InterDigitalInc/CompressAI

Bug when I try to evaluate a model on my own dataset

AlbertoPresta opened this issue · 2 comments

Hi,

I think I found a bug when I run the following command (suggested by you):

python3 -m compressai.utils.eval_model checkpoint /path/to/images/folder/ -a $ARCH -p $MODEL_CHECKPOINT...

The bug is the following:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/opt/conda/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/opt/conda/lib/python3.7/site-packages/compressai/utils/eval_model/__main__.py", line 310, in <module>
    main(sys.argv[1:])
  File "/opt/conda/lib/python3.7/site-packages/compressai/utils/eval_model/__main__.py", line 286, in main
    model = load_func(*opts, run)
  File "/opt/conda/lib/python3.7/site-packages/compressai/utils/eval_model/__main__.py", line 150, in load_checkpoint
    return architectures[arch].from_state_dict(state_dict).eval()
  File "/opt/conda/lib/python3.7/site-packages/compressai/models/google.py", line 157, in from_state_dict
    N = state_dict["g_a.0.weight"].size(0)
KeyError: 'g_a.0.weight'

I also think the error arises because from_state_dict should receive the network's actual state dict, which is state_dict["state_dict"], rather than the whole checkpoint. In my opinion, the line should be something like:

N = state_dict["state_dict"]["g_a.0.weight"].size(0)

Maybe I missed something in the command above.

Alberto

This error usually occurs when the model hasn't been updated via compressai.utils.update_model after training. See here for an example. Currently, this removes one layer of the "state_dict" (i.e. ckpt <- ckpt["state_dict"]) and also updates the CDFs. (Note: if you want to do further training, please create a copy of the checkpoint as a backup before running update_model.)
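To make the "extra layer" concrete, here is a minimal sketch that uses plain dictionaries as stand-ins for a real checkpoint. The checkpoint layout (an "epoch" entry next to "state_dict") and the helper name unwrap_checkpoint are assumptions for illustration only, not part of CompressAI:

```python
def unwrap_checkpoint(ckpt):
    """Return the model weights whether or not they are nested under "state_dict"."""
    return ckpt["state_dict"] if "state_dict" in ckpt else ckpt


# Stand-in for a checkpoint written by a training script (layout assumed).
training_ckpt = {
    "epoch": 10,
    "state_dict": {"g_a.0.weight": "weights placeholder"},
}

# Looking up a layer key on the wrapped checkpoint fails, as in the traceback.
try:
    training_ckpt["g_a.0.weight"]
    lookup_failed = False
except KeyError:
    lookup_failed = True

# After unwrapping, the layer key is found at the top level.
weights = unwrap_checkpoint(training_ckpt)
print(lookup_failed, "g_a.0.weight" in weights)
```

Calling the helper on an already-unwrapped dictionary returns it unchanged, so it should be safe to apply to either kind of checkpoint.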


An (unofficial) alternative that I personally use is just to call model.update(force=True) within load_checkpoint itself, and to remove the extra layer of "state_dict" when loading the checkpoint:

# compressai/utils/eval_model/__main__.py

def load_checkpoint(arch: str, checkpoint_path: str) -> nn.Module:
    ckpt = torch.load(checkpoint_path)
    state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
    state_dict = load_state_dict(state_dict)  # for pre-trained models
    model = architectures[arch].from_state_dict(state_dict).eval()
    model.update(force=True)
    return model

I loaded a pretrained model with the following code:

import torch
from compressai.zoo import bmshj2018_factorized
net = bmshj2018_factorized(quality=8, pretrained=True).eval().to(torch.device("cpu"))
net.update(force=True)

Then I printed the quantized CDF and got something like this (I show only the first dimension):

quantized_cdf = Tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15, 65536, 0, ...,0])
tail_mass = 1.

Moreover, all dimensions have tail_mass equal to 1. What does that mean? Should it be equal (or close) to 1e-9? And what is the domain of the quantized CDF if I want to plot it?

Thanks in advance for the answers.
Alberto
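On the plotting question, one possible approach is to turn a row of the quantized CDF back into per-bin probabilities by differencing adjacent entries and dividing by the final value (65536 in the printed row). This is only a sketch under stated assumptions: the helper name quantized_cdf_to_pmf is hypothetical, the trailing zeros are assumed to be padding, and the number of valid entries (17 here) is read off the printed row rather than taken from the model. The integer value each bin corresponds to depends on a per-channel offset that is not shown in the post, so the x-axis below is just the bin index.

```python
def quantized_cdf_to_pmf(cdf_row, length):
    """Convert one quantized CDF row into per-bin probabilities.

    cdf_row: integer CDF values, possibly followed by zero padding.
    length:  number of valid CDF entries (assumed known for this row).
    """
    valid = cdf_row[:length]
    total = valid[-1]  # 65536 for the row printed in the post
    return [(hi - lo) / total for lo, hi in zip(valid, valid[1:])]


# First row as printed in the post, with a few padding zeros kept.
row = list(range(16)) + [65536] + [0, 0, 0]
pmf = quantized_cdf_to_pmf(row, length=17)

# Bin indices for the x-axis; the mapping to actual symbol values is not shown here.
bins = list(range(len(pmf)))
print(len(pmf), pmf[0], pmf[-1])
```

For the printed row, every bin except the last receives the minimum possible count of 1 out of 65536, and the last bin holds the remainder.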