The jax-triton repository contains integrations between JAX and Triton.
This is not an officially supported Google product.
Install jax-triton with pip:
$ pip install jax-triton
Make sure you have a CUDA-compatible jaxlib installed. For example, you could run:
$ pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
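After installing, you can quickly confirm that the packages are visible to Python and that JAX picked up a GPU backend with a small check like the sketch below. It uses only the standard importlib.metadata module plus jax.default_backend(). The helper names installed_versions and jax_backend are hypothetical, and the sketch assumes the distribution names match the ones passed to pip above.

```python
# Minimal sketch (hypothetical helpers): report which of the packages
# installed above are visible to this interpreter and, if jax imports,
# which backend JAX selected by default.
from importlib import metadata

# Distribution names as passed to pip above (assumption).
PACKAGES = ("jax", "jaxlib", "jax-triton")


def installed_versions(names=PACKAGES):
    """Return {name: version string, or None if not installed}."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def jax_backend():
    """Return JAX's default backend name, or None if jax is not importable."""
    try:
        import jax
    except ImportError:
        return None
    return jax.default_backend()


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
    print("JAX backend:", jax_backend())
```

With a CUDA-enabled jaxlib and a visible GPU, the last line should report a GPU backend. 'cpu' suggests either a CPU-only jaxlib or no GPU visible to JAX.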
To develop jax-triton, you can clone the repo with:
$ git clone https://github.com/jax-ml/jax-triton.git
and do an editable install with:
$ cd jax-triton
$ pip install -e .
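To confirm that the editable install points at your clone, you can ask Python's import system where jax_triton resolves from. The sketch below assumes the package imports as jax_triton. The helper imported_from and the path path/to/jax-triton are hypothetical placeholders. Run it from a directory outside the clone, because the current directory is normally on the import path and would hide a missing install.

```python
# Minimal sketch (hypothetical helper): check whether a module resolves
# to a file inside a given source checkout, e.g. after `pip install -e .`.
import importlib.util
from pathlib import Path


def imported_from(module_name, checkout_dir):
    """Return True if module_name's source file lies under checkout_dir."""
    spec = importlib.util.find_spec(module_name)
    if spec is None or spec.origin is None:
        # Not importable, or a namespace package with no single origin file.
        return False
    origin = Path(spec.origin).resolve()
    root = Path(checkout_dir).resolve()
    return root in origin.parents


if __name__ == "__main__":
    # Replace the placeholder with the path to your clone.
    print(imported_from("jax_triton", "path/to/jax-triton"))
```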
To run the jax-triton tests, you'll need pytest and absl-py:
$ pip install pytest absl-py
$ pytest tests/