An attempt at building input pipelines purely on top of JAX, with support for jitting iterators. Heavily inspired by the tf.data.Dataset API, since this is
what most JAX users currently use.
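As a rough illustration of what jitting a pipeline segment means, the sketch below reproduces the output of the usage example that follows in plain JAX, without jaxio. The batch, square and negate steps between the two jit boundaries are treated as one function compiled with `jax.jit`. This is only an analogy and not how jaxio works internally. The helper name `segment` is hypothetical, and dropping the trailing partial batch (`[9]`) is inferred from the printed output rather than from jaxio's documentation.

```python
import jax
import jax.numpy as jnp

data = jnp.arange(10)
batch_size = 3

# Keep only full batches (assumption inferred from the printed output):
# 10 // 3 == 3, so the trailing element 9 is dropped.
num_batches = data.shape[0] // batch_size
batches = data[: num_batches * batch_size].reshape(num_batches, batch_size)


@jax.jit
def segment(batch):
    # Square, then negate. Under jax.jit both ops are traced once and
    # compiled together, so each batch runs as a single compiled call.
    return -jnp.square(batch)


# Iterating over a (3, 3) array yields rows of shape (3,); every row has
# the same shape, so segment is compiled only once.
results = [segment(b) for b in batches]
for r in results:
    print(r)
```

Because every batch has the same static shape, the compiled function is reused for each element, which is why fixed-size batches matter for jitting.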
import jax.numpy as jnp
import jaxio
d = jaxio.Dataset.from_pytree_slices(jnp.arange(10))
d = d.as_jit_compatible() # -- jit boundary --
d = d.batch(3) # <-- will be jitted
d = d.map(jnp.square) # <-- will be jitted
d = d.map(lambda x: -x) # <-- will be jitted
d = d.jit() # -- jit boundary --
d = d.prefetch(1)
for el in d:
    print(el)
# [ 0 -1 -4]
# [ -9 -16 -25]
# [-36 -49 -64]

To install:
pip install jaxio

The instructions below are intended only for those interested in contributing to jaxio directly.
One-time setup:
pip install 'build[virtualenv]' twine
virtualenv .venv
virtualenv .venv-docs
source .venv-docs/bin/activate
pip install --upgrade pip
pip install -r docs/requirements.txt
source .venv/bin/activate
pip install --upgrade pip
pip install -e .
pip install pytest

To test the package locally:
pytest

To re-generate the documentation pages locally:
source .venv-docs/bin/activate
pip install -e .
rm -rf docs/_build
cd docs && make html && cd -
source .venv/bin/activate

NOTE: when pushing to main, Read the Docs rebuilds the docs based on the
latest version on PyPI.
To upload to PyPI:
deactivate 2> /dev/null
python -m build .
twine upload -r testpypi dist/* # try it out first
# twine upload dist/*