bryanlimy/tf2-transformer-chatbot

Save and Restore Checkpoint


How do I save and restore a checkpoint?

I tried:

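  import os
  import tensorflow as tf

  # model, dataset and hparams are defined elsewhere in the training script.
  checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
  checkpoint_dir = os.path.dirname(checkpoint_path)

  # Most recent checkpoint in checkpoint_dir, or None if none exists yet.
  latest = tf.train.latest_checkpoint(checkpoint_dir)

  # Save only the model's weights, every 5 epochs. (period is deprecated
  # in favor of save_freq in later TensorFlow releases.)
  cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                   verbose=1,
                                                   save_weights_only=True,
                                                   period=5)

  model.fit(dataset,
            epochs=hparams.epochs,
            callbacks=[cp_callback])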

This saves the checkpoints, but I can't restore them.

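You need to call

  model.load_weights(latest)

to restore a checkpoint, where latest is the path of an actual saved checkpoint. checkpoint_path itself is only a filename template (cp-{epoch:04d}.ckpt), so look up the newest checkpoint with tf.train.latest_checkpoint(checkpoint_dir) after training (or at the start of a later run); before the first fit it returns None. Build the model with the same architecture before loading the weights. For details, see the TensorFlow tutorial: https://www.tensorflow.org/beta/tutorials/keras/save_and_restore_models#save_checkpoints_during_training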
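A minimal end-to-end sketch of the same save-then-restore flow, assuming TensorFlow 2.x with h5py installed (TensorFlow normally pulls it in). Everything here is illustrative: a tiny Dense model stands in for the chatbot's Transformer, and build_model, train_with_checkpoints, latest_weights_file and restore are hypothetical helper names. The sketch uses the .weights.h5 suffix because Keras 3 requires it for weights-only checkpoints. Since tf.train.latest_checkpoint only finds TensorFlow-format checkpoints (it reads the "checkpoint" index file those write), the sketch picks the newest file by its zero-padded epoch number instead.

```python
import glob
import os
import tempfile

import numpy as np
import tensorflow as tf


def build_model():
    # Hypothetical stand-in for the chatbot model: the architecture must be
    # identical when saving and when restoring.
    return tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(1),
    ])


def train_with_checkpoints(ckpt_dir, epochs=3):
    # Hypothetical helper: one weights-only file per epoch,
    # cp-0001.weights.h5, cp-0002.weights.h5, ...
    ckpt_path = os.path.join(ckpt_dir, "cp-{epoch:04d}.weights.h5")
    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=ckpt_path, save_weights_only=True, verbose=0)

    # Toy regression data in place of the real dataset.
    rng = np.random.default_rng(0)
    x = rng.random((32, 4), dtype=np.float32)
    y = rng.random((32, 1), dtype=np.float32)

    model = build_model()
    model.compile(optimizer="adam", loss="mse")
    model.fit(x, y, epochs=epochs, callbacks=[cp_callback], verbose=0)
    return model


def latest_weights_file(ckpt_dir):
    # Hypothetical helper: zero-padded epoch numbers sort lexicographically
    # in epoch order, so the last name is the newest checkpoint.
    files = sorted(glob.glob(os.path.join(ckpt_dir, "cp-*.weights.h5")))
    return files[-1] if files else None


def restore(ckpt_dir):
    # Hypothetical helper: rebuild the same architecture, then load weights.
    latest = latest_weights_file(ckpt_dir)
    model = build_model()
    model.load_weights(latest)
    return model, latest
```

Because the last checkpoint is written at the end of the final epoch, the weights returned by restore(ckpt_dir) should match those of the model returned by train_with_checkpoints(ckpt_dir). With the original .ckpt naming under Keras 2 (tf.keras in TensorFlow 2.15 and earlier), weights-only checkpoints are saved in TensorFlow format, so tf.train.latest_checkpoint(checkpoint_dir) can take the place of latest_weights_file.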