KeyError: 'pos_embed' when fine-tuning
RooKichenn opened this issue · 3 comments
When I use the uniformer_base_ls model to fine-tune on my own dataset, the code reports an error:
Traceback (most recent call last):
  File "main.py", line 502, in <module>
    main(args)
  File "main.py", line 278, in main
    pos_embed_checkpoint = checkpoint_model['pos_embed']
KeyError: 'pos_embed'
Then I checked the pre-trained weights and found that pos_embed does not exist as a standalone key; the only pos_embed entries are inside the blocks, e.g.:
'blocks1.0.pos_embed.weight', 'blocks1.0.pos_embed.bias'
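For reference, this is roughly how I inspected the checkpoint (a minimal sketch; the file name below is illustrative, and the 'model' wrapper key is an assumption based on the usual save format):

```python
import torch

# Load the downloaded weights onto the CPU just to look at the keys.
checkpoint = torch.load('uniformer_base_ls.pth', map_location='cpu')
# Some checkpoints wrap the weights in a 'model' entry; fall back to the dict itself.
state_dict = checkpoint.get('model', checkpoint)

print('pos_embed' in state_dict)                    # False for this checkpoint
print([k for k in state_dict if 'pos_embed' in k])  # only per-block keys show up
```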
Finally, I commented out the following lines of code and the program ran. What is causing this problem? Could you help me, please?
# interpolate position embedding
print(checkpoint_model.keys())
pos_embed_checkpoint = checkpoint_model['pos_embed']
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = model.patch_embed.num_patches
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
# height (== width) for the new position embedding
new_size = int(num_patches ** 0.5)
# class_token and dist_token are kept unchanged
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
checkpoint_model['pos_embed'] = new_pos_embed
Sorry for the late reply.
What you did is correct! The fine-tuning code above is copied from DeiT, where it is used to interpolate the absolute position embedding in ViT.
Since our UniFormer uses dynamic position embedding inside each block (which is why you only see keys like blocks1.0.pos_embed.weight), there is no global pos_embed to interpolate, so you can simply comment out that code!
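For anyone hitting the same error: instead of deleting the DeiT-style interpolation, you can also guard it so it only runs when the checkpoint actually has an absolute position embedding. This is a minimal sketch, not the repository's exact fine-tuning code; `checkpoint_model` and `model` are the same variables as in the snippet above:

```python
# Only run the DeiT-style interpolation when the checkpoint contains a
# global absolute position embedding and the model expects one.
if 'pos_embed' in checkpoint_model and hasattr(model, 'pos_embed'):
    pos_embed_checkpoint = checkpoint_model['pos_embed']
    embedding_size = pos_embed_checkpoint.shape[-1]
    num_patches = model.patch_embed.num_patches
    num_extra_tokens = model.pos_embed.shape[-2] - num_patches
    # height (== width) of the checkpoint and of the new position embedding
    orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
    new_size = int(num_patches ** 0.5)
    # keep class/dist tokens unchanged, interpolate only the position tokens
    extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
    pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
    pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
    pos_tokens = torch.nn.functional.interpolate(
        pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
    pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
    checkpoint_model['pos_embed'] = torch.cat((extra_tokens, pos_tokens), dim=1)
else:
    # UniFormer's dynamic position embedding lives inside the blocks
    # (e.g. blocks1.0.pos_embed.weight), so there is nothing to interpolate.
    print("No absolute pos_embed in checkpoint; skipping interpolation.")
```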
Sorry, I haven't read the DeiT source code. Thank you very much for answering this question for me!
As there is no more activity, I am closing the issue, don't hesitate to reopen it if necessary.