gpt-jax

Jax/Flax rewrite of Karpathy's nanoGPT

Primary language: Python. License: MIT.

Jax GPT

This is a work-in-progress rewrite of Andrej Karpathy's nanoGPT in Jax/Flax.

One goal of this project is to try out jax.experimental.pjit. I'm curious how performance varies across model sizes and distribution (sharding) configurations.

Currently, the GPT2 124M-parameter model reaches a validation loss of 2.906 after training on OpenWebText for 150K steps (about 20 hours) on a TPU v3-32 pod slice.

[Figure: GPT2 124M loss curve]
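The headline numbers above can be put in more familiar units. A minimal sketch, assuming the reported loss is the mean per-token cross-entropy in nats (natural log, as in nanoGPT), so perplexity is exp(loss); the throughput figure uses the approximate "about 20 hours" wall-clock time. Both helpers are illustrative and not part of gpt-jax.

```python
import math


def perplexity(loss_nats: float) -> float:
    """Perplexity implied by a mean per-token cross-entropy loss in nats (assumption)."""
    return math.exp(loss_nats)


def steps_per_second(steps: int, hours: float) -> float:
    """Average optimizer steps per second over a wall-clock duration."""
    return steps / (hours * 3600)


# Figures quoted in the README; "about 20 hours" makes the throughput approximate.
print(f"perplexity ~ {perplexity(2.906):.2f}")             # perplexity ~ 18.28
print(f"steps/sec  ~ {steps_per_second(150_000, 20):.2f}")  # steps/sec  ~ 2.08
```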

Steps to Reproduce

Prepare OpenWebText

Clone gpt-jax

git clone https://github.com/jenkspt/gpt-jax.git
cd gpt-jax

Install Python dependencies

pip install -U pip
pip install tqdm
pip install numpy
pip install tiktoken
pip install datasets
pip install tensorflow
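To confirm the data-preparation packages landed in the active environment, a quick standard-library check can help before running the prepare script. This is a convenience sketch, not part of gpt-jax; it assumes each pip package imports under the same name, which holds for these five.

```python
import importlib.util

# Modules installed by the pip commands above (pip name == import name here).
REQUIRED = ["tqdm", "numpy", "tiktoken", "datasets", "tensorflow"]


def missing_modules(names):
    """Return the names that cannot be located on the current Python path."""
    # find_spec only locates the module; it does not import it.
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED)
    if missing:
        print("missing:", ", ".join(missing))
    else:
        print("all data-prep dependencies found")
```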

Prepare data

python data/openwebtext/prepare.py

This generates the following files:
train_0.tfrecord, train_1.tfrecord, ..., train_{num_shards-1}.tfrecord
val_0.tfrecord, val_1.tfrecord, ..., val_{num_shards-1}.tfrecord

If you're training on a TPU, you should copy these files to a GCS bucket.
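The training command further down selects shards with the glob patterns train_??.tfrecord and val_??.tfrecord. Because ? matches exactly one character, those patterns only pick up shards whose index is exactly two characters long, so a quick local check before uploading can catch a naming mismatch. A minimal stdlib sketch, assuming the patterns are expanded with ordinary glob semantics (as Python's fnmatch does); check_shards and the directory you pass it are hypothetical.

```python
import fnmatch
import os

# Glob patterns passed to train.py in the launch step (file-name part only).
TRAIN_PATTERN = "train_??.tfrecord"
VAL_PATTERN = "val_??.tfrecord"


def matching(names, pattern):
    """Return the sorted subset of file names matched by a glob-style pattern."""
    return sorted(n for n in names if fnmatch.fnmatch(n, pattern))


def check_shards(directory):
    """Print how many local shards each training pattern would select."""
    names = os.listdir(directory)
    for label, pattern in (("train", TRAIN_PATTERN), ("val", VAL_PATTERN)):
        hits = matching(names, pattern)
        print(f"{label}: {len(hits)} file(s) match {pattern!r}")
```

For example, matching(["train_0.tfrecord", "train_00.tfrecord"], TRAIN_PATTERN) keeps only train_00.tfrecord.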

Train on a TPU v3-32

  1. Create a TPU v3-32 slice
ZONE=europe-west4-a
TPU_TYPE=v3-32
VM_NAME=jax-gpt-v3-32

gcloud alpha compute tpus tpu-vm create $VM_NAME \
    --zone=$ZONE \
    --accelerator-type=$TPU_TYPE \
    --version=v2-tf-stable \
    --preemptible
  2. Clone the repo and install dependencies
gcloud alpha compute tpus tpu-vm ssh $VM_NAME --zone=$ZONE \
--worker=all --command="
git clone https://github.com/jenkspt/gpt-jax.git
cd gpt-jax

pip install -U pip
pip install tyro
pip install wandb
pip install -U tensorflow
pip install -U \"jax[tpu]\" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install flax
"
  3. Launch the training job
EXPERIMENT=gpt2-124m/run_$(date +%Y-%m-%d_%H-%M-%S)
echo $EXPERIMENT
BRANCH=main

gcloud alpha compute tpus tpu-vm ssh $VM_NAME --zone=$ZONE \
--worker=all --command="

export WANDB_API_KEY=$WANDB_API_KEY  # expanded locally: set WANDB_API_KEY in your own shell first
export TF_CPP_MIN_LOG_LEVEL=3       # silence annoying TF warnings
export GPT_CONFIG=config/gpt2.yaml  # this is the default GPT config for this run

cd gpt-jax
git fetch --all
git checkout $BRANCH

python3 train.py \
    --out_dir=gs://{your-bucket}/$EXPERIMENT \
    --train_pattern=gs://{your-bucket}/openwebtext/train_??.tfrecord \
    --val_pattern=gs://{your-bucket}/openwebtext/val_??.tfrecord \
    --wandb.name=$EXPERIMENT \
    --wandb.notes=''
"
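Every ssh step above uses --worker=all because a v3-32 slice is not a single machine: it spans several host VMs, and each host has to clone the repo, install the packages, and run its own copy of train.py. The arithmetic, as a small sketch assuming the standard TPU v3 layout of 2 TensorCores per chip and 4 chips (8 cores) per host; v3_slice_layout is a hypothetical helper, not part of gpt-jax.

```python
def v3_slice_layout(accelerator_type: str) -> dict:
    """Split a TPU v3 accelerator type such as "v3-32" into cores, chips and hosts.

    Assumes 2 TensorCores per chip and 8 TensorCores (4 chips) per host VM.
    """
    generation, _, count = accelerator_type.partition("-")
    if generation != "v3" or not count.isdigit():
        raise ValueError(f"expected a type like 'v3-32', got {accelerator_type!r}")
    cores = int(count)
    cores_per_host = 8
    return {
        "cores": cores,
        "chips": cores // 2,
        "hosts": max(1, cores // cores_per_host),
    }


print(v3_slice_layout("v3-32"))  # {'cores': 32, 'chips': 16, 'hosts': 4}
```

Under that layout, each of the 4 train.py processes would typically see 8 local devices, with jax.device_count() reporting 32 and jax.process_count() reporting 4.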

Don't forget to delete the TPU instance when you're done:

gcloud alpha compute tpus tpu-vm delete $VM_NAME --zone=$ZONE