lukasHoel/novel-view-synthesis

Detailed Specification for Checkpoints Required

  1. There are two main options to implement checkpoints with PyTorch:
    a) torch.save(model, PATH) & torch.load(PATH): Saves/loads the entire model as a serialized object to/from disk. Serialized data is bound to the specific classes and the exact directory structure used when the model is saved.
    b) torch.save(model.state_dict(), PATH) & model.load_state_dict(torch.load(PATH)): Only the model’s learned parameters are saved. This is the approach recommended by the docs, and it is the one used in the notebook.

    Both options seem to have advantages and disadvantages. The first one requires a fixed class definition and directory structure. For the second one, I think we additionally need to save the architecture skeleton (input/output channels, etc. per layer) and instantiate the model from this skeleton first, to avoid a mismatch between the model architecture and the saved parameters. With the first option, this information would already be available. Which one would you prefer?

  2. What information is likely to be needed to continue training from where it stopped? In addition to the model and optimizer, I thought of these: the epoch, the loss functions, and the accuracy metrics used in the solver during training. Do we need more information?
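For reference, option b) combined with the extra state from point 2 could look roughly like the sketch below. TinyNet, model_kwargs, save_checkpoint, and load_checkpoint are hypothetical placeholders, not names from this repository; the idea is only that the constructor arguments (the "skeleton") are stored next to the weights, so the model can be re-instantiated before the parameters are loaded.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the project's model."""

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, out_channels, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return self.net(x)


def save_checkpoint(path, model, model_kwargs, optimizer, epoch):
    """Store the architecture skeleton together with the learned state."""
    torch.save(
        {
            "model_kwargs": model_kwargs,  # skeleton: constructor arguments
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "epoch": epoch,
        },
        path,
    )


def load_checkpoint(path, lr=1e-3):
    """Rebuild the model from its skeleton, then load weights and optimizer."""
    ckpt = torch.load(path, map_location="cpu")
    model = TinyNet(**ckpt["model_kwargs"])
    model.load_state_dict(ckpt["model_state"])
    # The lr passed here is only a placeholder; load_state_dict restores
    # the hyperparameters that were saved with the optimizer.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    optimizer.load_state_dict(ckpt["optimizer_state"])
    return model, optimizer, ckpt["epoch"]
```

If the saved skeleton and the class definition disagree, load_state_dict raises an error for the missing, unexpected, or mis-shaped entries, so a mismatch is caught at load time.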

  1. I would prefer variant b) - this is already implemented in the jupyter notebook for full training (last cells). You are right that this requires the skeleton to be defined, but this way we can also make sure that the weights are only used with the correct model definition from us. So I would consider this to be a feature, actually :)

  2. I think we might also need the iteration within an epoch from which we should continue, in case training is stopped in the middle of an epoch. Stopping should only ever be possible after the current backward pass has been applied, so we would never need to save the current computational graph. We can also discuss whether resumed training should keep logging into the same tensorboard files; if so, we would need their names/paths.

Do you know about example implementations for checkpointing training? Maybe we can take inspiration regarding best practices from other codebases :)

  1. Agreed. OK, then I will work on saving weights and corresponding layer metadata.
  2. Hmm, yes, there are many other aspects to consider. My thought was to make checkpoints only after an epoch is completed, because checkpointing mid-epoch gets problematic when we shuffle the order of the batches passed to the network within an epoch: after a restart we might pass the same batch multiple times. On the other hand, if we wait for the epoch to finish before stopping, this might take longer than desired. About logging, I think we can store and restore the folder path and continue logging to the same files.
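One possible way around the shuffling concern, sketched here as an assumption rather than something decided in this thread, is to derive the shuffle order from the epoch number. The order of an epoch is then reproducible after a restart, and the batches that were already processed can be skipped. The function names and the seeding scheme are hypothetical.

```python
import random


def epoch_batches(num_samples, batch_size, epoch, base_seed=0):
    """Return the batches (lists of sample indices) for one epoch.

    The shuffle is seeded from base_seed and epoch, so the same epoch
    always yields the same order, even after a restart.
    """
    indices = list(range(num_samples))
    random.Random(base_seed * 1_000_003 + epoch).shuffle(indices)
    return [indices[i:i + batch_size] for i in range(0, num_samples, batch_size)]


def remaining_batches(num_samples, batch_size, epoch, batches_done, base_seed=0):
    """Batches still to process in `epoch` after `batches_done` were completed."""
    return epoch_batches(num_samples, batch_size, epoch, base_seed)[batches_done:]
```

With this scheme a mid-epoch checkpoint would only need to store the epoch and the number of completed batches in addition to the model and optimizer state.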

I will look at some examples. I think this would help resolve some of the points we mentioned.

I have checked several examples. They are all very similar. The common practice seems to be to checkpoint at a fixed frequency in terms of epochs, and to save the model only if it achieves the best validation accuracy seen so far. I can integrate both approaches and add an option to enable/disable the best-accuracy condition for saving. So, I currently plan to adapt the following implementation to our case:

Yes, I think this implementation example looks good! I think the only thing it is missing is the tensorboard directory needed to resume tensorboard logging.
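The saving policy described above (checkpoint every N epochs, optionally only on a new best validation accuracy) could be sketched as follows. This is a hypothetical illustration, not the implementation referenced in the thread; CheckpointPolicy and its parameters are made-up names, and epochs are assumed to be counted from 1.

```python
class CheckpointPolicy:
    """Decide after each epoch whether a checkpoint should be written."""

    def __init__(self, save_every=1, only_if_best=True):
        self.save_every = save_every      # checkpoint frequency in epochs
        self.only_if_best = only_if_best  # toggle for the best-accuracy condition
        self.best_acc = float("-inf")

    def should_save(self, epoch, val_acc):
        # Track the best validation accuracy on every epoch, not only on
        # the epochs where a checkpoint is allowed.
        is_best = val_acc > self.best_acc
        if is_best:
            self.best_acc = val_acc
        if epoch % self.save_every != 0:
            return False
        return is_best or not self.only_if_best


policy = CheckpointPolicy(save_every=2, only_if_best=True)
val_accs = [0.5, 0.6, 0.7, 0.65]
saved_epochs = [e for e, acc in enumerate(val_accs, start=1) if policy.should_save(e, acc)]
```

Note that with save_every greater than 1, a best model reached on an epoch that is not a checkpoint epoch is never saved; whether that is acceptable is a design choice.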

I think we will then need new parameters or new methods in the solver classes to resume training?

  • Something like solver.train(..., checkpoint={}) where checkpoint is the data that we will save. And if the solver gets that argument, it knows how to continue training and also loads the model + optimizer that is passed to it. Currently, the tensorboard SummaryWriter gets created in the solver __init__, so we need to pass the correct log_dir parameter from the checkpoint when initializing the new solver.

  • Alternatively, we can provide a separate method dedicated to resuming training.

What do you think would be best?
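To make the first variant concrete, here is a minimal, framework-agnostic sketch. The Solver class below is a hypothetical stand-in, not the project's solver: it only assumes that the model and optimizer expose state_dict() and load_state_dict(), and that the checkpoint is a dict with the keys shown, including the log_dir that would be passed to the tensorboard SummaryWriter.

```python
class Solver:
    """Hypothetical stand-in for the project's solver."""

    def __init__(self, log_dir):
        # In the real solver, the tensorboard SummaryWriter would be
        # created here with this log_dir.
        self.log_dir = log_dir

    @classmethod
    def from_checkpoint(cls, checkpoint):
        """Create a solver that keeps logging into the checkpoint's log_dir."""
        return cls(log_dir=checkpoint["log_dir"])

    def train(self, model, optimizer, num_epochs, train_one_epoch, checkpoint=None):
        """Train up to num_epochs, resuming after the checkpointed epoch if given."""
        start_epoch = 0
        if checkpoint is not None:
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            start_epoch = checkpoint["epoch"] + 1

        last_epoch = start_epoch - 1
        for epoch in range(start_epoch, num_epochs):
            train_one_epoch(model, optimizer, epoch)
            last_epoch = epoch

        # Return the data a subsequent checkpoint would contain.
        return {
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "epoch": last_epoch,
            "log_dir": self.log_dir,
        }
```

The second variant would move the restore logic out of train into its own method; the data stored in the checkpoint would be the same either way.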