ispamm/hTorch

QBatchNorm2d RuntimeError 'shape invalid'


Hi,

It's me again. I'm running code similar to what I used in issue #8, but this time I'm using a ResNet model with batch normalization on the CIFAR-10 dataset (I used your custom collate_fn to add an additional channel to the images). The code for the model is given below:

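import torch.nn as nn
import torch.nn.functional as F
from htorch import layers  # import path assumed from htorch/layers.py in the traceback


class Quat_Block(nn.Module):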
    """
    A quaternion ResNet block.
    """
    def __init__(self, in_channels: int, out_channels: int, downsample=False):
        super().__init__()

        stride = 2 if downsample else 1

        self.conv1 = layers.QConv2d(in_channels, out_channels, kernel_size=3,
                                    stride=stride, padding=1, bias=False)
        self.bn1 = layers.QBatchNorm2d(out_channels)

        self.conv2 = layers.QConv2d(out_channels, out_channels, kernel_size=3,
                                    stride=1, padding=1, bias=False)
        self.bn2 = layers.QBatchNorm2d(out_channels)

        # Shortcut connection
        if downsample or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                # Use the main path's stride so both branches keep the same spatial size.
                layers.QConv2d(in_channels, out_channels, kernel_size=1,
                               stride=stride, bias=False),
                layers.QBatchNorm2d(out_channels)
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        return F.relu(out)


class Model(nn.Module):

    def __init__(self):
        super().__init__()
        num_segments = 3
        filters_per_segment = [4, 8, 16]
        architecture = [(num_filters, num_segments) for num_filters in
                        filters_per_segment]

        # Initial convolutional layer.
        current_filters = architecture[0][0]
        self.conv = layers.QConv2d(1, current_filters, kernel_size=3, stride=1,
                                   padding=1, bias=False)
        self.bn = layers.QBatchNorm2d(current_filters)

        # ResNet blocks
        blocks = []
        for segment_index, (filters, num_blocks) in enumerate(architecture):
            for block_index in range(num_blocks):
                downsample = segment_index > 0 and block_index == 0
                blocks.append(Quat_Block(current_filters, filters, downsample))
                current_filters = filters

        self.blocks = nn.Sequential(*blocks)

        # Final fc layer.
        self.fc = layers.QLinear(architecture[-1][0], 10)
        self.abs = layers.QuaternionToReal(10)

    def forward(self, x):
        out = F.relu(self.bn(self.conv(x)))
        out = self.blocks(out)
        out = F.avg_pool2d(out, out.size()[3])
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return self.abs(out)
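
As a sanity check on the architecture above, the spatial size of the feature maps can be traced by hand with the standard convolution output-size formula. The sketch below is plain Python and does not call hTorch; it assumes 32x32 CIFAR-10 inputs and only follows the kernel, stride, and padding values from the model code. The helper names `conv_out` and `trace_spatial_sizes` are hypothetical and exist only for this illustration.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a convolution along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_spatial_sizes(input_size: int = 32,
                        filters_per_segment=(4, 8, 16),
                        blocks_per_segment: int = 3):
    """Return the feature-map side length after the stem and after every block."""
    # Stem: 3x3 convolution with stride 1 and padding 1 keeps the size.
    size = conv_out(input_size, kernel=3, stride=1, padding=1)
    sizes = [size]
    for segment_index, _ in enumerate(filters_per_segment):
        for block_index in range(blocks_per_segment):
            # Same rule as in Model.__init__: downsample at the start of
            # every segment except the first.
            downsample = segment_index > 0 and block_index == 0
            stride = 2 if downsample else 1
            main = conv_out(size, kernel=3, stride=stride, padding=1)  # conv1
            main = conv_out(main, kernel=3, stride=1, padding=1)       # conv2
            if downsample:
                # 1x1 convolution shortcut with stride 2 and no padding.
                shortcut = conv_out(size, kernel=1, stride=2, padding=0)
            else:
                shortcut = size  # identity shortcut
            # The residual addition needs both branches to agree.
            assert main == shortcut, (main, shortcut)
            size = main
            sizes.append(size)
    return sizes


print(trace_spatial_sizes())
```

Under these assumptions the feature maps shrink from 32 to 16 to 8, so the final `F.avg_pool2d(out, out.size()[3])` pools over an 8x8 window and leaves one value per channel. This suggests the spatial bookkeeping of the model is consistent and that the error comes from elsewhere.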

I'm getting the following error:

/home/sahel/Documents/code/quaternion_lth/htorch/htorch/layers.py:577: UserWarning: torch.cholesky is deprecated in favor of torch.linalg.cholesky and will be removed in a future PyTorch release.
L = torch.cholesky(A)
should be replaced with
L = torch.linalg.cholesky(A)
and
U = torch.cholesky(A, upper=True)
should be replaced with
U = torch.linalg.cholesky(A.transpose(-2, -1).conj()).transpose(-2, -1).conj() (Triggered internally at  ../aten/src/ATen/native/BatchLinearAlgebra.cpp:1284.)
  ell = torch.cholesky(cov + self.eye, upper=True)
Traceback (most recent call last):
  File "train.py", line 47, in <module>
    H.train_model(
  File "/home/sahel/Documents/code/quaternion_lth/helper_methods.py", line 192, in train_model
    accuracy = test_model(model, testloader, device)
  File "/home/sahel/Documents/code/quaternion_lth/helper_methods.py", line 249, in test_model
    outputs = model(images)
  File "/home/sahel/Documents/code/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/sahel/Documents/code/quaternion_lth/models/resnet.py", line 172, in forward
    out = F.relu(self.bn(self.conv(x)))
  File "/home/sahel/Documents/code/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/sahel/Documents/code/quaternion_lth/htorch/htorch/layers.py", line 587, in forward
    weight = self.weight.view(4, 4, *shape)
RuntimeError: shape '[4, 4, 1, 4, 1, 1]' is invalid for input of size 16

I can't figure out whether this is caused by an error in my implementation of the model or by an error in the QBatchNorm2d layer.

Note: on line 545 of layers.py, the call to init.constant_() is missing an argument, which is probably the first error that will show up when you run the QBatchNorm2d layer.

Thanks a lot for this! There was indeed a mistake in the BatchNorm: the input channels should not be divided by 4. We pushed a fix for this (and for the constant_ part); let me know if everything is alright 👍🏻
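
The size mismatch in the traceback can be checked by hand: a view succeeds only when the product of the target shape equals the number of elements in the tensor. The sketch below is plain Python arithmetic on the two numbers from the error message; it does not reproduce hTorch's code, and the helper name `view_is_valid` is hypothetical.

```python
from math import prod


def view_is_valid(numel: int, target_shape) -> bool:
    """A tensor with `numel` elements can be viewed as `target_shape`
    only if the shape's dimensions multiply to exactly `numel`."""
    return prod(target_shape) == numel


# Values copied from the RuntimeError in the traceback.
target_shape = (4, 4, 1, 4, 1, 1)
weight_numel = 16

required = prod(target_shape)
print(required)                                   # elements the view asks for
print(view_is_valid(weight_numel, target_shape))  # whether 16 elements suffice
print(required // weight_numel)                   # size of the mismatch
```

The view asks for 64 elements while the weight holds 16, so the two differ by exactly a factor of 4. That is consistent with the explanation above that a channel count was divided by 4 where it should not have been, although the arithmetic alone does not show which line of layers.py was responsible.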

Yes, it's working fine now, thanks!