
Use the latest PyTorch if possible (2.3.0 at the time of writing).

pip install torch torchvision torchaudio
pip install \
    einops \
    fire \
    hnn-utils \
    lightning \
    rich \
    schedulefree
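To confirm the environment matches the commands above, a small check like the one below may help. It is a sketch that uses only the standard library's `importlib.metadata`. The package list mirrors the pip commands above, and the 2.3.0 minimum comes from the PyTorch note at the top. The helper names (`installed_versions`, `torch_is_recent`) are hypothetical and not part of this repo.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Distribution names as passed to pip above (not necessarily the import names).
REQUIRED = [
    "torch", "torchvision", "torchaudio",
    "einops", "fire", "hnn-utils", "lightning", "rich", "schedulefree",
]


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if missing."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


def torch_is_recent(torch_version, minimum=(2, 3, 0)):
    """Compare a version string like '2.3.0+cu121' against (major, minor, patch).

    Simplification: local build tags (+cu121) and pre-release suffixes are ignored.
    """
    release = torch_version.split("+")[0]
    parts = []
    for piece in release.split(".")[:3]:
        match = re.match(r"\d+", piece)  # keep only the leading digits
        parts.append(int(match.group()) if match else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts) >= minimum


if __name__ == "__main__":
    versions = installed_versions(REQUIRED)
    for name, ver in versions.items():
        print(f"{name:12} {ver or 'MISSING'}")
    if versions["torch"] and not torch_is_recent(versions["torch"]):
        print("Warning: PyTorch is older than 2.3.0")
```

Distribution names are used rather than import names so that packages like `hnn-utils` can be checked without guessing their module name.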

If your GPUs are new enough (Ampere or newer, e.g. RTX A5000, RTX A6000, RTX 6000 Ada, RTX 30-series, etc.), you can use FlashAttention, which reduces memory usage and increases speed:

pip install flash-attn --no-build-isolation
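FlashAttention-2 targets Ampere, Ada, and Hopper GPUs, i.e. CUDA compute capability 8.0 or higher. The sketch below checks that condition with `torch.cuda.get_device_capability()` and tests whether the `flash_attn` package is importable. The helper names are hypothetical, and how this repo actually switches between attention implementations is not shown here.

```python
import importlib.util


def capability_supports_flash_attn(capability):
    """True if a (major, minor) CUDA compute capability is Ampere (8.0) or newer."""
    return tuple(capability) >= (8, 0)


def flash_attn_available():
    """True if flash-attn is installed and the current GPU is Ampere or newer."""
    if importlib.util.find_spec("flash_attn") is None:
        return False  # package not installed
    try:
        import torch
    except ImportError:
        return False
    if not torch.cuda.is_available():
        return False
    # Capability of the current CUDA device, e.g. (8, 6)
    return capability_supports_flash_attn(torch.cuda.get_device_capability())


if __name__ == "__main__":
    print("FlashAttention usable:", flash_attn_available())
```

For reference, RTX 30-series and RTX A5000/A6000 cards report (8, 6), and Ada cards report (8, 9).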


You need to set up the data-loading logic under data/. The model expects a tensor of integer token IDs with shape [B, Seq_Len]. Add whatever tokens you need to data/vocab.json; it will probably look something like this:

    {
        "[START]": 0,
        "[STOP]": 1,
        "[PAD]": 2,
        "[UNK]": 3,
        "A": 4,
        "C": 5,
        "G": 6,
        "T": 7
    }
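A minimal sketch of turning raw sequences into the [B, Seq_Len] integer tensor the model expects. It assumes data/vocab.json holds a flat token-to-ID mapping like the one above, that each sequence is wrapped in [START]/[STOP], that unknown characters map to [UNK], and that shorter sequences are right-padded with [PAD]. These conventions and the function names are assumptions for illustration; the real loader under data/ may differ.

```python
import json


def load_vocab(path="data/vocab.json"):
    """Read the token-to-ID mapping from JSON."""
    with open(path) as f:
        return json.load(f)


def encode(sequence, vocab):
    """Map a string such as 'ACGT' to token IDs wrapped in [START] ... [STOP]."""
    unk = vocab["[UNK]"]
    ids = [vocab.get(ch, unk) for ch in sequence]  # one token per character
    return [vocab["[START]"], *ids, vocab["[STOP]"]]


def encode_batch(sequences, vocab):
    """Encode and right-pad a list of strings into a [B, Seq_Len] nested list."""
    encoded = [encode(s, vocab) for s in sequences]
    max_len = max(len(ids) for ids in encoded)
    pad = vocab["[PAD]"]
    return [ids + [pad] * (max_len - len(ids)) for ids in encoded]


if __name__ == "__main__":
    vocab = {"[START]": 0, "[STOP]": 1, "[PAD]": 2, "[UNK]": 3,
             "A": 4, "C": 5, "G": 6, "T": 7}
    batch = encode_batch(["ACGT", "GAN"], vocab)
    print(batch)  # [[0, 4, 5, 6, 7, 1], [0, 6, 4, 3, 1, 2]]
    try:
        import torch
        tokens = torch.tensor(batch, dtype=torch.long)  # shape [B, Seq_Len]
        print(tokens.shape)
    except ImportError:
        pass
```

Right-padding keeps every row the same length so the batch stacks into a single tensor; whether the model masks [PAD] positions is not covered by this sketch.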