A JAX (Flax) rewrite of Andrej Karpathy's nanoGPT. This repository will also hold a collection of newer JAX/Flax features, such as the Pallas kernel language for FlashAttention on TPU and data and tensor sharding with JAX on TPU.
- GPT-2-like model in Flax
- Mixed precision training with jmp
- Gradient accumulation with optax
- Data sharding across GPUs/TPUs using the new JAX shard_map (shmap) API
- Loading and saving checkpoints
- Reproducing the results on the shakespeare-char dataset
- TFRecord reader/writer with support for sharding data across hosts
- Multi-host training
- Reproducing results on the OpenWebText dataset
- Loading pre-trained GPT models from Hugging Face
- Fine-tuning GPT-2 weights on the Shakespeare dataset
- Sampling
- Estimating MFU (model FLOPs utilization)
- Profiling a training iteration
- Optimizing Inference
- FlashAttention with Pallas
- Experimenting with JAX tensor sharding
- Gradient checkpointing
- Experimenting with fine-tuning techniques
- ...
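As a rough illustration of the MFU item, here is a back-of-the-envelope estimator. It follows the approximation used in the original nanoGPT's `estimate_mfu` (which cites Appendix B of the PaLM paper): about 6N FLOPs per token for the forward and backward pass through N parameters, plus 12·L·H·Q·T for attention. The function name and its arguments are hypothetical, not this repository's API, and `peak_flops` is an assumption you must take from your accelerator's spec sheet.

```python
def estimate_mfu(n_params, n_layer, n_head, n_embd, block_size,
                 tokens_per_iter, dt, peak_flops):
    """Hypothetical helper: achieved FLOP/s divided by the hardware's peak FLOP/s.

    tokens_per_iter: tokens processed per optimizer step
                     (batch_size * block_size * grad_accum_steps).
    dt:              wall-clock seconds per step.
    peak_flops:      peak FLOP/s of the hardware (assumed, from the spec sheet).
    """
    head_dim = n_embd // n_head
    # 6N: forward + backward matmuls against the weights, per token.
    # 12*L*H*Q*T: attention score and weighted-sum matmuls, which grow with context T.
    flops_per_token = 6 * n_params + 12 * n_layer * n_head * head_dim * block_size
    flops_per_iter = flops_per_token * tokens_per_iter
    achieved_flops_per_sec = flops_per_iter / dt
    return achieved_flops_per_sec / peak_flops
```

In a training loop you would time one step, pass `tokens_per_iter = batch_size * block_size * grad_accum_steps`, and log the returned fraction alongside the loss.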
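The data-sharding item can be sketched with `shard_map`: each device computes gradients on its own slice of the batch, and `jax.lax.pmean` averages them across the mesh axis. This is a minimal data-parallel sketch, not the repository's training step. The toy loss, the axis name `"data"` and the function names are assumptions, and the import fallback is there because newer JAX releases expose `shard_map` at the top level instead of under `jax.experimental`.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax.experimental.shard_map import shard_map  # older JAX releases
except ImportError:
    from jax import shard_map  # newer releases expose it at the top level

# 1-D mesh over all local devices; the batch dimension is split along "data".
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))


def loss_fn(w, x, y):
    # Toy stand-in for the model loss.
    return jnp.mean((x @ w - y) ** 2)


def per_shard_grad(w, x, y):
    # Runs once per device, on that device's slice of x and y.
    g = jax.grad(loss_fn)(w, x, y)
    # Average the per-device gradients over the "data" axis.
    return jax.lax.pmean(g, axis_name="data")


sharded_grad = jax.jit(shard_map(
    per_shard_grad,
    mesh=mesh,
    in_specs=(P(), P("data"), P("data")),  # replicate w, shard the batch dim
    out_specs=P(),                          # averaged gradient is replicated
))

n_dev = len(jax.devices())
w = jnp.ones((4, 1))
x = jax.random.normal(jax.random.PRNGKey(0), (8 * n_dev, 4))  # batch divisible by devices
y = jnp.zeros((8 * n_dev, 1))
g_sharded = sharded_grad(w, x, y)
```

With equal-sized shards, the mean of the per-device mean gradients equals the full-batch gradient, so `g_sharded` matches `jax.grad(loss_fn)(w, x, y)` computed on the unsharded batch.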
To run training on a TPU VM, copy the generated data files into a GCP bucket.
Big thanks to TPU Research Cloud for providing v2-8/v3-8/v3-32 TPU instances on Google Cloud.