Autopadder doesn't work with LinearAttentionTransformer
jamarju opened this issue · 1 comment
jamarju commented
Is Autopadder supposed to work with LinearAttentionTransformer?
If I try this:
import torch
from linear_attention_transformer import LinearAttentionTransformer
from linear_attention_transformer.autopadder import Autopadder

model = Autopadder(LinearAttentionTransformer(
    dim = 512,
    heads = 8,
    depth = 1,
    max_seq_len = 8192,
    n_local_attn_heads = 4
)).cuda()

x = torch.randn(1, 8191, 512).cuda()
model(x) # (1, 8191, 512)
I get this:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-8-5a45a93503d7> in <module>
12
13 x = torch.randn(1, 8191, 512).cuda()
---> 14 model(x) # (1, 8192, 512)
/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
720 result = self._slow_forward(*input, **kwargs)
721 else:
--> 722 result = self.forward(*input, **kwargs)
723 for hook in itertools.chain(
724 _global_forward_hooks.values(),
/opt/conda/envs/sped37/lib/python3.7/site-packages/linear_attention_transformer/autopadder.py in forward(self, x, **kwargs)
53 kwargs.update(input_mask=new_mask)
54
---> 55 out = self.net(x, **kwargs)
56
57 output_slice = slice(0, t) if not self.pad_left else slice(padding, None)
/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
720 result = self._slow_forward(*input, **kwargs)
721 else:
--> 722 result = self.forward(*input, **kwargs)
723 for hook in itertools.chain(
724 _global_forward_hooks.values(),
/opt/conda/envs/sped37/lib/python3.7/site-packages/linear_attention_transformer/linear_attention_transformer.py in forward(self, x, **kwargs)
354
355 def forward(self, x, **kwargs):
--> 356 return self.layers(x, **kwargs)
357
358 class LinearAttentionTransformerLM(nn.Module):
/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
720 result = self._slow_forward(*input, **kwargs)
721 else:
--> 722 result = self.forward(*input, **kwargs)
723 for hook in itertools.chain(
724 _global_forward_hooks.values(),
/opt/conda/envs/sped37/lib/python3.7/site-packages/linear_attention_transformer/reversible.py in forward(self, x, **kwargs)
147
148 for (f, g), (f_args, g_args) in layers_and_args:
--> 149 x = x + f(x, **f_args)
150 x = x + g(x, **g_args)
151 return x
/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
720 result = self._slow_forward(*input, **kwargs)
721 else:
--> 722 result = self.forward(*input, **kwargs)
723 for hook in itertools.chain(
724 _global_forward_hooks.values(),
/opt/conda/envs/sped37/lib/python3.7/site-packages/linear_attention_transformer/linear_attention_transformer.py in forward(self, x, **kwargs)
65 def forward(self, x, **kwargs):
66 x = self.norm(x)
---> 67 return self.fn(x, **kwargs)
68
69 class Chunk(nn.Module):
/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
720 result = self._slow_forward(*input, **kwargs)
721 else:
--> 722 result = self.forward(*input, **kwargs)
723 for hook in itertools.chain(
724 _global_forward_hooks.values(),
/opt/conda/envs/sped37/lib/python3.7/site-packages/linear_attention_transformer/linear_attention_transformer.py in forward(self, x, input_mask, context, context_mask, **kwargs)
258
259 if has_local:
--> 260 local_out = self.local_attn(lq, lk, lv, input_mask = input_mask)
261 out.append(local_out)
262
/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
720 result = self._slow_forward(*input, **kwargs)
721 else:
--> 722 result = self.forward(*input, **kwargs)
723 for hook in itertools.chain(
724 _global_forward_hooks.values(),
/opt/conda/envs/sped37/lib/python3.7/site-packages/local_attention/local_attention.py in forward(self, q, k, v, input_mask)
136 if input_mask is not None:
137 h = b // input_mask.shape[0]
--> 138 input_mask = input_mask.reshape(-1, windows, window_size)
139 mq = mk = input_mask
140 mk = look_around(mk, pad_value=False, **look_around_kwargs)
RuntimeError: shape '[-1, 64, 128]' is invalid for input of size 4201983
I got rid of that error with this patch:
diff --git a/linear_attention_transformer/autopadder.py b/linear_attention_transformer/autopadder.py
index dd84663..d10927d 100644
--- a/linear_attention_transformer/autopadder.py
+++ b/linear_attention_transformer/autopadder.py
@@ -48,7 +48,10 @@ class Autopadder(nn.Module):
         x, padding = pad_to_multiple(x, self.pad_to, dim=self.pad_dim, pad_left=self.pad_left)
 
         if padding != 0:
-            offset = (0, padding) if not self.pad_left else (padding, 0)
+            if self.pad_dim == -1:
+                offset = (0, padding) if not self.pad_left else (padding, 0)
+            else:
+                offset = (0, 0, 0, padding) if not self.pad_left else (0, 0, padding, 0)
             new_mask = F.pad(input_mask, offset, value=False)
             kwargs.update(input_mask=new_mask)
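As a side note on the patch above: F.pad consumes the offset tuple in (left, right) pairs starting from the last dimension, so a 2-element tuple pads only the last dimension while a 4-element tuple also reaches the second-to-last one. A minimal sketch of the two cases (my own, not part of the original report):

import torch
import torch.nn.functional as F

x = torch.randn(1, 255, 512)
# 4-element offset: the first pair (0, 0) leaves the last dim (features) alone,
# the second pair (0, 1) pads the second-to-last dim (sequence) on the right.
print(F.pad(x, (0, 0, 0, 1)).shape)            # torch.Size([1, 256, 512])

# For a 2-D (batch, seq) mask the sequence axis is already the last dim,
# so a 2-element offset suffices there.
mask = torch.full((1, 255), True, dtype=torch.bool)
print(F.pad(mask, (0, 1), value=False).shape)  # torch.Size([1, 256])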
But then I hit this:
import torch
from linear_attention_transformer import LinearAttentionTransformer
from linear_attention_transformer.autopadder import Autopadder

model = Autopadder(LinearAttentionTransformer(
    dim = 128,
    heads = 4,
    depth = 1,
    max_seq_len = 256,
    n_local_attn_heads = 4
)).cuda()

x = torch.randn(1, 255, 512).cuda()
model(x)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-3-3a1475fdc86c> in <module>
12
13 x = torch.randn(1, 255, 512).cuda()
---> 14 model(x) # (1, 8191, 512)
/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
720 result = self._slow_forward(*input, **kwargs)
721 else:
--> 722 result = self.forward(*input, **kwargs)
723 for hook in itertools.chain(
724 _global_forward_hooks.values(),
/opt/conda/envs/sped37/lib/python3.7/site-packages/linear_attention_transformer/autopadder.py in forward(self, x, **kwargs)
56 kwargs.update(input_mask=new_mask)
57
---> 58 out = self.net(x, **kwargs)
59
60 output_slice = slice(0, t) if not self.pad_left else slice(padding, None)
/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
720 result = self._slow_forward(*input, **kwargs)
721 else:
--> 722 result = self.forward(*input, **kwargs)
723 for hook in itertools.chain(
724 _global_forward_hooks.values(),
/opt/conda/envs/sped37/lib/python3.7/site-packages/linear_attention_transformer/linear_attention_transformer.py in forward(self, x, **kwargs)
354
355 def forward(self, x, **kwargs):
--> 356 return self.layers(x, **kwargs)
357
358 class LinearAttentionTransformerLM(nn.Module):
/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
720 result = self._slow_forward(*input, **kwargs)
721 else:
--> 722 result = self.forward(*input, **kwargs)
723 for hook in itertools.chain(
724 _global_forward_hooks.values(),
/opt/conda/envs/sped37/lib/python3.7/site-packages/linear_attention_transformer/reversible.py in forward(self, x, **kwargs)
147
148 for (f, g), (f_args, g_args) in layers_and_args:
--> 149 x = x + f(x, **f_args)
150 x = x + g(x, **g_args)
151 return x
/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
720 result = self._slow_forward(*input, **kwargs)
721 else:
--> 722 result = self.forward(*input, **kwargs)
723 for hook in itertools.chain(
724 _global_forward_hooks.values(),
/opt/conda/envs/sped37/lib/python3.7/site-packages/linear_attention_transformer/linear_attention_transformer.py in forward(self, x, **kwargs)
64 self.norm = nn.LayerNorm(dim)
65 def forward(self, x, **kwargs):
---> 66 x = self.norm(x)
67 return self.fn(x, **kwargs)
68
/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
720 result = self._slow_forward(*input, **kwargs)
721 else:
--> 722 result = self.forward(*input, **kwargs)
723 for hook in itertools.chain(
724 _global_forward_hooks.values(),
/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/normalization.py in forward(self, input)
168 def forward(self, input: Tensor) -> Tensor:
169 return F.layer_norm(
--> 170 input, self.normalized_shape, self.weight, self.bias, self.eps)
171
172 def extra_repr(self) -> Tensor:
/opt/conda/envs/sped37/lib/python3.7/site-packages/torch/nn/functional.py in layer_norm(input, normalized_shape, weight, bias, eps)
2047 """
2048 return torch.layer_norm(input, normalized_shape, weight, bias, eps,
-> 2049 torch.backends.cudnn.enabled)
2050
2051
RuntimeError: Given normalized_shape=[128], expected input with shape [*, 128], but got input of size[1, 256, 512]
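For context on this second error: nn.LayerNorm(dim) requires the last dimension of its input to equal dim, and here the model was built with dim = 128 while x carries 512 features, which is exactly what the message reports. A minimal illustration (mine, not from the thread):

import torch
import torch.nn as nn

norm = nn.LayerNorm(128)
print(norm(torch.randn(1, 256, 128)).shape)   # works: torch.Size([1, 256, 128])
# norm(torch.randn(1, 256, 512))              # raises the RuntimeError shown above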
Thanks for this great repo!
lucidrains commented
@jamarju no problem! I've fixed that issue in the latest commit! https://github.com/lucidrains/linear-attention-transformer/releases/tag/0.14.1
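For anyone pinned to a version before that fix, a rough workaround sketch (my own, assuming a local attention window of 128, which is what the reshape in the first traceback suggests) is to pad the sequence and the mask manually and slice the padding back off:

import torch
import torch.nn.functional as F
from linear_attention_transformer import LinearAttentionTransformer

model = LinearAttentionTransformer(
    dim = 512,
    heads = 8,
    depth = 1,
    max_seq_len = 8192,
    n_local_attn_heads = 4
).cuda()

x = torch.randn(1, 8191, 512).cuda()
b, t, _ = x.shape

window = 128                                   # assumed local attention window size
padding = (window - t % window) % window
x_padded = F.pad(x, (0, 0, 0, padding))        # pad the sequence dimension
mask = F.pad(torch.ones(b, t, dtype=torch.bool, device=x.device),
             (0, padding), value=False)        # mask out the padded positions

out = model(x_padded, input_mask=mask)[:, :t]  # back to (1, 8191, 512)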