GRAAL-Research/deepparse

RuntimeError: Error(s) in loading state_dict for BPEmbSeq2SeqModel:

rbhatia46 opened this issue · 2 comments

Hi, I am trying to run the retrain script, specifically the code below:

import poutyne

from deepparse.dataset_container import PickleDatasetContainer
from deepparse.parser import AddressParser

# First, let's download the train and test data from the public repository.
saving_dir = "./model_data"
file_extension = "p"

# Now let's create a training and test container.
training_container = PickleDatasetContainer("india_train.p")
test_container = PickleDatasetContainer("india_test.p")

# We will retrain the fasttext version of our pretrained model.
address_parser = AddressParser(model_type="bpemb", device=0)

# Now let's retrain for 5 epochs using a batch size of 8 since the data is really small for the example.
# Let's start with the default learning rate of 0.01 and use a learning rate scheduler to lower the learning rate
# as we progress.
lr_scheduler = poutyne.StepLR(step_size=1, gamma=0.1)  # reduce LR by a factor of 10 each epoch

# The checkpoints (ckpt) are saved in the default "./checkpoints" directory.
address_parser.retrain(training_container, 0.8, epochs=5, batch_size=8, num_workers=2, callbacks=[lr_scheduler])

# Now let's test our fine tuned model using the best checkpoint (default parameter).
address_parser.test(test_container, batch_size=256)

When I run the above code with fasttext instead of bpemb, it works fine, but when I run it with bpemb, I get the following error:

RuntimeError: Error(s) in loading state_dict for BPEmbSeq2SeqModel:
	Missing key(s) in state_dict: "embedding_network.model.weight_ih_l0", "embedding_network.model.weight_hh_l0", "embedding_network.model.bias_ih_l0", "embedding_network.model.bias_hh_l0", "embedding_network.model.weight_ih_l0_reverse", "embedding_network.model.weight_hh_l0_reverse", "embedding_network.model.bias_ih_l0_reverse", "embedding_network.model.bias_hh_l0_reverse", "embedding_network.projection_layer.weight", "embedding_network.projection_layer.bias". 

Have you previously trained a model?

This is due to the retrain method using the default checkpoints directory. Since your earlier fasttext retraining saved its checkpoints in "./checkpoints", the bpemb retraining tried to resume from those incompatible weights, hence the missing keys.
In version 0.6.2, we handle this kind of error better. The documentation also includes a note about it, and the logging_path parameter's docstring mentions it as well.
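
As a workaround, a minimal sketch (reusing the training_container and lr_scheduler from the code above; the directory name is a hypothetical choice, any fresh path works) is to give each model type its own checkpoint directory through the logging_path parameter, so retrain never resumes from another model's weights:

# Point the bpemb retraining at its own checkpoint directory so it does not
# pick up the fasttext checkpoints left in the default "./checkpoints".
address_parser = AddressParser(model_type="bpemb", device=0)
address_parser.retrain(
    training_container,
    0.8,
    epochs=5,
    batch_size=8,
    num_workers=2,
    callbacks=[lr_scheduler],
    logging_path="./checkpoints_bpemb",  # hypothetical path, distinct per model type
)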