JetStream Engine implementation in PyTorch
Outline:
1. SSH to the Cloud TPU VM (using a v5e-8 TPU VM)
   - Create a Cloud TPU VM first if you haven't
2. Download the jetstream-pytorch github repo
3. Clone the repo and install dependencies
4. Download and convert weights
5. Run the checkpoint converter (quantizer)
6. Local run
7. Run the server
8. Run benchmarks
9. Typical Errors
gcloud compute config-ssh
gcloud compute tpus tpu-vm ssh "your-tpu-vm" --project "your-project" --zone "your-project-zone"
Follow the steps in the Cloud TPU documentation if you haven't created a TPU VM yet.
git clone https://github.com/google/jetstream-pytorch.git
(Optional) Create a virtual env using venv or conda and activate it.
cd jetstream-pytorch
git checkout jetstream-v0.2.2
source install_everything.sh
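After the install finishes, a quick way to sanity-check the environment is to confirm that JAX can see the TPU chips. This is a minimal check, assuming install_everything.sh installed a TPU-enabled jax:

```python
# Minimal sanity check after install_everything.sh: JAX should list the
# TPU chips on the VM (8 devices on a v5e-8).
import jax

print(jax.devices())
```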
Follow the instructions here to download the Llama weights:
- Llama-2: https://github.com/meta-llama/llama#download
- Llama-3: https://github.com/meta-llama/llama3/#download
Downloading the weights also downloads a tokenizer.model file; this is the tokenizer we will use.
Please sign the agreement on the Hugging Face website to access the Gemma checkpoints. Download the Gemma PyTorch checkpoint using huggingface-cli; the Gemma tokenizer is included in the checkpoint.
huggingface-cli download google/gemma-7b-pytorch --local-dir $input_ckpt_dir
You need to manually modify the config.json in the checkpoint folder to make it a valid JSON file: replace `'` with `"` and remove the trailing `,` after the last item in the JSON object.
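For example, here is a minimal sketch of that cleanup; the config.json path and the in-place rewrite are assumptions, so point it at the file inside your checkpoint directory and review the result:

```python
import re

# Hypothetical path; point this at the config.json inside $input_ckpt_dir.
path = "config.json"

with open(path) as f:
    text = f.read()

# Replace single quotes with double quotes and drop any trailing comma
# before a closing brace/bracket so the file parses as valid JSON.
text = text.replace("'", '"')
text = re.sub(r",(\s*[}\]])", r"\1", text)

with open(path, "w") as f:
    f.write(text)
```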
Please sign the agreement on the Hugging Face website to access the Mixtral checkpoints. Download the Mixtral PyTorch checkpoint using huggingface-cli; the Mixtral tokenizer is included in the checkpoint.
huggingface-cli download mistralai/Mixtral-8x7B-v0.1 --local-dir $input_ckpt_dir
export input_ckpt_dir=<directory containing the original model weights>
export output_ckpt_dir=<output directory for the converted weights>
export model_name="llama-3" # or "llama-2", "gemma", "mixtral"
export quantize_weights=True # Whether to quantize weights
export quantize_type="int8_per_channel" # Requires "quantize_weights" to be turned on. Available quantize types: {"int8", "int4"} x {"per_channel", "blockwise"}; "int8_per_channel" is the default if not specified.
python -m convert_checkpoints --model_name=$model_name --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize_type=$quantize_type
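To confirm the conversion finished, you can list the converted checkpoint directory. This is only a quick sanity check; the exact file names written by the converter depend on the model and quantization options:

```python
import os

# List whatever the converter wrote into $output_ckpt_dir.
output_ckpt_dir = os.environ["output_ckpt_dir"]
for name in sorted(os.listdir(output_ckpt_dir)):
    size_mb = os.path.getsize(os.path.join(output_ckpt_dir, name)) / 2**20
    print(f"{name}: {size_mb:.1f} MiB")
```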
Set the tokenizer path:
export tokenizer_path=<path to the tokenizer.model file>
Llama-2 7B:
python run_interactive.py --size=7b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
Llama-2 13B:
python run_interactive.py --size=13b --model_name=$model_name --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
Llama-3 8B:
python run_interactive.py --size=8b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
Llama-3 70B:
python run_interactive.py --size=70b --model_name=$model_name --batch_size=8 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
Gemma 7B:
python run_interactive.py --model_name=$model_name --size=7b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
Mixtral 8x7B:
python run_interactive.py --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
Here is an example of running the server with the Llama-2 7B config.
python run_server.py --model_name=$model_name --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config="default_shardings/llama.yaml"
Now you can send gRPC requests to it.
Optional flags:
- `--shard_on_batch=1` shards the model on the batch dimension, i.e. it runs in data-parallel mode instead of model-parallel mode and ignores the sharding config. This is recommended for the Gemma 2B model, because Gemma 2B is small enough to fit on a single TPU chip.
- `--sharding_config=<path>` uses an alternative sharding config instead of the ones in the default_shardings directory.
Below are the steps to run the server with Ray:
- SSH to a multi-host Cloud TPU VM (v5e-16 TPU VM)
- Follow steps 2 to 5 in the outline above
- Set up the Ray cluster
- Run the server with Ray
Log in to the host 0 VM and start the Ray head with the command below:
ray start --head
Log in to each of the other host VMs and start a Ray worker that connects to the head with the command below:
ray start --address='$ip:$port'
Note: get the IP address and port from the output of the Ray head command.
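Before launching the server, you can check from the head VM that every host has joined the cluster. A small sketch using the standard Ray API:

```python
import ray

# Connect to the cluster started with `ray start --head` above.
ray.init(address="auto")

# Each TPU host should appear as one alive node: the head plus the
# worker hosts that joined it.
for node in ray.nodes():
    print(node["NodeManagerAddress"], "alive:", node["Alive"])
```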
Here is an example of running the server with Ray for the Llama-2 7B model:
python run_server_with_ray.py --tpu_chips=16 --model_name=$model_name --size=7b --batch_size=96 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config="default_shardings/llama.yaml"
Start the server, then go to the deps/JetStream folder (downloaded during `install_everything.sh`).
cd deps/JetStream
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000 --dataset-path $dataset_path --dataset sharegpt --save-request-outputs --warmup-first=True
Please look at deps/JetStream/benchmarks/README.md for more information.
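If you passed --save-request-outputs, you can inspect the recorded responses afterwards. The file path and format below are only assumptions; check the benchmark script's flags and the benchmarks README for where the outputs actually land:

```python
import json

# Hypothetical location of the saved request outputs.
path = "request-outputs.json"

with open(path) as f:
    outputs = json.load(f)

# Assuming one record per benchmarked request.
print(f"loaded {len(outputs)} request outputs")
print(outputs[0])
```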
If you hit errors related to the installed jax and jaxlib versions, fix:
- Uninstall the jax and jaxlib dependencies
- Reinstall using `source install_everything.sh`
If the model runs out of memory (OOM), fix:
- Use a smaller batch size
- Use quantization
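As a back-of-the-envelope illustration of why these two fixes work: the KV cache grows linearly with batch size and cache length, and int8 KV-cache quantization halves it. A rough estimate for the Llama-2 7B example flags above:

```python
# Rough estimate only; shapes are Llama-2 7B (32 layers, 32 KV heads,
# head_dim 128), batch/cache values are from the example commands above.
layers, kv_heads, head_dim = 32, 32, 128
batch, cache_len = 128, 2048
bytes_per_elem = 2  # bf16; int8 KV-cache quantization halves this

kv_bytes = 2 * layers * batch * cache_len * kv_heads * head_dim * bytes_per_elem
print(f"KV cache: {kv_bytes / 2**30:.0f} GiB total, "
      f"{kv_bytes / 8 / 2**30:.0f} GiB per chip on a v5e-8")
```

This prints roughly 128 GiB in total, or 16 GiB per chip on a v5e-8; since a v5e chip has about 16 GB of HBM, that leaves no room for the weights, which is why a smaller batch size or quantization resolves the OOM.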