dimension y_hat before and after entropy coding
Closed this issue · 2 comments
Hi,
Sorry to bother you again. I just have one question about the STF model (though I suppose it could also apply to the CNN one).
When I do `print(y_hat.shape)` before and after the entropy coding, in the `forward(self, x)` function of the `SymmetricalTransFormer` class (in the stf.py file), I don't get the same dimensions.
I get `[1, 384, 48, 80]` before entropy coding and `[1, 384, 640, 48]` after entropy decoding (for an input tensor of shape `[1, 3, 768, 1280]`).
Aren't they supposed to be the same? Can you please tell me what I'm doing wrong?
Thank you!
Here's your code with my `print`s:
```python
def forward(self, x):
    """Forward function."""
    x = self.patch_embed(x)
    Wh, Ww = x.size(2), x.size(3)
    x = x.flatten(2).transpose(1, 2)
    x = self.pos_drop(x)
    for i in range(self.num_layers):
        layer = self.layers[i]
        x, Wh, Ww = layer(x, Wh, Ww)
    y = x
    C = self.embed_dim * 8
    y = y.view(-1, Wh, Ww, C).permute(0, 3, 1, 2).contiguous()
    y_shape = y.shape[2:]
    #######
    print("y.shape : ", y.shape)
    #########
    z = self.h_a(y)
    _, z_likelihoods = self.entropy_bottleneck(z)
    z_offset = self.entropy_bottleneck._get_medians()
    z_tmp = z - z_offset
    z_hat = ste_round(z_tmp) + z_offset
    latent_scales = self.h_scale_s(z_hat)
    latent_means = self.h_mean_s(z_hat)
    y_slices = y.chunk(self.num_slices, 1)
    y_hat_slices = []
    y_likelihood = []
    for slice_index, y_slice in enumerate(y_slices):
        support_slices = (y_hat_slices if self.max_support_slices < 0 else y_hat_slices[:self.max_support_slices])
        mean_support = torch.cat([latent_means] + support_slices, dim=1)
        mu = self.cc_mean_transforms[slice_index](mean_support)
        mu = mu[:, :, :y_shape[0], :y_shape[1]]
        scale_support = torch.cat([latent_scales] + support_slices, dim=1)
        scale = self.cc_scale_transforms[slice_index](scale_support)
        scale = scale[:, :, :y_shape[0], :y_shape[1]]
        _, y_slice_likelihood = self.gaussian_conditional(y_slice, scale, mu)
        y_likelihood.append(y_slice_likelihood)
        y_hat_slice = ste_round(y_slice - mu) + mu
        lrp_support = torch.cat([mean_support, y_hat_slice], dim=1)
        lrp = self.lrp_transforms[slice_index](lrp_support)
        lrp = 0.5 * torch.tanh(lrp)
        y_hat_slice += lrp
        y_hat_slices.append(y_hat_slice)
    y_hat = torch.cat(y_hat_slices, dim=1)
    y_likelihoods = torch.cat(y_likelihood, dim=1)
    y_hat = y_hat.permute(0, 2, 3, 1).contiguous().view(-1, Wh*Ww, C)
    for i in range(self.num_layers):
        layer = self.syn_layers[i]
        y_hat, Wh, Ww = layer(y_hat, Wh, Ww)
    ###########
    print("y_hat.shape : ", (y_hat.view(-1, Wh, Ww, self.embed_dim)).shape)
    ###########
    x_hat = self.end_conv(y_hat.view(-1, Wh, Ww, self.embed_dim).permute(0, 3, 1, 2).contiguous())
    return {
        "x_hat": x_hat,
        "likelihoods": {"y": y_likelihoods, "z": z_likelihoods},
    }
```
The way the variables are named may have caused the misunderstanding, sorry for that.
In `print("y_hat.shape : ", (y_hat.view(-1, Wh, Ww, self.embed_dim)).shape)`, the `y_hat` is not the `y_hat` you have in mind. At that point it is an intermediate feature map: the synthesis layers (`self.syn_layers`) have already upsampled it, and they keep reusing the name `y_hat`.
You should print the shape right after `y_hat = y_hat.permute(0, 2, 3, 1).contiguous().view(-1, Wh*Ww, C)`.
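To make the bookkeeping concrete, here is a minimal torch-free sketch (using numpy in place of torch, with the layer internals elided) of how the shapes evolve in this forward pass, plugging in the dimensions from the issue: for a 768x1280 input, the latent has `Wh=48`, `Ww=80`, `C=384`, and after all synthesis layers the token grid is 384x640 with `embed_dim=48`.

```python
import numpy as np

B, C, Wh, Ww = 1, 384, 48, 80
y_hat = np.zeros((B, C, Wh, Ww))  # the real quantized latent: [1, 384, 48, 80]

# This mirrors: y_hat = y_hat.permute(0, 2, 3, 1).contiguous().view(-1, Wh*Ww, C)
# i.e. the spatial map is turned into a token sequence before the synthesis layers.
tokens = y_hat.transpose(0, 2, 3, 1).reshape(-1, Wh * Ww, C)
print(tokens.shape)  # (1, 3840, 384)

# Each synthesis layer upsamples the token grid and shrinks the channels,
# while reusing the variable name y_hat. After the last layer:
Wh, Ww, embed_dim = 384, 640, 48
feature = np.zeros((B, Wh * Ww, embed_dim))  # decoder feature map, not the latent
print(feature.reshape(-1, Wh, Ww, embed_dim).shape)  # (1, 384, 640, 48)
```

So the `[1, 384, 640, 48]` tensor printed after the loop is the upsampled decoder feature map; the latent `y_hat` that matches the encoder side is the `[1, 384, 48, 80]` tensor before the `permute`/`view`.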
Hi,
It works! Thank you.