[Llava] Phi text model produces `ValueError: Attention mask should be of size (1, 1, 1, 230), but is torch.Size([1, 1, 1, 8])` when using `past_key_values` in generate
xenova opened this issue · 8 comments
System Info
- `transformers` version: 4.38.2
- Platform: Linux-6.1.58+-x86_64-with-glibc2.35
- Python version: 3.10.12
- Huggingface_hub version: 0.23.0
- Safetensors version: 0.4.3
- Accelerate version: 0.30.1
- Accelerate config: not found
- PyTorch version (GPU?): 2.2.1+cu121 (False)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): 0.8.3 (cpu)
- Jax version: 0.4.26
- JaxLib version: 0.4.26
- Using GPU in script?: no
- Using distributed or parallel set-up in script?: no
Who can help?
@gante (generate) @susnato (phi implementation) @younesbelkada (llava implementation)
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Following the multi-round conversation tutorial from here, I put together this minimal reproduction to show that switching Llava to a Phi text model (instead of, e.g., Llama) results in an error when reusing past key values.
Running:
from PIL import Image
import requests
from transformers import AutoProcessor, LlavaForConditionalGeneration
# Load model and processor
# THIS WORKS
# model_id = "Xenova/tiny-random-LlavaForConditionalGeneration"
# model = LlavaForConditionalGeneration.from_pretrained(model_id)
# THIS DOESN'T WORK
model_id = "Xenova/tiny-random-LlavaForConditionalGeneration_phi"
model = LlavaForConditionalGeneration.from_pretrained(model_id, attn_implementation="eager")
processor = AutoProcessor.from_pretrained(model_id)
# Define inputs
prompt = "<image>Hi"
url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/white-image.png?download=true"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text=prompt, images=image,
return_tensors="pt", padding=True)
# Generate w/o past_key_values
output = model.generate(
**inputs,
max_new_tokens=3,
return_dict_in_generate=True,
do_sample=False,
)
decoded = processor.batch_decode(
output["sequences"], skip_special_tokens=False)
# Prepare new inputs
new_inputs = processor(decoded, return_tensors="pt", padding=True)
# Generate w/ past_key_values
generate_ids = model.generate(
**new_inputs,
do_sample=False,
past_key_values=output['past_key_values'],
max_new_tokens=20,
)
print(f'{generate_ids=}')
decoded2 = processor.batch_decode(
generate_ids, skip_special_tokens=False)
print(f'{decoded2=}')
results in this error
Traceback (most recent call last):
File "/content/transformers.js/../test.py", line 39, in <module>
generate_ids = model.generate(
File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 1544, in generate
return self.greedy_search(
File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 2404, in greedy_search
outputs = self(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/llava/modeling_llava.py", line 469, in forward
outputs = self.language_model(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/phi/modeling_phi.py", line 1046, in forward
outputs = self.model(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/phi/modeling_phi.py", line 925, in forward
layer_outputs = decoder_layer(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/phi/modeling_phi.py", line 666, in forward
attn_outputs, self_attn_weights, present_key_value = self.self_attn(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/phi/modeling_phi.py", line 375, in forward
raise ValueError(
ValueError: Attention mask should be of size (1, 1, 1, 230), but is torch.Size([1, 1, 1, 8])
Expected behavior
If you try with a llama model (e.g., here; see comments) it works correctly.
One way to get it running is to specify the vision inputs:
- new_inputs = processor(decoded, return_tensors="pt", padding=True)
+ new_inputs = processor(decoded, images=image, return_tensors="pt", padding=True)
(though it's still odd that Llama works and Phi doesn't without the vision inputs) 👀
Here's a full example of it working (but inefficiently):
import requests
from PIL import Image
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
model_id = "xtuner/llava-phi-3-mini-hf"
prompt = "<|user|>\n<image>\nWhat are these?<|end|>\n<|assistant|>\n"
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
model = LlavaForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
).to(0)
processor = AutoProcessor.from_pretrained(model_id)
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(prompt, raw_image, return_tensors='pt').to(0, torch.float16)
# Generate w/o past_key_values
output = model.generate(
**inputs,
max_new_tokens=3, # Stop early so we can test out continuing it later
do_sample=False,
return_dict_in_generate=True,
)
print(processor.decode(output.sequences[0][inputs.input_ids.shape[-1]:], skip_special_tokens=False))
# outputs "These are two"
# Generate w/ past_key_values
continued_ids = model.generate(
output.sequences,
pixel_values=inputs.pixel_values,
attention_mask=torch.ones_like(output.sequences),
do_sample=False,
past_key_values=output['past_key_values'],
max_new_tokens=20,
)
print(processor.decode(continued_ids[0][inputs.input_ids.shape[-1]:], skip_special_tokens=False))
# outputs "These are two cats sleeping on a pink couch.<|end|><|end|><|end|><|end|><|endoftext|>"
As I understand it, the vision encoder is still run and the inputs are still merged, even though they will be cropped out later and the past_key_values will be used.
Hey!
Llama models were the first to get new features such as `StaticCache` and a `_update_causal_mask` method that builds the mask from the input tensor. That is why the code fails for Phi but not for Llama.
The real error, however, lies in the way we handle vision-language models. Right now we expand the input embeddings inside the modeling file by concatenating the image embeddings with the text embeddings. When we try to continue generation, the past key values hold tensors of a much larger sequence length than the text input we feed here, so the shapes of the keys/values do not match the shape of the attention mask:
continued_ids = model.generate(
output.sequences,
do_sample=False,
past_key_values=output['past_key_values'],
max_new_tokens=20,
)
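For intuition, here is a minimal sketch of the length mismatch, using hypothetical token counts (the real counts depend on the vision tower and tokenizer):

```python
# Hypothetical token counts, for illustration only.
num_image_tokens = 224  # embeddings the vision tower contributes per <image>
text_len = 5            # prompt tokens, with <image> counted as one token
new_tokens = 3          # tokens generated by the first generate() call

# The cache was built over the *expanded* sequence: the single <image>
# token was replaced by num_image_tokens image embeddings.
cache_len = (text_len - 1) + num_image_tokens + new_tokens

# Re-tokenizing only the text yields an attention mask over the
# unexpanded sequence, which is far shorter than the cache.
mask_len = text_len + new_tokens

assert cache_len != mask_len  # 231 vs 8 here -> the ValueError above
```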
> One way to get it running is to specify the vision inputs:

Yes, this way we trigger the expansion of the inputs by concatenating the image embeddings, so that the final shapes match the past cache.
The best solution is to move input expansion (with dummy values) into the processors. I believe @amyeroberts is working on it.
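A minimal sketch of what processor-side expansion could look like (the helper name and token count are hypothetical, not the actual transformers API):

```python
# Hypothetical helper, not the real transformers API: repeat the <image>
# placeholder so the tokenized prompt already has the final (post-merge)
# length, and the cache lines up on later calls.
def expand_image_tokens(prompt: str, num_image_tokens: int,
                        image_token: str = "<image>") -> str:
    return prompt.replace(image_token, image_token * num_image_tokens)

print(expand_image_tokens("<image>Hi", 3))
# <image><image><image>Hi
```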
Thanks for the great explanation @zucchini-nlp! Indeed, the differences between Phi and Llama confused me a bit (my use case is supporting this in Transformers.js), so I've now taken this into account in my implementation.
However, this is still problematic when images are passed. I'll continue looking into this.
@zucchini-nlp this means updating LLMs to the new cache format will remove this error, correct?
@gante No, it will not, as the past key values still will not match the size of the input sequence. The code will fail when identifying the correct start and end for `cache_position`.
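A hedged sketch of why (plain Python, logic heavily simplified): `generate()` derives the cache positions from the cache length plus the new input length, and a text-only input cannot reveal that most cached positions belong to image embeddings:

```python
# Numbers taken from the traceback above; the logic is simplified.
past_len = 230      # cache length after the first call (includes image tokens)
new_input_len = 8   # length of the re-tokenized, text-only continuation
cache_position = list(range(past_len, past_len + new_input_len))
print(cache_position)  # [230, 231, 232, 233, 234, 235, 236, 237]

# The text-only input_ids account for only 8 positions, so the correct
# start/end of the cached sequence cannot be recovered from them.
```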