
🧱 A package that facilitates machine learning development in JAX, providing several machine learning utilities and a trainer parent class.


EasyJax

Work in progress.

EasyJax is a Python package that facilitates machine learning development in JAX. It does so by providing:

  1. A high-level API for machine learning workflows in JAX (specifically, a trainer/experiment parent class).
  2. Several machine-learning-specific utilities for working with JAX (e.g., ml.update_step).
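The two ideas above (a reusable update step and a trainer parent class that subclasses specialize) can be illustrated in plain JAX. The following is a minimal sketch under stated assumptions, **not** EasyJax's actual API: `update_step`, `Trainer`, and `LinearRegressionTrainer` are hypothetical names, and the update rule is plain gradient descent on a toy linear regression problem. Only standard JAX calls (`jax.jit`, `jax.value_and_grad`, `jax.tree_util.tree_map`, `jax.random`) are used.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # assumed fixed step size for this sketch


def loss_fn(params, x, y):
    # Mean squared error of a linear model y_hat = x @ w + b.
    preds = x @ params["w"] + params["b"]
    return jnp.mean((preds - y) ** 2)


@jax.jit
def update_step(params, x, y):
    # Hypothetical helper (not EasyJax's ml.update_step): one step of plain
    # gradient descent applied leaf-by-leaf over the parameter pytree.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss


class Trainer:
    # Hypothetical parent class: subclasses supply parameters and data,
    # and the shared training loop lives here.
    def init_params(self):
        raise NotImplementedError

    def data(self):
        raise NotImplementedError

    def fit(self, num_steps):
        params = self.init_params()
        x, y = self.data()
        losses = []
        for _ in range(num_steps):
            params, loss = update_step(params, x, y)
            losses.append(float(loss))
        return params, losses


class LinearRegressionTrainer(Trainer):
    # Toy subclass: fits y = 2*x0 - 1*x1 + 0.5 from noise-free data.
    def init_params(self):
        return {"w": jnp.zeros((2,)), "b": jnp.zeros(())}

    def data(self):
        key = jax.random.PRNGKey(0)
        x = jax.random.normal(key, (64, 2))
        y = x @ jnp.array([2.0, -1.0]) + 0.5
        return x, y
```

Usage: `params, losses = LinearRegressionTrainer().fit(200)` runs the shared loop, and the loss should shrink toward zero. Keeping the loop in the parent class means a new experiment only has to override `init_params` and `data`.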