Work in progress.
EasyJax is a Python package that facilitates machine learning development in JAX. It does this by providing:
- A high-level API for machine learning workflows in JAX (specifically, a trainer and an experiment parent class).
- Several machine-learning-specific utilities for working with JAX (e.g., `ml.update_step`).
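This section does not document the signature of EasyJax's `ml.update_step`, so the block below is **not** its API. It is a hypothetical, plain-JAX sketch of the pattern a training-step utility like this typically wraps. It computes the loss and gradients with `jax.value_and_grad` and then applies a simple SGD update to every leaf of the parameter pytree with `jax.tree_util.tree_map`. The names `update_step`, `mse_loss`, and `train_step`, the fixed learning rate, and the toy linear-regression data are all illustrative assumptions.

```python
import jax
import jax.numpy as jnp


# Hypothetical helper (not EasyJax's ml.update_step): one step of plain gradient descent.
def update_step(params, batch, loss_fn, learning_rate=0.1):
    """Compute loss and gradients, then apply an SGD update to every leaf of params."""
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads
    )
    return new_params, loss


def mse_loss(params, batch):
    """Mean squared error of a linear model y = x @ w + b."""
    x, y = batch
    preds = x @ params["w"] + params["b"]
    return jnp.mean((preds - y) ** 2)


# Bind the loss function and learning rate, then JIT-compile the whole step.
@jax.jit
def train_step(params, batch):
    return update_step(params, batch, mse_loss, learning_rate=0.1)


# Toy data generated from y = 2x + 1.
x = jnp.array([[0.0], [1.0], [2.0], [3.0]])
y = 2.0 * x[:, 0] + 1.0
params = {"w": jnp.zeros((1,)), "b": jnp.zeros(())}

for _ in range(200):
    params, loss = train_step(params, (x, y))
```

The step function is pure: it returns new parameters instead of mutating them. That is what lets `jax.jit` compile it, and it is the main reason JAX training code tends to be organized around reusable step helpers.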