InterDigitalInc/CompressAI

Issue with Using CheckerboardMaskedConv2d from layer.py


While training with mbt2018 as the base model, I replaced the original MaskedConv2d context model with CheckerboardMaskedConv2d (mask_type=B). During testing, the PSNR rose to 33 (at epoch=100) while the bpp dropped to 0.13; in fact, the bpp was already around 0.2 from the first epoch, with lambda=0.01. Results this good are clearly abnormal. Is there a problem with swapping the layer in directly like this? Please tell me where the problem lies or whether there is a solution.

This usually occurs when some of the likelihood terms are missing during the bpp_loss computation. Perhaps the likelihood dict is not in the expected format:

assert likelihoods == {
    "y0": torch.Tensor(...),
    "y1": torch.Tensor(...),
    "z": torch.Tensor(...),
    ...
}
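To make the effect of a missing term concrete, here is a minimal sketch in plain Python. Lists of floats stand in for the likelihood tensors, and the values, the 4-pixel image size, and the `bpp` helper are all made up for illustration. It sums -log2(p) over every entry of the dict and compares the result with a sum that skips `y1`:

```python
import math

# Hypothetical per-symbol likelihoods; plain floats stand in for torch tensors.
likelihoods = {
    "y0": [0.5, 0.25],
    "y1": [0.5, 0.5],
    "z": [0.25],
}
num_pixels = 4  # assumed image size for this toy example


def bpp(terms):
    # Estimated bits per pixel: sum of -log2(p) over all given terms.
    total_bits = sum(-math.log2(p) for values in terms.values() for p in values)
    return total_bits / num_pixels


full_bpp = bpp(likelihoods)
# Dropping "y1" mimics a loss that only sees part of the latent.
partial_bpp = bpp({k: v for k, v in likelihoods.items() if k != "y1"})

print(full_bpp, partial_bpp)
```

The partial sum is always lower, so a loss that never sees one half of the checkerboard latent would report an optimistic bpp while the distortion term still benefits from the full latent.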

Similarly, during model eval, the bpp should be computed from the total length of all the bytestrings. If the eval bpp is also unrealistically small, the code may be measuring the number of dict keys or the length of a list instead of the lengths of the bytestrings themselves.
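As an illustration of that failure mode, the sketch below uses made-up bytestrings in a nested list layout (one inner list per latent, one bytestring per image; treat this layout as an assumption about what the encoder returns). Taking `len` at the wrong nesting level counts list entries instead of bytes:

```python
# Hypothetical encoder output: one inner list per latent, one bytestring per image.
strings = [
    [b"\x00" * 120],  # e.g. the y latent
    [b"\x00" * 30],   # e.g. the z latent
]
num_pixels = 64 * 64  # assumed image size

# Wrong: len() of each inner list is the number of images, not a byte count.
wrong_bpp = sum(len(s) for s in strings) * 8.0 / num_pixels

# Right: measure the bytestrings themselves.
right_bpp = sum(len(b) for s in strings for b in s) * 8.0 / num_pixels

print(wrong_bpp, right_bpp)
```

If the model adds another level of nesting (for example separate strings for the two checkerboard passes), the same mistake can occur one level higher.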


An alternative is to recursively flatten the likelihoods/bytestrings, so that every leaf is counted regardless of nesting:

def flatten_values(x, value_type=object):
    # Recursively yield every leaf of a nested list/tuple/set/dict structure.
    # value_type is passed down so that leaves are type-checked at every depth.
    if isinstance(x, (list, tuple, set)):
        for v in x:
            yield from flatten_values(v, value_type)
    elif isinstance(x, dict):
        for v in x.values():
            yield from flatten_values(v, value_type)
    elif isinstance(x, value_type):
        yield x
    else:
        raise ValueError(f"Unexpected type {type(x)}")


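import torch
import torch.nn as nn


class RateDistortionLoss(nn.Module):
    def forward(self, output, target):
        N, _, H, W = target.size()
        num_pixels = N * H * W
        out = {}
        # Sum the rate over every likelihood tensor, however deeply nested.
        out["bpp_loss"] = sum(
            likelihoods.log2().sum() / -num_pixels
            for likelihoods in flatten_values(output["likelihoods"], torch.Tensor)
        )
        ...
        return out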


def inference(model, x, vbr_stage=None, vbr_scale=None):
    ...
    bpp = sum(len(s) for s in flatten_values(out_enc["strings"], bytes)) * 8.0 / num_pixels

I might look into introducing a PR on this so this bug becomes less likely.