Bug when I try to evaluate a model on my own dataset
AlbertoPresta opened this issue · 2 comments
Hi,
I think I found a bug when I run the following command (suggested by you):
python3 -m compressai.utils.eval_model checkpoint /path/to/images/folder/ -a $ARCH -p $MODEL_CHECKPOINT...
The bug is the following:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/opt/conda/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/opt/conda/lib/python3.7/site-packages/compressai/utils/eval_model/__main__.py", line 310, in <module>
    main(sys.argv[1:])
  File "/opt/conda/lib/python3.7/site-packages/compressai/utils/eval_model/__main__.py", line 286, in main
    model = load_func(*opts, run)
  File "/opt/conda/lib/python3.7/site-packages/compressai/utils/eval_model/__main__.py", line 150, in load_checkpoint
    return architectures[arch].from_state_dict(state_dict).eval()
  File "/opt/conda/lib/python3.7/site-packages/compressai/models/google.py", line 157, in from_state_dict
    N = state_dict["g_a.0.weight"].size(0)
KeyError: 'g_a.0.weight'
I think the error occurs because from_state_dict should receive the actual state dict of the network, which is state_dict["state_dict"], rather than the whole checkpoint. In my opinion, we should have something like:
N = state_dict["state_dict"]["g_a.0.weight"].size(0)
Maybe I missed something in my command above.
Alberto
This error usually occurs when the model hasn't been updated via compressai.utils.update_model after training. See here for an example. Currently, this removes one layer of the "state_dict" (i.e. ckpt <- ckpt["state_dict"]) and also updates the CDFs. (Note: if you want to do further training, please create a copy of the checkpoint as a backup before running update_model.)
An (unofficial) alternative that I personally use is just to call model.update(force=True) within load_checkpoint itself, and to remove the extra layer of "state_dict" when loading the checkpoint:
# compressai/utils/eval_model/__main__.py
def load_checkpoint(arch: str, checkpoint_path: str) -> nn.Module:
    ckpt = torch.load(checkpoint_path)
    # Training checkpoints wrap the weights in an extra "state_dict" layer.
    state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
    state_dict = load_state_dict(state_dict)  # for pre-trained models
    model = architectures[arch].from_state_dict(state_dict).eval()
    # Update the CDFs here instead of running update_model beforehand.
    model.update(force=True)
    return model
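The unwrapping step on its own can be sketched without torch, to make the two checkpoint layouts explicit. This is only an illustration: the helper name unwrap_state_dict and the toy dictionaries (including the "epoch" key) are assumptions, not part of CompressAI.

```python
from typing import Any, Mapping


def unwrap_state_dict(ckpt: Mapping[str, Any]) -> Mapping[str, Any]:
    """Return the model weights from either checkpoint layout (hypothetical helper)."""
    # A training checkpoint nests the weights under "state_dict";
    # an updated or pre-trained file is already the weights mapping.
    if "state_dict" in ckpt:
        return ckpt["state_dict"]
    return ckpt


# Stand-in values instead of real tensors, just to show both layouts.
training_ckpt = {"epoch": 10, "state_dict": {"g_a.0.weight": "tensor"}}
flat_ckpt = {"g_a.0.weight": "tensor"}

print(list(unwrap_state_dict(training_ckpt)))
print(list(unwrap_state_dict(flat_ckpt)))
```

With the nested layout, indexing the checkpoint directly with "g_a.0.weight" raises the KeyError from the traceback, because that key only exists one level down.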
I downloaded a pretrained model with the following commands:
net = bmshj2018_factorized(quality=8, pretrained=True).eval().to(torch.device("cpu"))
net.update(force=True)
Then I tried to print the quantized CDF and got something like this (only the first row is shown):
quantized_cdf = Tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15, 65536, 0, ...,0])
tail_mass = 1.
Moreover, every row has tail_mass equal to 1. What does this mean? Shouldn't it be equal (or close) to 1e-9? And what is the domain of the quantized CDF if I want to plot it?
Thanks in advance for your answers.
Alberto
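One possible way to read a row like the one above, as a rough sketch only: assume the row holds cumulative integer counts scaled to 2**16 (which the trailing 65536 suggests) and that the zeros after it are padding. Under those assumptions, differencing the valid prefix gives per-bin probabilities that can be plotted against the bin index. The helper cdf_row_to_pmf and the choice length=17 are illustrative, not a CompressAI API, and the sketch does not map bins back to symbol values.

```python
from typing import List


def cdf_row_to_pmf(row: List[int], length: int) -> List[float]:
    """Turn the first `length` cumulative counts into probabilities (hypothetical helper)."""
    valid = row[:length]            # drop the assumed zero padding
    total = valid[-1] - valid[0]    # 65536 for the row shown above
    return [(b - a) / total for a, b in zip(valid, valid[1:])]


# The row from the post: 0..15, then 65536, then padding zeros.
row = list(range(16)) + [65536, 0, 0]
pmf = cdf_row_to_pmf(row, length=17)

print(len(pmf))   # number of bins
print(pmf[-1])    # share of the mass in the last bin
```

For this row, 15 bins each hold a count of 1 and the last bin holds the remaining 65521 counts, so nearly all of the mass sits in a single bin.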