Use the latest PyTorch release if possible (2.3.0 at the time of writing):
pip install torch torchvision torchaudio
pip install \
einops \
fire \
hnn-utils \
lightning \
rich \
schedulefree \
wandb
If your GPUs are new enough (e.g., Quadro A5000, A6000, A6000 Ada, 3000 series), you can also install flash attention, which decreases memory usage and increases speed:
pip install flash-attn --no-build-isolation
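Since flash attention is optional, it can be handy to detect at runtime whether the install succeeded. Below is a minimal sketch that uses only the standard library. The helper name `flash_attn_available` and the `use_flash` flag are hypothetical and not part of this repo; wire the result into whatever option your model actually exposes.

```python
import importlib.util


def flash_attn_available() -> bool:
    """Return True if the flash_attn package is importable in this environment."""
    # find_spec returns None when a top-level package is not installed.
    return importlib.util.find_spec("flash_attn") is not None


# Hypothetical flag: pass this to your own model/config code.
use_flash = flash_attn_available()
print(f"flash attention available: {use_flash}")
```

This only checks that the package is installed, not that the GPU supports it, so treat it as a first-pass guard.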
You need to implement the data-loading logic in data/vae_datamodule.py. The model expects an integer tensor of shape [B, Seq_Len], where each entry is a token ID.
Add whatever tokens are needed to data/vocab.json. I imagine it will look something like:
{
"[START]": 0,
"[STOP]": 1,
"[PAD]": 2,
"[UNK]": 3,
"A": 4,
"C": 5,
"G": 6,
"T": 7
}
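To make the expected input concrete, here is a rough sketch of turning raw sequences into padded rows of token IDs with a vocab like the one above. It rests on several assumptions that the repo does not confirm: one token per character, each sequence wrapped in [START]/[STOP], unknown characters mapped to [UNK], and right-padding with [PAD]. The `encode` and `encode_batch` helpers are hypothetical names.

```python
import json

# Same layout as the vocab sketched above; in practice, load data/vocab.json instead.
VOCAB_JSON = (
    '{"[START]": 0, "[STOP]": 1, "[PAD]": 2, "[UNK]": 3, '
    '"A": 4, "C": 5, "G": 6, "T": 7}'
)
vocab = json.loads(VOCAB_JSON)


def encode(seq, vocab):
    """Map one string to token IDs, wrapped in [START] ... [STOP] (assumed convention)."""
    unk = vocab["[UNK]"]
    ids = [vocab["[START]"]]
    ids += [vocab.get(ch, unk) for ch in seq]  # unknown characters fall back to [UNK]
    ids.append(vocab["[STOP]"])
    return ids


def encode_batch(seqs, vocab):
    """Encode several strings and right-pad them with [PAD] to a common length."""
    encoded = [encode(s, vocab) for s in seqs]
    max_len = max(len(ids) for ids in encoded)
    pad = vocab["[PAD]"]
    return [ids + [pad] * (max_len - len(ids)) for ids in encoded]


batch = encode_batch(["ACGT", "GN"], vocab)
print(batch)  # [[0, 4, 5, 6, 7, 1], [0, 6, 3, 1, 2, 2]]
```

Wrapping the result in `torch.tensor(batch, dtype=torch.long)` gives a tensor of shape [B, Seq_Len], here [2, 6].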