tjax

Tools for JAX

Primary language: Python. License: Apache License 2.0 (Apache-2.0).

This repository implements a variety of tools for the differentiable programming library JAX.

Major components

Tjax's major components are:

  • A dataclass decorator, dataclass, for defining structured JAX objects (so-called "pytrees"). It provides:
    • the ability to mark fields as static (not available in chex.dataclass), and
    • a display method that produces formatted text according to the tree structure.
  • A shim for the gradient transformation library optax that supports:
    • easy differentiation and vectorization of “gradient transformation” (learning rule) parameters,
    • gradient transformation objects that can be passed dynamically to jitted functions, and
    • generic type annotations.
  • A pretty printer print_generic for aggregate and vector types, including dataclasses. (See display.) It features:
    • support for traced values,
    • colorized tree output for aggregate structures, and
    • formatted tabular output for arrays (or statistics when there's no room for tabular output).
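
To illustrate what marking a field as static means for a pytree, here is a minimal sketch in plain Python. It is not tjax's implementation and does not use the tjax or JAX APIs: the names static_field, flatten, unflatten, and LinearLayer are hypothetical and exist only for this example. The idea it shows is the general one behind pytrees: dynamic fields become leaves that transformations can trace and differentiate, while static fields travel alongside as auxiliary data that describes the structure.

```python
from __future__ import annotations

from dataclasses import dataclass, field, fields
from typing import Any


def static_field(**kwargs: Any) -> Any:
    """Hypothetical helper: tag a dataclass field as static via its metadata."""
    return field(metadata={"static": True}, **kwargs)


def is_static(f: Any) -> bool:
    """Return True if the dataclass field was tagged as static."""
    return bool(f.metadata.get("static", False))


def flatten(obj: Any) -> tuple[list[Any], tuple[tuple[str, Any], ...]]:
    """Split a dataclass instance into dynamic leaves and static auxiliary data."""
    dynamic = [getattr(obj, f.name) for f in fields(obj) if not is_static(f)]
    static = tuple((f.name, getattr(obj, f.name)) for f in fields(obj) if is_static(f))
    return dynamic, static


def unflatten(cls: type, static: tuple[tuple[str, Any], ...], dynamic: list[Any]) -> Any:
    """Rebuild an instance from static auxiliary data and (possibly new) leaves."""
    dynamic_names = [f.name for f in fields(cls) if not is_static(f)]
    kwargs = dict(zip(dynamic_names, dynamic))
    kwargs.update(dict(static))
    return cls(**kwargs)


@dataclass(frozen=True)
class LinearLayer:
    weights: list[float]                              # dynamic: a leaf
    bias: float                                       # dynamic: a leaf
    activation: str = static_field(default="relu")    # static: structure only


layer = LinearLayer(weights=[1.0, 2.0], bias=0.5, activation="tanh")
leaves, aux = flatten(layer)

# Replace the leaves (as a transformation would) and keep the static data.
doubled = unflatten(LinearLayer, aux, [[2.0, 4.0], 1.0])
```

In JAX itself, the auxiliary data becomes part of the tree definition, so changing a static value under jit causes retracing rather than being traced as an array.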

Minor components

Tjax also includes:

  • Versions of custom_vjp and custom_jvp that can be used on methods: custom_vjp_method and custom_jvp_method. (See shims.)
  • Tools for working with cotangents. (See cotangent_tools.)
  • JAX tree registration for NetworkX graph types. (See graph.)
  • Leaky integration (leaky_integrate) and Ornstein-Uhlenbeck process iteration (diffused_leaky_integrate). (See leaky_integral.)
  • An improved version of jax.tree_util.Partial. (See partial.)
  • A testing function, assert_tree_allclose, that automatically produces testing code, and a related function, tree_allclose. (See testing.)
  • Basic tools like divide_where. (See tools.)
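
As a rough illustration of what leaky integration computes, the sketch below advances a scalar by one time step under the ODE dx/dt = drift - decay * x, using its closed-form solution. This is one common definition, assumed here for illustration; it is not taken from the tjax source, and the function name leaky_integrate_sketch and its parameters are hypothetical. The real leaky_integrate may differ in signature and options.

```python
import math


def leaky_integrate_sketch(value: float, time_step: float, drift: float, decay: float) -> float:
    """Advance `value` by `time_step` under dx/dt = drift - decay * x (assumed model).

    Closed form: x(t) = x0 * exp(-decay * t) + (drift / decay) * (1 - exp(-decay * t)).
    """
    if decay == 0.0:
        # No leak: plain integration of a constant drift.
        return value + drift * time_step
    scaled = decay * time_step
    # -expm1(-scaled) equals 1 - exp(-scaled), computed accurately for small steps.
    return value * math.exp(-scaled) + (drift / decay) * -math.expm1(-scaled)


# One full step versus two half steps: the exact solution composes.
full = leaky_integrate_sketch(1.0, 2.0, drift=2.0, decay=0.5)
half = leaky_integrate_sketch(1.0, 1.0, drift=2.0, decay=0.5)
two_halves = leaky_integrate_sketch(half, 1.0, drift=2.0, decay=0.5)

# After a long time the value settles at the fixed point drift / decay.
settled = leaky_integrate_sketch(1.0, 100.0, drift=2.0, decay=0.5)
```

Using the closed form rather than an Euler step keeps the result independent of how the interval is subdivided, which is why the two half steps match the full step up to rounding.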

Contribution guidelines

  • Conventions: PEP8.
  • How to run tests: pytest .
  • How to clean the source:
    • ruff .
    • pyright
    • mypy
    • isort .
    • pylint tjax tests