PygmalionAI/aphrodite-engine

[Usage]: What to set to get acceptable performance on Pascal GPUs? (Non-P100)


Your current environment

The output of `python env.py`

PyTorch version: 2.2.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35
Python version: 3.11.9 (main, Apr 19 2024, 16:48:06) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.10.16.3-microsoft-standard-WSL2-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA TITAN X (Pascal)
Nvidia driver version: 552.22
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 39 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 8
On-line CPU(s) list: 0-7
Vendor ID: GenuineIntel
Model name: Intel(R) Core(TM) i7-5775C CPU @ 3.30GHz
CPU family: 6
Model: 71
Thread(s) per core: 2
Core(s) per socket: 4
Socket(s): 1
Stepping: 1
BogoMIPS: 6599.99
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology cpuid pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti ssbd ibrs ibpb stibp tpr_shadow vnmi ept vpid ept_ad fsgsbase bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx smap xsaveopt md_clear flush_l1d arch_capabilities
Virtualization: VT-x
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 128 KiB (4 instances)
L1i cache: 128 KiB (4 instances)
L2 cache: 1 MiB (4 instances)
L3 cache: 6 MiB (1 instance)
L4 cache: 512 MiB (4 instances)
Vulnerability Itlb multihit: KVM: Mitigation: VMX disabled
Vulnerability L1tf: Mitigation; PTE Inversion; VMX conditional cache flushes, SMT vulnerable
Vulnerability Mds: Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Full generic retpoline, IBPB conditional, IBRS_FW, STIBP conditional, RSB filling
Vulnerability Srbds: Unknown: Dependent on hypervisor status
Vulnerability Tsx async abort: Mitigation; Clear CPU buffers; SMT Host state unknown
Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.2.0
[pip3] triton==2.2.0
[conda] numpy 1.26.4 pypi_0 pypi
[conda] torch 2.2.0 pypi_0 pypi
[conda] triton 2.2.0 pypi_0 pypi
ROCM Version: Could not collect
Aphrodite Version: 0.5.1
Aphrodite Build Flags:
CUDA Archs: Not Set; ROCm: Disabled

How would you like to use Aphrodite?

I got my hands on a GTX Titan X Pascal 12GB GPU while waiting for a shipment of a few Tesla P40s and P100s to arrive.
The Titan and the P40 use the GP102 GPU rather than the GP100 found in the P100, so they do not have acceptable FP16 performance (GP102 runs FP16 at 1/64 of its FP32 rate).

The problem is that Aphrodite seems to use only FP16 computation on the GPU, while oobabooga and llama.cpp, for example, can be set to use the MMQ kernels and compute in FP32, which is much faster on the Pascal GP102 GPUs. Is there a feature in Aphrodite that allows this as well? Being able to run compute in FP32 would make all the older Pascal GPUs usable, instead of as slow as they are right now.

Currently the GTX Titan X runs Exl2 or GPTQ Llama 3 8B models at 1.0 t/s, which is ridiculously slow, and the GPU is barely utilized because GP102 has very little FP16 throughput. Am I right that Exl2 cannot run FP32 compute?

Tesla P100 specs with heaps of FP16 performance: https://www.techpowerup.com/gpu-specs/tesla-p100-pcie-16-gb.c2888

Tesla P40 specs with almost no FP16: https://www.techpowerup.com/gpu-specs/tesla-p40.c2878
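For reference, the FP16 gap between these cards follows from the CUDA compute capability: GP100 (Tesla P100, compute capability 6.0) runs FP16 at twice its FP32 rate, while GP102 (Titan X Pascal, Tesla P40, compute capability 6.1) runs it at 1/64 of its FP32 rate. The sketch below is a hypothetical helper, not part of Aphrodite, that maps a capability tuple to that ratio. The optional `torch.cuda.get_device_capability` call only works with a CUDA build of PyTorch and a visible GPU; otherwise the script falls back to the GP102 value for illustration.

```python
# Hypothetical helper (not part of Aphrodite): FP16:FP32 throughput ratio
# for Pascal GPUs, keyed by CUDA compute capability (major, minor).
PASCAL_FP16_TO_FP32_RATIO = {
    (6, 0): 2.0,     # GP100 (Tesla P100): packed FP16 at 2x the FP32 rate
    (6, 1): 1 / 64,  # GP102/GP104/GP106/GP107 (Titan X Pascal, Tesla P40, GTX 10xx)
}


def fp16_ratio(capability):
    """Return the FP16:FP32 throughput ratio, or None if not a listed Pascal chip."""
    return PASCAL_FP16_TO_FP32_RATIO.get(tuple(capability))


def describe(capability):
    """Human-readable summary of what the ratio means for kernel choice."""
    ratio = fp16_ratio(capability)
    if ratio is None:
        return "unknown (not a listed Pascal compute capability)"
    if ratio < 1:
        return f"FP16 runs at 1/{round(1 / ratio)} of FP32: prefer FP32 or integer kernels"
    return f"FP16 runs at {ratio:g}x FP32: FP16 kernels are fine"


if __name__ == "__main__":
    try:
        import torch  # optional: only used to query the local GPU

        cap = torch.cuda.get_device_capability(0)
    except Exception:
        cap = (6, 1)  # fallback: the Titan X Pascal / Tesla P40 value
    print(cap, describe(cap))
```

On a Titan X Pascal or P40 this should report the 1/64 case, which matches the near-idle FP16 behaviour described above.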

Also, I cannot get GGUF models to load on this Titan X Pascal. I get the following error; any help would be appreciated.

python -m aphrodite.endpoints.openai.api_server \
--model /home/owen/models/Meta-Llama-3-8B-Instruct-Dolfin-v0.1-Q4_K_M.gguf \
--gpu-memory-utilization 0.95 --max-model-len 2048 --port 8000 --enforce-eager \
--served-model-name Meta-Llama-3-8B-Instruct-Dolfin --worker-use-ray
WARNING:  gguf quantization is not fully optimized yet. The speed can be slower than non-quantized models.
2024-05-05 23:53:02,780 INFO worker.py:1749 -- Started a local Ray instance.
INFO:     Initializing the Aphrodite Engine (v0.5.1) with the following config:
INFO:     Model = '/home/owen/models/Meta-Llama-3-8B-Instruct-Dolfin-v0.1-Q4_K_M.gguf'
INFO:     DataType = torch.float16
INFO:     Model Load Format = auto
INFO:     Number of GPUs = 1
INFO:     Disable Custom All-Reduce = False
INFO:     Quantization Format = gguf
INFO:     Context Length = 2048
INFO:     Enforce Eager Mode = True
INFO:     KV Cache Data Type = auto
INFO:     KV Cache Params Path = None
INFO:     Device = cuda
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/owen/miniconda3/envs/aphro/lib/python3.11/site-packages/aphrodite/endpoints/openai/api_server.py", line 563, in <module>
    engine = AsyncAphrodite.from_engine_args(engine_args)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/owen/miniconda3/envs/aphro/lib/python3.11/site-packages/aphrodite/engine/async_aphrodite.py", line 676, in from_engine_args
    engine = cls(parallel_config.worker_use_ray,
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/owen/miniconda3/envs/aphro/lib/python3.11/site-packages/aphrodite/engine/async_aphrodite.py", line 341, in __init__
    self.engine = self._init_engine(*args, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/owen/miniconda3/envs/aphro/lib/python3.11/site-packages/aphrodite/engine/async_aphrodite.py", line 410, in _init_engine
    return engine_class(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/owen/miniconda3/envs/aphro/lib/python3.11/site-packages/aphrodite/engine/aphrodite_engine.py", line 104, in __init__
    self._init_tokenizer()
  File "/home/owen/miniconda3/envs/aphro/lib/python3.11/site-packages/aphrodite/engine/aphrodite_engine.py", line 168, in _init_tokenizer
    self.tokenizer: TokenizerGroup = TokenizerGroup(
                                     ^^^^^^^^^^^^^^^
  File "/home/owen/miniconda3/envs/aphro/lib/python3.11/site-packages/aphrodite/transformers_utils/tokenizer.py", line 157, in __init__
    self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/owen/miniconda3/envs/aphro/lib/python3.11/site-packages/aphrodite/transformers_utils/tokenizer.py", line 78, in get_tokenizer
    return convert_gguf_to_tokenizer(tokenizer_name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/owen/miniconda3/envs/aphro/lib/python3.11/site-packages/aphrodite/transformers_utils/tokenizer.py", line 63, in convert_gguf_to_tokenizer
    tokenizer = LlamaTokenizer(**tokenizer_args)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/owen/miniconda3/envs/aphro/lib/python3.11/site-packages/transformers/models/llama/tokenization_llama.py", line 169, in __init__
    self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False))
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/owen/miniconda3/envs/aphro/lib/python3.11/site-packages/transformers/models/llama/tokenization_llama.py", line 196, in get_spm_processor
    tokenizer.Load(self.vocab_file)
  File "/home/owen/miniconda3/envs/aphro/lib/python3.11/site-packages/sentencepiece/__init__.py", line 961, in Load
    return self.LoadFromFile(model_file)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/owen/miniconda3/envs/aphro/lib/python3.11/site-packages/sentencepiece/__init__.py", line 316, in LoadFromFile
    return _sentencepiece.SentencePieceProcessor_LoadFromFile(self, arg)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Internal: unk is not defined.

Please refer to https://github.com/PygmalionAI/aphrodite-engine/wiki/8.-Quantization#gguf-dev-branch for how to use Llama 3 GGUF models (the Llama 3 tokenizer isn't a LlamaTokenizer).
The performance issue, on the other hand, can't be solved easily, as most quantization formats don't have an FP32 kernel.
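The traceback shows `convert_gguf_to_tokenizer` handing the GGUF vocabulary to the SentencePiece-based `LlamaTokenizer`, which fails with `unk is not defined`. One quick way to see which tokenizer a GGUF file carries is to read its `tokenizer.ggml.model` metadata key; our understanding is that llama.cpp's converter writes `llama` for SentencePiece vocabularies (Llama 1/2) and `gpt2` for BPE vocabularies such as Llama 3's. The following is a minimal pure-Python sketch of a GGUF v2/v3 metadata reader, written from the GGUF format description; it is not Aphrodite's loader, and the helper names are hypothetical. It walks the key/value pairs in order until it finds the requested key.

```python
import struct
import sys

# Scalar GGUF value types -> little-endian struct formats (GGUF v2/v3).
_SCALARS = {0: "<B", 1: "<b", 2: "<H", 3: "<h", 4: "<I", 5: "<i",
            6: "<f", 7: "<?", 10: "<Q", 11: "<q", 12: "<d"}
_STRING, _ARRAY = 8, 9


class _Reader:
    """Tiny sequential reader over an open binary file."""

    def __init__(self, f):
        self.f = f

    def unpack(self, fmt):
        size = struct.calcsize(fmt)
        data = self.f.read(size)
        if len(data) != size:
            raise ValueError("truncated GGUF header")
        return struct.unpack(fmt, data)[0]

    def string(self):
        n = self.unpack("<Q")  # strings are a uint64 length + UTF-8 bytes
        data = self.f.read(n)
        if len(data) != n:
            raise ValueError("truncated GGUF string")
        return data.decode("utf-8")

    def value(self, vtype):
        if vtype in _SCALARS:
            return self.unpack(_SCALARS[vtype])
        if vtype == _STRING:
            return self.string()
        if vtype == _ARRAY:
            item_type = self.unpack("<I")
            count = self.unpack("<Q")
            return [self.value(item_type) for _ in range(count)]
        raise ValueError(f"unknown GGUF value type {vtype}")


def gguf_metadata_value(f, wanted_key):
    """Return the metadata value for wanted_key, or None if it is absent."""
    r = _Reader(f)
    if f.read(4) != b"GGUF":
        raise ValueError("not a GGUF file")
    if r.unpack("<I") < 2:
        raise ValueError("GGUF v1 is not handled by this sketch")
    r.unpack("<Q")  # tensor count (not needed here)
    kv_count = r.unpack("<Q")
    for _ in range(kv_count):
        key = r.string()
        val = r.value(r.unpack("<I"))
        if key == wanted_key:
            return val
    return None


if __name__ == "__main__":
    if len(sys.argv) > 1:
        with open(sys.argv[1], "rb") as fh:
            print(gguf_metadata_value(fh, "tokenizer.ggml.model"))
```

Passing the Q4_K_M file from the command above as the first argument would print its tokenizer type. If the key happens to sit after the large `tokenizer.ggml.tokens` array, the pure-Python scan may take a few seconds.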

Thanks for replying. I definitely missed that part while banging my head against different things to make this less slow. Does the GGUF kernel run in FP32? One exists in llama.cpp, but as I understand it, Aphrodite converts GGUF to safetensors first anyway?
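On the llama.cpp side of this question: as far as we understand, its MMQ kernels on compute capability 6.1 cards do most of the work with int8 dot products via the `__dp4a` instruction rather than FP16 arithmetic, and apply the per-block scales in floating point once per block. The sketch below emulates that idea for one Q4_0-style weight block and one Q8-style activation block in plain Python. It is a simplified illustration under those assumptions (real kernels use packed registers, Q8_1 activations, and a different offset correction), not llama.cpp or Aphrodite code, and it does not describe how Aphrodite's GGUF kernels work.

```python
# Simplified emulation of an integer block dot product (hypothetical helpers).
# Inner loop: int8 x int8 -> int32, as Pascal's __dp4a provides on sm_61+.


def dp4a(a4, b4, acc):
    """Emulate __dp4a: acc + dot product of four signed 8-bit pairs."""
    assert len(a4) == len(b4) == 4
    return acc + sum(x * y for x, y in zip(a4, b4))


def unpack_q4_0(qs):
    """16 packed bytes -> 32 weights in [-8, 7], low nibbles first (ggml Q4_0 order)."""
    low = [(byte & 0x0F) - 8 for byte in qs]
    high = [(byte >> 4) - 8 for byte in qs]
    return low + high


def block_dot(d_w, qs_w, d_x, q_x):
    """Dot product of one 32-weight Q4_0-style block and 32 int8 activations."""
    w = unpack_q4_0(qs_w)
    acc = 0  # integer accumulator: no FP16 math in this loop
    for i in range(0, 32, 4):
        acc = dp4a(w[i:i + 4], q_x[i:i + 4], acc)
    # Both block scales are applied once, in floating point, at the end.
    return d_w * d_x * acc
```

The design point is that the expensive part (32 multiply-adds per block) stays in integer units, so the card's weak FP16 rate never enters the hot loop.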

EDIT: OK, I finally figured it out. For anyone trying to run Aphrodite on non-GP100 Pascal GPUs, here is how:

  1. Create and activate a miniconda environment: conda create -n aphrodite python=3.11, then conda activate aphrodite
  2. Install CUDA: conda install -y -c "nvidia/label/cuda-12.1.1" cuda
  3. Install PyTorch: pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu121
  4. Install Aphrodite from the root of the cloned aphrodite-engine repository: pip install -e .
  5. Run Aphrodite with either exl2 (slow as balls) or GGUF (fast)
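Version mismatches between the PyTorch wheels and the CUDA toolkit are a common reason steps like these fail. A small, hypothetical check using only the standard library's `importlib.metadata` can confirm that the pins from step 3 are what actually got installed; a local suffix such as `+cu121` is stripped before comparing.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions pinned in step 3 of the list above.
EXPECTED = {"torch": "2.2.0", "torchvision": "0.17.0", "torchaudio": "2.2.0"}


def base_version(v):
    """Drop a PEP 440 local suffix such as '+cu121' before comparing."""
    return v.split("+", 1)[0]


def mismatches(installed, expected=EXPECTED):
    """Return {package: (found, expected)} for every missing or mismatched package."""
    bad = {}
    for pkg, want in expected.items():
        found = installed.get(pkg)
        if found is None or base_version(found) != want:
            bad[pkg] = (found, want)
    return bad


def installed_versions(packages=EXPECTED):
    """Look up installed distribution versions; None if a package is missing."""
    out = {}
    for pkg in packages:
        try:
            out[pkg] = version(pkg)
        except PackageNotFoundError:
            out[pkg] = None
    return out


if __name__ == "__main__":
    problems = mismatches(installed_versions())
    print("OK" if not problems else problems)
```

Run inside the activated environment, it prints `OK` when the three wheels match the pins.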

On Llama 3 8B Q4_K_M I get about 40 t/s for a single request on a GTX Titan X Pascal 12GB, but performance tanks as soon as there are more than 4 parallel requests. The GPU core does seem to be fully utilized at 100% and chokes on the matrix multiplications without the help of Tensor cores. It's OK, but definitely not usable for many parallel requests.

Completed 16 prompts and produced 967 tokens in 30.533 seconds.
Average TPS across all 1 threads: 31.7 - Individual Threads: Min TPS: 31.7, Max TPS: 31.7

Completed 16 prompts and produced 4361 tokens in 61.385 seconds.
Average TPS across all 4 threads: 71.0 - Individual Threads: Min TPS: 15.5, Max TPS: 18.9

Completed 16 prompts and produced 9541 tokens in 484.179 seconds.
Average TPS across all 8 threads: 19.7 - Individual Threads: Min TPS: 2.2, Max TPS: 2.7

Completed 16 prompts and produced 14610 tokens in 544.334 seconds.
Average TPS across all 12 threads: 26.8 - Individual Threads: Min TPS: 1.9, Max TPS: 2.7
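To make the scaling easier to see, the summary lines above can be turned into aggregate throughput (total tokens divided by wall-clock seconds) and a speed-up relative to the single-request run. This is a hypothetical helper that only parses the text as printed here; it is not part of Aphrodite or of the benchmark script used in this thread.

```python
import re

# Patterns for the benchmark summary lines quoted above.
COMPLETED = re.compile(r"Completed (\d+) prompts and produced (\d+) tokens in ([\d.]+) seconds\.")
AVERAGE = re.compile(r"Average TPS across all (\d+) threads: ([\d.]+)")


def parse_runs(text):
    """Return (threads, tokens, seconds) tuples, one per benchmark run."""
    completed = [(int(t), float(s)) for _, t, s in COMPLETED.findall(text)]
    threads = [int(n) for n, _ in AVERAGE.findall(text)]
    return [(n, t, s) for n, (t, s) in zip(threads, completed)]


def summarize(runs):
    """(threads, aggregate TPS, speed-up vs the 1-thread run) for each run."""
    base = None
    rows = []
    for threads, tokens, seconds in runs:
        tps = tokens / seconds
        if threads == 1:
            base = tps
        scaling = round(tps / base, 2) if base else None
        rows.append((threads, round(tps, 1), scaling))
    return rows
```

On the numbers above, 4 parallel requests give roughly 2.24x the single-request throughput, while 8 and 12 drop to about 0.62x and 0.85x, which matches the collapse described in the comment.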