Support saving the model so that it can be reloaded from Huggingface
dfioravanti opened this issue · 2 comments
🚀 Feature
Right now, at the end of training, the model is saved as a ckpt file, and Hugging Face really does not like ckpt files. From what I can tell, there is no easy way to load a checkpoint into an HF model.
Motivation
Having this feature would make deploying a model much easier, as Hugging Face is becoming a somewhat standard library that quite a few people know and use.
Pitch
It should be possible to save the best model in a format that HF can easily load.
@SeanNaren wrote in Slack that this could potentially be done within the LightningModule (or the TransformerTask) in the on_save_checkpoint hook; however, we have to make sure that loading works correctly as well.
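A rough sketch of what the hook suggested above could look like. It assumes the task keeps its Hugging Face model on `self.model`; the mixin name and the `hf_save_dir` attribute are made up for illustration. `on_save_checkpoint(checkpoint)` is the LightningModule hook and `save_pretrained(save_directory)` is the Hugging Face model method; everything else is an assumption, not how the library does it today:

```python
class SaveHFOnCheckpointMixin:
    """Hypothetical mixin: combine it with a LightningModule that stores
    its Hugging Face model on `self.model` (an assumption)."""

    # Hypothetical output directory for the Hugging Face format.
    hf_save_dir = "hf_model"

    def on_save_checkpoint(self, checkpoint):
        # Lightning calls this hook with the checkpoint dict whenever it
        # writes a .ckpt file. Alongside that file, also write the wrapped
        # model in the directory layout that from_pretrained() reads.
        self.model.save_pretrained(self.hf_save_dir)
```

Mixed into a task, e.g. `class MyTask(SaveHFOnCheckpointMixin, pl.LightningModule)`, this should produce a directory that `AutoModel.from_pretrained("hf_model")` can read. The hook fires every time a checkpoint is written, so the directory holds whichever checkpoint was saved last, and the tokenizer would still need to be saved separately.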
I've currently resorted to monkey-patching to temporarily fix the problem. If you have a violent allergic reaction to code that has a tendency to break easily, please do not read the following:
import pytorch_lightning as pl
import torch


class ModelSwitcher:
    def __init__(self):
        self.model = pl.LightningModule()  # load model here
        self.checkpoint_path = "path/to/model.ckpt"

    def load_checkpoint(self):
        old_state_dict = torch.load(self.checkpoint_path)["state_dict"]
        # the checkpoint contains a dictionary of weights that can be
        # directly mapped to the transformer weights, the names are just
        # a little different, you might need to tweak this
        for key, value in old_state_dict.items():
            try:
                # split "a.b.weight" into module names ["a", "b"] and the
                # tensor name "weight", then walk down the module tree
                *module_names, tensor_name = key.split(".")
                module = self.model
                for name in module_names:
                    module = module._modules[name]
                # point the existing tensor at the checkpoint values;
                # no shape check is done here
                getattr(module, tensor_name).data = value
            except Exception as e:
                print(e, f"could not find {key} in model")
I can confirm this works for pytorch-lightning models, and since the manipulation happens at the PyTorch level, it might work for other projects too. You have to inspect your model's names to make SURE the naming conventions match up. Furthermore, there's no check for the size of the input/output weights. Please don't take this post too seriously. This is only for the truly desperate.
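For anyone less desperate, a somewhat safer variant of the same idea is to rename the checkpoint keys and let PyTorch do the matching. A minimal sketch, assuming the LightningModule stored the transformer under an attribute called `model`, so every key starts with `model.` (check your own checkpoint, the prefix may differ); `strip_prefix` is a hypothetical helper, not part of any library:

```python
def strip_prefix(state_dict, prefix="model."):
    """Return a new dict with `prefix` removed from each key.

    Keys that do not start with `prefix` (e.g. loss or metric state)
    are dropped. The default prefix is an assumption about how the
    LightningModule named its transformer attribute.
    """
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


# Intended use (needs torch and transformers, so left as comments):
#   ckpt = torch.load("path/to/model.ckpt", map_location="cpu")
#   hf_model.load_state_dict(strip_prefix(ckpt["state_dict"]))
#   hf_model.save_pretrained("path/to/hf_dir")
```

Because `load_state_dict` is strict by default, it raises on missing keys, unexpected keys, and shape mismatches, which covers the checks the monkey-patch skips.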
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.