kv_seq_len bug?
chenlidar opened this issue · 0 comments
chenlidar commented
if kv_seq_len > local_branch + global_branch and use_lambda_mask:
    past_key_value = (
        torch.cat([
            key_states[..., :global_branch, :],
            key_states[..., -local_branch:, :],
        ], dim=-2),
        torch.cat([
            value_states[..., :global_branch, :],
            value_states[..., -local_branch:, :],
        ], dim=-2),
        key_position_ids[..., :local_branch + global_branch]
    ) if use_cache else None
The code at models/llama.py lines 144-155 (quoted above) truncates past_key_value to the global and local branches, but it never updates kv_seq_len to match the truncated cache. Is that intentional, or is it a bug?
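If kv_seq_len is meant to track the length of the cache that is carried into the next step, I would have expected something like the sketch below right after the truncation. This is only a hypothetical change on my side (not code from the repo), reusing the variable names from the snippet above, and I may be missing a place where this is already handled:

    if kv_seq_len > local_branch + global_branch and use_lambda_mask and use_cache:
        # after truncation the cache only holds global_branch + local_branch positions,
        # so clamp kv_seq_len to the same value (hypothetical fix, not in the repo)
        kv_seq_len = local_branch + global_branch

Otherwise it seems like later uses of kv_seq_len would assume a longer cache than past_key_value actually contains.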