OpenPipe/ART

Feature request - load best checkpoint

Opened this issue · 2 comments

Currently there is a function TrainableModel.delete_checkpoints(best_checkpoint_metric) which removes all checkpoints except the best and the latest.
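
For context, pruning already works with a single call today; the line below just mirrors the signature quoted above and assumes the model has been registered with a backend:

await model.delete_checkpoints(best_checkpoint_metric="val/reward")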

Unfortunately, there is no straightforward way to load the weights of the best checkpoint.
It would be nice if such a function existed.

Example signature:

from typing import Literal

class TrainableModel:
    def load_checkpoint(
        self,
        which: int | Literal["best", "latest"] = "latest",
        best_checkpoint_metric: str = "val/reward",
    ):
        """
        Args:
          which (int | Literal["best", "latest"]): Which checkpoint to load.
            - "best" loads the best checkpoint according to `best_checkpoint_metric`
            - "latest" loads the latest checkpoint available
            - an integer loads the checkpoint saved at that step number
          best_checkpoint_metric (str): The name of the metric that determines which checkpoint is best.
        """
        ...
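
Usage could then look roughly like this (the model construction and backend registration below just follow the usual ART setup and are only illustrative; load_checkpoint is the new part):

import art

# illustrative setup; name / project / base_model are placeholders
model = art.TrainableModel(
    name="agent-001",
    project="my-project",
    base_model="Qwen/Qwen2.5-7B-Instruct",
)
await model.register(backend)  # backend created elsewhere, e.g. a LocalBackend

# proposed: load the best checkpoint (by val/reward) before a final evaluation
model.load_checkpoint(which="best", best_checkpoint_metric="val/reward")

# proposed: resume training from a specific step
model.load_checkpoint(which=1200)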

For training?

Both for training and evaluation:

  • For evaluation purposes, my use case is running a full evaluation after training is finished. Currently I have to search the logs for the best-performing iteration, locate the corresponding LoRA path, load that checkpoint by hand, and only then evaluate (see the sketch after this list).

  • For training purposes, this option provides an easy way to continue training from a past run not only from the latest checkpoint but from the BEST one, without having to dig through the logs to find the step at which the best checkpoint was reached.
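
For reference, my manual workaround today looks roughly like this (the step number and the checkpoint path are from my local setup and purely illustrative; the on-disk layout may differ):

import os

# step found by manually scanning the logged val/reward values in the training logs
best_step = 1200

# path to the LoRA adapter for that step (illustrative layout)
lora_path = os.path.join(
    ".art", "my-project", "models", "agent-001", "checkpoints", f"{best_step:04d}"
)

# the adapter at lora_path is then loaded by hand into the inference engine
# used for the final evaluation pass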

I personally feel that this feature nicely complements model.log and model.delete_checkpoints, which already abstract away manual checkpoint and file management.