bigcode-project/transformers

Benchmark multi-query attention in HF transformers

harm-devries opened this issue · 1 comment

For the full integration into HF, the best speedup we observed is about 24% for multi-query attention:

-------------------- attention_type == AttentionType.MULTI_QUERY_1---------------------
{'get_test_batch': 4.172325134277344e-05, 'generate_text_batch': 15.190143346786499, 'input_batch_size': 8, 'input_batch_length': 16, 'max_gen_length': 1024, 'num_beams': 1, 'do_sample': False, 'pad_token_id': 50256, 'dtype': torch.int64, 'device': device(type='cuda'), 'cuda_device_name': 'Tesla V100-PCIE-16GB-LS'}
-------------------- attention_type == AttentionType.MULTI_HEAD---------------------
{'get_test_batch': 5.459785461425781e-05, 'generate_text_batch': 19.78107237815857, 'input_batch_size': 8, 'input_batch_length': 16, 'max_gen_length': 1024, 'num_beams': 1, 'do_sample': False, 'pad_token_id': 50256, 'dtype': torch.int64, 'device': device(type='cuda'), 'cuda_device_name': 'Tesla V100-PCIE-16GB-LS'}
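
For reference, the sketch below shows the kind of timing loop that produces the `generate_text_batch` numbers above, using a stock GPT-2 checkpoint. The helper name mirrors the key in the dictionaries but is illustrative; the actual benchmark script lives in the bigcode-analysis repository linked at the bottom.

```python
# Illustrative timing of batched greedy generation (not the actual benchmark script).
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def generate_text_batch(model, input_ids, max_gen_length=1024):
    """Time a single batched greedy generation call, synchronizing the GPU."""
    torch.cuda.synchronize()
    start = time.time()
    model.generate(
        input_ids,
        max_length=max_gen_length,
        num_beams=1,
        do_sample=False,
        pad_token_id=50256,
    )
    torch.cuda.synchronize()
    return time.time() - start

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").cuda().eval()

# Batch of 8 prompts, each 16 tokens long, as in the runs above.
batch = torch.randint(0, model.config.vocab_size, (8, 16), device="cuda")
print("generate_text_batch:", generate_text_batch(model, batch), "seconds")
```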

We've also tried to isolate the attention layer to investigate the speed-ups for the best possible implementation of multi-query attention. To this end, we experimented with a simplified attention layer without softmax and normalization, and tried two variants: multi-query and multi-query_1 (see details about the differences here). We found speed-ups of up to 2x for the multi-query_1 variant. See the figure below.

[Figure: speed-up of the simplified multi-query attention variants over the multi-head baseline]
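
The simplified layer reduces to two batched matmuls. The sketch below shows one plausible reading of the multi-head baseline and of the two multi-query variants; the exact definitions of multi_query vs multi_query_1 are in the write-up linked from this issue.

```python
# Hedged sketch of the "simplified attention" micro-benchmark: only the two
# batched matmuls, no softmax and no scaling, so the effect of sharing k/v
# across heads is isolated. The split into variant A vs variant B is an
# assumption about what the two multi-query variants look like.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
batch, heads, seq, head_dim = 8, 12, 1024, 64

q = torch.randn(batch, heads, seq, head_dim, device=device)

# Multi-head: every head has its own keys and values.
k_mh = torch.randn(batch, heads, seq, head_dim, device=device)
v_mh = torch.randn(batch, heads, seq, head_dim, device=device)
out_mh = (q @ k_mh.transpose(-1, -2)) @ v_mh          # (batch, heads, seq, head_dim)

# Multi-query: a single key/value head is shared by all query heads.
k_mq = torch.randn(batch, seq, head_dim, device=device)
v_mq = torch.randn(batch, seq, head_dim, device=device)

# Variant A: broadcast the shared k/v over the head dimension.
out_a = (q @ k_mq.transpose(-1, -2).unsqueeze(1)) @ v_mq.unsqueeze(1)

# Variant B: fold the head dimension into the sequence dimension, so the
# shared k/v participates in one large matmul per batch element.
q_flat = q.transpose(1, 2).reshape(batch, seq * heads, head_dim)
out_b = (q_flat @ k_mq.transpose(-1, -2)) @ v_mq
out_b = out_b.reshape(batch, seq, heads, head_dim).transpose(1, 2)

# Both multi-query variants compute the same tensor; only the matmul layout
# differs, which is exactly what the micro-benchmark measures.
print((out_a - out_b).abs().max())
```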

Further engineering is necessary to investigate whether we can get better speed-ups, for example by:

  • combining the normalization with the q and k multiplication
  • splitting the weight matrix for q and for (k and v) (see the sketch after this list)
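
As a sketch of the second point, splitting the fused qkv projection into a per-head q projection and a single shared k/v projection could look roughly like the module below. The class name and interface are hypothetical, not the HF implementation.

```python
# Hypothetical projection split for multi-query attention: a wide projection
# for q (one slice per head) and a narrow one for the single shared k/v head.
import torch
import torch.nn as nn

class MultiQueryProjections(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # q keeps one slice per head; k and v share a single head.
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.kv_proj = nn.Linear(embed_dim, 2 * self.head_dim)

    def forward(self, hidden_states: torch.Tensor):
        batch, seq, _ = hidden_states.shape
        q = self.q_proj(hidden_states).view(batch, seq, self.num_heads, self.head_dim)
        k, v = self.kv_proj(hidden_states).split(self.head_dim, dim=-1)
        return q.transpose(1, 2), k, v  # q: (b, h, s, d); k, v: (b, s, d)

proj = MultiQueryProjections(embed_dim=768, num_heads=12)
q, k, v = proj(torch.randn(2, 16, 768))
print(q.shape, k.shape, v.shape)
```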

We also need to check whether we used the profiler in the right way. We observed that the CPU time is quite large, and the speedup is smaller for longer sequence lengths (we're not sure why yet).
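
A typical torch.profiler setup for double-checking this looks roughly as follows; `run_attention` is a stand-in for whichever forward pass is being profiled, not a function from the repo.

```python
# Sketch of a torch.profiler run separating CPU and CUDA time, to see whether
# the large CPU time is kernel-launch overhead or real work.
import torch
from torch.profiler import profile, ProfilerActivity

def run_attention():
    q = torch.randn(8, 12, 1024, 64, device="cuda")
    k = torch.randn(8, 1024, 64, device="cuda")
    v = torch.randn(8, 1024, 64, device="cuda")
    return (q @ k.transpose(-1, -2).unsqueeze(1)) @ v.unsqueeze(1)

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(10):
        run_attention()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```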

More details: https://github.com/bigcode-project/bigcode-analysis/tree/multi_query_experiments/multi_query_experiments