InstanceNorm not working
pimdh opened this issue · 2 comments
pimdh commented
Hi! Thanks for this library :)
I'm trying to use InstanceNorm and it appears there's a bug.
When I run the following:
```python
irreps = BalancedIrreps(3, 20)
norm = InstanceNorm(irreps)
x = torch.randn(9, 20)
batch = torch.zeros(9, dtype=torch.long)
norm(x, batch)
```
I get the following error:
```
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[7], line 8
6 x = torch.randn(9, 10)
7 batch = torch.zeros(9, dtype=torch.long)
----> 8 norm(x, batch)
File /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File /usr/local/lib/python3.8/dist-packages/segnn/segnn/instance_norm.py:85, in InstanceNorm.forward(self, input, batch)
82 # For scalars first compute and subtract the mean
83 if ir.l == 0:
84 # Compute the mean
---> 85 field_mean = global_mean_pool(field, batch).reshape(-1, mul, 1) # [batch, mul, 1]]
86 # Subtract the mean
87 field = field - field_mean[batch]
File /usr/local/lib/python3.8/dist-packages/torch_geometric/nn/pool/glob.py:63, in global_mean_pool(x, batch, size)
61 return x.mean(dim=dim, keepdim=x.dim() <= 2)
62 size = int(batch.max().item() + 1) if size is None else size
---> 63 return scatter(x, batch, dim=dim, dim_size=size, reduce='mean')
File /usr/local/lib/python3.8/dist-packages/torch_geometric/utils/scatter.py:81, in scatter(src, index, dim, dim_size, reduce)
78 count.scatter_add_(0, index, src.new_ones(src.size(dim)))
79 count = count.clamp(min=1)
---> 81 index = broadcast(index, src, dim)
82 out = src.new_zeros(size).scatter_add_(dim, index, src)
84 return out / broadcast(count, out, dim)
File /usr/local/lib/python3.8/dist-packages/torch_geometric/utils/scatter.py:21, in broadcast(src, ref, dim)
19 size = [1] * ref.dim()
20 size[dim] = -1
---> 21 return src.view(size).expand_as(ref)
RuntimeError: The expanded size of the tensor (10) must match the existing size (9) at non-singleton dimension 1. Target sizes: [9, 10, 1]. Tensor sizes: [1, 9, 1]
```
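The failing step at the bottom of the traceback can be reproduced with plain PyTorch. This sketch copies the three lines of `broadcast` shown in the traceback into a hypothetical helper, `broadcast_like_pyg` (not a PyTorch Geometric function), and applies it to a tensor shaped like `field` (`[9, 10, 1]`, the target size from the error message). Broadcasting the 9-element `batch` index along dimension 0 (the node axis) works, while broadcasting it along dimension 1 (the multiplicity axis) produces the `[1, 9, 1]` view from the error and raises. It only illustrates the shape mismatch; it is not a patch.

```python
import torch


def broadcast_like_pyg(src: torch.Tensor, ref: torch.Tensor, dim: int) -> torch.Tensor:
    # Hypothetical helper: same three steps as `broadcast` in the traceback.
    size = [1] * ref.dim()
    size[dim] = -1
    return src.view(size).expand_as(ref)


field = torch.randn(9, 10, 1)             # [num_nodes, mul, 1], the target size in the error
batch = torch.zeros(9, dtype=torch.long)  # one graph index per node

# Broadcasting the index along the node axis (dim 0) works.
ok = broadcast_like_pyg(batch, field, dim=0)
print(ok.shape)  # torch.Size([9, 10, 1])

# Broadcasting it along the multiplicity axis (dim 1) raises the same RuntimeError.
raised = False
try:
    broadcast_like_pyg(batch, field, dim=1)
except RuntimeError as err:
    raised = True
    print(err)
```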
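It appears this is because `global_mean_pool` from PyTorch Geometric does not support inputs with more than 2 dimensions: for the 3-D `field` of shape `[num_nodes, mul, 1]` it pools over dimension 1 (the multiplicity axis) instead of dimension 0 (the node axis), which is why the batch index ends up viewed as `[1, 9, 1]` in the error above. The solution could be to replace `global_mean_pool(field, batch)` with `global_mean_pool(field.view(-1, mul), batch)`.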
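For reference, here is a minimal pure-PyTorch sketch of the per-graph mean the scalar branch needs: it reduces over the node axis (dim 0) and works for any trailing shape such as `[num_nodes, mul, 1]`. `mean_pool_nodes` is a hypothetical helper, not part of SEGNN or PyTorch Geometric. With PyG installed, `global_mean_pool(field.view(-1, mul), batch).reshape(-1, mul, 1)` should give the same `[num_graphs, mul, 1]` result.

```python
import torch


def mean_pool_nodes(field: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: average `field` over nodes (dim 0) per graph in `batch`."""
    num_graphs = int(batch.max().item()) + 1
    # Sum node features per graph; trailing dimensions are kept as-is.
    out = field.new_zeros((num_graphs,) + tuple(field.shape[1:]))
    out.index_add_(0, batch, field)
    # Nodes per graph, clamped so empty graphs do not divide by zero.
    counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1)
    # Reshape counts to [num_graphs, 1, ..., 1] so they broadcast over trailing dims.
    counts = counts.to(field.dtype).view(-1, *([1] * (field.dim() - 1)))
    return out / counts


mul = 10
field = torch.randn(9, mul, 1)              # scalar (l = 0) field: [num_nodes, mul, 1]
batch = torch.zeros(9, dtype=torch.long)    # all nodes belong to graph 0
field_mean = mean_pool_nodes(field, batch)  # [num_graphs, mul, 1] = [1, 10, 1]
field = field - field_mean[batch]           # subtract the per-graph mean
```

Pooling along dim 0 explicitly avoids depending on which axis the library treats as the node axis for 3-D inputs.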
Cheers,
Pim
RobDHess commented
Hi,
Thanks for alerting us to this, I will fix it in the near future.
Cheers,
Rob