TylerYep/torchinfo

Inaccurate number of parameters at output layer

John-CYHui opened this issue · 0 comments

Describe the bug
`summary` counts the parameters of layers that are defined in `__init__` but never used in `forward`, and reports them on the top-level model row.

To Reproduce
Steps to reproduce the behavior:

class VanillaLeNet5(nn.Module):
    """
    Implementation of LeNet-5 following the original paper.

    Args:
        in_channels: Number of channels of the input image (passed to ``conv1``).
        number_classes: Number of output units; sizes ``fc3`` and the RBF
            output layer.

    Note:
        ``fc3`` is constructed here but never called in ``forward``. Its
        parameters are still registered on the module even though they do
        not contribute to the output.
    """
    def __init__(self, in_channels=1, number_classes=10):
        super().__init__()

        # First convolution: 6 feature maps from 5x5 kernels
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=6, kernel_size=5, stride=1)

        # Second convolution stage: 16 maps, each fed a subset of the 6 input maps
        self.conv2 = Conv2Layer()

        # Pooling layers carrying their own per-channel weights and biases
        self.pool = PoolingLayer(in_channels=6)
        self.pool2 = PoolingLayer(in_channels=16)

        # Fully connected layers
        self.fc1 = nn.Linear(in_features=16*5*5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        # Registered but unused in forward (see class docstring)
        self.fc3 = nn.Linear(in_features=84, out_features=number_classes)

        # RBF output layer; sized from number_classes instead of a hard-coded 10
        # so that the argument actually controls the output width
        self.rbf = RBFLayer(in_features=84, out_features=number_classes)

        # Image preprocessing pipeline kept on the model; forward does not apply it
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Resize(size=[32, 32])]
        )

    def tanh(self, x):
        """Scaled hyperbolic tangent: ``1.7159 * tanh(2/3 * x)``."""
        A = 1.7159
        S = 2/3
        return A * torch.tanh(S*x)

    def forward(self, x):
        """Run conv1 -> pool -> conv2 -> pool2 -> fc1 -> fc2 -> rbf and return the rbf output."""
        # Convolution / pooling stages, each convolution followed by the scaled tanh
        x = self.conv1(x)
        x = self.tanh(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = self.tanh(x)
        x = self.pool2(x)

        # Flatten everything except the batch dimension for the linear layers
        x = torch.flatten(x, 1)

        # Fully connected layers, then the RBF output layer (fc3 is not used)
        x = self.fc1(x)
        x = self.tanh(x)
        x = self.fc2(x)
        x = self.tanh(x)
        x = self.rbf(x)
        return x

class PoolingLayer(nn.Module):
    """
    2x2 average pooling followed by a learnable per-channel scale and shift
    and a scaled tanh.

    Args:
        in_channels: Length of the ``weights`` and ``bias`` parameter vectors.
    """
    def __init__(self, in_channels=0):
        super().__init__()
        # One scale and one shift per channel, both drawn from a standard normal
        self.weights = nn.Parameter(torch.randn(in_channels))
        self.bias = nn.Parameter(torch.randn(in_channels))
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)

    def tanh(self, x):
        """Scaled hyperbolic tangent: ``1.7159 * tanh(2/3 * x)``."""
        amplitude, slope = 1.7159, 2/3
        return amplitude * torch.tanh(slope*x)

    def forward(self, x):
        """Pool ``x``, apply the per-channel affine transform, then the scaled tanh."""
        # Reshape the 1-D vectors so they broadcast along dimension 1 of a 4-D input
        scale = self.weights.view(1, -1, 1, 1)
        shift = self.bias.view(1, -1, 1, 1)
        pooled = self.pool(x)
        return self.tanh(pooled * scale + shift)

class Conv2Layer(nn.Module):
    """
    Sixteen single-output 5x5 convolutions, each applied to a fixed subset of
    the input channels, with the results concatenated along the channel axis.

    The sub-layers are registered as ``conv2_1`` ... ``conv2_16``. The first
    six read 3 input channels, the next nine read 4, and the last reads the
    whole input (its convolution is built for 6 channels).
    """

    # Number of input channels of conv2_1 ... conv2_16, in order
    _FAN_IN = (3,) * 6 + (4,) * 9 + (6,)

    # For each convolution, the channel slices that are concatenated (in this
    # order) to form its input; a single slice is used directly as a view
    _INPUT_SLICES = (
        (slice(0, 3),),
        (slice(1, 4),),
        (slice(2, 5),),
        (slice(3, 6),),
        (slice(4, 6), slice(0, 1)),
        (slice(5, 6), slice(0, 2)),
        (slice(0, 4),),
        (slice(1, 5),),
        (slice(2, 6),),
        (slice(3, 6), slice(0, 1)),
        (slice(4, 6), slice(0, 2)),
        (slice(5, 6), slice(0, 3)),
        (slice(3, 5), slice(0, 2)),
        (slice(4, 6), slice(1, 3)),
        (slice(0, 1), slice(2, 4), slice(5, 6)),
        (slice(None),),
    )

    def __init__(self):
        super().__init__()
        # Assigning through setattr registers each convolution under its
        # numbered attribute name, in creation order
        for number, fan_in in enumerate(self._FAN_IN, start=1):
            conv = nn.Conv2d(in_channels=fan_in, out_channels=1, kernel_size=5, stride=1)
            setattr(self, f"conv2_{number}", conv)

    def forward(self, x):
        """Apply every convolution to its channel subset and stack the 16 outputs on dim 1."""
        feature_maps = []
        for number, channel_slices in enumerate(self._INPUT_SLICES, start=1):
            # Break of symmetry: each map only sees its own group of input channels
            pieces = [x[:, channels, :, :] for channels in channel_slices]
            if len(pieces) == 1:
                subset = pieces[0]
            else:
                subset = torch.concat(pieces, dim=1)
            feature_maps.append(getattr(self, f"conv2_{number}")(subset))

        return torch.cat(feature_maps, dim=1)

class RBFLayer(nn.Module):
    """
    Squared Euclidean distance between each input row and a set of fixed,
    randomly initialised centres.

    Args:
        in_features: Length of each input vector and of each centre.
        out_features: Number of centres, i.e. the width of the output.
    """
    def __init__(self, in_features=84, out_features=10):
        super().__init__()
        # Registered as a buffer instead of being created with ``.cuda()``:
        # construction no longer requires a GPU, the tensor follows the module
        # through ``.to()`` / ``.cuda()``, and it is saved in ``state_dict``.
        # A buffer is not a parameter, so it stays non-trainable.
        self.register_buffer("weights", torch.randn((in_features, out_features)))
        self.in_features = in_features
        self.out_features = out_features

    def forward(self, x):
        """Return a ``(x.size(0), out_features)`` tensor of squared distances to the centres."""
        # Expand input and centres to a common (batch, in_features, out_features) shape
        size = (x.size(0), self.in_features, self.out_features)
        x = x.unsqueeze(2).expand(size)
        c = self.weights.unsqueeze(0).expand(size)
        # Sum squared differences over the in_features dimension
        return (x - c).pow(2).sum(1)

if __name__ == "__main__":
    # Summarise the model for a batch of 16 single-channel 32x32 inputs
    input_shape = (16, 1, 32, 32)
    model = VanillaLeNet5()
    summary(model, input_shape)
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
VanillaLeNet5                            [16, 10]                  850
├─Conv2d: 1-1                            [16, 6, 28, 28]           156
├─PoolingLayer: 1-2                      [16, 6, 14, 14]           12
│    └─AvgPool2d: 2-1                    [16, 6, 14, 14]           --
├─Conv2Layer: 1-3                        [16, 16, 10, 10]          --
│    └─Conv2d: 2-2                       [16, 1, 10, 10]           76
│    └─Conv2d: 2-3                       [16, 1, 10, 10]           76
│    └─Conv2d: 2-4                       [16, 1, 10, 10]           76
│    └─Conv2d: 2-5                       [16, 1, 10, 10]           76
│    └─Conv2d: 2-6                       [16, 1, 10, 10]           76
│    └─Conv2d: 2-7                       [16, 1, 10, 10]           76
│    └─Conv2d: 2-8                       [16, 1, 10, 10]           101
│    └─Conv2d: 2-9                       [16, 1, 10, 10]           101
│    └─Conv2d: 2-10                      [16, 1, 10, 10]           101
│    └─Conv2d: 2-11                      [16, 1, 10, 10]           101
│    └─Conv2d: 2-12                      [16, 1, 10, 10]           101
│    └─Conv2d: 2-13                      [16, 1, 10, 10]           101
│    └─Conv2d: 2-14                      [16, 1, 10, 10]           101
│    └─Conv2d: 2-15                      [16, 1, 10, 10]           101
│    └─Conv2d: 2-16                      [16, 1, 10, 10]           101
│    └─Conv2d: 2-17                      [16, 1, 10, 10]           151
├─PoolingLayer: 1-4                      [16, 16, 5, 5]            32
│    └─AvgPool2d: 2-18                   [16, 16, 5, 5]            --
├─Linear: 1-5                            [16, 120]                 48,120
├─Linear: 1-6                            [16, 84]                  10,164
├─RBFLayer: 1-7                          [16, 10]                  --
==========================================================================================
Total params: 60,850
Trainable params: 60,850
Non-trainable params: 0
Total mult-adds (M): 5.32
==========================================================================================
Input size (MB): 0.07
Forward/backward pass size (MB): 0.83
Params size (MB): 0.24
Estimated Total Size (MB): 1.14
==========================================================================================

Expected behavior
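I expect the top-level `VanillaLeNet5` row to show no trainable parameters of its own, but it shows 850 (= 84 × 10 + 10, the weights and biases of `fc3`). If I comment out `self.fc3`, which is never used in `forward`, the issue goes away and the total drops from 60,850 to 60,000: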

class VanillaLeNet5(nn.Module):
    """
    Implementation of LeNet-5 following the original paper.

    Args:
        in_channels: Number of channels of the input image (passed to ``conv1``).
        number_classes: Number of output units of the RBF output layer.
    """
    def __init__(self, in_channels=1, number_classes=10):
        super().__init__()

        # First convolution: 6 feature maps from 5x5 kernels
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=6, kernel_size=5, stride=1)

        # Second convolution stage: 16 maps, each fed a subset of the 6 input maps
        self.conv2 = Conv2Layer()

        # Pooling layers carrying their own per-channel weights and biases
        self.pool = PoolingLayer(in_channels=6)
        self.pool2 = PoolingLayer(in_channels=16)

        # Fully connected layers
        self.fc1 = nn.Linear(in_features=16*5*5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        # Deliberately disabled: this layer was never used in forward
        #self.fc3 = nn.Linear(in_features=84, out_features=number_classes)

        # RBF output layer; sized from number_classes instead of a hard-coded 10
        # so that the argument actually controls the output width
        self.rbf = RBFLayer(in_features=84, out_features=number_classes)

        # Image preprocessing pipeline kept on the model; forward does not apply it
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Resize(size=[32, 32])]
        )

    def tanh(self, x):
        """Scaled hyperbolic tangent: ``1.7159 * tanh(2/3 * x)``."""
        A = 1.7159
        S = 2/3
        return A * torch.tanh(S*x)

    def forward(self, x):
        """Run conv1 -> pool -> conv2 -> pool2 -> fc1 -> fc2 -> rbf and return the rbf output."""
        # Convolution / pooling stages, each convolution followed by the scaled tanh
        x = self.conv1(x)
        x = self.tanh(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = self.tanh(x)
        x = self.pool2(x)

        # Flatten everything except the batch dimension for the linear layers
        x = torch.flatten(x, 1)

        # Fully connected layers, then the RBF output layer
        x = self.fc1(x)
        x = self.tanh(x)
        x = self.fc2(x)
        x = self.tanh(x)
        x = self.rbf(x)
        return x
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
VanillaLeNet5                            [16, 10]                  --
├─Conv2d: 1-1                            [16, 6, 28, 28]           156
├─PoolingLayer: 1-2                      [16, 6, 14, 14]           12
│    └─AvgPool2d: 2-1                    [16, 6, 14, 14]           --
├─Conv2Layer: 1-3                        [16, 16, 10, 10]          --
│    └─Conv2d: 2-2                       [16, 1, 10, 10]           76
│    └─Conv2d: 2-3                       [16, 1, 10, 10]           76
│    └─Conv2d: 2-4                       [16, 1, 10, 10]           76
│    └─Conv2d: 2-5                       [16, 1, 10, 10]           76
│    └─Conv2d: 2-6                       [16, 1, 10, 10]           76
│    └─Conv2d: 2-7                       [16, 1, 10, 10]           76
│    └─Conv2d: 2-8                       [16, 1, 10, 10]           101
│    └─Conv2d: 2-9                       [16, 1, 10, 10]           101
│    └─Conv2d: 2-10                      [16, 1, 10, 10]           101
│    └─Conv2d: 2-11                      [16, 1, 10, 10]           101
│    └─Conv2d: 2-12                      [16, 1, 10, 10]           101
│    └─Conv2d: 2-13                      [16, 1, 10, 10]           101
│    └─Conv2d: 2-14                      [16, 1, 10, 10]           101
│    └─Conv2d: 2-15                      [16, 1, 10, 10]           101
│    └─Conv2d: 2-16                      [16, 1, 10, 10]           101
│    └─Conv2d: 2-17                      [16, 1, 10, 10]           151
├─PoolingLayer: 1-4                      [16, 16, 5, 5]            32
│    └─AvgPool2d: 2-18                   [16, 16, 5, 5]            --
├─Linear: 1-5                            [16, 120]                 48,120
├─Linear: 1-6                            [16, 84]                  10,164
├─RBFLayer: 1-7                          [16, 10]                  --
==========================================================================================
Total params: 60,000
Trainable params: 60,000
Non-trainable params: 0
Total mult-adds (M): 5.32
==========================================================================================
Input size (MB): 0.07
Forward/backward pass size (MB): 0.83
Params size (MB): 0.24
Estimated Total Size (MB): 1.14
==========================================================================================

Screenshots
image