A Simple Transformer in JAX

This is the repository for my blog post on JITx.

The transformer here is small, has no dropout, and contains only a single multi-head attention (MHA) block. It is not meant to be production-grade code; it is intended for learning purposes only.
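To give a rough idea of what "a single MHA block without dropout" means, here is a minimal sketch in JAX. It is not the code from this repository. The function names (`init_params`, `multi_head_attention`, `transformer_block`), the parameter layout, the pre-layer-norm ordering, the GELU feed-forward layer and the absence of a causal mask are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter layout for one transformer block (assumption, not the repo's).
def init_params(key, d_model, n_heads, d_ff):
    assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
    keys = jax.random.split(key, 6)
    scale = 1.0 / jnp.sqrt(d_model)
    return {
        "wq": jax.random.normal(keys[0], (d_model, d_model)) * scale,
        "wk": jax.random.normal(keys[1], (d_model, d_model)) * scale,
        "wv": jax.random.normal(keys[2], (d_model, d_model)) * scale,
        "wo": jax.random.normal(keys[3], (d_model, d_model)) * scale,
        "w1": jax.random.normal(keys[4], (d_model, d_ff)) * scale,
        "w2": jax.random.normal(keys[5], (d_ff, d_model)) / jnp.sqrt(d_ff),
    }

# Layer norm without learned scale/shift, to keep the sketch short.
def layer_norm(x, eps=1e-5):
    mean = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)

def multi_head_attention(params, x, n_heads):
    seq_len, d_model = x.shape
    d_head = d_model // n_heads

    def split_heads(t):
        # (seq, d_model) -> (n_heads, seq, d_head)
        return t.reshape(seq_len, n_heads, d_head).transpose(1, 0, 2)

    q = split_heads(x @ params["wq"])
    k = split_heads(x @ params["wk"])
    v = split_heads(x @ params["wv"])
    # Scaled dot-product attention per head; no causal mask (assumption).
    scores = q @ k.transpose(0, 2, 1) / jnp.sqrt(d_head)  # (n_heads, seq, seq)
    weights = jax.nn.softmax(scores, axis=-1)
    out = weights @ v  # (n_heads, seq, d_head)
    # Merge heads back: (n_heads, seq, d_head) -> (seq, d_model)
    out = out.transpose(1, 0, 2).reshape(seq_len, d_model)
    return out @ params["wo"]

# One pre-LN block: attention and a feed-forward layer, each with a residual, no dropout.
def transformer_block(params, x, n_heads):
    x = x + multi_head_attention(params, layer_norm(x), n_heads)
    h = jax.nn.gelu(layer_norm(x) @ params["w1"])
    return x + h @ params["w2"]

# n_heads must be static because it determines reshape shapes.
block = jax.jit(transformer_block, static_argnums=2)

params = init_params(jax.random.PRNGKey(0), d_model=16, n_heads=4, d_ff=32)
x = jax.random.normal(jax.random.PRNGKey(1), (8, 16))  # (seq_len, d_model)
y = block(params, x, 4)
print(y.shape)  # (8, 16)
```

Marking `n_heads` as static with `static_argnums` lets `jax.jit` trace the reshapes with concrete shapes; changing it triggers a recompile, which is acceptable for a fixed model size.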