Feature request - load best checkpoint
Opened this issue · 2 comments
Currently there is a function `TrainableModel.delete_checkpoints(best_checkpoint_metric)` that removes all checkpoints except the best and the latest.
Unfortunately, there is no straightforward way to load the weights from the best checkpoint.
It would be nice if such a function existed.
Example signature:
```python
class TrainableModel:
    def load_checkpoint(
        self,
        which: int | Literal["best", "latest"] = "latest",
        best_checkpoint_metric: str = "val/reward",
    ) -> None:
        """
        Args:
            which (int | Literal["best", "latest"]): Which checkpoint to load.
                - "best" loads the best checkpoint according to `best_checkpoint_metric`
                - "latest" loads the most recent checkpoint available
                - an integer loads the checkpoint saved at that step number
            best_checkpoint_metric (str): Name of the metric that determines
                which checkpoint is best.
        """
```
...For training?
Both for training and evaluation:

- For evaluation, my use case is running a full evaluation after training is finished. Currently I have to search the logs for the best-performing iteration, locate the specific LoRA path, manually load the checkpoint, and only then evaluate.
- For training, this option provides an easy way to continue a past run not only from the latest checkpoint but from the best one, without manually checking the logs for the step at which the best checkpoint was achieved.
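Until such an API exists, the manual "look in the logs for the best iteration" step can be scripted. The sketch below assumes a JSON-lines training log where each record has a `"step"` key and metric keys; both the format and the helper name are hypothetical:

```python
import json


def best_step_from_log(log_path: str, metric: str = "val/reward"):
    """Scan a JSON-lines training log and return the step with the best metric.

    Assumes (for illustration) one JSON object per line, e.g.
    {"step": 20, "val/reward": 0.8}. Returns None if the metric
    never appears in the log.
    """
    best_step, best_val = None, float("-inf")
    with open(log_path) as f:
        for line in f:
            record = json.loads(line)
            if metric in record and record[metric] > best_val:
                best_step, best_val = record["step"], record[metric]
    return best_step
```

The returned step then has to be mapped to the corresponding LoRA path by hand, which is exactly the bookkeeping the proposed `load_checkpoint` would absorb.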
I personally feel that this feature nicely complements the existing `model.log` and `model.delete_checkpoints` functions, which already abstract away manual checkpoint and file management.