
FAT5 - A fast implementation of T5/UL2 with Flash Attention

Warning

This repository is still under development and may still contain various bugs. Please refer to the roadmap section of this README for known issues. A technical report detailing our approach is currently being written.

FAT5 (for Flash Attention T5) is an implementation of T5 in PyTorch with a UL2 objective, optimized for GPGPUs for both training and inference. It relies on an experimental feature combining Flash Attention (v2) with relative position encoding biases, which allows training or finetuning the model on longer sequence lengths than the original T5. It also supports other positional embeddings such as RoPE, ALiBi or FIRE.

Motivation

While a lot of effort has been focused on optimizing decoder-only models, older architectures remain useful in many practical applications. We focus on T5 by Raffel et al. (2020), an encoder-decoder architecture that exhibits very decent performance for instruction tuning and even sometimes outperforms much larger models when finetuned. Moreover, it is a natural architecture to consider when distilling much larger models.

A critical limitation of this architecture is the sequence length these models can deal with, due to attention's quadratic cost in memory. While this quadratic term cannot be removed without turning to other forms of attention (as in LongT5), it can still be alleviated to accommodate longer sequence lengths.
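
To give an order of magnitude, here is a back-of-the-envelope sketch of the memory taken by a materialized attention score matrix (the batch size, head count and precision are illustrative, not the settings used in the benchmarks below):

def attention_scores_memory_gb(batch_size, num_heads, seqlen, bytes_per_elem=4):
    # Memory of the (batch, heads, seqlen, seqlen) score tensor, in GB.
    return batch_size * num_heads * seqlen * seqlen * bytes_per_elem / 1024**3

for seqlen in (512, 1024, 2048):
    print(seqlen, round(attention_scores_memory_gb(32, 12, seqlen), 1), "GB")
# 512 -> 0.4 GB, 1024 -> 1.5 GB, 2048 -> 6.0 GB: the cost grows quadratically with length.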

Our work

We used the nanoT5 implementation (Nawrot, 2023) as the base for our work.

We worked on optimizing the core component of the model: the attention. We used Flash Attention (v2) by Dao (2023), which optimizes both memory usage and the use of Tensor Cores.

While the original implementation does not support attention biases, we added this component in this PR. The implementation supports full attention biases of shape (batch_size, num_heads, seqlen_q, seqlen_k) as well as partial attention biases of shape (1, 1, seqlen_q, seqlen_k). The latter allows us to remove the full-size attention mask from the T5 implementation, while causality can be enforced by masking in the kernel itself, reducing the memory needed for this tensor by a factor of the batch size. This makes it possible to fit larger batch sizes and thus increase throughput during training.
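
As a rough illustration of what these bias shapes mean, here is a plain PyTorch sketch of vanilla attention with an additive bias (this is not the patched CUDA kernel, which never materializes the score matrix; shapes and names are illustrative):

import torch

batch, heads, seqlen_q, seqlen_k, head_dim = 4, 12, 128, 128, 64
q = torch.randn(batch, heads, seqlen_q, head_dim)
k = torch.randn(batch, heads, seqlen_k, head_dim)
v = torch.randn(batch, heads, seqlen_k, head_dim)

# Full bias: one value per sample, head and position pair.
full_bias = torch.randn(batch, heads, seqlen_q, seqlen_k)
# Partial bias: broadcast over batch and heads, so the full-size tensor is never stored.
partial_bias = torch.randn(1, 1, seqlen_q, seqlen_k)

scores = q @ k.transpose(-2, -1) / head_dim**0.5
out = torch.softmax(scores + partial_bias, dim=-1) @ v  # full_bias would be added the same way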

FAT5 animation

Other parts of the architecture were optimized using ad-hoc Triton kernels for the cross-entropy (and z-loss) and the layer norm. We also provide a Triton implementation of Flash Attention 2 supporting attention biases for those who prefer not to recompile a custom patch of flash attention.
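
For reference, a minimal PyTorch sketch of the cross-entropy with z-loss that these kernels fuse (the coefficient value is purely illustrative; refer to the actual Triton kernels for the exact behaviour):

import torch
import torch.nn.functional as F

def cross_entropy_with_z_loss(logits, labels, z_loss_coef=1e-4, ignore_index=-100):
    # Standard cross-entropy plus a z-loss term penalizing large log-partition values;
    # the fused Triton kernel computes both in a single pass over the logits.
    ce = F.cross_entropy(logits, labels, ignore_index=ignore_index)
    log_z = torch.logsumexp(logits, dim=-1)
    mask = labels != ignore_index
    z_loss = z_loss_coef * (log_z[mask] ** 2).mean()
    return ce + z_loss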

For pretext tasks during pre-training, we use the UL2 mixture of denoisers by Tay and Dehghani (2022) with the following 7 tasks:

denoiser_list=[
    {"mu": 3.0, "r": 0.15, "max_spans": max_token_length, "prefix": "[R]"},
    {"mu": 8.0, "r": 0.15, "max_spans": max_token_length, "prefix": "[R]"},
    {"mu": 4.0, "r": 0.0, "max_spans": 1, "prefix": "[S]"},
    {"mu": 3.0, "r": 0.5, "max_spans": max_token_length, "prefix": "[X]"},
    {"mu": 8.0, "r": 0.15, "max_spans": max_token_length, "prefix": "[X]"},
    {"mu": 64.0, "r": 0.15, "max_spans": max_token_length, "prefix": "[X]"},
    {"mu": 64.0, "r": 0.5, "max_spans": max_token_length, "prefix": "[X]"},
]
denoiser_proportions=[0.165, 0.165, 0.34, 0.0825, 0.0825, 0.0825, 0.0825]

where mu is the mean span length, r the masking ratio, and prefix the type of the pretext task (the meaning of the letters [R], [S] and [X] is described here).

As no PyTorch implementation was available, we added one and adapted a dynamic batching mechanism to reduce padding in the model.
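
To make the parameters above concrete, here is a simplified sketch of how one denoiser could sample corruption spans (this is not the repository's implementation; sentinel insertion, overlap handling and edge cases are omitted):

import random

def sample_corruption_spans(seq_len, mu, r, max_spans):
    # Mask roughly seq_len * r tokens, grouped into spans of mean length mu.
    num_noise_tokens = max(1, round(seq_len * r))
    num_spans = min(max_spans, max(1, round(num_noise_tokens / mu)))
    span_len = max(1, num_noise_tokens // num_spans)
    starts = sorted(random.sample(range(seq_len), num_spans))
    return [(start, min(seq_len, start + span_len)) for start in starts]

# Example: an [R]-style denoiser (mu=3, r=0.15) on a 100-token sequence.
print(sample_corruption_spans(100, mu=3.0, r=0.15, max_spans=20))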

Benchmarks

The benchmarks were run on an A100 80G, comparing against the original implementation of T5 v1.1 available on Hugging Face. The sequence length is the same for both the encoder and the decoder. Different sequence lengths for the two parts are possible, and even recommended depending on the application.

We see below that for sequence lengths under 256, torch.compile does a pretty good job of optimizing the model, while Flash Attention starts to pick up speed at lengths of 512 and above. Note that the original model cannot accommodate sequence lengths larger than 512, despite running on an 80G GPU!

We implemented an interface to use both Flash Attention 2 and torch.compile. You can find a torch-compilable interface to Flash Attention 2 here.

We can see a clear improvement in memory usage in our implementation for larger batch sizes (no value means OOM):

Install

Training the model requires a custom installation of Flash Attention 2 using this patch. Another possibility is to rely on the Triton version of Flash Attention 2.

Pretraining

We tested and trained the model on A100s. It may or may not work with other GPUs. The training script is provided here. It assumes that the dataset is already pretokenized, and it uses the Hugging Face trainer.

python train_flash_t5.py config/flash-t5-base.yaml

It supports accelerate for out-of-the-box distributed training.
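
For example, a multi-GPU run can be launched through accelerate (assuming accelerate has been configured on your machine; adapt the flags to your setup):

accelerate launch train_flash_t5.py config/flash-t5-base.yaml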

Finetuning

Warning

We are currently benchmarking our pre-trained French models (see the next section) to analyze their quality and also to check whether our head implementations are correct. So this work is still WIP.

For the classic T5, four different heads are available on Hugging Face: T5ForConditionalGeneration, T5ForSequenceClassification, T5ForTokenClassification and T5ForQuestionAnswering. You can find the adaptation of the first head in this file and that of the last three in this file.

What we can say/observe at this stage is:

  • We tested the FlashT5ForConditionalGeneration head on a text summarization task, in particular on the orange_sum dataset. The target outputs of this dataset are 32 tokens long, which is why we set max_length = 32 for this line. You'll need to set this value manually if you want to generate a different length (a usage sketch is given after this list).
    For this head we based ourselves on the nanoT5 implementation and not the Hugging Face one, as the former is much faster (1 epoch of FlashT5ForConditionalGeneration takes us 6 min on FAT5-base vs. 3h30 on MT5-small).
    The hyperparameters recommended in the T5 documentation (i.e. lr = 1e-4 or 3e-4) don't seem to be the most suitable for the FAT5 on this task (we match the results of Barthez, who introduced the orange_sum dataset, in 3 epochs instead of 30, but then reach a plateau). We still need to carry out a hyperparameter search.
    For all the other tasks described below, a lr of 1e-4 gives the best results in the experiments we have carried out.
  • For the FlashT5ForTokenClassification head, we based ourselves on the implementation available on Hugging Face, which uses only the encoder. The number of parameters finetuned for this task is thus halved, and we obtain models with 67.1M parameters for the small version, 138M for the base version and 436M for the large version. This is something to bear in mind when benchmarking.
  • For sequence classification, the implementation available in Hugging Face is based on both the encoder and the decoder. This seems sub-optimal to us, so we developed an encoder-only head. The number of parameters finetuned for this task is thus halved, and we obtain models with 67.4M parameters for the small version, 138M for the base version and 436M for the large version. This is something to bear in mind when benchmarking.
  • For T5ForQuestionAnswering, the implementation available in Hugging Face is based on both the encoder and the decoder. This seems sub-optimal to us, so we developed an encoder-only head.
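
Below is a rough usage sketch for the summarization head mentioned in the first item of the list. The checkpoint path is a placeholder, and exposing the model through AutoModelForSeq2SeqLM with the standard generate signature is an assumption on our side, since the model is loaded via remote code; adapt it to your own finetuned model.

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

checkpoint = "path/to/your/finetuned-fat5-summarizer"  # placeholder path
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, trust_remote_code=True)

inputs = tokenizer("Texte à résumer ...", return_tensors="pt")
# max_length = 32 matches the orange_sum setup described above.
summary_ids = model.generate(**inputs, max_length=32)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))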

Warning

Once your model has been finetuned, if you want to upload the weights to the Hugging Face Hub using the push_to_hub function, the latter won't upload all the files needed to reuse the model later. You'll have to perform a second upload yourself to add the missing files (these files are listed in the PR below). This is due to a bug in the transformers library. It has been reported and you can follow its progress in this PR.
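
A minimal sketch of such a second upload using huggingface_hub (the file name and repository id are placeholders; the actual list of missing files is given in the PR mentioned above):

from huggingface_hub import HfApi

api = HfApi()
# Upload each file that push_to_hub skipped; the paths below are placeholders.
api.upload_file(
    path_or_fileobj="local/path/to/missing_file.py",
    path_in_repo="missing_file.py",
    repo_id="your-username/your-finetuned-fat5",
)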

Applications

To French

We used the code in this repository to pretrain three FAT5-UL2 models in French: a small version (147M parameters), a base version (305M parameters) and a large version (973M parameters). The weights will be released soon. The models are pre-trained on the French part of the CulturaX corpus by Nguyen et al. (2023), i.e. 1,258 GB of text. The base version was trained on a single A100 80G for 11 days, and the large version on two A100 80G for 25 days (100,000 steps in both cases).

To English

Our contribution focuses on French, with the pre-training and finetuning of models for comparison against French benchmarks. For English, we cannot afford to do the same kind of work. Nevertheless, to ensure that the models can be used by English speakers, we have adapted the weights of the various versions of FLAN-T5 by Chung, Hou, Longpre et al. (2022) to our method. We hope that in this way users will be able to efficiently continue pre-training one of these versions to adapt it to more recent data or specialize it in a specific domain, for example. All weights can be found in this Hugging Face collection. To use one of the models, simply run:

from transformers import AutoModel, AutoTokenizer
model = AutoModel.from_pretrained("CATIE-AQ/FAT5-small-flan-en", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")

To any language

If the Flan-T5 doesn't suit you and you'd like to use another T5 (the original T5, T5v1.1, etc.) or one in another language (mT5 or a T5 trained in a specific language), you can find the code we used for the Flan-T5 here.

Integration into the transformers library

In the code snippet above, you can see that we have to pass trust_remote_code=True to load the model, as it is not natively available in the transformers library.
We're working on adding FAT5 to the transformers library for simpler use (which would, for example, solve the push_to_hub problem listed above).
One sticking point for now is that Hugging Face relies on the official flash-attention library to run this type of model under the hood in transformers. So until our PRs to the Flash Attention library have been merged, a port to transformers is blocked. A second, non-blocking but time-consuming point is reaching an agreement with the transformers maintainers on how to properly format our code (the new heads we developed, the changes to the existing one, the code needed to display the documentation on the Hugging Face site, etc.).

Roadmap

Here are several follow-up works that we would like to carry out:

  • Support flash decoding for inference.

  • Experiment with finetuning or distillation with long sequences.

  • We are also trying to revisit the encoder-decoder architecture using subquadratic operators to replace the attention. Stay tuned for more information about this.

License

Apache-2.0 license

Acknowledgments

We use the following repositories and thank their authors:

  • nanoT5 for the simple implementation and the optimizer.
  • Flash Attention for the groundbreaking algorithm for computing attention.
  • Hugging Face for their excellent library.
  • FlagAttention for the implementation of FA2 in Triton.
  • Unsloth for the simple Triton kernels of the cross-entropy and layernorm that we adapted to our usage.

This work was supported by the Vaniila platform.

Vaniila Logo