locuslab/deq

Test ImageNet Pre-trained Model

Closed this issue · 10 comments

Hi, I tried to test the pretrained model MDEQ_XL_Cls.pkl. However, I got size-mismatch errors between the weights in the checkpoint and the model defined in the code.

I downloaded the model and ran the command: python tools/cls_valid.py --testModel pretrained_models/MDEQ_XL_Cls.pkl --cfg experiments/imagenet/cls_mdeq_XL.yaml

Part of the error log:
    size mismatch for downsample.0.weight: copying a param with shape torch.Size([88, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([80, 3, 3, 3]).
    size mismatch for downsample.1.weight: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([80]).
    size mismatch for downsample.1.bias: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([80]).
    size mismatch for downsample.1.running_mean: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([80]).
    size mismatch for downsample.1.running_var: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([80]).
    size mismatch for downsample.3.weight: copying a param with shape torch.Size([88, 88, 3, 3]) from checkpoint, the shape in current model is torch.Size([80, 80, 3, 3]).
    size mismatch for downsample.4.weight: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([80]).
    size mismatch for downsample.4.bias: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([80]).
    size mismatch for downsample.4.running_mean: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([80]).
    size mismatch for downsample.4.running_var: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([80]).
    size mismatch for stage0.0.weight: copying a param with shape torch.Size([88, 88, 1, 1]) from checkpoint, the shape in current model is torch.Size([80, 80, 1, 1]).
    size mismatch for stage0.1.weight: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([80]).
    size mismatch for stage0.1.bias: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([80]).
    size mismatch for stage0.1.running_mean: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([80]).
    size mismatch for stage0.1.running_var: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([80]).
    size mismatch for fullstage.branches.0.blocks.0.conv1.weight_g: copying a param with shape torch.Size([528, 1, 1, 1]) from checkpoint, the shape in current model is torch.Size([400, 1, 1, 1]).
    size mismatch for fullstage.branches.0.blocks.0.conv1.weight_v: copying a param with shape torch.Size([528, 88, 3, 3]) from checkpoint, the shape in current model is torch.Size([400, 80, 3, 3]).
    size mismatch for fullstage.branches.0.blocks.0.gn1.weight: copying a param with shape torch.Size([528]) from checkpoint, the shape in current model is torch.Size([400]).
    size mismatch for fullstage.branches.0.blocks.0.gn1.bias: copying a param with shape torch.Size([528]) from checkpoint, the shape in current model is torch.Size([400]).
    size mismatch for fullstage.branches.0.blocks.0.conv2.weight_g: copying a param with shape torch.Size([88, 1, 1, 1]) from checkpoint, the shape in current model is torch.Size([80, 1, 1, 1]).
    size mismatch for fullstage.branches.0.blocks.0.conv2.weight_v: copying a param with shape torch.Size([88, 528, 3, 3]) from checkpoint, the shape in current model is torch.Size([80, 400, 3, 3]).
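For anyone else debugging this, here is a minimal diagnostic sketch that prints every shape discrepancy at once instead of stopping at load_state_dict's first exception. It assumes "model" is the network that tools/cls_valid.py constructs; the helper name is illustrative, not part of the repo:

import torch

def report_mismatches(model, ckpt_path):
    # Compare every checkpoint tensor against the constructed model,
    # printing shape mismatches plus unexpected and missing keys.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    own = model.state_dict()
    for k, v in ckpt.items():
        if k not in own:
            print(f'unexpected key: {k}')
        elif own[k].shape != v.shape:
            print(f'{k}: checkpoint {tuple(v.shape)} vs model {tuple(own[k].shape)}')
    for k in own:
        if k not in ckpt:
            print(f'missing key: {k}')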

Hi @HieuPhan33 ,

Thanks for pointing this out. It's due to a discrepancy I created when merging the repo (see here for the original released MDEQ). I have fixed the yaml file, so you can pull again. Sorry for the confusion!

Hi Shaojie,

The discrepancy in the downsample block is gone now, which is a good step forward. But some mismatches still remain in fullstage.branches. Could you please double-check? Thanks.

    size mismatch for fullstage.branches.0.blocks.0.conv1.weight_g: copying a param with shape torch.Size([528, 1, 1, 1]) from checkpoint, the shape in current model is torch.Size([440, 1, 1, 1]).
    size mismatch for fullstage.branches.0.blocks.0.conv1.weight_v: copying a param with shape torch.Size([528, 88, 3, 3]) from checkpoint, the shape in current model is torch.Size([440, 88, 3, 3]).
    size mismatch for fullstage.branches.0.blocks.0.gn1.weight: copying a param with shape torch.Size([528]) from checkpoint, the shape in current model is torch.Size([440]).
    size mismatch for fullstage.branches.0.blocks.0.gn1.bias: copying a param with shape torch.Size([528]) from checkpoint, the shape in current model is torch.Size([440]).
    size mismatch for fullstage.branches.0.blocks.0.conv2.weight_v: copying a param with shape torch.Size([88, 528, 3, 3]) from checkpoint, the shape in current model is torch.Size([88, 440, 3, 3]).
    size mismatch for fullstage.branches.1.blocks.0.conv1.weight_g: copying a param with shape torch.Size([1056, 1, 1, 1]) from checkpoint, the shape in current model is torch.Size([880, 1, 1, 1]).
    size mismatch for fullstage.branches.1.blocks.0.conv1.weight_v: copying a param with shape torch.Size([1056, 176, 3, 3]) from checkpoint, the shape in current model is torch.Size([880, 176, 3, 3]).
    size mismatch for fullstage.branches.1.blocks.0.gn1.weight: copying a param with shape torch.Size([1056]) from checkpoint, the shape in current model is torch.Size([880]).
    size mismatch for fullstage.branches.1.blocks.0.gn1.bias: copying a param with shape torch.Size([1056]) from checkpoint, the shape in current model is torch.Size([880]).
    size mismatch for fullstage.branches.1.blocks.0.conv2.weight_v: copying a param with shape torch.Size([176, 1056, 3, 3]) from checkpoint, the shape in current model is torch.Size([176, 880, 3, 3]).
    size mismatch for fullstage.branches.2.blocks.0.conv1.weight_g: copying a param with shape torch.Size([2112, 1, 1, 1]) from checkpoint, the shape in current model is torch.Size([1760, 1, 1, 1]).
    size mismatch for fullstage.branches.2.blocks.0.conv1.weight_v: copying a param with shape torch.Size([2112, 352, 3, 3]) from checkpoint, the shape in current model is torch.Size([1760, 352, 3, 3]).
    size mismatch for fullstage.branches.2.blocks.0.gn1.weight: copying a param with shape torch.Size([2112]) from checkpoint, the shape in current model is torch.Size([1760]).
    size mismatch for fullstage.branches.2.blocks.0.gn1.bias: copying a param with shape torch.Size([2112]) from checkpoint, the shape in current model is torch.Size([1760]).
    size mismatch for fullstage.branches.2.blocks.0.conv2.weight_v: copying a param with shape torch.Size([352, 2112, 3, 3]) from checkpoint, the shape in current model is torch.Size([352, 1760, 3, 3]).
    size mismatch for fullstage.branches.3.blocks.0.conv1.weight_g: copying a param with shape torch.Size([4224, 1, 1, 1]) from checkpoint, the shape in current model is torch.Size([3520, 1, 1, 1]).
    size mismatch for fullstage.branches.3.blocks.0.conv1.weight_v: copying a param with shape torch.Size([4224, 704, 3, 3]) from checkpoint, the shape in current model is torch.Size([3520, 704, 3, 3]).
    size mismatch for fullstage.branches.3.blocks.0.gn1.weight: copying a param with shape torch.Size([4224]) from checkpoint, the shape in current model is torch.Size([3520]).
    size mismatch for fullstage.branches.3.blocks.0.gn1.bias: copying a param with shape torch.Size([4224]) from checkpoint, the shape in current model is torch.Size([3520]).
    size mismatch for fullstage.branches.3.blocks.0.conv2.weight_v: copying a param with shape torch.Size([704, 4224, 3, 3]) from checkpoint, the shape in current model is torch.Size([704, 3520, 3, 3]).

Ah, sorry, still the discrepancy issue 😄 The EXPANSION_FACTOR should be 6, not 5. Fixed now. https://github.com/locuslab/deq/blob/master/MDEQ-Vision/experiments/imagenet/cls_mdeq_XL.yaml#L16
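For reference, the numbers line up with that factor: each branch's residual block expands its channels by EXPANSION_FACTOR internally, and the XL branches use 88/176/352/704 channels. A quick check in plain Python (just arithmetic, nothing repo-specific):

channels = [88, 176, 352, 704]    # MDEQ-XL branch widths
print([c * 5 for c in channels])  # [440, 880, 1760, 3520]: model built from the old yaml
print([c * 6 for c in channels])  # [528, 1056, 2112, 4224]: shapes stored in the checkpoint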

Hi, the size mismatches are completely gone now, good job. But some parameter names in the checkpoint don't match the current model, which leads to the following unexpected-key errors:
Unexpected key(s) in state_dict: "fullstage_copy.branches.0.blocks.0.conv1.weight", "fullstage_copy.branches.0.blocks.0.gn1.weight", "fullstage_copy.branches.0.blocks.0.gn1.bias", "fullstage_copy.branches.0.blocks.0.conv2.weight", "fullstage_copy.branches.0.blocks.0.gn2.weight", "fullstage_copy.branches.0.blocks.0.gn2.bias", "fullstage_copy.branches.0.blocks.0.gn3.weight", "fullstage_copy.branches.0.blocks.0.gn3.bias", "fullstage_copy.branches.1.blocks.0.conv1.weight", "fullstage_copy.branches.1.blocks.0.gn1.weight", "fullstage_copy.branches.1.blocks.0.gn1.bias", "fullstage_copy.branches.1.blocks.0.conv2.weight", "fullstage_copy.branches.1.blocks.0.gn2.weight", "fullstage_copy.branches.1.blocks.0.gn2.bias", "fullstage_copy.branches.1.blocks.0.gn3.weight", "fullstage_copy.branches.1.blocks.0.gn3.bias", "fullstage_copy.branches.2.blocks.0.conv1.weight", "fullstage_copy.branches.2.blocks.0.gn1.weight", "fullstage_copy.branches.2.blocks.0.gn1.bias", "fullstage_copy.branches.2.blocks.0.conv2.weight", "fullstage_copy.branches.2.blocks.0.gn2.weight", "fullstage_copy.branches.2.blocks.0.gn2.bias", "fullstage_copy.branches.2.blocks.0.gn3.weight", "fullstage_copy.branches.2.blocks.0.gn3.bias", "fullstage_copy.branches.3.blocks.0.conv1.weight", "fullstage_copy.branches.3.blocks.0.gn1.weight", "fullstage_copy.branches.3.blocks.0.gn1.bias", "fullstage_copy.branches.3.blocks.0.conv2.weight", "fullstage_copy.branches.3.blocks.0.gn2.weight", "fullstage_copy.branches.3.blocks.0.gn2.bias", "fullstage_copy.branches.3.blocks.0.gn3.weight", "fullstage_copy.branches.3.blocks.0.gn3.bias", "fullstage_copy.fuse_layers.0.1.net.conv.weight", "fullstage_copy.fuse_layers.0.1.net.gnorm.weight", "fullstage_copy.fuse_layers.0.1.net.gnorm.bias", "fullstage_copy.fuse_layers.0.2.net.conv.weight", "fullstage_copy.fuse_layers.0.2.net.gnorm.weight", "fullstage_copy.fuse_layers.0.2.net.gnorm.bias", "fullstage_copy.fuse_layers.0.3.net.conv.weight", "fullstage_copy.fuse_layers.0.3.net.gnorm.weight", "fullstage_copy.fuse_layers.0.3.net.gnorm.bias", "fullstage_copy.fuse_layers.1.0.net.0.conv.weight", "fullstage_copy.fuse_layers.1.0.net.0.gnorm.weight", "fullstage_copy.fuse_layers.1.0.net.0.gnorm.bias", "fullstage_copy.fuse_layers.1.2.net.conv.weight", "fullstage_copy.fuse_layers.1.2.net.gnorm.weight", "fullstage_copy.fuse_layers.1.2.net.gnorm.bias", "fullstage_copy.fuse_layers.1.3.net.conv.weight", "fullstage_copy.fuse_layers.1.3.net.gnorm.weight", "fullstage_copy.fuse_layers.1.3.net.gnorm.bias", "fullstage_copy.fuse_layers.2.0.net.0.conv.weight", "fullstage_copy.fuse_layers.2.0.net.0.gnorm.weight", "fullstage_copy.fuse_layers.2.0.net.0.gnorm.bias", "fullstage_copy.fuse_layers.2.0.net.1.conv.weight", "fullstage_copy.fuse_layers.2.0.net.1.gnorm.weight", "fullstage_copy.fuse_layers.2.0.net.1.gnorm.bias", "fullstage_copy.fuse_layers.2.1.net.0.conv.weight", "fullstage_copy.fuse_layers.2.1.net.0.gnorm.weight", "fullstage_copy.fuse_layers.2.1.net.0.gnorm.bias", "fullstage_copy.fuse_layers.2.3.net.conv.weight", "fullstage_copy.fuse_layers.2.3.net.gnorm.weight", "fullstage_copy.fuse_layers.2.3.net.gnorm.bias", "fullstage_copy.fuse_layers.3.0.net.0.conv.weight", "fullstage_copy.fuse_layers.3.0.net.0.gnorm.weight", "fullstage_copy.fuse_layers.3.0.net.0.gnorm.bias", "fullstage_copy.fuse_layers.3.0.net.1.conv.weight", "fullstage_copy.fuse_layers.3.0.net.1.gnorm.weight", "fullstage_copy.fuse_layers.3.0.net.1.gnorm.bias", "fullstage_copy.fuse_layers.3.0.net.2.conv.weight", "fullstage_copy.fuse_layers.3.0.net.2.gnorm.weight", "fullstage_copy.fuse_layers.3.0.net.2.gnorm.bias", "fullstage_copy.fuse_layers.3.1.net.0.conv.weight", "fullstage_copy.fuse_layers.3.1.net.0.gnorm.weight", "fullstage_copy.fuse_layers.3.1.net.0.gnorm.bias", "fullstage_copy.fuse_layers.3.1.net.1.conv.weight", "fullstage_copy.fuse_layers.3.1.net.1.gnorm.weight", "fullstage_copy.fuse_layers.3.1.net.1.gnorm.bias", "fullstage_copy.fuse_layers.3.2.net.0.conv.weight", "fullstage_copy.fuse_layers.3.2.net.0.gnorm.weight", "fullstage_copy.fuse_layers.3.2.net.0.gnorm.bias", "fullstage_copy.post_fuse_layers.0.conv.weight", "fullstage_copy.post_fuse_layers.0.gnorm.weight", "fullstage_copy.post_fuse_layers.0.gnorm.bias", "fullstage_copy.post_fuse_layers.1.conv.weight", "fullstage_copy.post_fuse_layers.1.gnorm.weight", "fullstage_copy.post_fuse_layers.1.gnorm.bias", "fullstage_copy.post_fuse_layers.2.conv.weight", "fullstage_copy.post_fuse_layers.2.gnorm.weight", "fullstage_copy.post_fuse_layers.2.gnorm.bias", "fullstage_copy.post_fuse_layers.3.conv.weight", "fullstage_copy.post_fuse_layers.3.gnorm.weight", "fullstage_copy.post_fuse_layers.3.gnorm.bias"

Two ways to fix this:

  1. Re-download the pretrained model. I just uploaded a version that I tested to be okay (getting 79% on ImageNet). The new link is in the README.

  2. Run the following lines, with the path set to your current mdeq_XL_cls.pth:

import torch

# Strip the duplicated "copy" modules and the "deq" wrapper entries that the
# inference-time model does not register; keep everything else on CPU.
pd = torch.load('pretrained_models/mdeq_XL.pth')
new_pd = {}
count = 0
for k in pd:
    if "copy" not in k and "deq" not in k:
        new_pd[k] = pd[k].clone().detach().cpu()
        count += pd[k].nelement()  # number of parameters kept, as a sanity check

torch.save(new_pd, 'pretrained_models/mdeq_XL_new.pth')

And then use mdeq_XL_new.pth.
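Before re-running cls_valid.py, the stripped file can also be sanity-checked in isolation. A quick sketch, assuming model is built from the yaml the same way tools/cls_valid.py does it:

import torch

state = torch.load('pretrained_models/mdeq_XL_new.pth', map_location='cpu')
model.load_state_dict(state)  # strict by default: raises if any key or shape is still off
print('checkpoint loads cleanly')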

Hi @jerrybai1995.
One step forward, but not there yet. I tried both ways, and there are a few remaining errors:
Unexpected key(s) in state_dict: "fullstage.post_fuse_layers.0.gnorm.weight", "fullstage.post_fuse_layers.0.gnorm.bias", "fullstage.post_fuse_layers.1.gnorm.weight", "fullstage.post_fuse_layers.1.gnorm.bias", "fullstage.post_fuse_layers.2.gnorm.weight", "fullstage.post_fuse_layers.2.gnorm.bias", "fullstage.post_fuse_layers.3.gnorm.weight", "fullstage.post_fuse_layers.3.gnorm.bias"
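In case it helps others on the old download: since these are all unexpected (not missing) keys, one stopgap is to filter them out before loading; this only discards entries the current model has no slot for anyway. A rough sketch:

import torch

state = torch.load('pretrained_models/mdeq_XL_new.pth', map_location='cpu')
# Drop the stray fullstage.post_fuse_layers.*.gnorm.* entries reported above.
state = {k: v for k, v in state.items()
         if not (k.startswith('fullstage.post_fuse_layers') and '.gnorm.' in k)}
model.load_state_dict(state)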

Wonderful, it's working now. Thanks for your support.

Hi, it seems that the pretrained model for MDEQ-Small also needs to be updated (perhaps all of them should be checked).

RuntimeError: Error(s) in loading state_dict for MDEQClsNet: Unexpected key(s) in state_dict: "fullstage_copy.branches.0.blocks.0.conv1.weight", "fullstage_copy.branches.0.blocks.0.gn1.weight", "fullstage_copy.branches.0.blocks.0.gn1.bias", "fullstage_copy.branches.0.blocks.0.conv2.weight", "fullstage_copy.branches.0.blocks.0.gn2.weight", "fullstage_copy.branches.0.blocks.0.gn2.bias", "fullstage_copy.branches.0.blocks.0.gn3.weight", "fullstage_copy.branches.0.blocks.0.gn3.bias", "fullstage_copy.branches.1.blocks.0.conv1.weight", "fullstage_copy.branches.1.blocks.0.gn1.weight", "fullstage_copy.branches.1.blocks.0.gn1.bias", "fullstage_copy.branches.1.blocks.0.conv2.weight", "fullstage_copy.branches.1.blocks.0.gn2.weight", "fullstage_copy.branches.1.blocks.0.gn2.bias", "fullstage_copy.branches.1.blocks.0.gn3.weight", "fullstage_copy.branches.1.blocks.0.gn3.bias", "fullstage_copy.branches.2.blocks.0.conv1.weight", "fullstage_copy.branches.2.blocks.0.gn1.weight", "fullstage_copy.branches.2.blocks.0.gn1.bias", "fullstage_copy.branches.2.blocks.0.conv2.weight", "fullstage_copy.branches.2.blocks.0.gn2.weight", "fullstage_copy.branches.2.blocks.0.gn2.bias", "fullstage_copy.branches.2.blocks.0.gn3.weight", "fullstage_copy.branches.2.blocks.0.gn3.bias", "fullstage_copy.branches.3.blocks.0.conv1.weight", "fullstage_copy.branches.3.blocks.0.gn1.weight", "fullstage_copy.branches.3.blocks.0.gn1.bias", "fullstage_copy.branches.3.blocks.0.conv2.weight", "fullstage_copy.branches.3.blocks.0.gn2.weight", "fullstage_copy.branches.3.blocks.0.gn2.bias", "fullstage_copy.branches.3.blocks.0.gn3.weight", "fullstage_copy.branches.3.blocks.0.gn3.bias", "fullstage_copy.fuse_layers.0.1.net.conv.weight", "fullstage_copy.fuse_layers.0.1.net.gnorm.weight", "fullstage_copy.fuse_layers.0.1.net.gnorm.bias", "fullstage_copy.fuse_layers.0.2.net.conv.weight", "fullstage_copy.fuse_layers.0.2.net.gnorm.weight", "fullstage_copy.fuse_layers.0.2.net.gnorm.bias", "fullstage_copy.fuse_layers.0.3.net.conv.weight", "fullstage_copy.fuse_layers.0.3.net.gnorm.weight", "fullstage_copy.fuse_layers.0.3.net.gnorm.bias", "fullstage_copy.fuse_layers.1.0.net.0.conv.weight", "fullstage_copy.fuse_layers.1.0.net.0.gnorm.weight", "fullstage_copy.fuse_layers.1.0.net.0.gnorm.bias", "fullstage_copy.fuse_layers.1.2.net.conv.weight", "fullstage_copy.fuse_layers.1.2.net.gnorm.weight", "fullstage_copy.fuse_layers.1.2.net.gnorm.bias", "fullstage_copy.fuse_layers.1.3.net.conv.weight", "fullstage_copy.fuse_layers.1.3.net.gnorm.weight", "fullstage_copy.fuse_layers.1.3.net.gnorm.bias", "fullstage_copy.fuse_layers.2.0.net.0.conv.weight", "fullstage_copy.fuse_layers.2.0.net.0.gnorm.weight", "fullstage_copy.fuse_layers.2.0.net.0.gnorm.bias", "fullstage_copy.fuse_layers.2.0.net.1.conv.weight", "fullstage_copy.fuse_layers.2.0.net.1.gnorm.weight", "fullstage_copy.fuse_layers.2.0.net.1.gnorm.bias", "fullstage_copy.fuse_layers.2.1.net.0.conv.weight", "fullstage_copy.fuse_layers.2.1.net.0.gnorm.weight", "fullstage_copy.fuse_layers.2.1.net.0.gnorm.bias", "fullstage_copy.fuse_layers.2.3.net.conv.weight", "fullstage_copy.fuse_layers.2.3.net.gnorm.weight", "fullstage_copy.fuse_layers.2.3.net.gnorm.bias", "fullstage_copy.fuse_layers.3.0.net.0.conv.weight", "fullstage_copy.fuse_layers.3.0.net.0.gnorm.weight", "fullstage_copy.fuse_layers.3.0.net.0.gnorm.bias", "fullstage_copy.fuse_layers.3.0.net.1.conv.weight", "fullstage_copy.fuse_layers.3.0.net.1.gnorm.weight", "fullstage_copy.fuse_layers.3.0.net.1.gnorm.bias", "fullstage_copy.fuse_layers.3.0.net.2.conv.weight", 
"fullstage_copy.fuse_layers.3.0.net.2.gnorm.weight", "fullstage_copy.fuse_layers.3.0.net.2.gnorm.bias", "fullstage_copy.fuse_layers.3.1.net.0.conv.weight", "fullstage_copy.fuse_layers.3.1.net.0.gnorm.weight", "fullstage_copy.fuse_layers.3.1.net.0.gnorm.bias", "fullstage_copy.fuse_layers.3.1.net.1.conv.weight", "fullstage_copy.fuse_layers.3.1.net.1.gnorm.weight", "fullstage_copy.fuse_layers.3.1.net.1.gnorm.bias", "fullstage_copy.fuse_layers.3.2.net.0.conv.weight", "fullstage_copy.fuse_layers.3.2.net.0.gnorm.weight", "fullstage_copy.fuse_layers.3.2.net.0.gnorm.bias", "fullstage_copy.post_fuse_layers.0.conv.weight", "fullstage_copy.post_fuse_layers.0.gnorm.weight", "fullstage_copy.post_fuse_layers.0.gnorm.bias", "fullstage_copy.post_fuse_layers.1.conv.weight", "fullstage_copy.post_fuse_layers.1.gnorm.weight", "fullstage_copy.post_fuse_layers.1.gnorm.bias", "fullstage_copy.post_fuse_layers.2.conv.weight", "fullstage_copy.post_fuse_layers.2.gnorm.weight", "fullstage_copy.post_fuse_layers.2.gnorm.bias", "fullstage_copy.post_fuse_layers.3.conv.weight", "fullstage_copy.post_fuse_layers.3.gnorm.weight", "fullstage_copy.post_fuse_layers.3.gnorm.bias", "deq.func.branches.0.blocks.0.conv1.weight_g", "deq.func.branches.0.blocks.0.conv1.weight_v", "deq.func.branches.0.blocks.0.gn1.weight", "deq.func.branches.0.blocks.0.gn1.bias", "deq.func.branches.0.blocks.0.conv2.weight_g", "deq.func.branches.0.blocks.0.conv2.weight_v", "deq.func.branches.0.blocks.0.gn2.weight", "deq.func.branches.0.blocks.0.gn2.bias", "deq.func.branches.0.blocks.0.gn3.weight", "deq.func.branches.0.blocks.0.gn3.bias", "deq.func.branches.1.blocks.0.conv1.weight_g", "deq.func.branches.1.blocks.0.conv1.weight_v", "deq.func.branches.1.blocks.0.gn1.weight", "deq.func.branches.1.blocks.0.gn1.bias", "deq.func.branches.1.blocks.0.conv2.weight_g", "deq.func.branches.1.blocks.0.conv2.weight_v", "deq.func.branches.1.blocks.0.gn2.weight", "deq.func.branches.1.blocks.0.gn2.bias", "deq.func.branches.1.blocks.0.gn3.weight", "deq.func.branches.1.blocks.0.gn3.bias", "deq.func.branches.2.blocks.0.conv1.weight_g", "deq.func.branches.2.blocks.0.conv1.weight_v", "deq.func.branches.2.blocks.0.gn1.weight", "deq.func.branches.2.blocks.0.gn1.bias", "deq.func.branches.2.blocks.0.conv2.weight_g", "deq.func.branches.2.blocks.0.conv2.weight_v", "deq.func.branches.2.blocks.0.gn2.weight", "deq.func.branches.2.blocks.0.gn2.bias", "deq.func.branches.2.blocks.0.gn3.weight", "deq.func.branches.2.blocks.0.gn3.bias", "deq.func.branches.3.blocks.0.conv1.weight_g", "deq.func.branches.3.blocks.0.conv1.weight_v", "deq.func.branches.3.blocks.0.gn1.weight", "deq.func.branches.3.blocks.0.gn1.bias", "deq.func.branches.3.blocks.0.conv2.weight_g", "deq.func.branches.3.blocks.0.conv2.weight_v", "deq.func.branches.3.blocks.0.gn2.weight", "deq.func.branches.3.blocks.0.gn2.bias", "deq.func.branches.3.blocks.0.gn3.weight", "deq.func.branches.3.blocks.0.gn3.bias", "deq.func.fuse_layers.0.1.net.conv.weight", "deq.func.fuse_layers.0.1.net.gnorm.weight", "deq.func.fuse_layers.0.1.net.gnorm.bias", "deq.func.fuse_layers.0.2.net.conv.weight", "deq.func.fuse_layers.0.2.net.gnorm.weight", "deq.func.fuse_layers.0.2.net.gnorm.bias", "deq.func.fuse_layers.0.3.net.conv.weight", "deq.func.fuse_layers.0.3.net.gnorm.weight", "deq.func.fuse_layers.0.3.net.gnorm.bias", "deq.func.fuse_layers.1.0.net.0.conv.weight", "deq.func.fuse_layers.1.0.net.0.gnorm.weight", "deq.func.fuse_layers.1.0.net.0.gnorm.bias", "deq.func.fuse_layers.1.2.net.conv.weight", "deq.func.fuse_layers.1.2.net.gnorm.weight", 
"deq.func.fuse_layers.1.2.net.gnorm.bias", "deq.func.fuse_layers.1.3.net.conv.weight", "deq.func.fuse_layers.1.3.net.gnorm.weight", "deq.func.fuse_layers.1.3.net.gnorm.bias", "deq.func.fuse_layers.2.0.net.0.conv.weight", "deq.func.fuse_layers.2.0.net.0.gnorm.weight", "deq.func.fuse_layers.2.0.net.0.gnorm.bias", "deq.func.fuse_layers.2.0.net.1.conv.weight", "deq.func.fuse_layers.2.0.net.1.gnorm.weight", "deq.func.fuse_layers.2.0.net.1.gnorm.bias", "deq.func.fuse_layers.2.1.net.0.conv.weight", "deq.func.fuse_layers.2.1.net.0.gnorm.weight", "deq.func.fuse_layers.2.1.net.0.gnorm.bias", "deq.func.fuse_layers.2.3.net.conv.weight", "deq.func.fuse_layers.2.3.net.gnorm.weight", "deq.func.fuse_layers.2.3.net.gnorm.bias", "deq.func.fuse_layers.3.0.net.0.conv.weight", "deq.func.fuse_layers.3.0.net.0.gnorm.weight", "deq.func.fuse_layers.3.0.net.0.gnorm.bias", "deq.func.fuse_layers.3.0.net.1.conv.weight", "deq.func.fuse_layers.3.0.net.1.gnorm.weight", "deq.func.fuse_layers.3.0.net.1.gnorm.bias", "deq.func.fuse_layers.3.0.net.2.conv.weight", "deq.func.fuse_layers.3.0.net.2.gnorm.weight", "deq.func.fuse_layers.3.0.net.2.gnorm.bias", "deq.func.fuse_layers.3.1.net.0.conv.weight", "deq.func.fuse_layers.3.1.net.0.gnorm.weight", "deq.func.fuse_layers.3.1.net.0.gnorm.bias", "deq.func.fuse_layers.3.1.net.1.conv.weight", "deq.func.fuse_layers.3.1.net.1.gnorm.weight", "deq.func.fuse_layers.3.1.net.1.gnorm.bias", "deq.func.fuse_layers.3.2.net.0.conv.weight", "deq.func.fuse_layers.3.2.net.0.gnorm.weight", "deq.func.fuse_layers.3.2.net.0.gnorm.bias", "deq.func.post_fuse_layers.0.conv.weight_g", "deq.func.post_fuse_layers.0.conv.weight_v", "deq.func.post_fuse_layers.0.gnorm.weight", "deq.func.post_fuse_layers.0.gnorm.bias", "deq.func.post_fuse_layers.1.conv.weight_g", "deq.func.post_fuse_layers.1.conv.weight_v", "deq.func.post_fuse_layers.1.gnorm.weight", "deq.func.post_fuse_layers.1.gnorm.bias", "deq.func.post_fuse_layers.2.conv.weight_g", "deq.func.post_fuse_layers.2.conv.weight_v", "deq.func.post_fuse_layers.2.gnorm.weight", "deq.func.post_fuse_layers.2.gnorm.bias", "deq.func.post_fuse_layers.3.conv.weight_g", "deq.func.post_fuse_layers.3.conv.weight_v", "deq.func.post_fuse_layers.3.gnorm.weight", "deq.func.post_fuse_layers.3.gnorm.bias", "deq.func_copy.branches.0.blocks.0.conv1.weight", "deq.func_copy.branches.0.blocks.0.gn1.weight", "deq.func_copy.branches.0.blocks.0.gn1.bias", "deq.func_copy.branches.0.blocks.0.conv2.weight", "deq.func_copy.branches.0.blocks.0.gn2.weight", "deq.func_copy.branches.0.blocks.0.gn2.bias", "deq.func_copy.branches.0.blocks.0.gn3.weight", "deq.func_copy.branches.0.blocks.0.gn3.bias", "deq.func_copy.branches.1.blocks.0.conv1.weight", "deq.func_copy.branches.1.blocks.0.gn1.weight", "deq.func_copy.branches.1.blocks.0.gn1.bias", "deq.func_copy.branches.1.blocks.0.conv2.weight", "deq.func_copy.branches.1.blocks.0.gn2.weight", "deq.func_copy.branches.1.blocks.0.gn2.bias", "deq.func_copy.branches.1.blocks.0.gn3.weight", "deq.func_copy.branches.1.blocks.0.gn3.bias", "deq.func_copy.branches.2.blocks.0.conv1.weight", "deq.func_copy.branches.2.blocks.0.gn1.weight", "deq.func_copy.branches.2.blocks.0.gn1.bias", "deq.func_copy.branches.2.blocks.0.conv2.weight", "deq.func_copy.branches.2.blocks.0.gn2.weight", "deq.func_copy.branches.2.blocks.0.gn2.bias", "deq.func_copy.branches.2.blocks.0.gn3.weight", "deq.func_copy.branches.2.blocks.0.gn3.bias", "deq.func_copy.branches.3.blocks.0.conv1.weight", "deq.func_copy.branches.3.blocks.0.gn1.weight", "deq.func_copy.branches.3.blocks.0.gn1.bias", 
"deq.func_copy.branches.3.blocks.0.conv2.weight", "deq.func_copy.branches.3.blocks.0.gn2.weight", "deq.func_copy.branches.3.blocks.0.gn2.bias", "deq.func_copy.branches.3.blocks.0.gn3.weight", "deq.func_copy.branches.3.blocks.0.gn3.bias", "deq.func_copy.fuse_layers.0.1.net.conv.weight", "deq.func_copy.fuse_layers.0.1.net.gnorm.weight", "deq.func_copy.fuse_layers.0.1.net.gnorm.bias", "deq.func_copy.fuse_layers.0.2.net.conv.weight", "deq.func_copy.fuse_layers.0.2.net.gnorm.weight", "deq.func_copy.fuse_layers.0.2.net.gnorm.bias", "deq.func_copy.fuse_layers.0.3.net.conv.weight", "deq.func_copy.fuse_layers.0.3.net.gnorm.weight", "deq.func_copy.fuse_layers.0.3.net.gnorm.bias", "deq.func_copy.fuse_layers.1.0.net.0.conv.weight", "deq.func_copy.fuse_layers.1.0.net.0.gnorm.weight", "deq.func_copy.fuse_layers.1.0.net.0.gnorm.bias", "deq.func_copy.fuse_layers.1.2.net.conv.weight", "deq.func_copy.fuse_layers.1.2.net.gnorm.weight", "deq.func_copy.fuse_layers.1.2.net.gnorm.bias", "deq.func_copy.fuse_layers.1.3.net.conv.weight", "deq.func_copy.fuse_layers.1.3.net.gnorm.weight", "deq.func_copy.fuse_layers.1.3.net.gnorm.bias", "deq.func_copy.fuse_layers.2.0.net.0.conv.weight", "deq.func_copy.fuse_layers.2.0.net.0.gnorm.weight", "deq.func_copy.fuse_layers.2.0.net.0.gnorm.bias", "deq.func_copy.fuse_layers.2.0.net.1.conv.weight", "deq.func_copy.fuse_layers.2.0.net.1.gnorm.weight", "deq.func_copy.fuse_layers.2.0.net.1.gnorm.bias", "deq.func_copy.fuse_layers.2.1.net.0.conv.weight", "deq.func_copy.fuse_layers.2.1.net.0.gnorm.weight", "deq.func_copy.fuse_layers.2.1.net.0.gnorm.bias", "deq.func_copy.fuse_layers.2.3.net.conv.weight", "deq.func_copy.fuse_layers.2.3.net.gnorm.weight", "deq.func_copy.fuse_layers.2.3.net.gnorm.bias", "deq.func_copy.fuse_layers.3.0.net.0.conv.weight", "deq.func_copy.fuse_layers.3.0.net.0.gnorm.weight", "deq.func_copy.fuse_layers.3.0.net.0.gnorm.bias", "deq.func_copy.fuse_layers.3.0.net.1.conv.weight", "deq.func_copy.fuse_layers.3.0.net.1.gnorm.weight", "deq.func_copy.fuse_layers.3.0.net.1.gnorm.bias", "deq.func_copy.fuse_layers.3.0.net.2.conv.weight", "deq.func_copy.fuse_layers.3.0.net.2.gnorm.weight", "deq.func_copy.fuse_layers.3.0.net.2.gnorm.bias", "deq.func_copy.fuse_layers.3.1.net.0.conv.weight", "deq.func_copy.fuse_layers.3.1.net.0.gnorm.weight", "deq.func_copy.fuse_layers.3.1.net.0.gnorm.bias", "deq.func_copy.fuse_layers.3.1.net.1.conv.weight", "deq.func_copy.fuse_layers.3.1.net.1.gnorm.weight", "deq.func_copy.fuse_layers.3.1.net.1.gnorm.bias", "deq.func_copy.fuse_layers.3.2.net.0.conv.weight", "deq.func_copy.fuse_layers.3.2.net.0.gnorm.weight", "deq.func_copy.fuse_layers.3.2.net.0.gnorm.bias", "deq.func_copy.post_fuse_layers.0.conv.weight", "deq.func_copy.post_fuse_layers.0.gnorm.weight", "deq.func_copy.post_fuse_layers.0.gnorm.bias", "deq.func_copy.post_fuse_layers.1.conv.weight", "deq.func_copy.post_fuse_layers.1.gnorm.weight", "deq.func_copy.post_fuse_layers.1.gnorm.bias", "deq.func_copy.post_fuse_layers.2.conv.weight", "deq.func_copy.post_fuse_layers.2.gnorm.weight", "deq.func_copy.post_fuse_layers.2.gnorm.bias", "deq.func_copy.post_fuse_layers.3.conv.weight", "deq.func_copy.post_fuse_layers.3.gnorm.weight", "deq.func_copy.post_fuse_layers.3.gnorm.bias".

Thanks for bringing this up. Yes, you'll need to process the small model in the same way as the XL model using that script I provided above. I'll make sure to check and update all of the pretrained models in the next few days!
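Until they're all re-uploaded, the same stripping can be wrapped in a small helper and applied to any of the downloaded checkpoints. A sketch based on the script above; the mdeq_small filenames are placeholders for whichever file you have:

import torch

def strip_checkpoint(src, dst):
    # Keep only the weights the inference-time model registers, dropping
    # the duplicated 'copy' modules and the 'deq' wrapper entries.
    pd = torch.load(src, map_location='cpu')
    new_pd = {k: v.clone().detach().cpu() for k, v in pd.items()
              if 'copy' not in k and 'deq' not in k}
    torch.save(new_pd, dst)

strip_checkpoint('pretrained_models/mdeq_small.pth', 'pretrained_models/mdeq_small_new.pth')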
