Run LLaMA inference on Apple Silicon GPUs.
As you can see, unlike other LLMs, LLaMA is not biased in any way 😄
1. Clone this repo and change into it:
```shell
git clone https://github.com/jankais3r/LLaMA_MPS
cd LLaMA_MPS
```
2. Download the model weights and put them into a folder called `models` (e.g., `LLaMA_MPS/models/7B`). The tokenizer goes directly into `models` (`models/tokenizer.model`).
3. Install Python dependencies:
```shell
pip3 install virtualenv
python3 -m venv env
source env/bin/activate
pip3 install -r requirements.txt
pip3 install -e .
```
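4. (Optional) Reshard the model weights (13B/30B/65B).
Since we are running inference on a single GPU, the weights of the larger models, which ship split across several files, need to be merged into a single file. The commands below show the 13B model; for 30B or 65B, substitute the folder name.
```shell
mv models/13B models/13B_orig
mkdir models/13B
python3 reshard.py 1 models/13B_orig models/13B
```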
5. Run the inference:
```shell
python3 chat.py --ckpt_dir models/13B --tokenizer_path models/tokenizer.model --max_batch_size=8 --max_seq_len=256
```
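Before step 4 or 5 it can help to confirm that the downloaded weights sit where step 2 expects them. The sketch below is a hypothetical helper, not part of this repo. It assumes the original LLaMA release layout (`params.json` plus `consolidated.NN.pth` shards, with 1, 2, 4, and 8 shards for 7B, 13B, 30B, and 65B) and checks the layout *before* resharding, since a resharded folder holds a single shard.

```python
from pathlib import Path

# Assumed shard counts of the original LLaMA release (consolidated.NN.pth files).
EXPECTED_SHARDS = {"7B": 1, "13B": 2, "30B": 4, "65B": 8}


def missing_files(models_dir: str, size: str) -> list[str]:
    """Hypothetical helper: list expected files that are absent under models_dir."""
    root = Path(models_dir)
    expected = [root / "tokenizer.model", root / size / "params.json"]
    expected += [
        root / size / f"consolidated.{i:02d}.pth" for i in range(EXPECTED_SHARDS[size])
    ]
    # Report paths relative to the models folder, e.g. "13B/consolidated.01.pth".
    return [str(p.relative_to(root)) for p in expected if not p.is_file()]


if __name__ == "__main__":
    print(missing_files("models", "7B") or "all expected files present")
```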
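The idea behind resharding in step 4 can be illustrated with a simplified model: each shard holds a slice of every tensor-parallel weight, and merging means concatenating those slices along the dimension they were split on. This is a sketch, not the actual `reshard.py`. It uses plain nested lists instead of tensors, and its split rules are an assumption based on the original FairScale layers: column-parallel weights (`wq`, `wk`, `wv`, `w1`, `w3`, `output`) split along dim 0; row-parallel weights (`wo`, `w2`) and the parallel embedding split along dim 1; norms and other small tensors are replicated.

```python
# Simplified model of merging tensor-parallel LLaMA shards into one checkpoint.
# Assumed split rules (see lead-in); a real script would call torch.cat.
DIM0_SUFFIXES = ("wq.weight", "wk.weight", "wv.weight", "w1.weight", "w3.weight", "output.weight")
DIM1_SUFFIXES = ("wo.weight", "w2.weight", "tok_embeddings.weight")

Matrix = list[list[float]]


def concat_dim0(parts: list[Matrix]) -> Matrix:
    # Stack each shard's block of rows on top of the previous one.
    return [row[:] for part in parts for row in part]


def concat_dim1(parts: list[Matrix]) -> Matrix:
    # Join the i-th rows of all shards side by side.
    return [[x for part in parts for x in part[i]] for i in range(len(parts[0]))]


def merge_shards(shards: list[dict]) -> dict:
    merged = {}
    for key in shards[0]:
        parts = [shard[key] for shard in shards]
        if key.endswith(DIM0_SUFFIXES):
            merged[key] = concat_dim0(parts)
        elif key.endswith(DIM1_SUFFIXES):
            merged[key] = concat_dim1(parts)
        else:
            merged[key] = parts[0]  # replicated: identical in every shard
    return merged
```

Because every shard is held in memory while the merged copy is built, this kind of merge needs noticeably more RAM than the model alone.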
Model | Starting memory during inference | Peak memory during checkpoint conversion | Peak memory during resharding |
---|---|---|---|
7B | 16 GB | 14 GB | N/A |
13B | 32 GB | 37 GB | 45 GB |
30B | 66 GB | 76 GB | 125 GB |
65B | ?? GB | ?? GB | ?? GB |
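A rough lower bound on the inference numbers in the table comes from two terms: the fp16 weights (2 bytes per parameter) and the key/value cache, which grows with `max_batch_size` and `max_seq_len`. The sketch below assumes the parameter counts and layer sizes from the LLaMA paper and a KV cache laid out as in the original reference code (one key and one value vector of width `dim` per layer, per batch slot, per position, stored in fp16). LLaMA_MPS's actual allocations may differ, and activations and framework overhead are ignored.

```python
import math

BYTES_FP16 = 2

# (parameter count, n_layers, dim) per model, as reported in the LLaMA paper.
MODELS = {
    "7B": (6.7e9, 32, 4096),
    "13B": (13.0e9, 40, 5120),
    "30B": (32.5e9, 60, 6656),
    "65B": (65.2e9, 80, 8192),
}


def estimate_gb(model: str, max_batch_size: int = 1, max_seq_len: int = 512) -> tuple[float, float]:
    """Return (weights_gb, kv_cache_gb) under the fp16 assumptions above."""
    params, n_layers, dim = MODELS[model]
    weights = params * BYTES_FP16
    # Keys and values (factor 2) for every layer, batch slot, and position.
    kv_cache = 2 * n_layers * max_batch_size * max_seq_len * dim * BYTES_FP16
    return weights / 1e9, kv_cache / 1e9


if __name__ == "__main__":
    w, kv = estimate_gb("13B", max_batch_size=8, max_seq_len=256)
    print(f"13B: {w:.1f} GB weights + {kv:.2f} GB KV cache")
```

For the 13B command in step 5, this gives about 26 GB of weights plus about 1.7 GB of KV cache. The KV cache term scales linearly with both batch size and sequence length, which is why larger `--max_batch_size` values need spare memory.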
Min specs per model (slow due to swapping):
- 7B - 16 GB RAM
- 13B - 32 GB RAM
- 30B - 64 GB RAM
- 65B - needs testing
Recommended specs per model:
- 7B - 24 GB RAM
- 13B - 48 GB RAM
- 30B - 96 GB RAM
- 65B - needs testing
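The two spec lists can be read as lookup tables. The hypothetical helper below picks the largest model whose listed RAM requirement fits a given machine. It uses only the numbers above and leaves out 65B because its requirements are still untested.

```python
from typing import Optional

# Unified-memory thresholds in GB, copied from the spec lists (65B omitted: untested).
MIN_RAM_GB = {"7B": 16, "13B": 32, "30B": 64}
RECOMMENDED_RAM_GB = {"7B": 24, "13B": 48, "30B": 96}


def largest_model(ram_gb: int, thresholds: dict) -> Optional[str]:
    """Return the biggest model whose RAM threshold fits in ram_gb, or None."""
    fitting = [name for name, need in thresholds.items() if ram_gb >= need]
    return max(fitting, key=thresholds.get) if fitting else None


if __name__ == "__main__":
    print(largest_model(64, RECOMMENDED_RAM_GB))  # 13B
    print(largest_model(64, MIN_RAM_GB))          # 30B (slow due to swapping)
```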
- max_batch_size
If you have spare memory (e.g., when running the 13B model on a 64 GB Mac), you can increase the batch size with the `--max_batch_size=32` argument. The default value is 1.
- max_seq_len
To increase or decrease the length of the generated text, use the `--max_seq_len=256` argument. The default value is 512.
- use_repetition_penalty
The example script penalizes the model for generating repetitive content. This should lead to higher-quality output, but it slightly slows down inference. Run the script with the `--use_repetition_penalty=False` argument to disable the penalty algorithm.
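To illustrate what a repetition penalty does, here is one common formulation, the one popularized by the CTRL paper and used by Hugging Face's `RepetitionPenaltyLogitsProcessor`: before sampling, the score of every token that already appeared is divided by the penalty if positive and multiplied by it if negative. The implementation in `chat.py` (contributed by tloen) may differ in its details, and the default value `1.15` below is an illustrative assumption.

```python
def apply_repetition_penalty(logits: list[float], generated: list[int], penalty: float = 1.15) -> list[float]:
    """Return a copy of logits with previously generated token ids made less likely."""
    out = list(logits)
    for token_id in set(generated):
        # Shrink positive scores and push negative ones further down,
        # so the token's probability drops in either case.
        if out[token_id] > 0:
            out[token_id] /= penalty
        else:
            out[token_id] *= penalty
    return out


if __name__ == "__main__":
    print(apply_repetition_penalty([2.0, -1.0, 0.5], generated=[0, 1, 0], penalty=2.0))
    # [1.0, -2.0, 0.5]
```

This extra pass over the scores at every generation step is one reason the penalty costs a little speed.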
The best alternative to LLaMA_MPS for Apple Silicon users is llama.cpp, a C/C++ reimplementation that runs inference purely on the CPU cores of the SoC. Because compiled C/C++ code is much faster than Python, llama.cpp can beat this MPS implementation in speed, but at the cost of much worse power efficiency and heat output.
See the comparison below when deciding which implementation better fits your use case.
Implementation | Total run time (256 tokens) | Tokens/s | Peak memory use | Peak SoC temperature | Peak SoC power consumption | Tokens per Wh |
---|---|---|---|---|---|---|
LLaMA_MPS (13B fp16) | 75 s | 3.41 | 30 GB | 79 °C | 10 W | 1,228.80 |
llama.cpp (13B fp16) | 70 s | 3.66 | 25 GB | 106 °C | 35 W | 376.16 |
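The derived columns of the table can be reproduced from the run time and power figures. The sketch below shows the arithmetic. The tokens-per-Wh values match the assumption that peak power was drawn for the entire run, so they are best read as a conservative comparison rather than a measured energy total.

```python
def efficiency(tokens: int, seconds: float, watts: float) -> tuple[float, float]:
    """Return (tokens per second, tokens per watt-hour)."""
    tokens_per_s = tokens / seconds
    # Energy if power stayed at `watts` for the whole run (1 Wh = 3600 J).
    watt_hours = watts * seconds / 3600
    return tokens_per_s, tokens / watt_hours


if __name__ == "__main__":
    for name, secs, watts in [("LLaMA_MPS", 75, 10), ("llama.cpp", 70, 35)]:
        tps, tpwh = efficiency(256, secs, watts)
        print(f"{name}: {tps:.2f} tokens/s, {tpwh:.2f} tokens/Wh")
    # LLaMA_MPS: 3.41 tokens/s, 1228.80 tokens/Wh
    # llama.cpp: 3.66 tokens/s, 376.16 tokens/Wh
```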
Credits:
- facebookresearch (original code)
- markasoftware (cpu optimizations)
- remixer-dec (mps optimizations)
- venuatu (continuous token printing / loading optimizations)
- benob (reshard script)
- tloen (repetition penalty)