
A JAX implementation of nanoGPT by Andrej Karpathy.

Primary language: Jupyter Notebook

NanoGpt in JAX

This is a JAX version of the NanoGPT example from Andrej Karpathy's tutorial "Let's build GPT: from scratch, in code, spelled out."

The PyTorch version of the notebook is at https://colab.research.google.com/drive/1JMLa53HDuA-i7ZBmqV7ZnA3c_fvtXnx-?usp=sharing

The PyTorch code is at https://github.com/karpathy/nanoGPT

To learn more about JAX, see the official JAX documentation.

This notebook also uses the following neural network libraries built on top of JAX:

Flax: a Python neural network library.

Optax: Commonly used optimizers.