torch.cat() isn't displayed in the summary, and a block is displayed along with its inner contents.
gchhablani opened this issue
My forward method is:
def forward(self, x, skip):
    x = self.upSamp(x)
    print(x.shape)
    x = self.convRelu1(x)
    print(x.shape)
    x = torch.cat((x, skip), 1)
    print(x.shape)
    return x
Running summary() gives the following output (the first three lines come from the print calls in forward):
torch.Size([2, 64, 54, 54])
torch.Size([2, 128, 53, 53])
torch.Size([2, 160, 53, 53])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
          Upsample-1           [-1, 64, 54, 54]               0
            Conv2d-2          [-1, 128, 53, 53]          32,896
              ReLU-3          [-1, 128, 53, 53]               0
          ConvReLU-4          [-1, 128, 53, 53]               0
================================================================
Clearly, summary() ignores the torch.cat() call inside the forward method: the last row reports 128 channels, while the printed output of forward has 160.
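A likely explanation, assuming summary() builds its table from forward hooks registered on submodules (which is how torchsummary appears to work): torch.cat is a plain function, not an nn.Module, so no hook ever fires for it. The sketch below reproduces the table with plain `register_forward_hook` calls. The internals of ConvReLU, the 27x27 input size, the `record_layers` helper and the `Cat` wrapper module are all hypothetical; only the output shapes come from the issue.

```python
import torch
import torch.nn as nn


class ConvReLU(nn.Module):
    # Hypothetical reconstruction of the issue's block: Conv2d (64 -> 128
    # channels, kernel_size=2 gives 54 -> 53 px) followed by ReLU.
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=2)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))


class Cat(nn.Module):
    # Hypothetical wrapper module: makes the concatenation visible to hooks.
    def __init__(self, dim=1):
        super().__init__()
        self.dim = dim

    def forward(self, a, b):
        return torch.cat((a, b), self.dim)


class UpBlock(nn.Module):
    def __init__(self, wrap_cat=False):
        super().__init__()
        self.upSamp = nn.Upsample(scale_factor=2)  # assumed: 27 -> 54 px
        self.convRelu1 = ConvReLU(64, 128)
        self.cat = Cat(dim=1) if wrap_cat else None

    def forward(self, x, skip):
        x = self.upSamp(x)
        x = self.convRelu1(x)
        if self.cat is not None:
            return self.cat(x, skip)    # module call: its hook fires
        return torch.cat((x, skip), 1)  # plain function: no hook fires


def record_layers(model, *inputs):
    """List (type name, output shape) for every submodule whose hook fires."""
    rows = []

    def hook(module, inp, out):
        rows.append((type(module).__name__, tuple(out.shape)))

    handles = [m.register_forward_hook(hook) for m in model.modules() if m is not model]
    with torch.no_grad():
        model(*inputs)
    for h in handles:
        h.remove()
    return rows


x = torch.randn(2, 64, 27, 27)     # assumed input size
skip = torch.randn(2, 32, 53, 53)  # 160 - 128 = 32 skip channels
for wrap in (False, True):
    print(wrap, record_layers(UpBlock(wrap_cat=wrap), x, skip))
```

With `wrap_cat=False` the last recorded row is ConvReLU with 128 channels, matching the table above; with the hypothetical `Cat` module an extra row with 160 channels appears. Wrapping the call in a module is a user-side workaround, not a fix inside summary() itself.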
Also, it lists the block (ConvReLU-4) after its inner layers (Conv2d-2 and ReLU-3). One of them should not be there: either the block itself or its components.
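The ordering is also what forward hooks would produce, again assuming summary() relies on them: a module's hook fires only after its forward returns, so the ConvReLU hook fires after the hooks of the Conv2d and ReLU it calls. One possible way to avoid the duplicate row, sketched below with a hypothetical `hook_order` helper and a stand-in ConvReLU, is to hook only leaf modules (modules with no children).

```python
import torch
import torch.nn as nn


class ConvReLU(nn.Module):
    # Hypothetical stand-in for the issue's block (Conv2d followed by ReLU).
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=2)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))


def hook_order(model, x, leaves_only=False):
    """Return module type names in the order their forward hooks fire."""
    names = []
    targets = [
        m for m in model.modules()
        if m is not model and (not leaves_only or len(list(m.children())) == 0)
    ]
    handles = [
        m.register_forward_hook(lambda mod, inp, out: names.append(type(mod).__name__))
        for m in targets
    ]
    with torch.no_grad():
        model(x)
    for h in handles:
        h.remove()
    return names


model = nn.Sequential(nn.Upsample(scale_factor=2), ConvReLU(64, 128))
x = torch.randn(2, 64, 27, 27)
print(hook_order(model, x))                    # every submodule
print(hook_order(model, x, leaves_only=True))  # leaf modules only
```

The first call prints `['Upsample', 'Conv2d', 'ReLU', 'ConvReLU']`, with the block listed after its children; the leaf-only call prints `['Upsample', 'Conv2d', 'ReLU']`. The trade-off is that a block's own output shape disappears from the listing, which matters when the block does work outside its children, as with the torch.cat call in this issue.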