Hi, I don't understand the permute function
MrLinNing opened this issue · 1 comment
MrLinNing commented
Hello, @jhjacobsen
I am a newcomer to deep learning. I read your source code and there are a few things I don't understand.
Why is the permute function used like this?
import torch
import torch.nn as nn

class injective_pad(nn.Module):
    def __init__(self, pad_size):
        super(injective_pad, self).__init__()
        self.pad_size = pad_size
        self.pad = nn.ZeroPad2d((0, 0, 0, pad_size))

    def forward(self, x):
        x = x.permute(0, 2, 1, 3)  # swap the channel and height axes
        x = self.pad(x)            # pad zeros along what is now the channel axis
        return x.permute(0, 2, 1, 3)

    def inverse(self, x):
        return x[:, :x.size(1) - self.pad_size, :, :]  # drop the padded channels
And this,
class psi(nn.Module):
    def __init__(self, block_size):
        super(psi, self).__init__()
        self.block_size = block_size
        self.block_size_sq = block_size * block_size

    def inverse(self, input):
        output = input.permute(0, 2, 3, 1)
        (batch_size, d_height, d_width, d_depth) = output.size()
        s_depth = int(d_depth / self.block_size_sq)
        s_width = int(d_width * self.block_size)
        s_height = int(d_height * self.block_size)
        t_1 = output.contiguous().view(batch_size, d_height, d_width, self.block_size_sq, s_depth)
        spl = t_1.split(self.block_size, 3)
        stack = [t_t.contiguous().view(batch_size, d_height, s_width, s_depth) for t_t in spl]
        output = torch.stack(stack, 0).transpose(0, 1).permute(0, 2, 1, 3, 4).contiguous().view(batch_size, s_height, s_width, s_depth)
        output = output.permute(0, 3, 1, 2)
        return output.contiguous()

    def forward(self, input):
        output = input.permute(0, 2, 3, 1)
        (batch_size, s_height, s_width, s_depth) = output.size()
        d_depth = s_depth * self.block_size_sq
        d_height = int(s_height / self.block_size)
        t_1 = output.split(self.block_size, 2)
        stack = [t_t.contiguous().view(batch_size, d_height, d_depth) for t_t in t_1]
        output = torch.stack(stack, 1)
        output = output.permute(0, 2, 1, 3)
        output = output.permute(0, 3, 1, 2)
        return output.contiguous()
What is the role of these two classes? Thank you.
jhjacobsen commented
Permute is used because PyTorch's 2d padding (nn.ZeroPad2d) only pads along the last two (spatial) dimensions, but we want to pad zeros on the channel axis. So we permute the channel axis into a spatial position, pad, and permute back.
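A minimal sketch of that trick (example shapes are my own, not from the repo): ZeroPad2d on its own grows a spatial dimension, but sandwiched between two permutes it grows the channel dimension instead, leaving the original values untouched:

```python
import torch
import torch.nn as nn

x = torch.randn(2, 3, 4, 4)        # (N, C, H, W)
pad = nn.ZeroPad2d((0, 0, 0, 5))   # pads the "bottom" of the second-to-last dim

# Applied directly, ZeroPad2d pads the spatial axis H: (2, 3, 4, 4) -> (2, 3, 9, 4)
assert pad(x).shape == (2, 3, 9, 4)

# Swapping C and H first makes the same pad grow the channel axis instead:
y = pad(x.permute(0, 2, 1, 3)).permute(0, 2, 1, 3)
assert y.shape == (2, 8, 4, 4)

# The original channels are preserved and the new channels are zeros:
assert torch.equal(y[:, :3], x)
assert torch.equal(y[:, 3:], torch.zeros(2, 5, 4, 4))
```

Slicing off the extra channels, as injective_pad.inverse does, recovers the input exactly, which is why the padding is injective.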
The psi class is explained in Figure 2 of our paper: https://arxiv.org/abs/1802.07088
We call it invertible downsampling, but it is also known as "squeezing" in Real NVP, or as the space_to_depth and depth_to_space functions in TensorFlow, as used in subpixel convolutions.
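As an illustration (not the repo's psi, which may order the resulting channels differently), PyTorch's built-in pixel_unshuffle/pixel_shuffle pair, available in PyTorch 1.8+, performs the same kind of invertible downsampling: each block of pixels is moved into the channel axis, and the inverse puts it back with no information lost:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8)  # (N, C, H, W)

# "Squeeze" / space_to_depth: each 2x2 spatial block moves into the channel axis,
# so (2, 3, 8, 8) -> (2, 3 * 2 * 2, 4, 4). Spatial resolution halves, channels quadruple.
d = F.pixel_unshuffle(x, 2)
assert d.shape == (2, 12, 4, 4)

# Pure rearrangement, so the inverse (depth_to_space) recovers the input exactly:
assert torch.equal(F.pixel_shuffle(d, 2), x)
```

Because both directions are just index rearrangements, the round trip is bit-exact, which is what makes the downsampling invertible.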