pkunlp-icler/FastV

No obvious speedup on Qwen

Closed this issue · 3 comments

I copied the FastV-related code from src/transformers/src/transformers/models/llama/modeling_llama.py into modeling_qwen2.py and ran some tests. With k=3, r=0.75 there is only a slight speedup; with k=2, r=0.5 it is actually slower than running without FastV. What could be causing this? Could you provide a demo for Qwen?

During testing I set FASTV_INPLACE=False; with FASTV_INPLACE=True, CUDA reports an error...
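For context, a minimal sketch of what the two hyperparameters control (the variable names and token count below are illustrative, not taken from the repo): FastV prunes after decoder layer K, dropping a fraction R of the image tokens, so k=2, r=0.5 keeps far more image tokens alive in the later layers than k=3, r=0.75.

# Toy sketch of the FastV hyperparameters (illustrative values only)
image_token_length = 576                      # hypothetical number of image tokens after the projector
K, R = 3, 0.75                                # prune after layer K, drop ratio R of image tokens
kept = round(image_token_length * (1 - R))    # 144 image tokens survive past layer K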

Hi, with FASTV_INPLACE=False the token pruning is only simulated by modifying the attention mask, so the actual amount of computation is unchanged. With FASTV_INPLACE=True the tokens are really dropped, the computation is reduced, and there is a clear speedup; you can refer to the LLaVA setup. Latency measurements need to use FASTV_INPLACE=True. As for the CUDA error: values such as the system prompt length and the image token length have to be adjusted for QwenVL. Did you modify these values? An open-source Qwen demo is in preparation.
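To make the difference between the two modes concrete, here is a rough illustration on toy tensors (not the repo's actual code):

# FASTV_INPLACE=False vs FASTV_INPLACE=True, sketched on toy tensors
import torch

hidden_states = torch.randn(1, 1024, 4096)    # (batch, seq_len, hidden)
keep = torch.arange(0, 1024, 2)               # pretend half of the tokens survive

# FASTV_INPLACE=False: tokens stay in place and pruning is only simulated
# through the attention mask, so every layer still computes over 1024 positions.
simulated = hidden_states                     # shape unchanged, FLOPs unchanged

# FASTV_INPLACE=True: pruned tokens are physically removed, so later layers
# run on a shorter sequence and the computation actually drops.
inplace = hidden_states[:, keep, :]           # (1, 512, 4096)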

Thank you very much for the reply. I did modify the image token length; my understanding is that it should be features.shape[1] after the projector, is that correct? I did not modify the system prompt length and left it at 5 (if that is wrong, please tell me the correct value).
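In case it helps, a sketch of how the two values could be derived from the inputs instead of hard-coded. The image-placeholder token id and the toy input_ids below are assumptions of mine, not from the repo; verify the real placeholder id against your Qwen tokenizer/config before relying on it.

# Deriving FastV's Qwen-specific constants from the inputs (sketch, with assumed values)
import torch

IMAGE_PAD_ID = 151859                         # hypothetical image-placeholder id, verify against your tokenizer
# toy input_ids: 5 "system prompt" tokens, 4 image placeholders, 3 text tokens
input_ids = torch.tensor([[1, 2, 3, 4, 5,
                           IMAGE_PAD_ID, IMAGE_PAD_ID, IMAGE_PAD_ID, IMAGE_PAD_ID,
                           6, 7, 8]])

image_positions = (input_ids[0] == IMAGE_PAD_ID).nonzero(as_tuple=True)[0]
FASTV_image_token_start_index = image_positions[0].item()   # 5 here: length of the prefix before the image
FASTV_image_token_length = image_positions.numel()          # 4 here: should match features.shape[1] after the projector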

With FASTV_INPLACE=True, debugging shows that after the tokens are pruned, position_ids has the pruned shape, but the values in it look wrong. For example, position_ids has shape (1, 123, 1024), yet its maximum value exceeds 123, which breaks the shapes of the downstream tensors. Below is the FASTV_inplace=True code:

if FASTV_inplace:
    if layer_idx < FASTV_k:
        new_attention_mask = attention_mask

    elif layer_idx == FASTV_k:
        # compute pruned tokens, generate fastv sign
        last_layer_attention = layer_outputs[1]
        # compute average attention over different head
        last_layer_attention_avg = torch.mean(last_layer_attention, dim=1)[0]
        # generate new attention mask based on the average attention, sample the top ATTENTION_RANK tokens with highest attention
        last_layer_attention_avg_last_tok = last_layer_attention_avg[-1]
        # get the attention in image token
        last_layer_attention_avg_last_tok_image = last_layer_attention_avg_last_tok[FASTV_image_token_start_index:FASTV_image_token_start_index+FASTV_image_token_length]
        # get the indexs of the top ATTENTION_RANK tokens
        top_attention_rank_index = last_layer_attention_avg_last_tok_image.topk(round(FASTV_image_token_length*(1-FASTV_r))).indices + FASTV_image_token_start_index
        # keep index
        keep_indexs = torch.cat((torch.arange(FASTV_image_token_start_index, device=device), top_attention_rank_index, torch.arange(FASTV_image_token_start_index+FASTV_image_token_length, seq_length_with_past, device=device)))
        # sort index
        keep_indexs = keep_indexs.sort().values
        # update seq length
        new_seq_length = keep_indexs.shape[0]
        # filter hidden states
        hidden_states = hidden_states[:, keep_indexs, :]
        # update position ids
        position_ids = keep_indexs.unsqueeze(0)
        # update attention mask
        new_attention_mask = _prepare_4d_causal_attention_mask(
            None, (batch_size, new_seq_length), inputs_embeds, 0
        )

    if layer_idx == FASTV_k - 1:
        output_attentions = True
    else:
        output_attentions = False

    layer_outputs = decoder_layer(
        hidden_states,
        attention_mask=new_attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_values,
        output_attentions=output_attentions,
        use_cache=use_cache,
    )
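On the position_ids question, a toy illustration of why the values can legitimately exceed the pruned length: keep_indexs stores the original positions of the surviving tokens, which is what RoPE needs to stay aligned with the unpruned sequence. Whether the Qwen2-specific code downstream tolerates such out-of-range values is a separate question; the comments below mark that as an assumption, not a confirmed diagnosis.

# Toy example of position_ids after in-place pruning
import torch

seq_length_with_past = 10
keep_indexs = torch.tensor([0, 1, 2, 5, 8, 9])        # 6 of 10 tokens survive
position_ids = keep_indexs.unsqueeze(0)               # tensor([[0, 1, 2, 5, 8, 9]])

new_seq_length = keep_indexs.shape[0]                 # 6
print(position_ids.max().item() >= new_seq_length)    # True: expected under this scheme, not a bug by itself

# Assumption: if some Qwen2-specific path indexes a tensor that was sized with
# new_seq_length using these position_ids, that is where a shape/CUDA error
# would appear and would need to be adapted for QwenVL.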