Jax NaN
This repo contains code to reproduce a tricky NaN I found while training a neural network. It contains some slightly modified code from T5X and a script to reproduce the issue.
Setup
On a TPU v3-8, install the package the same way as T5X:
python3 -m pip install -e '.[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Download the checkpoint from gs://chrisc-public/debug-checkpoint:
gsutil -m cp -r gs://chrisc-public/debug-checkpoint .
Run
python3 main.py debug-checkpoint -a none
Results are finite
python3 main.py debug-checkpoint -a partial-nan
The gradient becomes NaN and the intermediate value x-diff becomes > 0, but the loss is finite
python3 main.py debug-checkpoint -a nan
Loss becomes NaN
The only thing that changes between these calls is the number of entirely unused arrays passed through the train_step function.
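The pattern behind the three runs can be sketched in a few lines of JAX. This is a minimal illustration, not the code in main.py: the toy loss_fn, the unused argument, and the count_nonfinite helper are all hypothetical names made up for this example. It jits a train_step that accepts a list of arrays it never reads, calls it with different numbers of such arrays, and checks the loss and gradients for NaN/Inf. On a correctly behaving setup the results should be identical and finite for every count; the bug described above is that, with the real model and checkpoint on TPU, they are not.

```python
import jax
import jax.numpy as jnp


def count_nonfinite(tree):
    """Total number of NaN/Inf entries across all array leaves of a pytree."""
    leaves = jax.tree_util.tree_leaves(tree)
    return sum(int(jnp.sum(~jnp.isfinite(leaf))) for leaf in leaves)


def loss_fn(params, batch):
    # Toy mean-squared-error loss standing in for the real model (assumption).
    pred = batch["x"] @ params["w"]
    return jnp.mean((pred - batch["y"]) ** 2)


@jax.jit
def train_step(params, batch, unused):
    # `unused` is accepted but never read, mirroring the extra arrays
    # described above (hypothetical argument name).
    del unused
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    return loss, grads


params = {"w": jnp.ones((4, 1))}
batch = {"x": jnp.arange(8.0).reshape(2, 4), "y": jnp.zeros((2, 1))}

results = {}
for n in (0, 1, 3):
    # Vary only the number of unused arrays passed to the jitted function.
    unused = [jnp.zeros((2, 2)) for _ in range(n)]
    loss, grads = train_step(params, batch, unused)
    results[n] = (loss, grads)
    print(n, float(loss), bool(jnp.isfinite(loss)), count_nonfinite(grads))
```

If any row reports a non-finite loss or a non-zero count while the others do not, the unused inputs are changing the computation, which is the symptom this repo reproduces.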