Benchmark multi-query attention in HF transformers
harm-devries opened this issue · 1 comment
For the full integration into HF transformers, the best speedup we see for multi-query attention is about 24% (generation time drops from ~19.8 s to ~15.2 s in the runs below):
```
-------------------- attention_type == AttentionType.MULTI_QUERY_1 --------------------
{'get_test_batch': 4.172325134277344e-05, 'generate_text_batch': 15.190143346786499, 'input_batch_size': 8, 'input_batch_length': 16, 'max_gen_length': 1024, 'num_beams': 1, 'do_sample': False, 'pad_token_id': 50256, 'dtype': torch.int64, 'device': device(type='cuda'), 'cuda_device_name': 'Tesla V100-PCIE-16GB-LS'}
-------------------- attention_type == AttentionType.MULTI_HEAD --------------------
{'get_test_batch': 5.459785461425781e-05, 'generate_text_batch': 19.78107237815857, 'input_batch_size': 8, 'input_batch_length': 16, 'max_gen_length': 1024, 'num_beams': 1, 'do_sample': False, 'pad_token_id': 50256, 'dtype': torch.int64, 'device': device(type='cuda'), 'cuda_device_name': 'Tesla V100-PCIE-16GB-LS'}
```
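For reference, the timing above corresponds to a loop along these lines. This is only a minimal sketch (the `gpt2` checkpoint and the random token ids are placeholders, not the actual benchmark script), but it shows the `generate` settings reported in the result dicts:

```python
import time
import torch
from transformers import AutoModelForCausalLM

device = torch.device("cuda")
# Placeholder checkpoint; the benchmark uses the multi-query / multi-head model variants.
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

# Batch of 8 random prompts of length 16, matching input_batch_size / input_batch_length above.
batch = torch.randint(0, model.config.vocab_size, (8, 16), device=device)

torch.cuda.synchronize()
start = time.time()
model.generate(
    batch,
    max_length=1024,     # max_gen_length
    num_beams=1,
    do_sample=False,     # greedy decoding
    pad_token_id=50256,
)
torch.cuda.synchronize()
print("generate_text_batch:", time.time() - start)
```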
We've also tried to isolate the attention layer to investigate the speed-ups for the best possible implementation of multi-query attention. To this end, we experimented with a simplified attention layer without softmax and normalization and tried two variants: multi-query and multi-query_1 (see details about the differences here). We found speed-ups of up to 2x for the multi-query_1 variant. See the figure below.
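To make the comparison concrete, here is a rough sketch of the core batched matmuls in such a simplified layer (no softmax or normalization). The tensor shapes are illustrative assumptions, and the broadcast below is just one way to share K/V across heads; the linked notes describe how the two benchmarked variants actually lay out this computation:

```python
import torch

b, h, s, d = 8, 12, 1024, 64   # batch, heads, sequence length, head dim (illustrative)

# Multi-head: every head has its own keys and values.
q = torch.randn(b, h, s, d, device="cuda")
k = torch.randn(b, h, s, d, device="cuda")
v = torch.randn(b, h, s, d, device="cuda")
attn_mh = torch.matmul(q, k.transpose(-1, -2))   # (b, h, s, s)
out_mh = torch.matmul(attn_mh, v)                # (b, h, s, d)

# Multi-query: a single key/value head shared by all query heads,
# shrinking the K/V tensors (and the KV cache) by a factor of h.
k1 = torch.randn(b, 1, s, d, device="cuda")
v1 = torch.randn(b, 1, s, d, device="cuda")
attn_mq = torch.matmul(q, k1.transpose(-1, -2))  # broadcasts over the head dimension
out_mq = torch.matmul(attn_mq, v1)               # (b, h, s, d)
```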
Further engineering is necessary to investigate whether we can get better speed-ups, for example by:
- combining the normalisation with the q and k multiplication
- splitting the weight matrix for q and (k, v) (a minimal sketch follows this list)
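As an illustration of the second point, a multi-query layer could keep one projection for the per-head queries and a much smaller one for the shared key/value head. This is a hypothetical sketch of the idea, not the planned implementation:

```python
import torch
import torch.nn as nn

class MultiQueryProjection(nn.Module):
    """Hypothetical split projection: per-head queries, single shared key/value head."""
    def __init__(self, hidden_size: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = hidden_size // n_heads
        # q keeps one slice per head; k and v are a single shared head each.
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.kv_proj = nn.Linear(hidden_size, 2 * self.head_dim)

    def forward(self, x: torch.Tensor):
        b, s, _ = x.shape
        q = self.q_proj(x).view(b, s, self.n_heads, self.head_dim).transpose(1, 2)  # (b, h, s, d)
        k, v = self.kv_proj(x).split(self.head_dim, dim=-1)                         # (b, s, d) each
        return q, k.unsqueeze(1), v.unsqueeze(1)  # K/V get a broadcastable head dim of 1
```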
We also need to check whether we used the profiler correctly: we observed that the CPU time is quite large, and the speedup is smaller for longer sequence lengths (we're not sure why yet).
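One thing worth double-checking on the profiling side is that GPU work is fully synchronized before the timers are read, e.g. with CUDA events; otherwise wall-clock measurements mostly capture kernel launch (CPU) time. A minimal sketch, with the timed function left as a placeholder:

```python
import torch

def time_cuda(fn, *args, warmup: int = 5, iters: int = 20) -> float:
    """Return average milliseconds per call, synchronizing so GPU time isn't missed."""
    for _ in range(warmup):
        fn(*args)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters
```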
More details: https://github.com/bigcode-project/bigcode-analysis/tree/multi_query_experiments/multi_query_experiments