A set of utilities for retaining sanity while managing and monitoring ML experiments that utilise models written in JAX.
This is for my own personal use: I'm building a tech stack that lets me rapidly run RL experiments for my doctorate, and I plan to add more utilities as I need them.
HEAVILY inspired by, and based on the code of, torchkit. My main goal was essentially to replicate it while making it compatible with JAX patterns.
Currently has the following utilities:
- mission_control.checkpoint: Provides a CheckpointManager class that can be used to save and load Checkpoint objects of arbitrary PyTrees, e.g. Haiku Params, Optax OptStates, JAX arrays, NumPy arrays, and any other PyTree whose leaves can be serialized with np.save. The interface just requires that you provide PyTrees as kwargs.
  Note that the checkpointing solution I went with uses pickle to save the treedef, and is thus far from ideal: pickled files are tied to the Python classes that produced them, and unpickling an untrusted file can execute arbitrary code. However, it was simple enough, it will do for my use case, and it also seems to be used by other practitioners. For actual "prod" use cases, Orbax is superior since it serializes the tree structure to JSON instead of pickling it.
- mission_control.loggers: Provides a Logger interface for logging common training artifacts such as metrics, images and videos. Currently supports logging to Weights & Biases and TensorBoard with WandbLogger and TensorboardLogger.
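The checkpointing approach described above (array leaves written with np.save, the treedef pickled) fits in a few lines of JAX. This is a minimal sketch, not mission_control's actual CheckpointManager: save_checkpoint, load_checkpoint, and the one-subdirectory-per-kwarg layout are hypothetical choices made for this illustration.

```python
import os
import pickle
import tempfile

import jax
import jax.numpy as jnp
import numpy as np


def save_checkpoint(directory: str, **pytrees) -> None:
    """Hypothetical helper: write each keyword PyTree to its own subdirectory."""
    for name, tree in pytrees.items():
        tree_dir = os.path.join(directory, name)
        os.makedirs(tree_dir, exist_ok=True)
        leaves, treedef = jax.tree_util.tree_flatten(tree)
        for i, leaf in enumerate(leaves):
            # np.asarray copies JAX arrays to host memory so np.save can write them.
            np.save(os.path.join(tree_dir, f"leaf_{i}.npy"), np.asarray(leaf))
        # The treedef is not an array, so it goes through pickle (the trade-off noted above).
        with open(os.path.join(tree_dir, "treedef.pkl"), "wb") as f:
            pickle.dump(treedef, f)


def load_checkpoint(directory: str) -> dict:
    """Hypothetical helper: rebuild every PyTree written by save_checkpoint."""
    restored = {}
    for name in sorted(os.listdir(directory)):
        tree_dir = os.path.join(directory, name)
        with open(os.path.join(tree_dir, "treedef.pkl"), "rb") as f:
            treedef = pickle.load(f)
        # The treedef records how many leaves it expects, in flattening order.
        leaves = [
            jnp.asarray(np.load(os.path.join(tree_dir, f"leaf_{i}.npy")))
            for i in range(treedef.num_leaves)
        ]
        restored[name] = jax.tree_util.tree_unflatten(treedef, leaves)
    return restored


# Usage: each keyword argument becomes one saved PyTree.
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
opt_state = (jnp.array(5), {"mu": jnp.full(3, 0.5)})
with tempfile.TemporaryDirectory() as ckpt_dir:
    save_checkpoint(ckpt_dir, params=params, opt_state=opt_state)
    restored = load_checkpoint(ckpt_dir)
```

Only the tree structure depends on pickle here; the arrays themselves stay in the portable .npy format.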
pip install "git+https://github.com/evangelos-ch/mission-control.git"
If you want to use the loggers, you need to install the required extras (either wandb or tensorboard). For example:
pip install "mission-control[wandb] @ git+https://github.com/evangelos-ch/mission-control.git"
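The loggers follow a common pattern: one abstract Logger interface with interchangeable backends, so training code is written once and the backend is picked at setup time. The sketch below illustrates that pattern only; the method names log_scalar, log_image and log_video, the InMemoryLogger backend and the train_loop function are all hypothetical and not taken from mission_control. A toy in-memory backend stands in for Weights & Biases or TensorBoard so it runs without either extra installed.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List

import numpy as np


class Logger(ABC):
    """Hypothetical interface: backends differ only in where artifacts are sent."""

    @abstractmethod
    def log_scalar(self, name: str, value: float, step: int) -> None: ...

    @abstractmethod
    def log_image(self, name: str, image: np.ndarray, step: int) -> None: ...

    @abstractmethod
    def log_video(self, name: str, frames: np.ndarray, step: int, fps: int = 30) -> None: ...


class InMemoryLogger(Logger):
    """Toy backend that records each call instead of uploading it."""

    def __init__(self) -> None:
        self.records: List[Dict[str, Any]] = []

    def log_scalar(self, name: str, value: float, step: int) -> None:
        self.records.append({"kind": "scalar", "name": name, "value": float(value), "step": step})

    def log_image(self, name: str, image: np.ndarray, step: int) -> None:
        image = np.asarray(image)
        # Expect a single (H, W, C) image; reject other shapes early.
        if image.ndim != 3:
            raise ValueError(f"expected an (H, W, C) image, got shape {image.shape}")
        self.records.append({"kind": "image", "name": name, "value": image, "step": step})

    def log_video(self, name: str, frames: np.ndarray, step: int, fps: int = 30) -> None:
        frames = np.asarray(frames)
        # Expect a (T, H, W, C) stack of frames.
        if frames.ndim != 4:
            raise ValueError(f"expected (T, H, W, C) frames, got shape {frames.shape}")
        self.records.append(
            {"kind": "video", "name": name, "value": frames, "step": step, "fps": fps}
        )


def train_loop(logger: Logger, num_steps: int) -> None:
    """Depends only on the Logger interface, so any backend can be passed in."""
    for step in range(num_steps):
        loss = 1.0 / (step + 1)  # placeholder metric for illustration
        logger.log_scalar("train/loss", loss, step)
```

Swapping the toy backend for a real one would only change the object passed to train_loop, not the training code itself.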