/jax-transformers

Transformers implemented in pure JAX, in a stax-like manner.

Primary language: Python

Transformers implemented in pure JAX (WIP)

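This is an implementation of classic transformers such as BERT, GPT, and T5 in pure JAX, written in the style of stax: each layer is a pair of functions, one that initializes parameters and one that applies them.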
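
To illustrate what a stax-style layer might look like, here is a minimal sketch of a single-head self-attention layer built from `(init_fun, apply_fun)` pairs. The `Dense` and `SelfAttention` constructors, their arguments, and the shapes used are assumptions made for illustration; they are not taken from this repository's code.

```python
import jax
import jax.numpy as jnp


def Dense(out_dim):
    """Hypothetical stax-style dense layer: returns (init_fun, apply_fun)."""
    def init_fun(rng, input_shape):
        in_dim = input_shape[-1]
        # Scaled normal initialization for the weights, zeros for the bias.
        w = jax.random.normal(rng, (in_dim, out_dim)) / jnp.sqrt(in_dim)
        b = jnp.zeros((out_dim,))
        output_shape = tuple(input_shape[:-1]) + (out_dim,)
        return output_shape, (w, b)

    def apply_fun(params, inputs, **kwargs):
        w, b = params
        return inputs @ w + b

    return init_fun, apply_fun


def SelfAttention(d_model):
    """Hypothetical single-head self-attention layer in the same style."""
    q_init, q_apply = Dense(d_model)
    k_init, k_apply = Dense(d_model)
    v_init, v_apply = Dense(d_model)

    def init_fun(rng, input_shape):
        # One independent key per projection.
        rq, rk, rv = jax.random.split(rng, 3)
        output_shape, q_params = q_init(rq, input_shape)
        _, k_params = k_init(rk, input_shape)
        _, v_params = v_init(rv, input_shape)
        return output_shape, (q_params, k_params, v_params)

    def apply_fun(params, inputs, **kwargs):
        q_params, k_params, v_params = params
        q = q_apply(q_params, inputs)
        k = k_apply(k_params, inputs)
        v = v_apply(v_params, inputs)
        # Scaled dot-product attention over the sequence axis.
        scores = q @ jnp.swapaxes(k, -1, -2) / jnp.sqrt(d_model)
        weights = jax.nn.softmax(scores, axis=-1)
        return weights @ v

    return init_fun, apply_fun


init_fun, apply_fun = SelfAttention(d_model=16)
rng = jax.random.PRNGKey(0)
x = jnp.ones((2, 5, 8))  # (batch, sequence, features)
out_shape, params = init_fun(rng, x.shape)
y = apply_fun(params, x)
```

Because parameters are returned by `init_fun` and passed explicitly to `apply_fun`, the layer stays a pure function, so it can be wrapped directly in `jax.jit` or `jax.grad`.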