metrax is a library of standard evaluation metric implementations in JAX.
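To illustrate the kind of computation an evaluation metric in JAX involves, here is a minimal sketch of a classification accuracy metric written directly with `jax.numpy`. It does not use metrax's own API: the `accuracy` function and the sample arrays are hypothetical names chosen for this example only.

```python
import jax
import jax.numpy as jnp


@jax.jit
def accuracy(predictions: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical helper (not part of metrax): fraction of exact matches."""
    # Compare element-wise, cast the boolean mask to float, then average.
    matches = (predictions == labels).astype(jnp.float32)
    return jnp.mean(matches)


# Toy data: three of the four predictions match their labels.
predictions = jnp.array([1, 0, 1, 1])
labels = jnp.array([1, 0, 0, 1])

result = float(accuracy(predictions, labels))
print(result)
```

A metrics library typically packages computations like this behind a consistent interface so that results can be accumulated across batches. Consult the metrax docs for the classes it actually provides.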
Install the package from the PyPI release:
pip install google-metrax
Run the tests:
pytest src/metrax
Develop the docs locally:
pip install -r ./docs/requirements.txt
sphinx-build ./docs /tmp/metrax_docs
python -m http.server --directory /tmp/metrax_docs