Magix is a minimalist toolkit for training LLMs with flexible data and model parallelism.
- Training billion-scale LLMs on GPUs and TPUs.
- Familiar Huggingface model interfaces and ecosystem (datasets, hub, etc.).
- Pre-defined model parallel (sharding) rules for popular models like Llama, Mistral, Gemma, etc.
- Acceleration with flash attention and operation fusion.
- Fast checkpoint save/restore with arbitrary device and parallelism designs.
If you have ever used Huggingface Flax transformers, using magix is as simple as adding several magic functions to the common workflow.
- We start by importing the necessary dependencies,

```python
import magix
from magix.models.llama_model import FlaxLlamaForCausalLM
```

- We will explicitly reason about all the GPU (or TPU) devices available to us. We will place the devices in a grid (aka mesh) using the `magix.create_device_mesh` function.

```python
# Assume we have 4 GPUs in total; we can arrange them arbitrarily.
# Say, we arrange them into a 2x2 mesh and name the first axis `data` and the second axis `model`.
# These axes will be responsible for data and model parallelism respectively.
mesh = magix.create_device_mesh((2,2), names=('data', 'model'))
```
- For the next step, we will load our model onto the mesh; each device will hold a part (shard) of the full model. Instead of the familiar `from_pretrained`, we will use the `magix.load_model_hub` function, which calls `from_pretrained` internally but also places the model shards onto the mesh correctly.

```python
model, params = magix.load_model_hub(
    FlaxLlamaForCausalLM,
    'meta-llama/Llama-2-13b',
    FlaxLlamaForCausalLM.partition_rules,  # use the pre-defined partitioning
    mesh
)
```
Here `params` is partitioned and placed onto the mesh. As a side note, JAX reasons about the model definition and the parameters separately, analogous to `y = f(x|θ)`.
- For training, you will also need to do something similar and build the optimizer state onto the mesh,

```python
opt_state = magix.initialize_opt_state(optimizer, params, sharding_config, mesh)
```
- You may have seen tutorials using `jax.pmap`. For our case with both data and model parallelism, we will use the more powerful `jax.jit` (see the sketch after this walkthrough for what `train_step` itself might look like),

```python
train_step = jax.jit(
    train_step,  # or generate_step
    donate_argnums=...,  # set based on the actual function input
    out_shardings=(magix.item_sharding(params), magix.item_sharding(opt_state), ...),  # set based on the actual function output
)
```
With all of these in place, you are ready to start your training/inference loop. Take a look at the complete scripts in `train.py`, `train_lora.py`, and `generate.py`.
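For concreteness, here is a minimal sketch of what a `train_step` passed to `jax.jit` above might look like. It is illustrative only: it assumes an optax AdamW optimizer (the same `optimizer` handed to `magix.initialize_opt_state`), a batch dictionary with an `input_ids` field, and that the Flax model is called as `model(input_ids, params=...)`; the actual scripts handle details such as attention masks, loss masking, and checkpointing.

```python
import jax
import jax.numpy as jnp
import optax

# Assumed optimizer; any optax optimizer can be used.
optimizer = optax.adamw(learning_rate=5e-5, weight_decay=1e-3)

# `params` from magix.load_model_hub is already sharded over the mesh; to see
# where each weight lives, you can inspect its sharding:
#   jax.tree_util.tree_map(lambda x: x.sharding, params)

def train_step(params, opt_state, batch):
    """One optimization step (batch layout and loss masking are simplified)."""
    def loss_fn(p):
        logits = model(batch['input_ids'], params=p).logits
        # Next-token prediction: shift logits and labels by one position.
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits[:, :-1, :], batch['input_ids'][:, 1:]
        )
        return jnp.mean(loss)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

# After wrapping with jax.jit as shown above, the loop is simply:
#   for batch in batches:
#       params, opt_state, loss = train_step(params, opt_state, batch)
```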
Assume we have 4 GPUs. Let's train `mistral-7b` on UltraChat with data and tensor parallelism, `dp=2` and `tp=2` (`mesh_shape=2 2`):

```bash
python train_lora.py \
    --checkpoint_dir /absolute/path/to/checkpoint \
    --model_type mistral \
    --model_name mistralai/Mistral-7B-v0.1 \
    --tokenizer_name mistralai/Mistral-7B-v0.1 \
    --train_file HuggingFaceH4/ultrachat_200k \
    --split train_sft \
    --train_data_field messages \
    --use_chat_template \
    --batch_size 32 \
    --num_epochs 1 \
    --learning_rate 5e-5 \
    --seed 12345 \
    --mesh_shape 2 2 \
    --weight_decay 0.001 \
    --max_length 1024
```
After training, let's solve some math problems. Do generation with full tensor parallelism, `tp=4` (`mesh_shape=1 -1`):

```bash
python generate.py \
    --prompts gsm8k \
    --hf_data_config main \
    --hf_data_split test \
    --use_chat_template \
    --data_field question \
    --output_file generation.jsonl \
    --mesh_shape 1 -1 \
    --model_type mistral \
    --model_name_or_path mistralai/Mistral-7B-v0.1 \
    --tokenizer_name_or_path mistralai/Mistral-7B-v0.1 \
    --model_config_name mistralai/Mistral-7B-v0.1 \
    --batch_size 32 \
    --pad_to_multiple_of 64 \
    --max_length 512 \
    --lora /absolute/path/to/checkpoint/EVALUATION_STEP/lora
```
We recommend using the jax-toolbox JAX container image from NVIDIA. We have an example Dockerfile and Singularity Definition File.
Install the appropriate `jax` build, `torch-cpu`, and then the rest of the dependencies.

```bash
pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# get torch-cpu for model conversion
pip install torch --index-url https://download.pytorch.org/whl/cpu

git clone https://github.com/luyug/magix.git
cd magix
pip install -e .
```
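As a quick sanity check after installation, you can confirm that JAX sees all of your accelerators before launching a run; the device count must match the product of the `--mesh_shape` you pass to the scripts (e.g. 2 x 2 = 4):

```python
import jax

# Lists every visible GPU/TPU; the mesh shape must multiply out to this count.
print(jax.devices())
print(jax.device_count())
```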