chinhsuanwu/coatnet-pytorch

Dimension mismatch between dots and relative position bias in the Transformer block

Opened this issue · 0 comments

I tried to run CoAtNet without modifying anything:

def coatnet_0():
    num_blocks = [1, 1, 1, 1, 1]  # L
    channels = [64, 96, 192, 384, 768]  # D
    return CoAtNet((50, 50), 3, num_blocks, channels, num_classes=3)

img = torch.randn(1, 3, 50, 50)

net = coatnet_0()
out = net(img)
print(out.shape)

The error is that the dots are computed from the downsampled image,
while the relative position bias is calculated from the original image. I do not think that the attention mechanism itself is miscalculating the dots or the relative position bias.
I think the issue comes from how attention is implemented in the Transformer block: somehow the attention in the Transformer block is not using the downsampled image for the relative position bias; instead, the bias is calculated based on the original image.
How can I solve this issue?


RuntimeError Traceback (most recent call last)
Cell In[15], line 1
----> 1 out = net(img)
2 print(out.shape)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None

Cell In[10], line 25, in CoAtNet.forward(self, x)
23 x = self.s1(x)
24 x = self.s2(x)
---> 25 x = self.s3(x)
26 x = self.s4(x)
28 x = self.pool(x).view(-1, x.shape[1])

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/container.py:215, in Sequential.forward(self, input)
213 def forward(self, input):
214 for module in self:
--> 215 input = module(input)
216 return input

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None

Cell In[9], line 31, in Transformer.forward(self, x)
29 def forward(self, x):
30 if self.downsample:
---> 31 x = self.proj(self.pool1(x)) + self.attn(self.pool2(x))
32 else:
33 x = x + self.attn(x)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/container.py:215, in Sequential.forward(self, input)
213 def forward(self, input):
214 for module in self:
--> 215 input = module(input)
216 return input

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None

Cell In[4], line 8, in PreNorm.forward(self, x, **kwargs)
7 def forward(self, x, **kwargs):
----> 8 return self.fn(self.norm(x), **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None

Cell In[8], line 47, in Attention.forward(self, x)
43 relative_bias = self.relative_bias_table.gather(
44 0, self.relative_index.repeat(1, self.heads))
45 relative_bias = rearrange(
46 relative_bias, '(h w) c -> 1 c h w', h=self.ih*self.iw, w=self.ih*self.iw)
---> 47 dots = dots + relative_bias
49 attn = self.attend(dots)
50 out = torch.matmul(attn, v)

RuntimeError: The size of tensor a (16) must match the size of tensor b (9) at non-singleton dimension 3