
JAX for Graphcore IPU (experimental)

Primary language: Python. License: Apache License 2.0 (Apache-2.0).


🔴 Non-official experimental 🔴 JAX on Graphcore IPU



🔴 ⚠️ Non-official experimental ⚠️ 🔴

This is a very thin fork of http://github.com/google/jax for Graphcore IPUs. It is provided by Graphcore Research for experimentation purposes only and is not intended for production use (inference or training).

Features and limitations of experimental JAX on IPUs

The following features are supported:

  • Vanilla JAX API: no additional IPU-specific API, so any code written for IPUs remains compatible with other backends (CPU/GPU/TPU);
  • JAX asynchronous dispatch on the IPU backend;
  • Multiple IPUs with collectives, using pmap and (experimental) pjit;
  • Large coverage of JAX lax operators;
  • Support for JAX buffer donation, to keep parameters in IPU SRAM.
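The sketch below uses only standard JAX APIs to illustrate two of the features listed above: a pmap collective (jax.lax.psum) across all local devices, and buffer donation through donate_argnums on a jitted update. It is written to run on any backend; the helper names all_reduce_sum and sgd_step, and the SGD-style update itself, are hypothetical illustrations rather than part of the IPU package. Whether a donated buffer actually stays in IPU SRAM is up to the backend.

```python
from functools import partial

import jax
import jax.numpy as jnp


def all_reduce_sum(x):
    """Hypothetical helper: sum `x` across all local devices with pmap.

    The leading axis of `x` must equal jax.local_device_count();
    every device receives the same reduced row.
    """
    return jax.pmap(lambda v: jax.lax.psum(v, axis_name="devices"),
                    axis_name="devices")(x)


# Hypothetical SGD-style update. Donating argument 0 (`params`) lets XLA
# reuse its device buffer for the output, so updated parameters can stay
# resident on the device instead of being reallocated on every step.
@partial(jax.jit, donate_argnums=0)
def sgd_step(params, grads, lr=0.1):
    return params - lr * grads


n = jax.local_device_count()
x = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)
reduced = all_reduce_sum(x)

params = jnp.ones(4, dtype=jnp.float32)
grads = jnp.full(4, 2.0, dtype=jnp.float32)
# Rebind `params`: a donated input must not be used again after the call.
params = sgd_step(params, grads)
```

On a single-device backend all_reduce_sum returns its input unchanged, which makes the sketch easy to check on CPU before running it on several IPUs.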

Known limitations of the project:

  • No eager mode: every JAX call has to be compiled, loaded, and finally executed on the IPU device;
  • Generated IPU code can be larger than with the official Graphcore TensorFlow or PopTorch frameworks (limiting batch size or model size);
  • Multi-IPU collectives have topology restrictions (following the Graphcore GCL API);
  • Some linear algebra operators are missing;
  • Incomplete support for JAX random number generation on the IPU device;
  • JAX infeeds and outfeeds are deactivated.
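Because there is no eager mode, it can help to make the compile step explicit so that its cost is paid once, up front. The sketch below uses upstream JAX's ahead-of-time API (jax.jit(f).lower(...).compile()); it has not been checked against the IPU backend, and scaled_tanh is a hypothetical placeholder computation.

```python
import jax
import jax.numpy as jnp


def scaled_tanh(x, scale):
    # Hypothetical computation: several lax ops compiled into one program.
    return scale * jnp.tanh(x) + 1.0


x = jnp.linspace(-2.0, 2.0, 8, dtype=jnp.float32)
scale = jnp.float32(0.5)

# Trace and lower for these argument shapes/dtypes, then compile once.
compiled = jax.jit(scaled_tanh).lower(x, scale).compile()

# Each call now only executes the already-compiled program; arguments must
# match the shapes and dtypes used when lowering.
y1 = compiled(x, scale)
y2 = compiled(x * 2.0, scale)
```

The same applies implicitly with plain jax.jit: compiled programs are cached per input shape and dtype, so keeping shapes fixed across a training loop avoids triggering a new compilation on every step.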

This is a research project, not an official Graphcore product. Expect bugs and sharp edges. Please help by trying it out, reporting bugs, and letting us know what you think!

Installation

The experimental JAX wheels require Ubuntu 20.04, Graphcore Poplar SDK 3.1 or 3.2, and Python 3.8. For SDK 3.1, they can be installed as follows:

pip install jax==0.3.16+ipu jaxlib==0.3.15+ipu.sdk310 -f https://graphcore-research.github.io/jax-experimental/wheels.html

For SDK 3.2, change the jaxlib version to jaxlib==0.3.15+ipu.sdk320.
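After installing, a quick sanity check is to confirm which JAX platforms are visible. The helper below is a sketch that relies only on jax.devices(platform), which raises a RuntimeError when the requested platform is unknown or fails to initialize; the name platform_devices is hypothetical. Without the IPU wheels (or without IPU hardware or the emulator enabled), "ipu" is simply reported with zero devices.

```python
import jax


def platform_devices(platforms=("ipu", "cpu")):
    """Hypothetical helper: map each platform name to its visible devices."""
    found = {}
    for name in platforms:
        try:
            found[name] = jax.devices(name)
        except RuntimeError:
            # Unknown platform, or the backend failed to initialize.
            found[name] = []
    return found


print("jax", jax.__version__)
for name, devices in platform_devices().items():
    print(f"{name}: {len(devices)} device(s) {devices}")
```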

Minimal example

The following example can be run on Graphcore IPU Paperspace (or on a non-IPU machine using the IPU emulator):

from functools import partial
import jax
import numpy as np

@partial(jax.jit, backend="ipu")
def ipu_function(data):
    return data**2 + 1

data = np.array([1, -2, 3], np.float32)
output = ipu_function(data)
print(output, output.device())
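Since no IPU-specific API is needed, the same function can also be written without hard-coding backend="ipu". The sketch below assumes the IPU platform is exposed under the name "ipu", as in the example above, and falls back to the CPU when it is not available; pick_device and square_plus_one are hypothetical names. Committing the input to a device with jax.device_put makes the jitted computation run on that device, and because dispatch is asynchronous, block_until_ready() is used to wait for the result.

```python
import jax
import numpy as np


def pick_device(preferred="ipu"):
    """Hypothetical helper: first `preferred` device, else the first CPU device."""
    try:
        return jax.devices(preferred)[0]
    except RuntimeError:
        return jax.devices("cpu")[0]


@jax.jit
def square_plus_one(data):
    return data**2 + 1


device = pick_device()
# Committing the input to `device` makes the jitted call execute there.
data = jax.device_put(np.array([1, -2, 3], np.float32), device)
# Dispatch is asynchronous: block until the result is actually computed.
output = square_plus_one(data).block_until_ready()
print(output)
```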

JAX on IPU Paperspace notebooks

Additional JAX on IPU examples:

Useful JAX backend flags:

As is standard in JAX, these flags can be set on the config object imported with from jax.config import config.

Flag | Description
--- | ---
config.FLAGS.jax_platform_name = 'ipu' / 'cpu' | Configure the default JAX backend. Useful for CPU initialization.
config.FLAGS.jax_ipu_use_model = True | Use the IPU model emulator.
config.FLAGS.jax_ipu_model_num_tiles = 8 | Set the number of tiles in the IPU model.
config.FLAGS.jax_ipu_device_count = 2 | Set the number of IPUs visible in JAX (up to the number of local IPUs available).
config.FLAGS.jax_ipu_visible_devices = '0,1' | Set the specific collection of local IPUs visible in JAX.

Alternatively, like other JAX flags, these can be set using environment variables (e.g. JAX_IPU_USE_MODEL, JAX_IPU_MODEL_NUM_TILES,...).
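For example, the emulator flags listed above can be set from Python before JAX is first imported, since upstream JAX reads flag defaults from the environment at import time. This is a minimal sketch, assuming the IPU flags follow that same mechanism and that boolean flags accept the string "true" as standard JAX boolean environment flags do; the helper name use_ipu_model is hypothetical.

```python
import os


def use_ipu_model(num_tiles=8):
    """Hypothetical helper: enable the IPU model emulator via environment variables.

    Must run before `import jax`, because JAX reads flag defaults from the
    environment when it is first imported.
    """
    os.environ["JAX_IPU_USE_MODEL"] = "true"
    os.environ["JAX_IPU_MODEL_NUM_TILES"] = str(num_tiles)


use_ipu_model(num_tiles=8)
```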

Useful PopVision environment variables:

  • Generate a PopVision Graph Analyser profile: POPLAR_ENGINE_OPTIONS='{"autoReport.all":"true", "debug.allowOutOfMemory":"true"}'
  • Generate a PopVision System Analyser profile: PVTI_OPTIONS='{"enable":"true", "directory":"./reports"}'
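Both variables hold JSON strings, which are easy to mis-quote in a shell. One option, sketched below, is to build them with json.dumps and pass them in the environment of a child process; the helpers popvision_env and run_with_profiling are hypothetical, and the option values are exactly those listed above.

```python
import json
import os
import subprocess


def popvision_env(report_dir="./reports"):
    """Hypothetical helper: copy of os.environ with PopVision profiling enabled."""
    env = dict(os.environ)
    # Graph Analyser profile options.
    env["POPLAR_ENGINE_OPTIONS"] = json.dumps(
        {"autoReport.all": "true", "debug.allowOutOfMemory": "true"}
    )
    # System Analyser profile, written to `report_dir`.
    env["PVTI_OPTIONS"] = json.dumps({"enable": "true", "directory": report_dir})
    return env


def run_with_profiling(cmd, report_dir="./reports"):
    """Hypothetical helper: run `cmd` (a list of arguments) with profiling enabled."""
    return subprocess.run(cmd, env=popvision_env(report_dir), check=True)
```

For example, run_with_profiling([sys.executable, "train.py"]) would profile a script, where train.py stands for your own program.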

Documentation

License

The project remains licensed under the Apache License 2.0, with the following files unchanged:

The additional dependencies introduced for Graphcore IPU support are: