
This is an attempt to create the most optimized LLM architecture.


Optimized LLM

The goal of this repository is to use the most optimized techniques to train an LLM from scratch and to be the fastest at inference time.

I call this LLM anemone, but I'm not satisfied with this name. Why not MoM, for mixture of mixtures?

Installation

Using a virtual environment is recommended.

pip install --upgrade torch --index-url https://download.pytorch.org/whl/cu121 
pip install -r requirements.txt

TODO

Test

Model Without mixture of depth

To test the first model, which has 1.58-bit linear layers, a Jamba base architecture, and MoAH (mixture of attention heads), you can clone this repo at this commit:

and run the following command:

python infer.py

Model With mixture of depth

To test the second model, which has 1.58-bit linear layers, a Jamba base architecture, MoAH, and MoD (mixture of depths), you can clone this repo at this commit.

You can start the training process by running the following command:

python train.py

and compare the results with the first model.

You can also run the following command to test the inference:

python infer.py

MoMv2-bf16

This model is a mixture of mixtures (MoE, MoD, MoAH) with a Jamba base architecture.

This model doesn't contain any 1.58-bit linear layers.
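As a rough illustration of how such mixtures can be composed (this is an assumption for the sake of the example, not the actual block from this repo), a mixture-of-depths router can select a fraction of the tokens to go through a Jamba-style block, whose MLP would itself be a mixture of experts, while the remaining tokens skip the block:

```python
import torch
import torch.nn as nn

class MoDWrapper(nn.Module):
    """Route only the top-scoring fraction of tokens through the inner block."""

    def __init__(self, block: nn.Module, d_model: int, capacity: float = 0.25):
        super().__init__()
        self.block = block                     # e.g. a Jamba block with an MoE MLP
        self.capacity = capacity               # fraction of tokens that enter the block
        self.router = nn.Linear(d_model, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model)
        scores = self.router(x).squeeze(-1)    # (batch, seq_len)
        k = max(1, int(self.capacity * x.size(1)))
        top = scores.topk(k, dim=-1).indices   # tokens selected for the block
        out = x.clone()
        for b in range(x.size(0)):
            idx = top[b]
            out[b, idx] = self.block(x[b, idx].unsqueeze(0)).squeeze(0)
        return out

# Usage with a stand-in block (the real block would be Jamba-style):
block = nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True)
layer = MoDWrapper(block, d_model=64)
y = layer(torch.randn(2, 16, 64))
```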

The difference between this model and the previous one is the use of a softmax function to weight the tokens for the MoD router. This breaks causality, which may be why the model outputs nonsensical text.
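Below is a minimal sketch (not taken from this repo) of why a softmax over the sequence makes the MoD token weighting non-causal: perturbing a later token's router score changes the weights of earlier tokens, whereas a per-token sigmoid does not.

```python
import torch

def mod_weights_softmax(router_logits: torch.Tensor) -> torch.Tensor:
    # Softmax over the sequence dimension: every weight depends on the whole
    # sequence, including future tokens, so the weighting is not causal.
    return torch.softmax(router_logits, dim=-1)

def mod_weights_causal(router_logits: torch.Tensor) -> torch.Tensor:
    # A causal alternative: an element-wise sigmoid only looks at each
    # token's own score.
    return torch.sigmoid(router_logits)

logits = torch.randn(1, 8)                    # (batch, seq_len) router scores
w_before = mod_weights_softmax(logits)
perturbed = logits.clone()
perturbed[0, -1] += 5.0                       # change only the LAST token's score
w_after = mod_weights_softmax(perturbed)
# The weights of EARLIER tokens changed -> prints False
print(torch.allclose(w_before[0, :-1], w_after[0, :-1]))
# The sigmoid weighting of earlier tokens is unaffected -> prints True
print(torch.allclose(mod_weights_causal(logits)[0, :-1],
                     mod_weights_causal(perturbed)[0, :-1]))
```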

You can also test the inference at this commit by running the following command:

python infer.py

MoMv3

This model is a mixture of mixtures (MoE, MoD, MoAH) with a Jamba base architecture.

All Mamba, router, MoE, and MLP linear layers are 1.58-bit. The linear layers in the attention mechanism are not 1.58-bit.
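For context, here is a minimal sketch of what a 1.58-bit (ternary) linear layer can look like, in the spirit of BitNet b1.58; it is an illustration, not the exact layer used in this repo.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BitLinear158(nn.Linear):
    """Linear layer with weights quantized to {-1, 0, +1} (~1.58 bits per weight)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Absmean scaling, then round-and-clip to ternary values.
        scale = self.weight.abs().mean().clamp(min=1e-5)
        w_q = (self.weight / scale).round().clamp(-1, 1) * scale
        # Straight-through estimator: quantized weights in the forward pass,
        # gradients flow to the full-precision shadow weights.
        w = self.weight + (w_q - self.weight).detach()
        return F.linear(x, w, self.bias)

layer = BitLinear158(64, 64)
y = layer(torch.randn(2, 64))              # behaves like a regular nn.Linear
```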

You can also run the following command to test the inference, changing MoMv3 to MoMv3-mixed-precision in the file:

python infer.py --prompt "This is the story of"

To run the full 1.58-bit model, you can run the following command:

python infer.py --prompt "This is the story of" --model "MoMv3-1.58bits"

To run the model with Mamba and attention in bf16 and the rest in 1.58-bit, you can run the following command:

python infer.py --prompt "This is the story of" --model "MoMv3-M-A-mixed-precision"

To run the full bf16 model, you can run the following command:

python infer.py --prompt "This is the story of" --model "MoMv3-bf16"

MoMv4

This model is a mixture of mixtures (MoE, MoD, MoAH) with a Jamba base architecture.

All Mamba, router, MoE, and MLP linear layers are 1.58-bit. The linear layers in the attention mechanism are in bf16 precision.

Of the total parameters, 1.7% are in bf16 and the rest are in 1.58-bit.

As a first estimate, there are about 87M active parameters out of 1B total parameters.

Each MLP layer has 12.4M parameters. A token can pass through 7 plain MLP layers and 7 MoE layers, the latter counting as 2×7 MLP layers, i.e. 21 MLP layers in total, so the MLP parameter count is 12.4M × 21 ≈ 261M. Adding the Mamba and attention parameters, which are near 107M, gives about 368M, and since only 1/4 of the tokens pass through a block, the number of active parameters is roughly 368M / 4 ≈ 92M, in the same ballpark as the ~87M first estimate above.
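As a quick check, the estimate above can be reproduced in a few lines (the figures are taken from the text, so this is an approximation rather than a measurement):

```python
mlp_params = 12.4e6                       # parameters per MLP layer
mlp_equivalent_layers = 7 + 2 * 7         # 7 plain MLP layers + 7 MoE layers (x2)
mlp_total = mlp_params * mlp_equivalent_layers        # ~260.4M
mamba_attention = 107e6                   # Mamba + attention parameters
active = (mlp_total + mamba_attention) / 4             # 1/4 of tokens enter a block
print(f"{active / 1e6:.1f}M active parameters")         # ~91.9M
```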

To test the inference, you can run the following command:

python infer.py --prompt "This is the story of" --model "MoMv4-1.58bits"

and

python infer.py --prompt "This is the story of" --model "MoMv4-bf16"

MoMv5

This model is a mixture of mixtures (MoE, MoD, MoAH) with a Jamba base architecture.

All Mamba, router, MoE, and MLP layers are in bf16 precision.

To test the inference, you can run the following command:

python infer.py --prompt "This is the story of" --model "MoMv5-bf16"

and

python eval.py --model "MoMv5-bf16" --max_seq_length 512

perplexity: 15.02

Evaluation

To evaluate the model, you can run the following command:

python eval.py --model "MoMv4-1.58bits" --max_seq_length 512

which gives a loss of 2.62 and a perplexity of 13.77.

You can also run the evaluation for the full bf16 model:

python eval.py --model "MoMv4-bf16" --max_seq_length 512

which gives a loss of 2.53 and a perplexity of 12.59.

The bf16 version is a bit better.
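As a side note, the reported perplexity appears to be the exponential of the average cross-entropy loss, which is how the loss and perplexity figures above relate:

```python
import math

for loss in (2.62, 2.53):
    print(f"loss={loss:.2f} -> perplexity~{math.exp(loss):.2f}")
# loss=2.62 -> perplexity~13.74 (reported: 13.77)
# loss=2.53 -> perplexity~12.55 (reported: 12.59)
```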

Conclusion

We can see (here) that the baseline (MoMv3-bf16) has a loss curve similar to both the version with attention and Mamba in bf16 and the rest in 1.58-bit (MoMv3-M-A-mixed-precision) and the version with only the attention in bf16 and the rest in 1.58-bit (MoMv3-mixed-precision).

Furthermore, training is faster when the attention is not in 1.58-bit, and it uses less VRAM too.

To train a model with a long context and many parameters, while keeping inference fast and low-memory, I found that using the Jamba architecture with all linear layers in 1.58-bit except the attention mechanism's layers can be a good strategy. With only 1.7% of the parameters in bf16, the model can fit on a cheap GPU at inference time. Moreover, by using all the mixtures (MoAH, MoE, and MoD), you can train the model faster with only a few active parameters.
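A hedged sketch of that recipe, reusing the BitLinear158 layer sketched earlier: replace every linear layer with a 1.58-bit one except those inside attention modules. Identifying attention submodules by name is an assumption for illustration, not how this repo does it.

```python
import torch.nn as nn

def to_mixed_precision(model: nn.Module) -> nn.Module:
    """Swap nn.Linear layers for BitLinear158, keeping attention linears in bf16."""
    for name, child in model.named_children():
        if "attn" in name.lower() or "attention" in name.lower():
            continue                                   # leave attention layers untouched
        if isinstance(child, nn.Linear):
            quantized = BitLinear158(child.in_features, child.out_features,
                                     bias=child.bias is not None)
            quantized.weight.data.copy_(child.weight.data)
            if child.bias is not None:
                quantized.bias.data.copy_(child.bias.data)
            setattr(model, name, quantized)
        else:
            to_mixed_precision(child)                  # recurse into submodules
    return model
```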

Contributing

Contributions are welcome.

Please open a pull request with the proposed changes.

License

Apache License 2.0