weight_norm layer prevents "Kernel Shape" output
helion-du-mas-des-bourboux-thales commented
Taking the example from the README, with `torch.nn.utils.weight_norm` applied to `conv2`:
```python
from torchsummaryX import summary
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)
        self.conv2 = torch.nn.utils.weight_norm(self.conv2)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

summary(Net(), torch.zeros((1, 1, 28, 28)))
```
returns the following, with no `Kernel Shape` reported for `conv2`:
```
================================================================
              Kernel Shape     Output Shape  Params  Mult-Adds
Layer
0_conv1      [1, 10, 5, 5]  [1, 10, 24, 24]   260.0     144.0k
1_conv2                  -    [1, 20, 8, 8]   5.04k      5.02k
2_conv2_drop             -    [1, 20, 8, 8]       -          -
3_fc1            [320, 50]          [1, 50]  16.05k      16.0k
4_fc2             [50, 10]          [1, 10]   510.0      500.0
----------------------------------------------------------------
                       Totals
Total params           21.86k
Trainable params       21.86k
Non-trainable params      0.0
Mult-Adds             165.52k
================================================================
```
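
The missing entry is presumably because `weight_norm` reparameterizes the layer: it removes the plain `weight` parameter and registers `weight_g` (magnitude) and `weight_v` (direction) in its place, recomputing `weight` in a forward pre-hook. A summary tool that looks for a parameter literally named `weight` therefore finds nothing. A quick check against the `Net` defined above:

```python
net = Net()

# After weight_norm, conv2 no longer owns a Parameter named "weight";
# only the reparameterized pieces are registered.
print([name for name, _ in net.conv2.named_parameters()])
# ['bias', 'weight_g', 'weight_v']

# weight_v has the same shape as the original weight, so the kernel
# shape is still recoverable from it.
print(list(net.conv2.weight_v.shape))
# [20, 10, 5, 5]
```

Falling back on `weight_v` (or on the recomputed `weight` attribute that the pre-hook sets before each forward) would be enough to fill in the missing row.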
It should instead return the following:
```
================================================================
              Kernel Shape     Output Shape  Params  Mult-Adds
Layer
0_conv1      [1, 10, 5, 5]  [1, 10, 24, 24]   260.0     144.0k
1_conv2     [10, 20, 5, 5]    [1, 20, 8, 8]   5.04k      5.02k
2_conv2_drop             -    [1, 20, 8, 8]       -          -
3_fc1            [320, 50]          [1, 50]  16.05k      16.0k
4_fc2             [50, 10]          [1, 10]   510.0      500.0
----------------------------------------------------------------
                       Totals
Total params           21.86k
Trainable params       21.86k
Non-trainable params      0.0
Mult-Adds             165.52k
================================================================
```
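
Until that is handled, a workaround is to strip the reparameterization just for the summary. This is a sketch using the standard PyTorch utilities; note that both calls create fresh `Parameter` objects, so do the round trip before building an optimizer:

```python
net = Net()

# Temporarily fold weight_g / weight_v back into a plain `weight`
# parameter so the summary can report the kernel shape.
torch.nn.utils.remove_weight_norm(net.conv2)
summary(net, torch.zeros((1, 1, 28, 28)))

# Re-apply the reparameterization; g and v are rebuilt from the
# restored weight, so the effective weights are unchanged.
net.conv2 = torch.nn.utils.weight_norm(net.conv2)
```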