JAX-Toolbox

Primary language: Python. License: Apache License 2.0 (Apache-2.0).

JAX Toolbox

Image | Build | Tests
base | build status | n/a
Frameworks
jax | build status | V100, A100
te | included in JAX build | unit tests, integration tests
rosetta-t5x | build status | t5x, rosetta-t5x
rosetta-pax | build status | pax, rosetta-pax

Note

This repo currently hosts a public CI for JAX on NVIDIA GPUs and covers several JAX libraries, including T5X, PAXML, and Transformer Engine, with more to come soon.

Supported Models

We currently enable training and evaluation for the following models:

Model Name | Pretraining | Fine-tuning | Evaluation
GPT-3 (paxml) | ✔️ | | ✔️
T5 (t5x) | ✔️ | ✔️ | ✔️
ViT | ✔️ | ✔️ | ✔️
Imagen | ✔️ | | ✔️

We will update this table as new models become available, so stay tuned.

Environment Variables

The JAX image sets the following XLA flags and environment variables by default for performance tuning:

XLA Flag | Value | Explanation
--xla_gpu_enable_latency_hiding_scheduler | true | Allows XLA to move communication collectives to increase overlap with compute kernels.
--xla_gpu_enable_async_all_gather | true | Allows XLA to run NCCL AllGather kernels on a separate CUDA stream so they can overlap with compute kernels.
--xla_gpu_enable_async_reduce_scatter | true | Allows XLA to run NCCL ReduceScatter kernels on a separate CUDA stream so they can overlap with compute kernels.
--xla_gpu_enable_triton_gemm | false | Uses cuBLAS instead of Triton GEMM kernels.
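XLA reads its flags from the XLA_FLAGS environment variable as a space-separated list of `--name=value` entries. The sketch below assumes the image passes the flags above through XLA_FLAGS and that you want to change one of them from Python (here re-enabling Triton GEMM). `override_xla_flag` is a hypothetical helper, not part of JAX or the image. XLA_FLAGS must be set before JAX initializes its GPU backend, so the safest place is before `import jax`.

```python
import os


def override_xla_flag(flags: str, name: str, value: str) -> str:
    """Hypothetical helper: return `flags` with `name` set to `value`.

    Any existing setting for `name` is dropped so the flag appears only once.
    """
    kept = [f for f in flags.split() if f.split("=", 1)[0] != name]
    kept.append(f"{name}={value}")
    return " ".join(kept)


# Run this before `import jax`; XLA only reads XLA_FLAGS when the backend starts.
os.environ["XLA_FLAGS"] = override_xla_flag(
    os.environ.get("XLA_FLAGS", ""), "--xla_gpu_enable_triton_gemm", "true"
)
```

Replacing the existing entry rather than appending a duplicate keeps the effective value unambiguous when you print XLA_FLAGS to debug a run.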
Environment Variable | Value | Explanation
CUDA_DEVICE_MAX_CONNECTIONS | 1 | Uses a single queue for GPU work to lower the latency of stream operations; this is acceptable because XLA already orders kernel launches.
NCCL_IB_SL | 1 | Sets the InfiniBand Service Level used by NCCL.
NCCL_NVLS_ENABLE | 0 | Disables NVLink SHARP (NVLS). Future releases will re-enable this feature.
CUDA_MODULE_LOADING | EAGER | Disables lazy loading of CUDA modules; eager loading uses slightly more GPU memory.
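These variables are read when CUDA and NCCL initialize, so changing them usually means launching a fresh process with a modified environment. A minimal sketch, assuming you launch training from a Python wrapper and want to change a single setting (for example switching CUDA_MODULE_LOADING back to LAZY to save a little GPU memory). `IMAGE_DEFAULTS` and `effective_env` are hypothetical names that mirror the environment-variable table; the image does not provide them.

```python
import os

# Hypothetical mirror of the image's defaults, copied from the table; not an image API.
IMAGE_DEFAULTS = {
    "CUDA_DEVICE_MAX_CONNECTIONS": "1",
    "NCCL_IB_SL": "1",
    "NCCL_NVLS_ENABLE": "0",
    "CUDA_MODULE_LOADING": "EAGER",
}


def effective_env(overrides=None, base=None):
    """Build an environment dict for a child process.

    Precedence: explicit `overrides` > values already in `base` > IMAGE_DEFAULTS.
    `base` defaults to the current process environment.
    """
    env = dict(os.environ if base is None else base)
    for key, value in IMAGE_DEFAULTS.items():
        env.setdefault(key, value)  # keep anything the user already exported
    env.update(overrides or {})
    return env
```

The result can be passed straight to `subprocess.run([...], env=effective_env({"CUDA_MODULE_LOADING": "LAZY"}), check=True)`, where the command list is your own (hypothetical) training entry point.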

FAQ (Frequently Asked Questions)

`bus error` when running JAX in a Docker container

Solution:

docker run -it --shm-size=1g ...

Explanation: A bus error can occur when the shared memory available in /dev/shm is too small. You can fix this by increasing the shared memory size with the --shm-size option when launching the container.
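To check from inside a running container whether /dev/shm is likely too small, you can query its size with Python's standard `shutil.disk_usage`. This is a minimal sketch; the 1 GiB threshold simply mirrors the `--shm-size=1g` example above and is an assumption, not a documented JAX requirement. The helper names are hypothetical.

```python
import shutil


def shm_size_bytes(path="/dev/shm"):
    """Return the total size in bytes of the filesystem at `path`, or None if it is missing."""
    try:
        return shutil.disk_usage(path).total
    except FileNotFoundError:
        return None


def shm_looks_too_small(min_bytes=1 << 30, path="/dev/shm"):
    """True only if `path` exists and is smaller than `min_bytes` (default 1 GiB, an assumed threshold)."""
    size = shm_size_bytes(path)
    return size is not None and size < min_bytes


if shm_looks_too_small():
    print("/dev/shm is smaller than 1 GiB; consider restarting with a larger --shm-size")
```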

enroot/pyxis reports error code 404 when importing multi-arch images

Problem description:

slurmstepd: error: pyxis:     [INFO] Authentication succeeded
slurmstepd: error: pyxis:     [INFO] Fetching image manifest list
slurmstepd: error: pyxis:     [INFO] Fetching image manifest
slurmstepd: error: pyxis:     [ERROR] URL https://ghcr.io/v2/nvidia/jax/manifests/<TAG> returned error code: 404 Not Found

Solution: Upgrade enroot to version 3.4.0 or later, or apply the single-file patch described in the enroot v3.4.0 release notes.

Explanation: Docker has traditionally used the Docker manifest schema V2.2 for multi-arch manifest lists, but since version 20.10 it uses the Open Container Initiative (OCI) format instead. Enroot added support for the OCI format in version 3.4.0, so older enroot versions fail to fetch these images.
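The two formats can be told apart by the manifest's media type: `application/vnd.docker.distribution.manifest.list.v2+json` for a Docker manifest list and `application/vnd.oci.image.index.v1+json` for an OCI image index. A minimal sketch, assuming you already have the manifest JSON (for example the output of `docker manifest inspect`) and a plain dotted enroot version string. The helper names are hypothetical, and because the OCI spec only recommends (rather than requires) the `mediaType` field, treat this as a heuristic.

```python
import json

DOCKER_MANIFEST_LIST = "application/vnd.docker.distribution.manifest.list.v2+json"
OCI_IMAGE_INDEX = "application/vnd.oci.image.index.v1+json"


def parse_version(version):
    """Turn '3.4.0', 'v3.3.1' or '3.4.1-1' into a comparable tuple such as (3, 4, 1)."""
    core = version.strip().lstrip("v").split("-")[0]
    return tuple(int(part) for part in core.split(".")[:3])


def needs_enroot_upgrade(manifest_json, enroot_version):
    """Heuristic: True if the manifest is an OCI index and enroot predates 3.4.0."""
    media_type = json.loads(manifest_json).get("mediaType")
    return media_type == OCI_IMAGE_INDEX and parse_version(enroot_version) < (3, 4, 0)
```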

JAX on Public Clouds

Resources