pytorch/hub

Model from URL

ChrisDelClea opened this issue · 1 comments

In TF, one can download a model directly from a URI with just two lines of code.
It would be great if I could do the same with PyTorch:

import tensorflow_hub as hub

model = hub.KerasLayer("https://my.own.repo.de/artifactory/pytorch-proxy/my_model/2")

I would like to load the model from my own URL rather than from GitHub or another host. Would that be possible?

Hi @ChrisDelClea ,

The pytorch equivalent would be something like torch.hub.load_state_dict_from_url:

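import torch

model = MyModel()  # instantiate your own nn.Module subclass first
state_dict = torch.hub.load_state_dict_from_url(THE_URL)  # download (and cache) the weights
model.load_state_dict(state_dict)  # copy the weights into the model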

This differs slightly from TF: in PyTorch, instantiating the model is decoupled from loading its weights. You'll find more details in this tutorial: https://pytorch.org/tutorials/beginner/saving_loading_models.html
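
If it helps, here is a slightly fuller sketch of the same idea that you can run end to end. `MyModel` and `load_model_from_url` are hypothetical names made up for illustration and are not part of PyTorch. To keep the example self-contained, it assumes a `file://` URL pointing at a locally saved checkpoint as a stand-in for your own server; with a real deployment you would pass your `https://` URL instead.

```python
import tempfile
from pathlib import Path

import torch
from torch import nn


class MyModel(nn.Module):
    """Hypothetical stand-in for your own architecture."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)


def load_model_from_url(url, model_dir=None):
    """Build MyModel and fill it with weights fetched from `url` (illustrative helper)."""
    model = MyModel()
    state_dict = torch.hub.load_state_dict_from_url(
        url,
        model_dir=model_dir,   # directory where the downloaded file is cached
        map_location="cpu",    # load tensors onto the CPU
        progress=False,        # no progress bar
    )
    model.load_state_dict(state_dict)
    model.eval()  # switch to inference mode
    return model


# Assumption for this sketch: a local checkpoint addressed by a file:// URL
# plays the role of the remote artifact server.
workdir = Path(tempfile.mkdtemp())
source = MyModel()
torch.save(source.state_dict(), workdir / "my_model.pt")
url = (workdir / "my_model.pt").as_uri()

restored = load_model_from_url(url, model_dir=str(workdir / "cache"))
```

Note that the checkpoint only contains the weights, so the code defining `MyModel` still has to be importable wherever you load it. The downloaded file is cached in `model_dir` under the file name taken from the URL, so later calls with the same URL reuse the cached copy.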