jax-triton contains integrations between JAX and OpenAI Triton

Primary language: Python. License: Apache License 2.0 (Apache-2.0).

jax-triton

The jax-triton repository contains integrations between JAX and Triton.

This is not an officially supported Google product.

Installation

$ pip install jax-triton

Make sure you have a CUDA-compatible jaxlib installed. For example, you could run:

$ pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
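After installing, it can be useful to confirm that the packages are importable and that JAX sees a GPU backend. The following is a minimal sketch, not part of jax-triton: the helper name check_environment is hypothetical, and it assumes the import names jax, jaxlib, triton, and jax_triton. It relies only on the standard library plus jax.default_backend(), and it reports a backend of None when JAX is missing or fails to initialize.

```python
import importlib.util


def check_environment():
    """Report which packages are importable and which JAX backend is active.

    Hypothetical helper for illustration; it is not provided by jax-triton.
    """
    report = {}
    # find_spec locates a top-level package without importing it.
    for name in ("jax", "jaxlib", "triton", "jax_triton"):
        report[name] = importlib.util.find_spec(name) is not None

    backend = None
    if report["jax"]:
        try:
            import jax

            # Returns a platform name such as "cpu" or "gpu".
            backend = jax.default_backend()
        except Exception:
            # JAX is present but could not initialize a backend.
            backend = None
    report["backend"] = backend
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```

If the reported backend is "cpu" rather than "gpu", the installed jaxlib is most likely not a CUDA build, and the jaxlib install step above is the first thing to revisit.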

Development

To develop jax-triton, you can clone the repo with:

$ git clone https://github.com/jax-ml/jax-triton.git

and do an editable install with:

$ cd jax-triton
$ pip install -e .

To run the jax-triton tests, you'll need pytest and absl-py:

$ pip install pytest absl-py
$ pytest tests/