grok-jax

Reproduction of the grokking paper using JAX.

Primary language: Python. License: MIT.

Reproduction of the grokking paper

Installation

First, install PyTorch and JAX. Then clone this repo and, from its root directory, run:

pip install -e .
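Before running the editable install, you may want to confirm that both prerequisites are visible to the Python interpreter you are using. The snippet below is a minimal sketch and not part of this repo. `missing_packages` is a hypothetical helper, and the sketch assumes the standard import names `torch` (PyTorch) and `jax` (JAX). It does not check this repo's own package, because the README does not state its import name.

```python
import importlib.util

# Assumed import names for the prerequisites: "torch" is PyTorch, "jax" is JAX.
REQUIRED = ("torch", "jax")


def missing_packages(names=REQUIRED):
    """Hypothetical helper: return the names in `names` that cannot be found.

    importlib.util.find_spec locates a top-level module without importing it,
    so this check stays cheap even for large packages like torch and jax.
    """
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Missing prerequisites:", ", ".join(missing))
    else:
        print("PyTorch and JAX are importable.")
```

If anything is reported missing, install it into the same environment before running `pip install -e .`.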