I have a question about the source code in modeling_llama.py
park1200656 opened this issue · 6 comments
System Info
path : "src/transformers/models/llama/modeling_llama.py"
Lines 85 and 232 of this file contain float32 as a hard-coded constant.
I think it looks like a bug. Or is there another reason?
Thanks.
Who can help?
No response
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return (self.weight * hidden_states).to(input_dtype)
======
class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.max_position_embeddings
        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        # [bsz, nh, t, hd]
        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        past_key_value = (key_states, value_states) if use_cache else None
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )
        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask
            attn_weights = torch.max(
                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
            )
        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)
        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights, past_key_value
Expected behavior
Maybe the hard-coded float32 should be replaced with the tensor's dtype (float32 --> dtype)?
Hey @park1200656 👋
Some operations degrade the quality of the outputs if not performed at a certain minimum precision. The softmax in the attention layer and the variance accumulation in RMSNorm performed in FP32 are two examples of that :) Related read: this issue
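To make the precision point concrete, here is a minimal pure-Python sketch of what can go wrong when a mean of squares is accumulated at bfloat16 precision. The helpers `to_bf16` and `mean_of_squares` are hypothetical names written for this illustration, not part of transformers or PyTorch, and the step-by-step rounding is a deliberately naive model: real GPU kernels may accumulate in higher precision internally, so treat this as an illustration of the failure mode rather than a description of what any library does.

```python
import struct


def to_bf16(x: float) -> float:
    """Round a finite float to the nearest bfloat16 value (ties to even)."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    # bfloat16 keeps only the top 16 bits of the float32 encoding.
    bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + bias) & 0xFFFF0000
    return struct.unpack(">f", struct.pack(">I", bits))[0]


def mean_of_squares(values, round_fn):
    """Accumulate mean(x**2) one element at a time, rounding after every step."""
    acc = 0.0
    for v in values:
        acc = round_fn(acc + round_fn(v * v))
    return round_fn(acc / len(values))


# Hypothetical hidden-state row; the true mean of squares is exactly 1.0.
values = [1.0] * 4096

low = mean_of_squares(values, to_bf16)       # every intermediate rounded to bfloat16
high = mean_of_squares(values, lambda x: x)  # Python float (float64) accumulator

print("bf16 accumulator:", low)
print("wide accumulator:", high)
```

In this toy setup the bfloat16 accumulator stalls at 256, because 257 is not representable with 8 significant bits, so it reports 0.0625 instead of 1.0. Accumulating in a wider type and casting only the final result avoids that.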
Following our issues guidelines, we reserve GitHub issues for bugs in the repository and/or feature requests. For any other matters, we'd like to invite you to use our forum 🤗 If this is your first issue with us, check this guide.
I had the same question yesterday. Can we make it optional? At least for the softmax.
BF16 is good enough. By "good enough" I mean that it does not crash at long context on my laptop's 3080 Ti, and the return values are the same anyway, so the instability might be overstated.
Example. Making it optional:
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 24231c3f7..230e5333c 100755
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -228,8 +228,12 @@ class LlamaAttention(nn.Module):
                 attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
             )
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        # optionally upcast attention to fp32
+        if self.config.use_attn_upcast:
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        else:
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(query_states.dtype)
Test script:
from transformers import AutoModelForCausalLM
import torch
import sys
model = AutoModelForCausalLM.from_pretrained("./models/open_llama_3b/", torch_dtype=torch.bfloat16).cuda()
model.config.use_attn_upcast = "--no-oom" not in sys.argv
print("Predict that OOM will happen: ", model.config.use_attn_upcast)
input_ids = torch.arange(20)[None].cuda()
print(model(input_ids).logits.mean(-1))
input_ids = torch.arange(1000)[None].cuda()
print(model(input_ids).logits.mean())
With upcast removed:
$ python demo_py.py --no-oom
Predict that OOM will happen: False
tensor([[-9.0000, -6.0938, -1.8281, -7.7812, -7.5000, -7.5000, -7.6250, -7.7500,
-7.1250, -7.0000, -7.7188, -7.5625, -6.9688, -5.5312, -6.1562, -6.5312,
-7.5938, -7.0000, -7.1875, -6.8750]], device='cuda:0',
dtype=torch.bfloat16, grad_fn=<MeanBackward1>)
tensor(-6.9062, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)
With upcast:
$ python demo_py.py
Predict that OOM will happen: True
tensor([[-9.0000, -6.0938, -1.8281, -7.7812, -7.5000, -7.5000, -7.6250, -7.7500,
-7.1250, -7.0000, -7.7188, -7.5625, -6.9688, -5.5312, -6.1562, -6.5312,
-7.5938, -7.0000, -7.1875, -6.8750]], device='cuda:0',
dtype=torch.bfloat16, grad_fn=<MeanBackward1>)
Traceback (most recent call last):
File "/home/fella/src/llama/text-generation-webui/demo_py.py", line 14, in <module>
print(model(input_ids).logits.mean())
^^^^^^^^^^^^^^^^
File "/home/fella/src/sd/sd/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/fella/src/sd/sd/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 690, in forward
outputs = self.model(
^^^^^^^^^^^
File "/home/fella/src/sd/sd/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/fella/src/sd/sd/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 580, in forward
layer_outputs = decoder_layer(
^^^^^^^^^^^^^^
File "/home/fella/src/sd/sd/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/fella/src/sd/sd/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 295, in forward
hidden_states, self_attn_weights, present_key_value = self.self_attn(
^^^^^^^^^^^^^^^
File "/home/fella/src/sd/sd/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/fella/src/sd/sd/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 232, in forward
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/fella/src/sd/sd/lib/python3.11/site-packages/torch/nn/functional.py", line 1845, in softmax
ret = input.softmax(dim, dtype=dtype)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 124.00 MiB (GPU 0; 15.74 GiB total capacity; 14.83 GiB already allocated; 134.38 MiB free; 15.35 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
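The size of the failed allocation can be sanity-checked with some quick arithmetic on the attention-weight tensor, whose shape is [bsz, num_heads, q_len, kv_seq_len]. The sketch below assumes 32 attention heads, which is not stated anywhere in this thread; bsz=1 and the 1000-token input come from the test script above. `attn_weights_bytes` is a hypothetical helper written for this illustration.

```python
def attn_weights_bytes(bsz: int, num_heads: int, q_len: int, kv_len: int, bytes_per_elem: int) -> int:
    """Size in bytes of one [bsz, num_heads, q_len, kv_len] attention-weight tensor."""
    return bsz * num_heads * q_len * kv_len * bytes_per_elem


MIB = 2**20

# Assumption: 32 attention heads (not stated in the thread).
# bsz=1 and 1000 tokens match the second forward pass in the test script.
fp32_bytes = attn_weights_bytes(1, 32, 1000, 1000, 4)  # softmax output in float32
bf16_bytes = attn_weights_bytes(1, 32, 1000, 1000, 2)  # softmax output in bfloat16

print(f"fp32 softmax output: {fp32_bytes / MIB:.2f} MiB")
print(f"bf16 softmax output: {bf16_bytes / MIB:.2f} MiB")
```

Under that assumption the float32 softmax output is about 122 MiB per layer, in the same range as the 124.00 MiB allocation reported in the traceback (the small gap may be allocator rounding, which is a guess), while the bfloat16 version would need half of that. The printed tensors carry a grad_fn, so autograd was active in the test script and intermediate tensors were probably being kept alive across layers, which would add to the pressure.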
@Maykeye we have other options to reduce the memory footprint at inference time -- have you tried playing with our support for 4-bit inference? On a 3080 TI you may be able to run the 7B LLaMA model this way :)
Yes, and quantized models produce noticeably different results.
In general, lowering the precision of these operations will have a more significant impact on downstream performance (take it from the person that initially added the upcast at Meta).
Since we have other memory reduction strategies, we will not add the flag you're proposing. (Still, the code is open-source, so feel free to fork transformers and keep your changes 🤗)