flash-nanoGPT (Under development)

A JAX (Flax) re-write of Andrej Karpathy's nanoGPT, built on common JAX libraries and features (shard_map, Pallas, jmp, optax, orbax). This repository collects examples of newer JAX/Flax capabilities, such as the Pallas kernel language for FlashAttention on TPU and data/tensor sharding with JAX on TPU.

Todos

  • GPT-2-like model in Flax
  • Mixed precision training with jmp (sketch below)
  • Gradient accumulation with optax (sketch below)
  • Data sharding across GPUs/TPUs using the new JAX shard_map (shmap) (sketch below)
  • Loading and saving checkpoints with orbax (sketch below)
  • Reproducing the results on the shakespeare-char dataset
  • TFRecord reader/writer with support for data sharding across hosts
  • Multi-host training
  • Reproducing the results on the OpenWebText dataset
  • Loading Hugging Face pre-trained GPT-2 models
  • Fine-tuning GPT-2 weights on the Shakespeare dataset
  • Sampling
  • Estimating MFU (model FLOPs utilization) (sketch below)
  • Profiling a training iteration
  • Optimizing inference
  • FlashAttention with Pallas
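
Usage sketches

A minimal sketch of how mixed precision could be wired up with jmp. The toy model, dtype choices, and loss-scale comment are illustrative assumptions, not the repository's actual configuration.

```python
import jax.numpy as jnp
import jmp

# Keep parameters in float32 and run compute in bfloat16 (a common default on TPU).
policy = jmp.get_policy("params=float32,compute=bfloat16,output=float32")

params = {"w": jnp.zeros((128, 128), jnp.float32)}  # toy parameters

def loss_fn(params, x):
    # Cast parameters and inputs to the compute dtype before the forward pass.
    params = policy.cast_to_compute(params)
    x = policy.cast_to_compute(x)
    y = x @ params["w"]                  # toy "model"
    loss = jnp.mean(y ** 2)
    return policy.cast_to_output(loss)   # report the loss in float32

loss = loss_fn(params, jnp.ones((8, 128)))

# With float16 on GPUs, a dynamic loss scale is usually added as well, e.g.
# scale = jmp.DynamicLossScale(jnp.float32(2 ** 15)); scaled_loss = scale.scale(loss)
```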
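
One way to implement gradient accumulation in optax is optax.MultiSteps, which accumulates gradients over k calls to update before applying a real optimizer step. The learning rate, weight decay, and accumulation factor below are placeholders.

```python
import jax.numpy as jnp
import optax

params = {"w": jnp.zeros((4, 4))}  # toy parameters
grads = {"w": jnp.ones((4, 4))}    # toy gradients from one micro-batch

# Apply an AdamW update only once every 4 accumulated micro-batches.
optimizer = optax.MultiSteps(
    optax.adamw(learning_rate=3e-4, weight_decay=0.1),
    every_k_schedule=4,
)

opt_state = optimizer.init(params)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```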
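
A data-parallel sketch with shard_map over a 1-D device mesh; the toy per-shard loss stands in for the model's forward pass, and all names are illustrative.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

# 1-D mesh with a single "data" axis spanning all available devices.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

@functools.partial(
    shard_map,
    mesh=mesh,
    in_specs=(P("data"), P("data")),  # split the batch dimension across devices
    out_specs=P(),                    # replicated scalar loss
)
def data_parallel_loss(x, y):
    # Toy per-shard loss; a real model's forward pass would go here.
    local = jnp.mean((x - y) ** 2)
    return jax.lax.pmean(local, axis_name="data")

x = jnp.ones((8 * jax.device_count(), 128))
y = jnp.zeros_like(x)
loss = jax.jit(data_parallel_loss)(x, y)
```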
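
Checkpoint saving and loading with orbax could look roughly like the PyTreeCheckpointer pattern shown in the Flax documentation; the state pytree and paths below are placeholders.

```python
import jax.numpy as jnp
import orbax.checkpoint
from flax.training import orbax_utils

# Toy stand-in for the real training state pytree.
state = {"params": {"w": jnp.zeros((2, 2))}, "step": jnp.array(0)}

checkpointer = orbax.checkpoint.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(state)
checkpointer.save("/tmp/flash-nanogpt-ckpt", state, save_args=save_args)

# Restore using the same structure as a template.
restored = checkpointer.restore("/tmp/flash-nanogpt-ckpt", item=state)
```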
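
MFU estimation can follow the accounting used in nanoGPT (PaLM appendix B): roughly 6N + 12·L·H·Q·T FLOPs per token. The peak-FLOPs figure depends on the accelerator (nanoGPT uses 312 TFLOPS for A100 bfloat16), and the example numbers below are illustrative only.

```python
def estimate_mfu(n_params, n_layer, n_head, head_dim, seq_len,
                 tokens_per_iter, iter_time_s, peak_flops=312e12):
    """Rough model-FLOPs-utilization estimate (PaLM appendix B accounting)."""
    flops_per_token = 6 * n_params + 12 * n_layer * n_head * head_dim * seq_len
    flops_per_iter = flops_per_token * tokens_per_iter
    flops_achieved = flops_per_iter / iter_time_s
    return flops_achieved / peak_flops

# Example with GPT-2 (124M)-shaped numbers: 12 sequences of 1024 tokens per iteration.
mfu = estimate_mfu(n_params=124e6, n_layer=12, n_head=12, head_dim=64,
                   seq_len=1024, tokens_per_iter=12 * 1024, iter_time_s=0.5)
```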

Future work

  • Experimenting with Jax tensor sharding
  • Gradient checkpointing
  • Experimenting with fine-tuning techniques
  • ...

Data generation

To run training on a TPU VM, copy the generated data files into a Google Cloud Storage (GCS) bucket.
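
One way to do the copy from Python is with tf.io.gfile (gsutil or gcloud storage work just as well); the local directory and bucket name below are placeholders.

```python
import tensorflow as tf

LOCAL_DIR = "data/openwebtext"               # hypothetical local output directory
GCS_DIR = "gs://<your-bucket>/openwebtext"   # replace with your bucket

tf.io.gfile.makedirs(GCS_DIR)
for path in tf.io.gfile.glob(f"{LOCAL_DIR}/*"):
    filename = path.split("/")[-1]
    tf.io.gfile.copy(path, f"{GCS_DIR}/{filename}", overwrite=True)
```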

Acknowledgement

Big thanks to the TPU Research Cloud program for providing v2-8/v3-8/v3-32 TPU instances on Google Cloud.

References

  • Original nanoGPT repository [1]
  • JAX-based nanoGPT repositories [1] [2]
  • NVIDIA mixed precision training [1]
  • Google Cloud documentation [1]