
Port numpy (and some scipy) code to JAX

Primary language: Python. License: Apache License 2.0 (Apache-2.0).

jax-port

CLI script for porting code that uses numpy and scipy to the JAX equivalents.

Usage

The script uses only the standard library, so no further setup should be needed if you already have Python installed. Note that only Python 3.9 and later are supported, because the script relies on ast.unparse.

python jax_port.py -i some_numpy_code.py > some_jax_code.py
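To give a feel for the kind of transformation involved, here is a minimal sketch of an ast-based import rewrite. It is not the actual implementation of jax_port.py: the names MODULE_MAP, ImportRewriter and port_source are hypothetical, and the sketch only handles import statements, assuming numpy maps to jax.numpy and scipy to jax.scipy.

```python
import ast

# Hypothetical mapping for illustration; the real script may cover more cases.
MODULE_MAP = {"numpy": "jax.numpy", "scipy": "jax.scipy"}


def _port_module(name):
    """Map e.g. 'numpy.linalg' to 'jax.numpy.linalg'; leave other modules alone."""
    root, _, rest = name.partition(".")
    if root not in MODULE_MAP:
        return name
    return MODULE_MAP[root] + ("." + rest if rest else "")


class ImportRewriter(ast.NodeTransformer):
    """Rewrite numpy/scipy imports to point at their JAX counterparts."""

    def visit_Import(self, node):
        for alias in node.names:
            if alias.asname is None and "." in alias.name:
                # `import numpy.linalg` binds the name `numpy`; this sketch
                # leaves such imports untouched rather than guess a rewrite.
                continue
            new_name = _port_module(alias.name)
            if new_name != alias.name:
                # Keep the original binding name so later references still resolve.
                alias.asname = alias.asname or alias.name
                alias.name = new_name
        return node

    def visit_ImportFrom(self, node):
        # Only absolute imports (level 0) are candidates for rewriting.
        if node.module and node.level == 0:
            node.module = _port_module(node.module)
        return node


def port_source(source):
    """Parse source, rewrite its imports, and return regenerated code."""
    tree = ImportRewriter().visit(ast.parse(source))
    ast.fix_missing_locations(tree)
    return ast.unparse(tree)  # requires Python 3.9+


if __name__ == "__main__":
    print(port_source("import numpy as np\nx = np.zeros(3)\n"))
```

Because ast.unparse regenerates source text from the syntax tree, comments and the original formatting are not preserved in the output of this sketch.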

Roadmap