lightly-ai/lightly

Add function to load model from pretrained checkpoint

Opened this issue · 0 comments

We should add a function to load backbones from the benchmark checkpoints. The function should roughly do the following:

from torchvision.models import resnet50
from torch.hub import load_state_dict_from_url

model = resnet50()
state_dict = load_state_dict_from_url("https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_simclr_2023-06-22_09-11-13/pretrain/version_0/checkpoints/epoch%3D99-step%3D500400.ckpt")
new_state_dict = {}
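for key, value in state_dict["state_dict"].items():
    if key.startswith("backbone."):
        # Slice off the prefix; str.lstrip strips a set of characters, not a prefix.
        new_state_dict[key[len("backbone."):]] = value
# strict=False because the checkpoint has no weights for the fc head.
missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
assert set(missing_keys) == {"fc.weight", "fc.bias"}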

Maybe we can leave the load_state_dict_from_url call outside the function and make the function just take a state dict as input and return the new state dict as output.
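
A minimal sketch of what such a function could look like. The name `extract_backbone_state_dict` and the `prefix` parameter are hypothetical, not an existing lightly API. It assumes the input is the flat mapping stored under the checkpoint's "state_dict" key and that backbone weights are prefixed with "backbone.":

```python
from typing import Any, Dict, Mapping


def extract_backbone_state_dict(
    state_dict: Mapping[str, Any], prefix: str = "backbone."
) -> Dict[str, Any]:
    """Return a new state dict with only the backbone weights.

    Hypothetical helper: keeps entries whose key starts with `prefix`
    and removes that prefix from the key. All other entries are dropped.
    """
    new_state_dict = {}
    for key, value in state_dict.items():
        if key.startswith(prefix):
            # Slice instead of str.lstrip, which would also eat leading
            # characters of the remaining key (e.g. "bn1" -> "1").
            new_state_dict[key[len(prefix):]] = value
    return new_state_dict
```

The result could then be passed to model.load_state_dict(..., strict=False), checking that the only missing keys are "fc.weight" and "fc.bias". Because the prefix includes the trailing dot, keys under other modules that merely begin with "backbone" are not matched.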

TODO

  • Add function
  • Document function