SparseAttention doesn't support the cache mechanism, causing an unexpected keyword argument bug
jzq2000 opened this issue · 3 comments
jzq2000 commented
Hi! At the moment, the SparseAttention class inherits from Attention, but it does not support the cache mechanism. I suspect this is the cause of the unexpected keyword argument bug.
Relevant code:
dalle_pytorch/attention.py
line 366
def forward(self, x, mask = None, rotary_pos_emb = None):
dalle_pytorch/transformer.py
line 60
class CachedAs(nn.Module):
"""
A wrapper that defines a key for the inference cache.
"""
def __init__(self, cache_key, fn):
super().__init__()
self.cache_key = cache_key
self.fn = fn
def forward(self, x, *, cache=None, **kwargs):
return self.fn(x, cache=cache, cache_key=self.cache_key, **kwargs)
and line 279
if isinstance(attn, Attention):
attn = CachedAs(f'attn_{ind}', attn)
else:
# at the moment, other attention classes don't support cache
attn = NonCached(attn)
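A possible workaround, sketched under the assumption that SparseAttention is importable from dalle_pytorch.attention: because SparseAttention subclasses Attention, the isinstance check above routes it through CachedAs, which then calls self.fn(x, cache=cache, cache_key=self.cache_key, **kwargs) even though SparseAttention.forward only accepts (x, mask, rotary_pos_emb), hence the TypeError. Excluding it explicitly keeps it on the NonCached path:

from dalle_pytorch.attention import Attention, SparseAttention

# Only plain Attention understands the inference cache; SparseAttention
# (and the other non-cached variants) don't accept cache/cache_key kwargs yet.
if isinstance(attn, Attention) and not isinstance(attn, SparseAttention):
    attn = CachedAs(f'attn_{ind}', attn)
else:
    attn = NonCached(attn)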
isseebx123 commented
It seems the problem still occurs. Have you solved it yet?
jzq2000 commented
Not yet. Instead, I chose to use the sparse attention variants mentioned in the paper (e.g. axial_col).
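For reference, a minimal sketch of that workaround, assuming the usual DALLE constructor from dalle_pytorch with an attn_types argument that accepts the axial variants; the hyperparameters and the VAE below are placeholders, not recommendations:

from dalle_pytorch import DALLE, OpenAIDiscreteVAE

vae = OpenAIDiscreteVAE()  # any trained discrete VAE should work here

# Skip the 'sparse' attention type entirely and rely on the axial
# variants described in the paper, e.g. axial_row / axial_col.
dalle = DALLE(
    dim = 512,
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth = 4,
    heads = 8,
    attn_types = ('full', 'axial_row', 'axial_col'),  # no 'sparse'
)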
hmbenhaim commented
I'm still getting this bug.