huggingface/transformers

I have a question about the source file modeling_llama.py

park1200656 opened this issue · 6 comments

System Info

@ArthurZucker @gante

path : "src/transformers/models/llama/modeling_llama.py"

Lines 85 and 232 of this file contain float32 as a hard-coded constant.
I think it looks like a bug. Or is there another reason?

Thanks.

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import torch
from torch import nn


class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # mean of squares is accumulated in float32 regardless of the input dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # cast back to the dtype the layer received
        return (self.weight * hidden_states).to(input_dtype)
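
A minimal sketch of how the dtypes flow through the forward above, assuming random bf16 activations (the shape 2 x 8 is arbitrary). Once the variance is float32, PyTorch's type promotion makes the normalized hidden_states float32 too, which is why the result is cast back with .to(input_dtype).

```python
import torch

# Hypothetical bf16 activations: 2 tokens, hidden size 8
hidden_states = torch.randn(2, 8).to(torch.bfloat16)
eps = 1e-6

# Same expressions as LlamaRMSNorm.forward
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
normed = hidden_states * torch.rsqrt(variance + eps)

print(variance.dtype)  # torch.float32
print(normed.dtype)    # torch.float32, because bf16 * fp32 promotes to fp32
print(normed.to(hidden_states.dtype).dtype)  # torch.bfloat16 after the final cast
```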

======

import math
from typing import Optional, Tuple

import torch
from torch import nn

# LlamaConfig, LlamaRotaryEmbedding and apply_rotary_pos_emb are defined elsewhere
# in modeling_llama.py (at the version this issue refers to).


class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.max_position_embeddings

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        # [bsz, nh, t, hd]

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask
            attn_weights = torch.max(
                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
            )

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
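
The "# upcast attention to fp32" line relies on the dtype argument of torch.nn.functional.softmax, which casts the input to that dtype before the operation. A small sketch, assuming random bf16 scores standing in for attn_weights (shape chosen arbitrarily), shows that the softmax output is float32 and only becomes bf16 again after .to(...), so a float32 tensor the size of the full attention matrix exists in between.

```python
import torch
from torch import nn

# Hypothetical bf16 attention scores: [bsz, num_heads, q_len, kv_seq_len]
attn_weights = torch.randn(1, 2, 4, 4).to(torch.bfloat16)

# Same call as in LlamaAttention.forward: computed in float32
probs_fp32 = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
# attn_weights.dtype stands in for query_states.dtype here
probs = probs_fp32.to(attn_weights.dtype)

print(probs_fp32.dtype)    # torch.float32
print(probs.dtype)         # torch.bfloat16
print(probs_fp32.sum(-1))  # each row sums to (approximately) 1
```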

Expected behavior

Maybe float32 --> dtype, i.e. use the input's dtype instead of the hard-coded torch.float32.

gante commented

Hey @park1200656 👋

Some operations degrade the quality of the outputs if they are not performed at a certain minimum precision. The softmax in the attention layer and the variance accumulation in RMSNorm, both performed in FP32, are two examples of that :) Related read: this issue
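
A rough sketch of the precision point for RMSNorm, assuming random bf16 activations (64 rows and hidden size 4096 are arbitrary choices, not taken from the thread). It compares the mean of squares computed after upcasting to float32 with the same reduction kept in bfloat16, using a float64 computation on the same inputs as the reference.

```python
import torch

torch.manual_seed(0)

# Hypothetical activations: 64 rows, hidden size 4096, stored in bf16
x = torch.randn(64, 4096).to(torch.bfloat16)

# Reference: the same bf16 values, reduced in float64
ref = x.to(torch.float64).pow(2).mean(-1)

# RMSNorm-style: upcast to float32 before squaring and averaging
var_fp32 = x.to(torch.float32).pow(2).mean(-1)

# Without the upcast: squares and the result are stored in bf16
var_bf16 = x.pow(2).mean(-1)

err_fp32 = ((var_fp32.double() - ref).abs() / ref).max().item()
err_bf16 = ((var_bf16.double() - ref).abs() / ref).max().item()
print(f"max relative error, fp32 path: {err_fp32:.2e}")
print(f"max relative error, bf16 path: {err_bf16:.2e}")
```

The fp32 path keeps every square exact (a bf16 mantissa squared fits in float32), so its error comes only from the summation, while the bf16 path also rounds each square and the final result to 8 significant bits.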


Following our issues guidelines, we reserve GitHub issues for bugs in the repository and/or feature requests. For any other matters, we'd like to invite you to use our forum 🤗 If this is your first issue with us, check this guide.

I had the same question yesterday. Can we make it optional? At least for the softmax.

BF16 is good enough. And by "good enough" I mean it "does not crash at long context on my laptop's 3080 Ti" and "the return values are the same anyway; the instability might be overstated".

Example of making it optional:

diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 24231c3f7..230e5333c 100755
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -228,8 +228,12 @@ class LlamaAttention(nn.Module):
                 attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
             )
 
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        # optionally upcast attention to fp32
+        if self.config.use_attn_upcast:
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        else:
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(query_states.dtype)
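
The two branches of the diff can be compared outside the model with a small helper. This is only a sketch: softmax_maybe_upcast is a hypothetical name (not part of transformers), and the random bf16 scores stand in for real attention weights. Both branches cast the result back to bf16, so the flag changes the intermediate precision and the size of the temporary buffer, not the output dtype.

```python
import torch
from torch import nn


def softmax_maybe_upcast(attn_weights: torch.Tensor, upcast: bool) -> torch.Tensor:
    """Hypothetical helper mirroring the two branches of the proposed diff."""
    out_dtype = attn_weights.dtype  # stands in for query_states.dtype
    if upcast:
        # current behavior: softmax computed in float32, then cast back
        return nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(out_dtype)
    # proposed behavior: softmax computed directly in the input dtype
    return nn.functional.softmax(attn_weights, dim=-1).to(out_dtype)


torch.manual_seed(0)
# Hypothetical bf16 scores: [bsz, num_heads, q_len, kv_seq_len]
scores = (torch.randn(1, 4, 16, 16) * 4).to(torch.bfloat16)

with_upcast = softmax_maybe_upcast(scores, upcast=True)
without_upcast = softmax_maybe_upcast(scores, upcast=False)

max_diff = (with_upcast.float() - without_upcast.float()).abs().max().item()
print(with_upcast.dtype, without_upcast.dtype)
print(f"max absolute difference: {max_diff:.4f}")
```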

Test script:

from transformers import AutoModelForCausalLM
import torch
import sys

model = AutoModelForCausalLM.from_pretrained("./models/open_llama_3b/", torch_dtype=torch.bfloat16).cuda()
model.config.use_attn_upcast = "--no-oom" not in sys.argv
print("Predict that OOM will happen: ", model.config.use_attn_upcast)

input_ids = torch.arange(20)[None].cuda()
print(model(input_ids).logits.mean(-1))

input_ids = torch.arange(1000)[None].cuda()
print(model(input_ids).logits.mean())

With upcast removed:

$ python demo_py.py --no-oom

Predict that OOM will happen:  False
tensor([[-9.0000, -6.0938, -1.8281, -7.7812, -7.5000, -7.5000, -7.6250, -7.7500,
         -7.1250, -7.0000, -7.7188, -7.5625, -6.9688, -5.5312, -6.1562, -6.5312,
         -7.5938, -7.0000, -7.1875, -6.8750]], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<MeanBackward1>)
tensor(-6.9062, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)

With upcast:

$ python demo_py.py 

Predict that OOM will happen:  True
tensor([[-9.0000, -6.0938, -1.8281, -7.7812, -7.5000, -7.5000, -7.6250, -7.7500,
         -7.1250, -7.0000, -7.7188, -7.5625, -6.9688, -5.5312, -6.1562, -6.5312,
         -7.5938, -7.0000, -7.1875, -6.8750]], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<MeanBackward1>)
Traceback (most recent call last):
  File "/home/fella/src/llama/text-generation-webui/demo_py.py", line 14, in <module>
    print(model(input_ids).logits.mean())
          ^^^^^^^^^^^^^^^^
  File "/home/fella/src/sd/sd/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fella/src/sd/sd/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 690, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/home/fella/src/sd/sd/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fella/src/sd/sd/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 580, in forward
    layer_outputs = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/home/fella/src/sd/sd/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fella/src/sd/sd/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 295, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
                                                          ^^^^^^^^^^^^^^^
  File "/home/fella/src/sd/sd/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fella/src/sd/sd/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 232, in forward
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fella/src/sd/sd/lib/python3.11/site-packages/torch/nn/functional.py", line 1845, in softmax
    ret = input.softmax(dim, dtype=dtype)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 124.00 MiB (GPU 0; 15.74 GiB total capacity; 14.83 GiB already allocated; 134.38 MiB free; 15.35 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
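
A back-of-the-envelope check on the traceback: with dtype=torch.float32, the softmax result is a float32 tensor of shape [bsz, num_heads, q_len, kv_seq_len], at 4 bytes per element instead of 2 for bf16. The sketch only multiplies out that shape; attn_buffer_mib is a hypothetical helper, and the 32 attention heads are an assumption about open_llama_3b, not something stated in the thread. Under that assumption the float32 tensor for the 1000-token input comes out close to the 124.00 MiB allocation reported above.

```python
def attn_buffer_mib(bsz: int, num_heads: int, q_len: int, kv_len: int, bytes_per_elem: int) -> float:
    """Size in MiB of one [bsz, num_heads, q_len, kv_len] attention tensor."""
    return bsz * num_heads * q_len * kv_len * bytes_per_elem / 2**20


# Assumed: batch 1, 32 heads, 1000 tokens (as in the test script)
fp32_mib = attn_buffer_mib(1, 32, 1000, 1000, bytes_per_elem=4)  # softmax output with dtype=torch.float32
bf16_mib = attn_buffer_mib(1, 32, 1000, 1000, bytes_per_elem=2)  # same tensor kept in bf16

print(f"float32: {fp32_mib:.2f} MiB, bf16: {bf16_mib:.2f} MiB")  # float32: 122.07 MiB, bf16: 61.04 MiB
```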

gante commented

@Maykeye we have other options to reduce the memory footprint at inference time -- have you tried playing with our support for 4-bit inference? On a 3080 TI you may be able to run the 7B LLaMA model this way :)

Yes, and quantized models produce noticeably different results.

gante commented

In general, lowering the precision of these operations will have a more significant impact on downstream performance (take it from the person that initially added the upcast at Meta).

Since we have other memory reduction strategies, we will not add the flag you're proposing. (Still, the code is open-source, feel free to fork transformers and keep your changes 🤗 )

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.