Glaciohound/LM-Infinite

kv_seq_len bug?

chenlidar opened this issue · 0 comments

```python
if kv_seq_len > local_branch + global_branch and use_lambda_mask:
    past_key_value = (
        torch.cat([
            key_states[..., :global_branch, :],
            key_states[..., -local_branch:, :],
        ], dim=-2),
        torch.cat([
            value_states[..., :global_branch, :],
            value_states[..., -local_branch:, :],
        ], dim=-2),
        key_position_ids[..., :local_branch + global_branch]
    ) if use_cache else None
```

The code in models/llama.py (lines 144-155, quoted above) truncates `past_key_value` down to the first `global_branch` and last `local_branch` entries, but `kv_seq_len` is left at its pre-truncation value. Is this intentional, or should `kv_seq_len` also be reset to `local_branch + global_branch` so it stays consistent with the truncated cache?
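To make the question concrete, here is a minimal self-contained sketch (toy tensor shapes, not the repo's actual call site) showing the hypothetical fix I have in mind: resetting `kv_seq_len` right after the truncation. Variable names follow the snippet above; whether this reset is actually the intended behavior is exactly what I'm asking.

```python
import torch

# Toy shapes: (batch, heads, seq_len, head_dim)
global_branch, local_branch = 4, 8
key_states = torch.randn(1, 2, 20, 16)
value_states = torch.randn(1, 2, 20, 16)
key_position_ids = torch.arange(20).view(1, 20)
kv_seq_len = key_states.shape[-2]  # 20
use_lambda_mask, use_cache = True, True

if kv_seq_len > local_branch + global_branch and use_lambda_mask:
    # Same truncation as in models/llama.py: keep the first
    # global_branch and the last local_branch keys/values.
    past_key_value = (
        torch.cat([
            key_states[..., :global_branch, :],
            key_states[..., -local_branch:, :],
        ], dim=-2),
        torch.cat([
            value_states[..., :global_branch, :],
            value_states[..., -local_branch:, :],
        ], dim=-2),
        key_position_ids[..., :local_branch + global_branch],
    ) if use_cache else None
    # Hypothetical fix: keep kv_seq_len in sync with the truncated cache.
    kv_seq_len = local_branch + global_branch

# With the reset, the length matches the cached tensors (12 == 12).
assert past_key_value[0].shape[-2] == kv_seq_len
```

Without the last assignment, `kv_seq_len` stays at 20 while the cached keys/values only hold 12 entries, which seems like it could confuse any downstream logic that derives mask or position sizes from `kv_seq_len`.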