abhishekkrthakur/tez

Resuming the training through checkpoint with tez

vikas-nexcom opened this issue · 1 comment

Hi,

I am wondering if it is possible to load a saved model and resume/continue training with tez. I am new to PyTorch. Here is what I tried:

class Bert(tez.Model):

    def __init__(self, num_classes, num_train_steps=None):
        super().__init__()
        self.bert = transformers.BertModel.from_pretrained(
            'bert-base-uncased',
            return_dict=False,
        )

        if config.RETRAINING:  # set to True
            self.bert.load('demo.bin', device='cuda')

        self.bert_drop = nn.Dropout(0.3)
        self.out = nn.Linear(self.bert.config.hidden_size, num_classes)

This doesn't work, and I am not sure what I am missing. I found this thread for PyTorch:

https://discuss.pytorch.org/t/loading-a-saved-model-for-continue-training/17244

but I am not sure how to use this together with tez.

Sure, just do model.load() on the tez model itself and you can re-train. You might also want to load the state of the optimizer and scheduler; I'll add support for saving them.
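For reference, the pattern the reply (and the linked PyTorch thread) describes can be sketched in plain PyTorch: checkpoint both the model and the optimizer state, then restore both before resuming. This is a minimal illustration, not tez's internal implementation; the file name `demo.bin` and the tiny model are placeholders.

```python
import torch
import torch.nn as nn

# A tiny stand-in model; with tez you would call save()/load() on the
# tez.Model instance rather than on the inner transformers module.
model = nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Save model weights together with optimizer state so training can
# resume where it left off (file name is illustrative).
torch.save(
    {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    },
    "demo.bin",
)

# Resume: rebuild the objects with the same architecture, then restore
# their saved states before continuing the training loop.
resumed = nn.Linear(4, 2)
resumed_opt = torch.optim.AdamW(resumed.parameters(), lr=1e-3)
checkpoint = torch.load("demo.bin")
resumed.load_state_dict(checkpoint["state_dict"])
resumed_opt.load_state_dict(checkpoint["optimizer"])
```

Restoring only the weights works, but optimizer state (e.g. AdamW's moment estimates) and the scheduler position also matter if you want training to continue smoothly rather than restart its warmup.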