Googolxx/STF

Dimension of y_hat before and after entropy coding


Hi,

Sorry to bother you again. I just have one question about the STF model (though I suppose it could also apply to the CNN one).

When I print the shape before and after the entropy coding in the forward(self, x) function of the SymmetricalTransFormer class (in stf.py), I don't get the same dimensions.
I get [1, 384, 48, 80] before entropy coding and [1, 384, 640, 48] after entropy decoding (for an input tensor of shape [1, 3, 768, 1280]).

Aren't they supposed to be the same? Can you please tell me what I'm doing wrong?

Thank you!

Here's your code with my `print`s:

```python
def forward(self, x):
    """Forward function."""
    x = self.patch_embed(x)

    Wh, Ww = x.size(2), x.size(3)
    x = x.flatten(2).transpose(1, 2)
    x = self.pos_drop(x)
    for i in range(self.num_layers):
        layer = self.layers[i]
        x, Wh, Ww = layer(x, Wh, Ww)

    y = x
    C = self.embed_dim * 8
    y = y.view(-1, Wh, Ww, C).permute(0, 3, 1, 2).contiguous()
    y_shape = y.shape[2:]

    #######
    print("y.shape : ", y.shape)
    #########

    z = self.h_a(y)
    _, z_likelihoods = self.entropy_bottleneck(z)
    z_offset = self.entropy_bottleneck._get_medians()
    z_tmp = z - z_offset
    z_hat = ste_round(z_tmp) + z_offset

    latent_scales = self.h_scale_s(z_hat)
    latent_means = self.h_mean_s(z_hat)

    y_slices = y.chunk(self.num_slices, 1)
    y_hat_slices = []
    y_likelihood = []

    for slice_index, y_slice in enumerate(y_slices):
        support_slices = (y_hat_slices if self.max_support_slices < 0 else y_hat_slices[:self.max_support_slices])
        mean_support = torch.cat([latent_means] + support_slices, dim=1)
        mu = self.cc_mean_transforms[slice_index](mean_support)
        mu = mu[:, :, :y_shape[0], :y_shape[1]]

        scale_support = torch.cat([latent_scales] + support_slices, dim=1)
        scale = self.cc_scale_transforms[slice_index](scale_support)
        scale = scale[:, :, :y_shape[0], :y_shape[1]]

        _, y_slice_likelihood = self.gaussian_conditional(y_slice, scale, mu)

        y_likelihood.append(y_slice_likelihood)
        y_hat_slice = ste_round(y_slice - mu) + mu

        lrp_support = torch.cat([mean_support, y_hat_slice], dim=1)
        lrp = self.lrp_transforms[slice_index](lrp_support)
        lrp = 0.5 * torch.tanh(lrp)
        y_hat_slice += lrp

        y_hat_slices.append(y_hat_slice)

    y_hat = torch.cat(y_hat_slices, dim=1)
    y_likelihoods = torch.cat(y_likelihood, dim=1)

    y_hat = y_hat.permute(0, 2, 3, 1).contiguous().view(-1, Wh * Ww, C)
    for i in range(self.num_layers):
        layer = self.syn_layers[i]
        y_hat, Wh, Ww = layer(y_hat, Wh, Ww)

    ###########
    print("y_hat.shape : ", (y_hat.view(-1, Wh, Ww, self.embed_dim)).shape)
    ###########

    x_hat = self.end_conv(y_hat.view(-1, Wh, Ww, self.embed_dim).permute(0, 3, 1, 2).contiguous())
    return {
        "x_hat": x_hat,
        "likelihoods": {"y": y_likelihoods, "z": z_likelihoods},
    }
```

The way the variables are named may cause the misunderstanding, sorry for that.
In `print("y_hat.shape : ", (y_hat.view(-1, Wh, Ww, self.embed_dim)).shape)`, this `y_hat` is not the `y_hat` you have in mind. At that point it is an intermediate feature map, since it has already gone through the synthesis layers (`self.syn_layers`).
You should print the shape at the line `y_hat = y_hat.permute(0, 2, 3, 1).contiguous().view(-1, Wh*Ww, C)`.
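
For reference, a minimal sketch of the relevant lines from the `forward()` above, with the second print moved to that point so it reports the quantized latent rather than the decoded feature map (the shapes in the comments are the ones from this thread):

```python
y_hat = torch.cat(y_hat_slices, dim=1)
y_likelihoods = torch.cat(y_likelihood, dim=1)

# Here y_hat is still the quantized latent, so its shape matches y,
# e.g. [1, 384, 48, 80] for a [1, 3, 768, 1280] input.
print("y_hat.shape : ", y_hat.shape)

# After this reshape and the syn_layers loop, y_hat becomes an
# intermediate (partially decoded) feature map with a different shape.
y_hat = y_hat.permute(0, 2, 3, 1).contiguous().view(-1, Wh * Ww, C)
for i in range(self.num_layers):
    layer = self.syn_layers[i]
    y_hat, Wh, Ww = layer(y_hat, Wh, Ww)
```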

Hi,
It works!
Thank you.