InterDigitalInc/CompressAI

Issue Regarding the Use of GaussianMixtureConditional

formioq opened this issue · 6 comments

I want to use the `GaussianMixtureConditional` component from the library, but its usage seems to conflict with its parent class, `GaussianConditional`. Specifically, in the likelihood function of `GaussianMixtureConditional`, the line `M = inputs.size(1)` and the subsequent slicing appear to assume a single Gaussian distribution (the slices run from 0 to the channel size of `inputs`, so the parameters can only be split into one part). When I change it to `M = inputs.size(1) // self.K` so that the slicing produces K parts (the number of Gaussian components), I get a channel dimension mismatch in the parent class's likelihood function, specifically in `values = inputs - means` (since only `means` is sliced, but not `inputs`). Is my understanding of the code incorrect, or can it actually not handle multiple Gaussian distributions?
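To illustrate the shape issue I am describing with toy numbers (this is not the actual library code, just an illustration; assuming K=3 and M=192 latent channels):

```python
import torch

K, M, H, W = 3, 192, 16, 16
inputs = torch.randn(1, M, H, W)   # y, with M channels
scales = torch.randn(1, M, H, W)   # same channel count as y in my current setup

# If each mixture component is sliced with step M = inputs.size(1),
# only the first slice is non-empty when scales has just M channels:
for k in range(K):
    print(scales[:, k * M:(k + 1) * M].shape)
# -> torch.Size([1, 192, 16, 16]), torch.Size([1, 0, 16, 16]), torch.Size([1, 0, 16, 16])
```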

M is the number of channels in y.

B, M, H, W = y.shape

Please ensure:

assert scales.shape[1] == M * K
assert means.shape[1] == M * K
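As a quick sanity check, a toy-shape sketch of that contract (illustrative numbers only, e.g. K=3, M=192):

```python
import torch

K, M, B, H, W = 3, 192, 4, 16, 16

y = torch.randn(B, M, H, W)           # latent: M channels
scales = torch.rand(B, M * K, H, W)   # K scale maps per latent channel
means = torch.randn(B, M * K, H, W)   # K mean maps per latent channel

assert scales.shape[1] == M * K
assert means.shape[1] == M * K
```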

Thank you very much for your answer; it is now working correctly. However, I noticed that the bpp (calculated using your RateDistortionLoss) seems to have increased significantly (at lambda=0.01, the bpp was around 0.45 when using GC, but it is over 10 after using GMM). Do you know a solution to this problem?

Are you training with the same settings? After the first epoch of training, bpp_loss should usually be ≤ 3.

The only difference between the two experiments should be:

  • GC. No changes.
  • GMM. Layer that generates means/scales has 3x output channels.

What architecture is your model based on (e.g., mbt2018-mean, cheng2020-anchor, elic2022-chandelier, ...)?

Yes, I am using 'mbt2018' as the base model (with a context model), but I modified the output of entropy_parameters to M*9 channels because GMM seems to require an additional weight parameter. I used chunk(3, 1) to divide it into scales_hat, means_hat, and weight_hat, each with M*3 channels, to enable modeling with a K=3 GMM. All other settings were unchanged in both experiments. Here is the modified code (the commented-out parts are the changes for running under GMM; I also used softmax to ensure the weights sum to 1). Could it be an issue with the weight settings, or is there a logical misunderstanding on my part?
I look forward to your response.
```python
r"""

..
              ┌───┐    y     ┌───┐  z  ┌───┐ z_hat      z_hat ┌───┐
        x ──►─┤g_a├──►─┬──►──┤h_a├──►──┤ Q ├───►───·⋯·───►───┤h_s├─┐
              └───┘    │     └───┘     └───┘        EB        └───┘ │
                       ▼                                            │
                     ┌─┴─┐                                          │
                     │ Q │                                   params ▼
                     └─┬─┘                                          │
                 y_hat ▼                  ┌─────┐                   │
                       ├──────────►───────┤  CP ├────────►──────────┤
                       │                  └─────┘                   │
                       ▼                                            ▼
                       │                                            │
                       ·                  ┌─────┐                   │
                    GC : ◄────────◄───────┤  EP ├────────◄──────────┘
                       ·     scales_hat   └─────┘
                       │      means_hat
                 y_hat ▼
                       │
              ┌───┐    │
    x_hat ──◄─┤g_s├────┘
              └───┘

    EB = Entropy bottleneck
    GC = Gaussian conditional
    EP = Entropy parameters network
    CP = Context prediction (checkerboard)
"""

def __init__(self, N, M, **kwargs):
    super().__init__(**kwargs)

    self.entropy_bottleneck = EntropyBottleneck(N)

    self.g_a = nn.Sequential(
        conv(3, N),
        GDN(N),
        conv(N, N),
        GDN(N),
        conv(N, N),
        GDN(N),
        conv(N, M),
    )

    self.g_s = nn.Sequential(
        deconv(M, N),
        GDN(N, inverse=True),
        deconv(N, N),
        GDN(N, inverse=True),
        deconv(N, N),
        GDN(N, inverse=True),
        deconv(N, 3),
    )

    self.h_a = nn.Sequential(
        conv(M, N, stride=1, kernel_size=3),
        nn.ReLU(inplace=True),
        conv(N, N),
        nn.ReLU(inplace=True),
        conv(N, N),
    )

    self.h_s = nn.Sequential(
        deconv(N, N),
        nn.ReLU(inplace=True),
        deconv(N, M * 3 // 2),
        nn.ReLU(inplace=True),
        conv(M * 3 // 2, M * 2, stride=1, kernel_size=3),
        nn.ReLU(inplace=True),
    )

    self.entropy_parameters = nn.Sequential(
        nn.Conv2d(M * 12 // 3, M * 10 // 3, 1),
        nn.LeakyReLU(inplace=True),
        nn.Conv2d(M * 10 // 3, M * 8 // 3, 1),
        nn.LeakyReLU(inplace=True),
        nn.Conv2d(M * 8 // 3, M * 6 // 3, 1),

        # For GMM (K=3), replace the layer above with one that outputs
        # M * 9 channels (M*3 scales + M*3 means + M*3 weights):
        # nn.Conv2d(M * 8 // 3, M * 9, 1),

    )

    self.context_prediction = CheckerboardMaskedConv2d(
        M, 2 * M, kernel_size=5, padding=2, stride=1, mask_type='B'
    )
    self.gaussian_conditional = GaussianConditional(None)

    #self.gaussian_conditional = GaussianMixtureConditional(K=3)

    self.N = int(N)
    self.M = int(M)

@property
def downsampling_factor(self) -> int:
    return 2 ** (4 + 2)

def forward(self, x):
    y = self.g_a(x)
    z = self.h_a(torch.abs(y))
    z_hat, z_likelihoods = self.entropy_bottleneck(z)

    params = self.h_s(z_hat)
    y_hat = self.gaussian_conditional.quantize(
        y, "noise" if self.training else "dequantize"
    )
    ctx_hat = self.context_prediction(y_hat)
    gaussian_params = self.entropy_parameters(
        torch.cat((params, ctx_hat), dim=1)
    )

    #scales_hat, means_hat, weight_hat = gaussian_params.chunk(3, 1)

    scales_hat, means_hat = gaussian_params.chunk(2, 1)

    #weight_hat = F.softmax(weight_hat, dim=1)

    y_hat1, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat)

    #y_hat1, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat, weights=weight_hat)

    x_hat = self.g_s(y_hat1)

    return {
        "x_hat": x_hat,
        "likelihoods": {"y": y_likelihoods, "z": z_likelihoods},
    }
```

The weight_hat softmax should be done along a dimension of length 3 (the K mixture components), not along the full M*3 channel dimension.

You may need to reshape, apply the softmax along the length-3 dimension, and then reshape back so that it's compatible with the GaussianMixtureConditional interface. Perhaps:

```python
B, Mx3, H, W = weight_hat.shape
M = Mx3 // 3  # three mixture components per latent channel

weight_hat = F.softmax(
    weight_hat.reshape(B, 3, M, H, W), dim=1
).reshape(B, Mx3, H, W)
```
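Putting it together with the forward pass you posted (a sketch under the same assumptions: K=3, entropy_parameters outputs M*9 channels, and the mixture weights are passed via the weights= keyword as in your commented-out call):

```python
K = 3
gaussian_params = self.entropy_parameters(torch.cat((params, ctx_hat), dim=1))

# Three groups of M*K channels each: scales, means, mixture weights.
scales_hat, means_hat, weight_hat = gaussian_params.chunk(3, 1)

# Normalize the weights over the K components of each latent channel,
# then flatten back to (B, M*K, H, W).
B, MK, H, W = weight_hat.shape
weight_hat = F.softmax(
    weight_hat.reshape(B, K, MK // K, H, W), dim=1
).reshape(B, MK, H, W)

y_hat, y_likelihoods = self.gaussian_conditional(
    y, scales_hat, means=means_hat, weights=weight_hat
)
```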

The problem is solved. Thank you very much, you are really amazing!