Dao-AILab/flash-attention

Feature Request: Fused Linear and Cross-Entropy Loss

imoneoi opened this issue · 1 comment

Recent models such as Llama 3 and Gemma adopt extremely large vocabularies (128K-256K tokens), so the logits tensor can become very large and consume a large proportion of VRAM. For example, the following shows the VRAM usage for Llama 3 8B with a batch size of 81920 tokens:

Logits size: 128256 (vocabulary size) * 81920 (batch size in tokens) * 2 bytes (bf16) = 19.57 GiB
Hidden state size (checkpointed): 4096 (hidden size) * 81920 (batch size in tokens) * 32 (layers) * 2 bytes (bf16) = 20 GiB
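As a sanity check, the same arithmetic in a few lines of Python (variable names here are just for illustration):

```python
GiB = 1024 ** 3
tokens, bf16_bytes = 81920, 2

logits_bytes = 128256 * tokens * bf16_bytes         # vocab_size * tokens * bf16
hidden_bytes = 4096 * tokens * 32 * bf16_bytes      # hidden_size * tokens * layers * bf16

print(f"Logits:        {logits_bytes / GiB:.2f} GiB")   # 19.57 GiB
print(f"Hidden states: {hidden_bytes / GiB:.2f} GiB")   # 20.00 GiB
```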

Therefore, a fused linear and cross-entropy loss operator that does not require materializing the full logits could roughly halve VRAM consumption in this setting. It would be a great addition to the FlashAttention model implementations.
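To illustrate the idea (not an actual fused kernel), here is a minimal PyTorch sketch that chunks the tokens and checkpoints each chunk, so only `chunk_size x vocab_size` logits are live at any time; the function name `chunked_linear_cross_entropy`, the `chunk_size` default, and the handling of the loss reduction are all just assumptions for the sketch:

```python
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


def chunked_linear_cross_entropy(hidden, lm_head_weight, labels, chunk_size=8192):
    """Cross-entropy over hidden @ lm_head_weight.T without materializing the
    full [num_tokens, vocab_size] logits: only one chunk of logits exists at a
    time, and checkpointing recomputes it in backward instead of storing it."""
    num_tokens = hidden.shape[0]
    total_loss = hidden.new_zeros((), dtype=torch.float32)

    def chunk_loss(h, y):
        logits = F.linear(h, lm_head_weight)          # [chunk, vocab_size]
        return F.cross_entropy(logits.float(), y, reduction="sum")

    for start in range(0, num_tokens, chunk_size):
        h = hidden[start:start + chunk_size]
        y = labels[start:start + chunk_size]
        total_loss = total_loss + checkpoint(chunk_loss, h, y, use_reentrant=False)

    return total_loss / num_tokens
```

A true fused Triton/CUDA kernel could go further and avoid writing even the per-chunk logits to global memory, but the chunked version already caps peak logits memory at `chunk_size * vocab_size` instead of `num_tokens * vocab_size`.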

I personally don't have cycles for this, but we welcome PRs.