UX-Decoder/Semantic-SAM

model and ckpt don't match

jiang25262 opened this issue · 3 comments

Thanks for your great work!
Why do some of the weights not match?

*UNLOADED* norm0.bias, Model Shape: torch.Size([192])
*UNLOADED* norm0.weight, Model Shape: torch.Size([192])
*UNLOADED* norm1.bias, Model Shape: torch.Size([384])
*UNLOADED* norm1.weight, Model Shape: torch.Size([384])
*UNLOADED* norm2.bias, Model Shape: torch.Size([768])
*UNLOADED* norm2.weight, Model Shape: torch.Size([768])
*UNLOADED* norm3.bias, Model Shape: torch.Size([1536])
*UNLOADED* norm3.weight, Model Shape: torch.Size([1536])
$UNUSED$ head.bias, Ckpt Shape: torch.Size([21841])
$UNUSED$ head.weight, Ckpt Shape: torch.Size([21841, 1536])
$UNUSED$ layers.0.blocks.1.attn_mask, Ckpt Shape: torch.Size([64, 144, 144])
$UNUSED$ layers.1.blocks.1.attn_mask, Ckpt Shape: torch.Size([16, 144, 144])
$UNUSED$ layers.2.blocks.1.attn_mask, Ckpt Shape: torch.Size([4, 144, 144])
$UNUSED$ layers.2.blocks.11.attn_mask, Ckpt Shape: torch.Size([4, 144, 144])
$UNUSED$ layers.2.blocks.13.attn_mask, Ckpt Shape: torch.Size([4, 144, 144])
$UNUSED$ layers.2.blocks.15.attn_mask, Ckpt Shape: torch.Size([4, 144, 144])
$UNUSED$ layers.2.blocks.17.attn_mask, Ckpt Shape: torch.Size([4, 144, 144])
$UNUSED$ layers.2.blocks.3.attn_mask, Ckpt Shape: torch.Size([4, 144, 144])
$UNUSED$ layers.2.blocks.5.attn_mask, Ckpt Shape: torch.Size([4, 144, 144])
$UNUSED$ layers.2.blocks.7.attn_mask, Ckpt Shape: torch.Size([4, 144, 144])
$UNUSED$ layers.2.blocks.9.attn_mask, Ckpt Shape: torch.Size([4, 144, 144])
$UNUSED$ norm.bias, Ckpt Shape: torch.Size([1536])
$UNUSED$ norm.weight, Ckpt Shape: torch.Size([1536])
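
In case it helps: the mismatch can be reproduced by diffing the checkpoint keys against the backbone keys the model expects. A minimal sketch, assuming `model` is the already-built Semantic-SAM model and mirroring the `backbone.` prefix stripping in my load code below:

import torch

# Compare checkpoint keys with the backbone keys the model expects.
ckpt = torch.load('./ckps/swin_large_patch4_window12_384_22k.pth', map_location='cpu')
ckpt = ckpt.get('model', ckpt)  # some checkpoints wrap the weights in a 'model' entry

backbone_keys = {k[len('backbone.'):] for k in model.state_dict() if k.startswith('backbone.')}
ckpt_keys = set(ckpt.keys())

print('*UNLOADED* (expected by the backbone, absent from the ckpt):', sorted(backbone_keys - ckpt_keys))
print('$UNUSED$ (present in the ckpt, never consumed):', sorted(ckpt_keys - backbone_keys))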

Here is my loading code:

def from_pretrained(self, load_dir):
    # Load the raw checkpoint on CPU; some checkpoints wrap the weights in a 'model' entry.
    state_dict = torch.load(load_dir, map_location='cpu')
    if 'model' in state_dict:
        state_dict = state_dict['model']
    # state_dict = {k[6:]: v for k, v in state_dict.items() if k.startswith('model.')}
    # for k in self.model.state_dict():
    #     if k not in state_dict:
    #         assert k[:-2] in state_dict
    #         state_dict[k] = state_dict.pop(k[:-2])
    # Align checkpoint keys with the model's backbone keys, then load non-strictly.
    state_dict = align_and_update_state_dicts(self.model.state_dict(), state_dict)
    self.model.load_state_dict(state_dict, strict=False)
    return self
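
One note on the last step: with strict=False, load_state_dict does not raise on missing or unexpected keys, it only reports them in its return value, so mismatches like the ones above pass silently. A quick way to surface them (this is the standard torch.nn.Module.load_state_dict return value, shown as a drop-in for the call above):

incompat = self.model.load_state_dict(state_dict, strict=False)
print('missing keys:', incompat.missing_keys)
print('unexpected keys:', incompat.unexpected_keys)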

def align_and_update_state_dicts(model_state_dict, ckpt_state_dict):
    # Keep only the backbone parameters and strip the 'backbone.' prefix so the
    # keys line up with a standalone Swin checkpoint.
    model_state_dict = {k[9:]: v for k, v in model_state_dict.items() if k.startswith('backbone.')}
    model_keys = sorted(model_state_dict.keys())
    ckpt_keys = sorted(ckpt_state_dict.keys())
    result_dicts = {}
    matched_log = []
    unmatched_log = []
    unloaded_log = []
    for model_key in model_keys:
        model_weight = model_state_dict[model_key]
        if model_key in ckpt_keys:
            ckpt_weight = ckpt_state_dict[model_key]
            if model_weight.shape == ckpt_weight.shape:
                result_dicts[model_key] = ckpt_weight
                ckpt_keys.pop(ckpt_keys.index(model_key))
                matched_log.append("Loaded {}, Model Shape: {} <-> Ckpt Shape: {}".format(model_key, model_weight.shape, ckpt_weight.shape))
            else:
                unmatched_log.append("*UNMATCHED* {}, Model Shape: {} <-> Ckpt Shape: {}".format(model_key, model_weight.shape, ckpt_weight.shape))
        else:
            unloaded_log.append("*UNLOADED* {}, Model Shape: {}".format(model_key, model_weight.shape))
    # Whatever is still left in ckpt_keys here corresponds to the $UNUSED$ entries in the log above.
    return result_dicts

mask_generator = SemanticSamAutomaticMaskGenerator(build_semantic_sam(model_type='L', ckpt='./ckps/swin_large_patch4_window12_384_22k.pth'))

I think I used the wrong checkpoint. I will close this issue. Sorry.
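
For reference: swin_large_patch4_window12_384_22k.pth is the plain Swin-L ImageNet-22K classification checkpoint, which is why the 21841-way classification head and the final norm go unused while the per-stage norm0-norm3 weights the segmentation backbone expects are missing. For SemanticSamAutomaticMaskGenerator, the released Semantic-SAM checkpoint should be passed instead, along these lines (the filename below is the Swin-L entry I recall from the project's model zoo, so adjust the path to whatever you actually downloaded):

mask_generator = SemanticSamAutomaticMaskGenerator(
    build_semantic_sam(model_type='L', ckpt='./ckps/swinl_only_sam_many2many.pth'))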