This repo is an unofficial implementation of the meta upscale module from the paper Meta-SR: A Magnification-Arbitrary Network for Super-Resolution.
Paper link: https://arxiv.org/abs/1903.00875
- Python packages: torch, basicsr, and other essentials
- Software: CUDA
- Hardware: NVIDIA GPU
In addition to the original algorithm, I modified the code to support different scale factors in the horizontal and vertical directions. Two implementations are offered in this repo: im2col and loop.

`r_v` and `r_h` are the scale factors in the vertical and horizontal directions, respectively. For example, with `r_v = 1.5` and `r_h = 2`, a 32x32 LR feature map is upscaled to 48x64.
The input LR (low resolution) feature map `x` has shape (n, c_in, h, w). `x` is first upscaled to the HR (high resolution) `x_up` by repeating elements (see lines 43-45 for details); the shape of `x_up` is (n, c_in, H, W). Then `nn.functional.unfold` is called to apply the im2col operation to `x_up`, changing its shape to (H * W, n, c_in * k * k) (see line 46). Next, the convolution `weight` is predicted by the `Pos2Weight` module below and reshaped to (H * W, c_in * k * k, out_c). Finally, a batched matrix multiplication of `x_up` and `weight` produces the output (see lines 51 and 52).
```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class Pos2Weight(nn.Module):
    """Pos2Weight module for Meta-SR.

    Maps (i/r_v - i//r_v, j/r_h - j//r_h, 1/r_v, 1/r_h) -> W.
    """

    def __init__(self, in_c, out_c=1, kernel_size=3):
        super(Pos2Weight, self).__init__()
        self.in_c = in_c
        self.out_c = out_c
        self.kernel_size = kernel_size
        self.meta_block = nn.Sequential(
            nn.Linear(4, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, kernel_size * kernel_size * in_c * out_c),
        )

    def forward(self, x):
        # x: (out_h * out_w, 4) position matrix -> (out_h * out_w, k * k * in_c * out_c) weights
        x = self.meta_block(x)
        return x
```
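For reference, the position matrix that `Pos2Weight` consumes can be precomputed from the output coordinates. Below is a minimal sketch following the mapping in the docstring; the helper name `build_pos_mat` is not part of this repo:

```python
def build_pos_mat(out_h, out_w, r_v, r_h):
    """Build the (out_h * out_w, 4) position matrix described above.

    Row (i * out_w + j) holds (i/r_v - i//r_v, j/r_h - j//r_h, 1/r_v, 1/r_h).
    """
    i = torch.arange(out_h, dtype=torch.float32)
    j = torch.arange(out_w, dtype=torch.float32)
    frac_v = i / r_v - torch.floor(i / r_v)   # vertical fractional offset, (out_h,)
    frac_h = j / r_h - torch.floor(j / r_h)   # horizontal fractional offset, (out_w,)
    pos = torch.zeros(out_h, out_w, 4)
    pos[..., 0] = frac_v.unsqueeze(1)         # broadcast down the columns
    pos[..., 1] = frac_h.unsqueeze(0)         # broadcast across the rows
    pos[..., 2] = 1.0 / r_v
    pos[..., 3] = 1.0 / r_h
    return pos.view(out_h * out_w, 4)
```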
```python
class MetaUpsample_im2col(nn.Module):
    """Meta upsample module (im2col version)."""

    def __init__(self, in_c, out_c=1, kernel_size=3, s_v=2., s_h=2.):
        super(MetaUpsample_im2col, self).__init__()
        self.in_c = in_c
        self.out_c = out_c
        self.kernel_size = kernel_size
        self.s_v = s_v
        self.s_h = s_h
        self.phi = Pos2Weight(in_c, out_c, kernel_size)

    def forward(self, x, pos_mat):
        """
        Args:
            x (torch.Tensor): LR feature map. Shape: (n, in_c, in_h, in_w).
            pos_mat (torch.Tensor): Position matrix. Shape: (out_h * out_w, 4).
        """
        out_h = int(x.size(2) * self.s_v)
        out_w = int(x.size(3) * self.s_h)
        n = x.size(0)
        # Upscale by repeating elements (nearest-neighbor style indexing).
        v_idxes = (np.arange(out_h) // self.s_v).astype(int)
        h_idxes = (np.arange(out_w) // self.s_h).astype(int)
        x_up = x[:, :, v_idxes, :][:, :, :, h_idxes]  # (n, in_c, out_h, out_w)
        x_up = F.unfold(x_up, self.kernel_size, padding=self.kernel_size // 2).permute(2, 0, 1)  # (out_h * out_w, n, in_c * k * k)
        weight = self.phi(pos_mat)  # (out_h * out_w, k * k * in_c * out_c)
        weight = weight.view(out_h * out_w, self.kernel_size * self.kernel_size * self.in_c, self.out_c)
        # (out_h * out_w, n, in_c * k * k) @ (out_h * out_w, in_c * k * k, out_c) -> (out_h * out_w, n, out_c)
        out = torch.bmm(x_up, weight).permute(1, 2, 0)
        return out.contiguous().view(n, self.out_c, out_h, out_w)
```
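A quick shape check of the im2col module, using the hypothetical `build_pos_mat` helper sketched above:

```python
n, in_c, h, w = 2, 16, 12, 12
s_v, s_h = 1.5, 2.0

up = MetaUpsample_im2col(in_c, out_c=3, kernel_size=3, s_v=s_v, s_h=s_h)
x = torch.randn(n, in_c, h, w)
pos_mat = build_pos_mat(int(h * s_v), int(w * s_h), s_v, s_h)  # (out_h * out_w, 4)

y = up(x, pos_mat)
print(y.shape)  # torch.Size([2, 3, 18, 24])
```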
The im2col version is fast thanks to PyTorch's highly optimized built-in matrix multiplication routines. However, materializing the upscaled LR feature map and its im2col expansion costs a lot of memory: the unfolded tensor holds roughly k * k * s_v * s_h times as many elements as the LR feature map, which limits the embedding dimension when this module is used with Transformer encoders. These four lines consume most of the GPU memory during training:
```python
v_idxes = (np.arange(out_h) // self.s_v).astype(int)
h_idxes = (np.arange(out_w) // self.s_h).astype(int)
x_up = x[:, :, v_idxes, :][:, :, :, h_idxes]  # (n, in_c, out_h, out_w)
x_up = F.unfold(x_up, self.kernel_size, padding=self.kernel_size // 2).permute(2, 0, 1)
```
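For intuition, here is a rough estimate of the blow-up with FP32 activations and some illustrative sizes (the numbers below are only an example, not a measurement from this repo):

```python
# Illustrative sizes only: batch 16, 96 channels, 64x64 LR input, 3x3 kernel, 2x upscaling.
n, c_in, h, w, k, s_v, s_h = 16, 96, 64, 64, 3, 2, 2

lr_bytes = n * c_in * h * w * 4                                # LR feature map, FP32
unfolded_bytes = (s_v * h) * (s_h * w) * n * c_in * k * k * 4  # x_up after unfold, FP32

print(lr_bytes / 2**20)        # 24.0 MiB
print(unfolded_bytes / 2**20)  # 864.0 MiB, i.e. 36x larger
```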
To overcome this issue, a loop version of the meta upscale module is also offered.

The loop version is much easier to understand: it implements the definition of convolution directly by looping over the height, width, channel, and kernel dimensions, so no extra space is used. The demo Python code:
```python
from tqdm import tqdm


def meta_upscale_naive(x, weight, s_v, s_h, batch_size, in_c, out_c, out_h, out_w, kernel_size):
    # weight: (out_h * out_w, k * k * in_c * out_c), predicted by Pos2Weight
    weight = weight.view(out_h, out_w, kernel_size, kernel_size, in_c, out_c)
    out = torch.zeros(batch_size, out_c, out_h, out_w, device=x.device)
    x = nn.functional.pad(x, (1, 1, 1, 1))  # zero padding for the 3x3 kernel
    for i in tqdm(range(out_h)):
        for j in range(out_w):
            i_p = int(i / s_v)  # corresponding LR row
            j_p = int(j / s_h)  # corresponding LR column
            for k1 in range(kernel_size):
                for k2 in range(kernel_size):
                    for ci in range(in_c):
                        for co in range(out_c):
                            out[:, co, i, j] += x[:, ci, i_p + k1, j_p + k2] * weight[i, j, k1, k2, ci, co]
    return out
```
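A small shape check of the naive loop, reusing `Pos2Weight` and the hypothetical `build_pos_mat` from above (expect it to be slow even at this size):

```python
n, in_c, out_c, h, w, ks = 2, 8, 3, 6, 6, 3
s_v, s_h = 2.0, 1.5
out_h, out_w = int(h * s_v), int(w * s_h)

x = torch.randn(n, in_c, h, w)
phi = Pos2Weight(in_c, out_c, ks)
weight = phi(build_pos_mat(out_h, out_w, s_v, s_h))  # (out_h * out_w, ks * ks * in_c * out_c)

y = meta_upscale_naive(x, weight, s_v, s_h, n, in_c, out_c, out_h, out_w, ks)
print(y.shape)  # torch.Size([2, 3, 12, 9])
```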
Needless to say, this code is extremely inefficient. A more practical way is to write a new operator in CUDA/C++ and call it from PyTorch; see the `ops/meta_upscale` folder for details.

Here are some tutorials on how to develop a new PyTorch operator:
- https://pytorch.org/tutorials/advanced/cpp_extension.html
- https://zhuanlan.zhihu.com/p/595851188
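For reference, such an extension is typically JIT-compiled and imported with `torch.utils.cpp_extension.load`; the source file names below are illustrative, check `ops/meta_upscale` for the actual ones:

```python
from torch.utils.cpp_extension import load

# Hypothetical source names; see ops/meta_upscale in this repo for the real files.
meta_upscale_ext = load(
    name="meta_upscale_ext",
    sources=[
        "ops/meta_upscale/meta_upscale.cpp",     # C++ bindings / kernel launchers
        "ops/meta_upscale/meta_upscale_cuda.cu"  # CUDA kernels (forward/backward below)
    ],
    verbose=True,
)
# The compiled module exposes whatever functions the .cpp file binds,
# e.g. meta_upscale_ext.forward(...) and meta_upscale_ext.backward(...).
```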
The CUDA kernel code:
```cuda
// x:      (n, c_in, h_in + pad, w_in + pad)
// weight: (out_h * out_w, k * k * in_c * out_c)
// out:    (n, c_out, h_out, w_out)
__global__ void meta_upscale_forward_kernel(float *x, float *weight, float *out,
                                            float s_v, float s_h,
                                            int n, int c_in, int h_in, int w_in,
                                            int c_out, int h_out, int w_out, int ks)
{
    // One thread per output spatial position (i, j).
    const int i = threadIdx.y + blockIdx.y * blockDim.y;
    const int j = threadIdx.x + blockIdx.x * blockDim.x;
    if (i >= h_out || j >= w_out) return;

    int i_p = i / s_v;
    int j_p = j / s_h;
    int h_in_pad = h_in + ks - 1;
    int w_in_pad = w_in + ks - 1;

    for (int ibatch = 0; ibatch < n; ++ibatch)
        for (int k1 = 0; k1 < ks; ++k1)
            for (int k2 = 0; k2 < ks; ++k2)
                for (int ci = 0; ci < c_in; ++ci)
                    for (int co = 0; co < c_out; ++co)
                    {
                        // weight viewed as (h_out, w_out, ks, ks, c_in, c_out)
                        // w[i][j][k1][k2][ci][co]
                        int w_idx = co + ci * (c_out) + k2 * (c_out * c_in)
                                    + k1 * (c_out * c_in * ks) + j * (c_out * c_in * ks * ks)
                                    + i * (c_out * c_in * ks * ks * w_out);
                        // x[ibatch][ci][i_p + k1][j_p + k2]
                        int x_idx = (j_p + k2) + (i_p + k1) * (w_in_pad)
                                    + ci * (w_in_pad * h_in_pad) + ibatch * (w_in_pad * h_in_pad * c_in);
                        // out[ibatch][co][i][j]
                        int out_idx = j + i * (w_out) + co * (w_out * h_out) + ibatch * (w_out * h_out * c_out);
                        out[out_idx] += weight[w_idx] * x[x_idx];
                    }
}
```
```cuda
__global__ void meta_upscale_backward_kernel(float *dx, float *dweight, float *dout,
                                             float *x, float *weight,
                                             float s_v, float s_h,
                                             int n, int c_in, int h_in, int w_in,
                                             int c_out, int h_out, int w_out, int ks)
{
    const int i = threadIdx.y + blockIdx.y * blockDim.y;
    const int j = threadIdx.x + blockIdx.x * blockDim.x;
    if (i >= h_out || j >= w_out) return;

    int i_p = i / s_v;
    int j_p = j / s_h;
    int h_in_pad = h_in + ks - 1;
    int w_in_pad = w_in + ks - 1;

    for (int ibatch = 0; ibatch < n; ++ibatch)
        for (int k1 = 0; k1 < ks; ++k1)
            for (int k2 = 0; k2 < ks; ++k2)
                for (int ci = 0; ci < c_in; ++ci)
                    for (int co = 0; co < c_out; ++co)
                    {
                        int w_idx = co + ci * (c_out) + k2 * (c_out * c_in)
                                    + k1 * (c_out * c_in * ks) + j * (c_out * c_in * ks * ks)
                                    + i * (c_out * c_in * ks * ks * w_out);
                        int x_idx = (j_p + k2) + (i_p + k1) * (w_in_pad)
                                    + ci * (w_in_pad * h_in_pad) + ibatch * (w_in_pad * h_in_pad * c_in);
                        int out_idx = j + i * (w_out) + co * (w_out * h_out) + ibatch * (w_out * h_out * c_out);
                        // Gradients w.r.t. weight and x. dweight[w_idx] is only touched by this
                        // (i, j) thread, so a plain += is safe; dx[x_idx] can be written by
                        // several threads, hence atomicAdd.
                        dweight[w_idx] += x[x_idx] * dout[out_idx];
                        atomicAdd(&dx[x_idx], weight[w_idx] * dout[out_idx]);
                    }
}
```
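On the Python side, these two kernels are typically exposed to autograd through a `torch.autograd.Function`. A minimal sketch, assuming the hypothetical `meta_upscale_ext` bindings from the loading example above:

```python
class MetaUpscaleFunction(torch.autograd.Function):
    """Sketch of an autograd wrapper around the forward/backward kernels above."""

    @staticmethod
    def forward(ctx, x, weight, s_v, s_h, out_h, out_w, ks):
        ctx.save_for_backward(x, weight)
        ctx.params = (s_v, s_h, out_h, out_w, ks)
        # The binding is expected to zero-initialize the output and launch meta_upscale_forward_kernel.
        return meta_upscale_ext.forward(x.contiguous(), weight.contiguous(), s_v, s_h, out_h, out_w, ks)

    @staticmethod
    def backward(ctx, dout):
        x, weight = ctx.saved_tensors
        s_v, s_h, out_h, out_w, ks = ctx.params
        # The binding is expected to launch meta_upscale_backward_kernel to fill dx and dweight.
        dx, dweight = meta_upscale_ext.backward(dout.contiguous(), x, weight, s_v, s_h, out_h, out_w, ks)
        # One gradient per forward input; the scalar arguments get None.
        return dx, dweight, None, None, None, None, None
```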
The loop version saves up to 40% of GPU memory compared to the im2col version. (Tested on an RTX 4090: the maximum embed_dim is 96 with the im2col version and 156 with the loop version.)
- Optimize the CUDA kernel function using shared memory and other techniques.
- Use C++ templates to support different data types (like FP16).