Fix get_usable_length for StaticCache
guangy10 opened this issue · 2 comments
System Info
transformers: 4.45.0.dev0
torch: 2.5.0.dev20240716+cpu
Who can help?
@ArthurZucker
@gante
@zucchini-nlp
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
This bug only affects the new "Export to ExecuTorch" workflow that we're trying to enable. The method `get_usable_length` should be overridden for `StaticCache`, where recompiling, resizing, or evicting existing cache entries is not applicable. This is because, unlike an eager or `torch.compile`d model, an exported model runs in a non-Python environment where recompiling from the eager source isn't available.
Expected behavior
The generation length using the exported artifact should not exceed the maximum cache length, because the model and the size of its cache are fixed at export time. When `get_usable_length` returns 0, generation should terminate.
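As a minimal sketch of the expected behavior, the exported generation loop should stop once the fixed-size cache can't hold another token. The function names below are hypothetical, not the actual transformers API:

```python
# Hypothetical sketch: stop decoding when a fixed-size (static) cache is full,
# rather than evicting old entries. Names are illustrative only.

def can_generate_more(cached_tokens: int, max_cache_len: int, new_tokens: int = 1) -> bool:
    """Return False once the static cache cannot hold the next token(s)."""
    return cached_tokens + new_tokens <= max_cache_len

def generate(step_fn, max_cache_len: int, prompt_len: int):
    """Greedy decoding loop that terminates when the static cache is full."""
    cached = prompt_len
    outputs = []
    while can_generate_more(cached, max_cache_len):
        outputs.append(step_fn(cached))  # step_fn stands in for one forward pass
        cached += 1
    return outputs
```

With `max_cache_len=5` and a prompt of 2 tokens, the loop produces exactly 3 new tokens and then stops, never overrunning the exported cache.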
`get_usable_length` is something we are deprecating anyway, but good to note. `StaticCache` should not even have it!
@ArthurZucker Because @helunwencser on our team is working on Phi-3-mini, and in Phi-3's modeling code `get_usable_length` is used in several places, like here or here. The modeling code itself is fine because it makes no assumption about the type of cache being used. However, when tracing the model via `torch.export`, the default `get_usable_length` will trigger eviction of old cache entries, so the attentions computed from that point on will be incorrect. I think it would make sense to override it for `StaticCache` to avoid this misbehavior in the short term, before it can be fully deprecated.