huggingface/transformers

[Efficiency] The llama model with flash attention is slower than that without flash attention

KexinFeng opened this issue · 7 comments

System Info

The test ran with this fix applied: #26984

- `transformers` version: 4.34.0
- Platform: Linux-5.15.0-1045-aws-x86_64-with-glibc2.31
- Python version: 3.9.18
- Huggingface_hub version: 0.17.3
- Safetensors version: 0.4.0
- Accelerate version: 0.23.0
- Accelerate config:    not found
- PyTorch version (GPU?): 2.0.1+cu118 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>

Who can help?

@ArthurZucker and @younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

The model loading:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def get_model_tokenizer(model_id, flash_attn=False):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_id_or_path = "huggyllama/llama-7b"
    model = AutoModelForCausalLM.from_pretrained(
        model_id_or_path,
        device_map='auto' if device.type == 'cuda' else 'cpu',
        use_flash_attention_2=flash_attn)
    lm_block = HuggingfaceBlock(model)  # custom wrapper (not part of transformers)
    tokenizer = AutoTokenizer.from_pretrained(model_id_or_path,
                                              padding_side='left')
    tokenizer.pad_token = "[PAD]"

    return lm_block, tokenizer

Input_length = 760
batch_size = 13
Max_gen_token = [300, 100, 50, 20]

When `flash_attn=True`:

token_latency: [18.3 ms/token, 20.7 ms/token, 26.4 ms/token, 44.1 ms/token]

When `flash_attn=False`:

token_latency: [14.1 ms/token, 17.8 ms/token, 24.3 ms/token, 44.2 ms/token]
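
The timing harness itself isn't shown; below is a minimal sketch of how these per-token latencies might be measured. It is hypothetical: it times `model.generate` directly rather than the `HuggingfaceBlock` wrapper used in the actual benchmark, and it uses a dummy prompt.

import time
import torch

def measure_token_latency(model, tokenizer, input_length=760, batch_size=13,
                          max_gen_tokens=(300, 100, 50, 20)):
    # Build a dummy batch of the requested prompt length (real benchmarks would
    # use varied-length prompts, which is what introduces padding in the mask).
    prompt = "hello " * input_length
    inputs = tokenizer([prompt] * batch_size, return_tensors="pt",
                       truncation=True, max_length=input_length).to(model.device)
    latencies = []
    for n in max_gen_tokens:
        torch.cuda.synchronize()
        start = time.perf_counter()
        model.generate(**inputs, max_new_tokens=n, min_new_tokens=n, do_sample=False)
        torch.cuda.synchronize()
        latencies.append((time.perf_counter() - start) / n * 1000)  # ms per generated token
    return latencies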

Expected behavior

Flash attention should accelerate the inference.

Hi @KexinFeng
Thanks for the issue. The speedup is usually quite considerable for large sequence lengths. Can you try your experiment with, for example, seq_len=2048? Also make sure to use a batch size that is divisible by 2.

@younesbelkada Thanks for pointing out the sequence length. Indeed, at seq_len=3500 flash attention does gain a speedup, but it is not significant compared to non-flash attention.

Input_length = 3500
batch_size = 4
Max_gen_token = [300, 100, 50, 20]

Corresponding to each max_gen_token:

flash_attn=True

token_latency = 33.9 ms/token, 39.7 ms/token, 49.3 ms/token, 78.8 ms/token 

flash_attn = False

token_latency = 28.8 ms/token, 39.9 ms/token, 57.3 ms/token, 110 ms/token 

I thought the expected behavior would be that flash attention is strictly faster than non-flash attention. What factor contributes this overhead to flash attention compared to non-flash attention?

From the benchmark above, it seems that flash attention gets slower relative to non-flash attention as gen_token grows. This suggests that the overhead specific to flash attention is incurred at every decoding step, so the speedup gained at the prefill step is gradually overridden by this per-step overhead as decoding proceeds.
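
As a rough cross-check of this reading, fitting the numbers above to total time ≈ prefill + N × per-step time gives, with flash attention, about 30.7 ms per decoding step and roughly 0.96 s of prefill, versus about 23.0 ms per step and roughly 1.74 s of prefill without it. So flash attention saves roughly 0.78 s at prefill but pays about 7.7 ms extra per decoding step, and the two should cross near 100 generated tokens, which matches the nearly identical 100-token latencies above.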

If you are passing the attention mask to the model, I think the pad and unpad operations add a non-negligible overhead.
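
For illustration, roughly what that round trip looks like: when a padded attention mask is supplied, the real tokens are gathered into a flat buffer before the kernel and the outputs are scattered back afterwards, at every layer of every decoding step. This is a simplified sketch, not the actual transformers implementation; it assumes flash-attn 2.x's `flash_attn_varlen_func` and shows the prefill case where queries and keys have the same length.

import torch
import torch.nn.functional as F
from flash_attn import flash_attn_varlen_func

def attn_with_padded_batch(q, k, v, attention_mask):
    # q, k, v: (batch, seqlen, n_heads, head_dim); attention_mask: (batch, seqlen) of 0/1
    batch, seqlen, n_heads, head_dim = q.shape
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    max_seqlen = int(seqlens.max())

    # "unpad": gather only the real tokens into a flat (total_tokens, n_heads, head_dim) layout
    def unpad(t):
        return t.reshape(batch * seqlen, n_heads, head_dim)[indices]

    out_unpad = flash_attn_varlen_func(
        unpad(q), unpad(k), unpad(v),
        cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
        causal=True,
    )

    # "pad": scatter the outputs back into the padded (batch, seqlen, ...) layout
    out = torch.zeros(batch * seqlen, n_heads, head_dim,
                      dtype=out_unpad.dtype, device=out_unpad.device)
    out[indices] = out_unpad
    return out.reshape(batch, seqlen, n_heads, head_dim)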

@ArthurZucker Yes, indeed, I fed the attention mask into the model, with a lot of 0 entries (corresponding to PAD tokens). Thanks for this insight. But is there any plan to remove this overhead? It seems to me that the flash attention algorithm doesn't, in principle, require the pad and unpad operations. Currently, the advantage of flash attention over non-flash attention does not look clear.

Hi @KexinFeng
As stated by @ArthurZucker, adding pad tokens to the sequence adds a considerable overhead in FA modules. The expected speedups and the best scenarios for using FA-2 are clearly stated in this section of the docs: https://huggingface.co/docs/transformers/perf_infer_gpu_one#expected-speedups

@younesbelkada Thank you for pointing me to this document! Indeed, the issue I brought up here is documented there. What's more, the document also shows data on how the speedup depends on the maximum prompt length, which is very helpful.

However, regarding the solution proposed in the document,

To overcome this, one should use Flash Attention without padding tokens in the sequence for training (e.g., by packing a dataset, i.e., concatenating sequences until reaching the maximum sequence length). An example is provided here.

it doesn't seem to be applicable to the model inference and serving scenario, which is where this issue originates. Especially with dynamically batched inference, this packing of the dataset doesn't work. It seems to me that padding is unavoidable in inference scenarios. A possible way to avoid it would be to switch the flash attention kernel to something like var_len_single_query_attention (which already exists in the flash attention repo), where the input is flattened into a 1D tensor.
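
For illustration, a rough sketch of the padding-free layout such a path would consume: requests of different lengths are concatenated into one 1D token stream plus cumulative-length offsets (the cu_seqlens layout flash-attn's varlen kernels take). This is a hypothetical serving-side sketch, not an API that transformers exposes.

import torch

def build_ragged_batch(request_token_ids):
    # request_token_ids: list of 1D LongTensors with different lengths
    seqlens = torch.tensor([len(t) for t in request_token_ids], dtype=torch.int32)
    input_ids = torch.cat(request_token_ids)  # (total_tokens,), no pad tokens at all
    cu_seqlens = torch.nn.functional.pad(
        torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    return input_ids, cu_seqlens, int(seqlens.max())

# Example: requests of lengths 5, 3 and 7 become a 15-token stream with offsets
# [0, 5, 8, 15]; the kernel is told where each sequence starts and ends, so no
# mask of zeros (and no pad/unpad round trip) is needed.
ids, cu, max_len = build_ragged_batch([torch.arange(5), torch.arange(3), torch.arange(7)])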

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.