EdinburghNLP/nematus

How to continue training from the last checkpoint

LinuxBeginner opened this issue · 5 comments

Hi, I am training the model in Colab, so I will need to restart training several times.
I used the following training script to train the model:

THEANO_FLAGS=mode=FAST_RUN,floatX=float32,device=$device,gpuarray.preallocate=0.8 python $nematus_home/nematus/train.py \
    --model $working_dir/model.npz \
    --datasets $data_dir/train.bpe.$src $data_dir/train.bpe.$trg \
    --valid_datasets $data_dir/dev.bpe.$src $data_dir/dev.bpe.$trg \
    --dictionaries $data_dir/train.bpe.$src.json $data_dir/train.bpe.$trg.json \
    --valid_script $script_dir/validate.sh \
    --dim_word 512 \
    --dim 1024 \
    --lrate 0.0001 \
    --optimizer adam \
    --maxlen 50 \
    --batch_size 80 \
    --valid_batch_size 40 \
    --validFreq 10000 \
    --dispFreq 1000 \
    --saveFreq 10000 \
    --sampleFreq 10000 \
    --tie_decoder_embeddings \
    --layer_normalisation \
    --dec_base_recurrence_transition_depth 8 \
    --enc_recurrence_transition_depth 4

After a while, training stopped because I had used up my daily 12-hour limit in Colab. Only one checkpoint was saved (model.npz.data-00000-of-00001); the working directory contains:

enmn.bpe		       model.npz.json		truecase-model.en
model.npz.data-00000-of-00001  model.npz.meta		truecase-model.mn
model.npz.index		       model.npz.progress.json

According to the comments here, the parameter --model $working_dir/model.npz should be enough.

However, when I rerun the above script without any changes, training starts again from epoch 0.

I also tried adding --reload $working_dir/model.npz.data-00000-of-00001 to the script, but it gave me an error.

Which parameter should I use to continue training from the last checkpoint?

Using "--reload $working_dir/model.npz" should do the trick.
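The error from passing model.npz.data-00000-of-00001 is consistent with how TensorFlow stores a checkpoint: several files (model.npz.index, model.npz.data-00000-of-00001, model.npz.meta) share one prefix, here model.npz, and that prefix is what --reload expects. If you only have one of the file names at hand, a minimal bash sketch like the following recovers the prefix (checkpoint_prefix is a hypothetical helper, not part of Nematus):

```shell
#!/usr/bin/env bash
# Sketch: map any file of a TensorFlow checkpoint to the checkpoint prefix
# that --reload expects. checkpoint_prefix is a hypothetical helper name.
set -euo pipefail

checkpoint_prefix() {
  local f="$1"
  # A checkpoint "P" is stored as P.index, P.meta and P.data-XXXXX-of-YYYYY.
  if [[ "$f" =~ ^(.*)\.data-[0-9]+-of-[0-9]+$ ]]; then
    f="${BASH_REMATCH[1]}"
  else
    f="${f%.index}"
    f="${f%.meta}"
  fi
  printf '%s\n' "$f"
}

# Example (other arguments unchanged from the original command):
# python "$nematus_home/nematus/train.py" \
#     --reload "$(checkpoint_prefix "$working_dir/model.npz.data-00000-of-00001")" \
#     ...
```

A name that is already a prefix (e.g. model.npz) is returned unchanged, so the helper is safe to apply to either form.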

On a side note, make sure to save checkpoints often enough ("--save_freq", and maybe "--valid_freq" as well) that you don't waste a lot of compute.
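A rough way to choose the value: time between saves is about save_freq × batch_size / (sentences per second). Assuming throughput similar to the roughly 20 Sents/sec in the log posted later in this thread, and ignoring time spent on validation and saving, --saveFreq 10000 with --batch_size 80 means one save about every 11 hours, which could explain why only one checkpoint survived a 12-hour session. A back-of-the-envelope sketch (both function names are hypothetical):

```shell
#!/usr/bin/env bash
# Back-of-the-envelope checkpoint spacing for a time-limited session.
# Assumes every update consumes batch_size sentences, throughput equals the
# Sents/sec reported in the training log, and validation/saving time is ignored.
set -euo pipefail

# hours_between_saves SAVE_FREQ BATCH_SIZE SENTS_PER_SEC
hours_between_saves() {
  awk -v f="$1" -v b="$2" -v s="$3" 'BEGIN { printf "%.1f\n", f * b / s / 3600 }'
}

# save_freq_for_hours HOURS BATCH_SIZE SENTS_PER_SEC (rounded down, at least 1)
save_freq_for_hours() {
  awk -v h="$1" -v b="$2" -v s="$3" \
    'BEGIN { n = int(h * 3600 * s / b); if (n < 1) n = 1; print n }'
}

hours_between_saves 10000 80 20   # prints 11.1
save_freq_for_hours 1 80 20       # prints 900 (updates per hour at this speed)
```

Since batches can be smaller than batch_size after length filtering, treat these as estimates and leave some margin before the session limit.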

@rsennrich Thank you very much. It's working now.

If I want to do domain adaptation, should I use the same reload command?
And do I need to add anything else? Is there a parameter like Marian's --no-restore-corpus for when I change the corpus?

Yes, you can do domain adaptation with the same command, simply swapping out the --datasets (and --valid_datasets) parameters.

No, there is no need to add other parameters when switching out the corpus.
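For illustration, a fine-tuning run might reuse the original flags with --reload pointing at the base model and only the data flags swapped. This is a minimal sketch, not an official recipe: base_dir, adapt_dir, the in-domain file names (indomain.bpe.*, dev.indomain.bpe.*) and the default values are hypothetical, and saving into a separate directory assumes --reload accepts a checkpoint outside the --model directory. Dictionaries and architecture flags stay identical to the base run, since restored variables such as the embedding matrices must keep their shapes. By default it only prints the command; set RUN=1 to execute it.

```shell
#!/usr/bin/env bash
# Sketch: fine-tune (domain-adapt) a trained Nematus model on new data.
# Hypothetical: base_dir, adapt_dir, indomain.bpe.*, dev.indomain.bpe.* and
# all default values below; adjust them to your own setup.
set -euo pipefail

: "${nematus_home:=nematus}"
: "${data_dir:=data}"
: "${base_dir:=model}"            # directory holding the base model.npz* files
: "${adapt_dir:=model_adapted}"   # separate directory for the adapted model
: "${src:=en}" "${trg:=mn}"

# --reload takes the checkpoint prefix of the base model. Dictionaries and
# architecture flags must match the base run; hyperparameters are copied from
# the original command, with more frequent validation and saving.
args=(
  --reload "$base_dir/model.npz"
  --model "$adapt_dir/model.npz"
  --datasets "$data_dir/indomain.bpe.$src" "$data_dir/indomain.bpe.$trg"
  --valid_datasets "$data_dir/dev.indomain.bpe.$src" "$data_dir/dev.indomain.bpe.$trg"
  --dictionaries "$data_dir/train.bpe.$src.json" "$data_dir/train.bpe.$trg.json"
  --dim_word 512 --dim 1024
  --tie_decoder_embeddings --layer_normalisation
  --dec_base_recurrence_transition_depth 8 --enc_recurrence_transition_depth 4
  --lrate 0.0001 --optimizer adam --maxlen 50
  --batch_size 80 --valid_batch_size 40
  --validFreq 2000 --dispFreq 1000 --saveFreq 2000 --sampleFreq 10000
)

cmd=(python "$nematus_home/nematus/train.py" "${args[@]}")
if [[ "${RUN:-0}" == 1 ]]; then
  mkdir -p "$adapt_dir"
  "${cmd[@]}"
else
  printf '%q ' "${cmd[@]}"; echo   # dry run: only show the command
fi
```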

Hi @rsennrich,

I then trained the system from scratch with the following script:

THEANO_FLAGS=mode=FAST_RUN,floatX=float32,device=$device,gpuarray.preallocate=0.8 python $nematus_home/nematus/train.py \
    --model $working_dir/model.npz \
    --datasets $data_dir/train.bpe.$src $data_dir/train.bpe.$trg \
    --valid_datasets $data_dir/dev.bpe.$src $data_dir/dev.bpe.$trg \
    --dictionaries $data_dir/train.bpe.$src.json $data_dir/train.bpe.$trg.json \
    --dim_word 512 \
    --dim 1024 \
    --lrate 0.0001 \
    --optimizer adam \
    --maxlen 50 \
    --batch_size 80 \
    --valid_batch_size 40 \
    --validFreq 2000 \
    --dispFreq 1000 \
    --saveFreq 4000 \
    --sampleFreq 5000 \
    --tie_decoder_embeddings \
    --layer_normalisation \
    --dec_base_recurrence_transition_depth 8 \
    --enc_recurrence_transition_depth 4

It ran up to epoch 65 (the last validation was: Validation cross entropy (AVG/SUM/N_SENTS/N_TOKENS): 104.05570943714682 45680.45644290745 439 12761).
When I restart it, training seems to resume from the point where the average (AVG) validation cross entropy was lowest, which in my case was at epoch 21:
Validation cross entropy (AVG/SUM/N_SENTS/N_TOKENS): 93.52767660609682 41058.650030076504 439 12761

So, on restart, the system continues from epoch 21. Is there a way to resume from the latest saved checkpoint instead of the one with the lowest average validation cross entropy?

Log summary of training from scratch:

2020-10-06 18:57:26.056344: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1858] Adding visible gpu devices: 0
INFO: Initializing model parameters from scratch...
INFO: Done
INFO: Reading data...
INFO: Done
INFO: Initial uidx=0
INFO: Starting epoch 0
Epoch: 5 Update: 500 Loss/word: 4.50615650551452 Words/sec: 560.6486453034273 Sents/sec: 18.544880056627484
Epoch: 10 Update: 1000 Loss/word: 3.0483226237200056 Words/sec: 597.8954025054368 Sents/sec: 19.811542751676704
Validation cross entropy (AVG/SUM/N_SENTS/N_TOKENS): 98.99814292781717 43460.18474531174 439 12761
INFO: nematus/scripts/../model/tmpggran0_o/model.npz is not in all_model_checkpoint_paths. Manually adding it.
Epoch: 15 Update: 1500 Loss/word: 2.253776026582657 Words/sec: 613.7468884867649 Sents/sec: 20.39968236473054
Epoch: 21 Update: 2000 Loss/word: 1.5804471220434453 Words/sec: 616.3454607128992 Sents/sec: 20.497640055614173
Validation cross entropy (AVG/SUM/N_SENTS/N_TOKENS): 93.52767660609682 41058.650030076504 439 12761
INFO: nematus/scripts/../model/tmp2qs2c2me/model.npz is not in all_model_checkpoint_paths. Manually adding it.
INFO: nematus/scripts/../model/model.npz-2000 is not in all_model_checkpoint_paths. Manually adding it.
Epoch: 26 Update: 2500 Loss/word: 1.0360141398181792 Words/sec: 585.9223627750297 Sents/sec: 19.421033707411443
Epoch: 31 Update: 3000 Loss/word: 0.6585676208047083 Words/sec: 608.1438455339212 Sents/sec: 20.217283316325187
Validation cross entropy (AVG/SUM/N_SENTS/N_TOKENS): 94.28880919170712 41392.78723515943 439 12761
Epoch: 36 Update: 3500 Loss/word: 0.40650631285085853 Words/sec: 611.7562059492847 Sents/sec: 20.416566209876706
Epoch: 42 Update: 4000 Loss/word: 0.25016178405585543 Words/sec: 623.6446264420395 Sents/sec: 20.594209689107416
Validation cross entropy (AVG/SUM/N_SENTS/N_TOKENS): 97.6569819442232 42871.415073513985 439 12761
INFO: nematus/scripts/../model/model.npz-4000 is not in all_model_checkpoint_paths. Manually adding it.
Epoch: 47 Update: 4500 Loss/word: 0.1609760880468487 Words/sec: 613.71050013876 Sents/sec: 20.408149789038234
Epoch: 52 Update: 5000 Loss/word: 0.10791939750585201 Words/sec: 600.6070125173119 Sents/sec: 20.031917560592202
Validation cross entropy (AVG/SUM/N_SENTS/N_TOKENS): 103.0881025204244 45255.677006466314 439 12761
Epoch: 57 Update: 5500 Loss/word: 0.07942602140902875 Words/sec: 597.0705878833814 Sents/sec: 19.769243517218847
Validation cross entropy (AVG/SUM/N_SENTS/N_TOKENS): 104.05570943714682 45680.45644290745 439 12761
INFO: nematus/scripts/../model/model.npz-6000 is not in all_model_checkpoint_paths. Manually adding it.
INFO: Starting epoch 64
INFO: Starting epoch 65

Also, the average validation cross entropy decreased only once, at epoch 21, and kept increasing afterwards, whereas the training Loss/word kept decreasing.
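To see which validation point a best-so-far checkpoint should correspond to, a small sketch can scan the training log for the lowest AVG value. It assumes the log format shown above and records the most recent "Update:" value printed before each validation line. This only approximates the validation step when display and validation frequencies differ (the last validation above, for instance, follows "Update: 5500"). The function name best_validation is hypothetical.

```shell
#!/usr/bin/env bash
# Sketch: find the lowest AVG validation cross entropy in a Nematus log.
# Reads the log on stdin and prints "<update> <avg>", where <update> is the
# last "Update: N" value seen before that validation line (an approximation).
set -euo pipefail

best_validation() {
  awk '
    { for (k = 1; k < NF; k++) if ($k == "Update:") upd = $(k + 1) }
    /Validation cross entropy/ {
      n++
      avg = $(NF - 3) + 0     # AVG is the 4th-last field (AVG SUM N_SENTS N_TOKENS)
      if (n == 1 || avg < best) { best = avg; best_str = $(NF - 3); at = upd }
    }
    END { if (n == 0) exit 1; printf "%s %s\n", at, best_str }
  '
}

# Usage: best_validation < train.log
```

The original AVG string is printed rather than the parsed number, so the output can be matched against the log verbatim.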

model.npz is generally the checkpoint with the lowest validation perplexity. If you want to continue training from the latest checkpoint instead, you can use --reload model.npz-XXXX (where XXXX is the update number of the latest checkpoint).
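Picking the latest numbered checkpoint by eye is easy to get wrong once the numbers differ in length, because a plain text sort puts model.npz-10000 before model.npz-6000. A minimal bash sketch (latest_checkpoint is a hypothetical helper) that compares the numbers numerically, using each checkpoint's .index file to detect it:

```shell
#!/usr/bin/env bash
# Sketch: print the prefix of the highest-numbered checkpoint <name>-<N> in a
# directory, detected via its .index file. latest_checkpoint is hypothetical.
set -euo pipefail

latest_checkpoint() {
  local dir="$1" name="${2:-model.npz}" best=-1 best_n="" f n
  for f in "$dir/$name"-*.index; do
    [ -e "$f" ] || continue          # the glob matched nothing
    n=${f##*/"$name"-}               # drop directory and "<name>-"
    n="${n%.index}"                  # what is left should be the number N
    [[ "$n" =~ ^[0-9]+$ ]] || continue
    if (( 10#$n > best )); then      # numeric comparison, forced base 10
      best=$((10#$n)); best_n="$n"
    fi
  done
  [ -n "$best_n" ] || return 1
  printf '%s/%s-%s\n' "$dir" "$name" "$best_n"
}

# Usage: --reload "$(latest_checkpoint "$working_dir")"
```

The unnumbered model.npz.index is ignored on purpose, since that file holds the best-validation checkpoint rather than the latest one.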

If you train a base model and then continue fine-tuning it on other data, I also find it good practice to change the working directory (and possibly the model name) so that I don't unintentionally overwrite the base model.
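One way to make this practice hard to forget is a guard that refuses to start a run in a directory that already holds checkpoint files for the given model name. A minimal bash sketch: the helper name ensure_fresh_workdir is hypothetical, and the check simply looks for any file whose name starts with the model name.

```shell
#!/usr/bin/env bash
# Sketch: refuse to reuse a working directory that already contains
# checkpoints for <name> (e.g. a base model), so a fine-tuning run cannot
# overwrite them. ensure_fresh_workdir is a hypothetical helper.
set -euo pipefail

ensure_fresh_workdir() {
  local dir="$1" name="${2:-model.npz}" f
  for f in "$dir/$name"*; do
    if [ -e "$f" ]; then
      echo "refusing to use $dir: it already contains $(basename "$f")" >&2
      return 1
    fi
  done
  mkdir -p "$dir"
}

# Usage (hypothetical directory names):
#   ensure_fresh_workdir "$adapted_dir" && python "$nematus_home/nematus/train.py" \
#       --reload "$base_dir/model.npz" --model "$adapted_dir/model.npz" ...
```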