lucidrains/linear-attention-transformer

Autopadder doesn't work with LinearAttentionTransformer

jamarju opened this issue · 1 comment

Is Autopadder supposed to work with LinearAttentionTransformer?

If I try this:

import torch
from linear_attention_transformer import LinearAttentionTransformer
from linear_attention_transformer.autopadder import Autopadder

model =  Autopadder(LinearAttentionTransformer(
    dim = 512,
    heads = 8,
    depth = 1,
    max_seq_len = 8192,
    n_local_attn_heads = 4
)).cuda()

x = torch.randn(1, 8191, 512).cuda()
model(x) # (1, 8191, 512)

I get this:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-8-5a45a93503d7> in <module>
     12 
     13 x = torch.randn(1, 8191, 512).cuda()
---> 14 model(x) # (1, 8192, 512)

/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/opt/conda/envs/sped37/lib/python3.7/site-packages/linear_attention_transformer/autopadder.py in forward(self, x, **kwargs)
     53             kwargs.update(input_mask=new_mask)
     54 
---> 55         out = self.net(x, **kwargs)
     56 
     57         output_slice = slice(0, t) if not self.pad_left else slice(padding, None)

/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/opt/conda/envs/sped37/lib/python3.7/site-packages/linear_attention_transformer/linear_attention_transformer.py in forward(self, x, **kwargs)
    354 
    355     def forward(self, x, **kwargs):
--> 356         return self.layers(x, **kwargs)
    357 
    358 class LinearAttentionTransformerLM(nn.Module):

/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/opt/conda/envs/sped37/lib/python3.7/site-packages/linear_attention_transformer/reversible.py in forward(self, x, **kwargs)
    147 
    148         for (f, g), (f_args, g_args) in layers_and_args:
--> 149             x = x + f(x, **f_args)
    150             x = x + g(x, **g_args)
    151         return x

/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/opt/conda/envs/sped37/lib/python3.7/site-packages/linear_attention_transformer/linear_attention_transformer.py in forward(self, x, **kwargs)
     65     def forward(self, x, **kwargs):
     66         x = self.norm(x)
---> 67         return self.fn(x, **kwargs)
     68 
     69 class Chunk(nn.Module):

/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/opt/conda/envs/sped37/lib/python3.7/site-packages/linear_attention_transformer/linear_attention_transformer.py in forward(self, x, input_mask, context, context_mask, **kwargs)
    258 
    259         if has_local:
--> 260             local_out = self.local_attn(lq, lk, lv, input_mask = input_mask)
    261             out.append(local_out)
    262 

/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/opt/conda/envs/sped37/lib/python3.7/site-packages/local_attention/local_attention.py in forward(self, q, k, v, input_mask)
    136         if input_mask is not None:
    137             h = b // input_mask.shape[0]
--> 138             input_mask = input_mask.reshape(-1, windows, window_size)
    139             mq = mk = input_mask
    140             mk = look_around(mk, pad_value=False, **look_around_kwargs)

RuntimeError: shape '[-1, 64, 128]' is invalid for input of size 4201983

I got rid of that error with the following change to `linear_attention_transformer/autopadder.py`:

diff --git a/linear_attention_transformer/autopadder.py b/linear_attention_transformer/autopadder.py
index dd84663..d10927d 100644
--- a/linear_attention_transformer/autopadder.py
+++ b/linear_attention_transformer/autopadder.py
@@ -48,7 +48,10 @@ class Autopadder(nn.Module):
         x, padding = pad_to_multiple(x, self.pad_to, dim=self.pad_dim, pad_left=self.pad_left)

         if padding != 0:
-            offset = (0, padding) if not self.pad_left else (padding, 0)
+            if self.pad_dim == -1:
+                offset = (0, padding) if not self.pad_left else (padding, 0)
+            else:
+                offset = (0, 0, 0, padding) if not self.pad_left else (0, 0, padding, 0)
             new_mask = F.pad(input_mask, offset, value=False)
             kwargs.update(input_mask=new_mask)

But with that change applied, calling `model(x)` on the following smaller setup fails with a different error:

import torch
from linear_attention_transformer import LinearAttentionTransformer
from linear_attention_transformer.autopadder import Autopadder

model =  Autopadder(LinearAttentionTransformer(
    dim = 128,
    heads = 4,
    depth = 1,
    max_seq_len = 256,
    n_local_attn_heads = 4
)).cuda()

x = torch.randn(1, 255, 512).cuda()
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-3-3a1475fdc86c> in <module>
     12 
     13 x = torch.randn(1, 255, 512).cuda()
---> 14 model(x) # (1, 8191, 512)

/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/opt/conda/envs/sped37/lib/python3.7/site-packages/linear_attention_transformer/autopadder.py in forward(self, x, **kwargs)
     56             kwargs.update(input_mask=new_mask)
     57 
---> 58         out = self.net(x, **kwargs)
     59 
     60         output_slice = slice(0, t) if not self.pad_left else slice(padding, None)

/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/opt/conda/envs/sped37/lib/python3.7/site-packages/linear_attention_transformer/linear_attention_transformer.py in forward(self, x, **kwargs)
    354 
    355     def forward(self, x, **kwargs):
--> 356         return self.layers(x, **kwargs)
    357 
    358 class LinearAttentionTransformerLM(nn.Module):

/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/opt/conda/envs/sped37/lib/python3.7/site-packages/linear_attention_transformer/reversible.py in forward(self, x, **kwargs)
    147 
    148         for (f, g), (f_args, g_args) in layers_and_args:
--> 149             x = x + f(x, **f_args)
    150             x = x + g(x, **g_args)
    151         return x

/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/opt/conda/envs/sped37/lib/python3.7/site-packages/linear_attention_transformer/linear_attention_transformer.py in forward(self, x, **kwargs)
     64         self.norm = nn.LayerNorm(dim)
     65     def forward(self, x, **kwargs):
---> 66         x = self.norm(x)
     67         return self.fn(x, **kwargs)
     68 

/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/normalization.py in forward(self, input)
    168     def forward(self, input: Tensor) -> Tensor:
    169         return F.layer_norm(
--> 170             input, self.normalized_shape, self.weight, self.bias, self.eps)
    171 
    172     def extra_repr(self) -> Tensor:

/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/functional.py in layer_norm(input, normalized_shape, weight, bias, eps)
   2047     """
   2048     return torch.layer_norm(input, normalized_shape, weight, bias, eps,
-> 2049                             torch.backends.cudnn.enabled)
   2050 
   2051 

RuntimeError: Given normalized_shape=[128], expected input with shape [*, 128], but got input of size[1, 256, 512]

Thanks for this great repo!