
Jit-able runtime assertions for JAX in NumPy style.

Primary language: Python. License: Apache License 2.0 (Apache-2.0).

🧪 testax


testax provides runtime assertions for JAX through the testing interface familiar to NumPy users.

>>> import jax
>>> from jax import numpy as jnp
>>> import testax
>>>
>>> def safe_log(x):
...     testax.assert_array_less(0, x)
...     return jnp.log(x)
>>>
>>> safe_log(jnp.arange(2))
Traceback (most recent call last):
    ...
jax._src.checkify.JaxRuntimeError:
Arrays are not less-ordered
<BLANKLINE>
Mismatched elements: 1 / 2 (50%)
Max absolute difference: 1
Max relative difference: 1
 x: Array(0, dtype=int32, weak_type=True)
 y: Array([0, 1], dtype=int32)
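
For comparison, NumPy's own `numpy.testing.assert_array_less` performs the same check eagerly on concrete arrays and raises a plain `AssertionError`. The sketch below uses only NumPy, not testax, to show the interface testax mirrors; the exact header wording of the report varies between NumPy versions.

```python
import numpy as np

x = np.arange(2)

# NumPy's eager assertion: raises immediately because 0 < 0 is false.
try:
    np.testing.assert_array_less(0, x)
    message = None
except AssertionError as err:
    message = str(err)

print(message)

# Strictly positive input passes silently.
np.testing.assert_array_less(0, x + 1)
```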

testax assertions are jit-able, but their errors must be functionalized to satisfy JAX's requirement that functions be pure and free of side effects (see the checkify guide for details). In short, a checkified function returns a tuple (error, value): the first element holds any error raised by an assertion, and the second is the return value of the original function.

>>> jitted = jax.jit(safe_log)
>>> checkified = testax.checkify(jitted)
>>> error, y = checkified(jnp.arange(2))
>>> error.throw()
Traceback (most recent call last):
    ...
jax._src.checkify.JaxRuntimeError:
Arrays are not less-ordered
<BLANKLINE>
Mismatched elements: 1 / 2 (50%)
Max absolute difference: 1
Max relative difference: 1
 x: Array(0, dtype=int32, weak_type=True)
 y: Array([0, 1], dtype=int32)
>>> y
Array([-inf,   0.], dtype=float32)
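
The same functionalization pattern can be reproduced with JAX's own `jax.experimental.checkify` module, where `checkify.check` plays the role of a testax assertion. This is a minimal sketch of the mechanism described in the checkify guide, not testax's implementation; the error message is illustrative. Here the function is checkified first and then jitted; JAX also accepts the jit-then-checkify order used above.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def safe_log(x):
    # checkify.check records a functional error instead of raising eagerly.
    checkify.check(jnp.all(x > 0), "x must be strictly positive")
    return jnp.log(x)


# The checkified function is pure and returns (error, value).
checked_log = jax.jit(checkify.checkify(safe_log))

# Failing input: the error is returned alongside the (still computed) value.
error, y = checked_log(jnp.arange(2.0))
print(error.get())  # message describing the failed check
# error.throw() would raise a JaxRuntimeError here.

# Passing input: no check failed, so error.get() is None.
error_ok, y_ok = checked_log(jnp.arange(1.0, 3.0))
print(error_ok.get())
```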

Installation

testax is available on PyPI and can be installed with pip:

pip install testax

Interface

testax mirrors the numpy.testing interface familiar to NumPy users, providing functions such as assert_allclose.
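
To illustrate what a jit-able counterpart of `assert_allclose` involves, the sketch below applies NumPy's tolerance rule, `|actual - desired| <= atol + rtol * |desired|`, inside `checkify.check`. The helper `allclose_check` and the function `compare` are hypothetical names for this example; testax's actual `assert_allclose` may differ in signature and error reporting.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def allclose_check(actual, desired, rtol=1e-7, atol=0.0):
    """Hypothetical jit-able analog of numpy.testing.assert_allclose (not testax's code)."""
    actual = jnp.asarray(actual)
    desired = jnp.asarray(desired)
    # NumPy's tolerance rule, evaluated elementwise.
    close = jnp.abs(actual - desired) <= atol + rtol * jnp.abs(desired)
    # Record a functional error reporting how many elements are out of tolerance.
    checkify.check(
        jnp.all(close),
        "Not equal to tolerance: {} element(s) mismatched",
        jnp.sum(~close),
    )


def compare(actual, desired):
    allclose_check(actual, desired, rtol=1e-5)
    return actual - desired


checked_compare = jax.jit(checkify.checkify(compare))

err_ok, _ = checked_compare(jnp.array([1.0, 2.0]), jnp.array([1.0, 2.0 + 1e-6]))
err_bad, _ = checked_compare(jnp.array([1.0, 2.0]), jnp.array([1.0, 2.5]))
print(err_ok.get())   # within tolerance, so no error
print(err_bad.get())  # message describing the failed check
```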