lucidrains/DALLE-pytorch

SparseAttention doesn't support the cache mechanism, which causes an "unexpected keyword argument" bug

jzq2000 opened this issue · 3 comments

Hi! At the moment, the SparseAttention class inherits from Attention, but it does not support the cache mechanism. I suspect this is the cause of the "unexpected keyword argument" bug.

CODE

dalle_pytorch/attention.py line 366

def forward(self, x, mask = None, rotary_pos_emb = None):

dalle_pytorch/transformer.py line 60

class CachedAs(nn.Module):
    """
    A wrapper that defines a key for the inference cache.
    """

    def __init__(self, cache_key, fn):
        super().__init__()
        self.cache_key = cache_key
        self.fn = fn

    def forward(self, x, *, cache=None, **kwargs):
        return self.fn(x, cache=cache, cache_key=self.cache_key, **kwargs)

and line 279

            if isinstance(attn, Attention):
                attn = CachedAs(f'attn_{ind}', attn)
            else:
                # at the moment, other attention classes don't support cache
                attn = NonCached(attn)
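To make the failure mode concrete, here is a minimal, torch-free sketch that mirrors the structure quoted above. The classes are simplified stand-ins and are assumptions, not the library's code: the real ones are `nn.Module`s, and the real `NonCached` does more than drop the `cache` argument. Because `SparseAttention` subclasses `Attention`, the `isinstance` check routes it to `CachedAs`, which then passes `cache` and `cache_key` to a `forward` that does not declare them.

```python
# Simplified, torch-free stand-ins for the classes quoted above (not the library code).
class Attention:
    # Assumed: the base class accepts the cache arguments that CachedAs passes.
    def forward(self, x, mask=None, rotary_pos_emb=None, cache=None, cache_key=None):
        return x

    def __call__(self, *args, **kwargs):
        # Mimics nn.Module.__call__ dispatching to forward().
        return self.forward(*args, **kwargs)


class SparseAttention(Attention):
    # Narrower override, matching the signature quoted from attention.py line 366.
    def forward(self, x, mask=None, rotary_pos_emb=None):
        return x


class CachedAs:
    def __init__(self, cache_key, fn):
        self.cache_key = cache_key
        self.fn = fn

    def __call__(self, x, *, cache=None, **kwargs):
        # Always forwards cache and cache_key to the wrapped layer.
        return self.fn(x, cache=cache, cache_key=self.cache_key, **kwargs)


class NonCached:
    def __init__(self, fn):
        self.fn = fn

    def __call__(self, x, *, cache=None, **kwargs):
        # Simplified: just drops the cache instead of rebuilding the sequence.
        return self.fn(x, **kwargs)


def wrap(attn, ind):
    # Same routing rule as transformer.py line 279.
    if isinstance(attn, Attention):
        return CachedAs(f"attn_{ind}", attn)
    return NonCached(attn)


wrapped = wrap(SparseAttention(), 0)
print(type(wrapped).__name__)  # CachedAs: SparseAttention passes the isinstance check

try:
    wrapped("x", cache={})
except TypeError as err:
    # SparseAttention.forward() rejects the cache keyword arguments.
    print("TypeError:", err)
```

The subclass relationship is what defeats the `isinstance` guard: the comment "other attention classes don't support cache" holds for unrelated classes, but not for a subclass that narrows `forward`.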

It seems the problem still occurs. Have you solved it?

Not yet. Instead, I chose to use the sparse attention types described in the paper (e.g. axial_col).

I'm still getting this bug.
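One possible workaround, sketched with the same kind of simplified, torch-free stand-ins: route each layer by whether its `forward` actually declares the cache parameters, instead of by `isinstance`. `supports_cache` is a hypothetical helper, not part of DALLE-pytorch, and this is not a fix confirmed by the maintainers. It assumes the real `Attention.forward` declares `cache` and `cache_key` as named parameters; a layer that accepts them only through `**kwargs` would be routed to `NonCached`. An explicit `not isinstance(attn, SparseAttention)` check would be a narrower alternative.

```python
import inspect


# Simplified stand-ins (assumptions, not the library code).
class Attention:
    def forward(self, x, mask=None, rotary_pos_emb=None, cache=None, cache_key=None):
        return x

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class SparseAttention(Attention):
    def forward(self, x, mask=None, rotary_pos_emb=None):
        return x


class CachedAs:
    def __init__(self, cache_key, fn):
        self.cache_key = cache_key
        self.fn = fn

    def __call__(self, x, *, cache=None, **kwargs):
        return self.fn(x, cache=cache, cache_key=self.cache_key, **kwargs)


class NonCached:
    def __init__(self, fn):
        self.fn = fn

    def __call__(self, x, *, cache=None, **kwargs):
        # Simplified: drops the cache rather than rebuilding the sequence.
        return self.fn(x, **kwargs)


def supports_cache(module):
    # Hypothetical helper: True only if forward() names both cache parameters.
    params = inspect.signature(module.forward).parameters
    return "cache" in params and "cache_key" in params


def wrap(attn, ind):
    # Route by the actual forward() signature instead of by class hierarchy.
    if supports_cache(attn):
        return CachedAs(f"attn_{ind}", attn)
    return NonCached(attn)


print(type(wrap(Attention(), 0)).__name__)        # CachedAs
print(type(wrap(SparseAttention(), 1)).__name__)  # NonCached
```

With this routing, SparseAttention layers fall back to the non-cached path, so sampling presumably no longer crashes, but those layers do not benefit from caching.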