nanoXLSTM is a minimal codebase for playing around with language models based on the xLSTM (Extended Long Short-Term Memory) architecture from the awesome research paper "xLSTM: Extended Long Short-Term Memory", and heavily inspired by Andrej Karpathy's nanoGPT.
**Note: Work in progress!!!** I am working on improving the generated text.
No lofty goals here - just a simple codebase for tinkering with this innovative xLSTM technology!
Contributions are more than welcome as I continue exploring this exciting research direction.
```sh
# install dependencies
pip install torch numpy transformers datasets tiktoken wandb tqdm

# prepare the character-level Shakespeare dataset
python data/shakespeare_char/prepare.py

# train on the Shakespeare data, then sample from the trained model
python train.py config/train_shakespeare_char.py
python sample.py --out_dir=out-shakespeare-char
```
- Run hyperparameter sweep
- Import OneCycleLR: The `OneCycleLR` scheduler is imported from `torch.optim.lr_scheduler`.
- `sLSTM` class: `f_bias` and `dropout` are added to the `sLSTM` class.
- `mLSTM` class: `f_bias` and `dropout` are added to the `mLSTM` class.
- `xLSTMBlock` class: The `xLSTMBlock` class is implemented with a configurable ratio of `sLSTM` and `mLSTM` blocks, and layer normalization is applied.
- `GPT` class: The `xLSTM_blocks` are used in the `GPT` class instead of separate `sLSTM` and `mLSTM` blocks.
- `configure_optimizers` method: The `configure_optimizers` method in the `GPT` class is updated to use the AdamW optimizer and a OneCycleLR scheduler (see the sketch below).
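As a rough sketch of that last item: the snippet below shows how AdamW and OneCycleLR can be wired together and returned as a pair. The function signature, hyperparameter names (`weight_decay`, `betas`, `total_steps`), and the example values are assumptions for illustration, not the exact code in `model.py`.

```python
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import OneCycleLR

def configure_optimizers(model: nn.Module, weight_decay: float, learning_rate: float,
                         betas: tuple, total_steps: int):
    """Build an AdamW optimizer and a OneCycleLR schedule spanning the whole run."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate,
                                  betas=betas, weight_decay=weight_decay)
    # OneCycleLR warms the learning rate up to max_lr and anneals it back down
    # over total_steps optimizer updates.
    scheduler = OneCycleLR(optimizer, max_lr=learning_rate, total_steps=total_steps)
    return optimizer, scheduler

# Example with a stand-in model (train.py calls this on the GPT instance):
optimizer, scheduler = configure_optimizers(nn.Linear(64, 64), weight_decay=0.1,
                                            learning_rate=6e-4, betas=(0.9, 0.95),
                                            total_steps=5000)
```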
20/05/24
- Initialize the forget gate bias (self.f_bias) with values between 3 and 6 instead of ones. This helps the forget gate to be effective from the beginning of training.
- Introduce a stabilization technique to avoid overflow due to the exponential function: use the max function to compute a stabilization factor and subtract it from the input-gate and forget-gate pre-activations before applying the exponential (a sketch of both changes follows below).
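A minimal sketch of these two changes for a single step, assuming exponential gates as in the xLSTM paper. The names (`f_bias`, `i_tilde`, `f_tilde`, `m`) and shapes are illustrative and may differ from the actual `sLSTM` code:

```python
import torch
import torch.nn as nn

hidden_size = 128

# Forget gate bias initialized uniformly in [3, 6] so the forget gate starts close to
# "remember everything" instead of the near-neutral behavior of a bias of 1.
f_bias = nn.Parameter(torch.empty(hidden_size).uniform_(3.0, 6.0))

# Pre-activations of the input and forget gates for one step (illustrative values).
i_tilde = torch.randn(hidden_size)
f_tilde = torch.randn(hidden_size) + f_bias

# Stabilization: subtract a running max before exponentiating so exp() cannot overflow.
m_prev = torch.zeros(hidden_size)                 # stabilizer state from the previous step
m = torch.maximum(f_tilde + m_prev, i_tilde)      # new stabilizer factor
i_gate = torch.exp(i_tilde - m)                   # stabilized input gate
f_gate = torch.exp(f_tilde + m_prev - m)          # stabilized forget gate
```

Because `m` tracks the running maximum of the gate pre-activations, the arguments to `exp()` are never positive, so the gate values stay in (0, 1] and cannot overflow.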
- Import statement: The `OneCycleLR` scheduler is imported.
- Optimizer and scheduler initialization: The optimizer and scheduler are obtained from the `configure_optimizers` method of the `GPT` class.
- Loading optimizer and scheduler state: The optimizer and scheduler states are loaded from the checkpoint when resuming training.
- Saving scheduler state: The scheduler state is included in the checkpoint dictionary.
- Stepping the scheduler: `scheduler.step()` is called after each optimizer step.
- Logging learning rate and MFU: The learning rate and MFU are logged using `wandb` (if `wandb_log` is enabled).
- `estimate_loss` function: The `estimate_loss` function is updated to use the `ctx` context manager.
- Training loop: The training loop is updated to use `scaler.scale(loss).backward()` and `scaler.step(optimizer)` for gradient scaling when training in fp16 (see the training-loop sketch below).
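Taken together, the training-loop changes look roughly like the self-contained sketch below. The toy model, synthetic batches, and hyperparameter values are placeholders standing in for the GPT model and the Shakespeare data used by `train.py`:

```python
import os
from contextlib import nullcontext

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import OneCycleLR

device = 'cuda' if torch.cuda.is_available() else 'cpu'
use_fp16 = (device == 'cuda')

# Stand-ins for the GPT model and the configure_optimizers() output.
model = nn.Linear(32, 32).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4, weight_decay=0.1)
max_iters = 10
scheduler = OneCycleLR(optimizer, max_lr=6e-4, total_steps=max_iters)

# fp16 path: autocast context plus a GradScaler for the backward pass.
ctx = torch.amp.autocast(device_type='cuda', dtype=torch.float16) if use_fp16 else nullcontext()
scaler = torch.cuda.amp.GradScaler(enabled=use_fp16)

for it in range(max_iters):
    X = torch.randn(8, 32, device=device)   # placeholder batch
    Y = torch.randn(8, 32, device=device)
    with ctx:
        loss = nn.functional.mse_loss(model(X), Y)
    scaler.scale(loss).backward()   # scaled backward pass (scaling is a no-op outside fp16)
    scaler.step(optimizer)          # unscales gradients, then optimizer.step()
    scaler.update()
    scheduler.step()                # advance OneCycleLR after every optimizer step
    optimizer.zero_grad(set_to_none=True)
    # With wandb_log enabled, this is where the loss, scheduler.get_last_lr()[0],
    # and the running MFU estimate get logged.

# The checkpoint includes the scheduler state so a resumed run continues the same LR cycle.
out_dir = 'out'
os.makedirs(out_dir, exist_ok=True)
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scheduler': scheduler.state_dict(),
    'iter_num': max_iters,
}
torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))

# Resuming: restore optimizer and scheduler state alongside the model weights.
ckpt = torch.load(os.path.join(out_dir, 'ckpt.pt'))
model.load_state_dict(ckpt['model'])
optimizer.load_state_dict(ckpt['optimizer'])
scheduler.load_state_dict(ckpt['scheduler'])
```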