Error loading checkpoint for vit_b_sam backbone
themattinthehatt commented
When loading weights from a fine-tuned vit_b_sam backbone, if the fine-tuning frame size is not 1024x1024, the following error is raised:
```
RuntimeError: Error(s) in loading state_dict for HeatmapTracker:
size mismatch for backbone.pos_embed: copying a param with shape torch.Size([1, 16, 16, 768]) from checkpoint, the shape in current model is torch.Size([1, 64, 64, 768]).
```
The problem:
- During training, the regular vit_b_sam backbone is constructed, which assumes an image shape of 1024x1024
- If the image size that we are fine-tuning on is not 1024x1024, the position embedding is automatically updated during training and the new weights are stored (and eventually saved)
- When loading the weights into a new model, a position embedding parameter sized for 1024x1024 images is constructed, but the checkpoint holds a parameter sized for the (different) fine-tune image size, which triggers the error above.
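The shape mismatch follows directly from the ViT patch grid: with 16x16 patches, a 1024x1024 image gives a 64x64 grid of patch tokens, while the checkpoint's (1, 16, 16, 768) shape implies a 256x256 fine-tune frame. A minimal sketch of the relationship (patch size 16 and embed dim 768 are the standard vit_b_sam values):

```python
# Sketch: how the vit_b_sam position-embedding shape depends on image size.
def pos_embed_shape(img_size: int, patch_size: int = 16, embed_dim: int = 768):
    grid = img_size // patch_size
    return (1, grid, grid, embed_dim)

print(pos_embed_shape(1024))  # shape built by the default backbone: (1, 64, 64, 768)
print(pos_embed_shape(256))   # shape stored in the fine-tuned checkpoint: (1, 16, 16, 768)
```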
The solution:
Instead of loading the state dict directly into the model using Model.load_from_checkpoint, this step needs to be broken into several parts:
- Initialize the model (which includes loading the SAM weights) - this will set the position embedding parameter in a way that assumes 1024x1024 images
- Manually update the position embedding parameter to match the desired fine-tune image size
- Load the weights from the checkpoint
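The three steps above could be sketched roughly as follows. This is only an illustration, not the project's actual API: the `load_finetuned` helper, the `model.backbone.pos_embed` attribute path, and the checkpoint layout (`{"state_dict": ...}`) are assumptions, and bicubic interpolation is one common way to resize a 2D position-embedding grid.

```python
# Sketch of the three-step loading fix (hypothetical helper and attribute names).
import torch
import torch.nn.functional as F


def load_finetuned(model, ckpt_path, img_size, patch_size=16):
    # Step 1: the model is assumed already initialized with SAM weights, so
    # its position embedding is sized for 1024x1024 images (a 64x64 grid).
    pos_embed = model.backbone.pos_embed              # (1, 64, 64, 768)

    # Step 2: manually resize the grid to match the fine-tune image size.
    grid = img_size // patch_size
    resized = F.interpolate(
        pos_embed.permute(0, 3, 1, 2),                # to (1, 768, 64, 64)
        size=(grid, grid),
        mode="bicubic",
        align_corners=False,
    ).permute(0, 2, 3, 1)                             # back to (1, grid, grid, 768)
    model.backbone.pos_embed = torch.nn.Parameter(resized)

    # Step 3: shapes now agree, so the checkpoint loads without a size mismatch.
    state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
    model.load_state_dict(state_dict)
    return model
```

For the error above, `img_size=256` would shrink the grid from 64x64 to 16x16 before the checkpoint weights are copied in.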