Model from URL
ChrisDelClea opened this issue · 1 comment
ChrisDelClea commented
In TF, one can download a model directly from a URI with just two lines of code.
It would be great if I could also do this with PyTorch:
```python
import tensorflow_hub as hub
model = hub.KerasLayer("https://my.own.repo.de/artifactory/pytorch-proxy/my_model/2")
```
I'd like to load the model directly from my own URL like this, instead of downloading it from GitHub or other hosts. Would that be possible?
NicolasHug commented
Hi @ChrisDelClea,
The PyTorch equivalent would be something like torch.hub.load_state_dict_from_url:
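```python
import torch

model = MyModel()  # your own nn.Module; its architecture must match the saved weights
state_dict = torch.hub.load_state_dict_from_url(THE_URL)  # downloads (and caches) the file
model.load_state_dict(state_dict)  # copies the downloaded weights into the model
```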
It's slightly different from TF because in PyTorch, instantiating the model is decoupled from loading its weights: you first build the model in code, then load the weights into it. You'll find more details in this tutorial: https://pytorch.org/tutorials/beginner/saving_loading_models.html
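If you want something closer to the TF one-liner, the two steps can be wrapped in a small helper. The sketch below is a minimal, self-contained illustration under a few assumptions: MyModel is a hypothetical stand-in architecture, load_model_from_url and the file names are made up for the example, and a local file:// URL stands in for an https:// URL (such as an Artifactory one) so the example runs offline. The file_name argument is passed because load_state_dict_from_url names the cached file after the last path segment of the URL, so a URL ending in ".../my_model/2" would otherwise be cached simply as "2".

```python
import tempfile
from pathlib import Path

import torch
from torch import nn


class MyModel(nn.Module):
    """Hypothetical stand-in for your own architecture."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


def load_model_from_url(url: str, cache_dir: str, file_name: str) -> MyModel:
    """Hypothetical helper: build MyModel, then fill it with weights fetched from `url`."""
    model = MyModel()  # step 1: instantiate the architecture in code
    state_dict = torch.hub.load_state_dict_from_url(
        url,
        model_dir=cache_dir,  # directory where the downloaded file is cached
        map_location="cpu",  # load tensors onto the CPU regardless of where they were saved
        progress=False,  # no progress bar
        file_name=file_name,  # explicit cache file name instead of the URL's last path segment
    )
    model.load_state_dict(state_dict)  # step 2: copy the weights into the model
    model.eval()  # switch to inference mode
    return model


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        # Simulate a published checkpoint by saving a state_dict to disk ...
        source = MyModel()
        weights_path = Path(tmp) / "my_model.pt"
        torch.save(source.state_dict(), weights_path)

        # ... and fetch it through a file:// URL (assumption: stands in for an https:// URL).
        restored = load_model_from_url(
            weights_path.as_uri(),
            cache_dir=str(Path(tmp) / "cache"),
            file_name="my_model_v2.pt",
        )

        x = torch.randn(3, 4)
        print(torch.allclose(source(x), restored(x)))
```

Because the download is cached under model_dir (by default, a checkpoints folder inside the torch.hub cache directory), a second call with the same file_name reuses the local file instead of downloading it again.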