This is a JAX version of the NanoGPT example from Andrej Karpathy's tutorial "Let's build GPT from scratch, in code, spelled out."
The PyTorch version of the notebook is at https://colab.research.google.com/drive/1JMLa53HDuA-i7ZBmqV7ZnA3c_fvtXnx-?usp=sharing and the full PyTorch code is at https://github.com/karpathy/nanoGPT
This notebook also uses the following neural network libraries built on top of JAX: