sksq96/pytorch-summary

torch.cat() isn't displayed in the summary, while the inner contents of a block are displayed along with the block itself.

gchhablani opened this issue

My forward method is:


    def forward(self,x,skip):
        x = self.upSamp(x)
        print(x.shape)
        x = self.convRelu1(x)
        print(x.shape)
        x = torch.cat((x,skip),1)
        print(x.shape)
        return x

I get the following output from summary:

    torch.Size([2, 64, 54, 54])
    torch.Size([2, 128, 53, 53])
    torch.Size([2, 160, 53, 53])
    ----------------------------------------------------------------
            Layer (type)               Output Shape         Param #
    ================================================================
              Upsample-1           [-1, 64, 54, 54]               0
                Conv2d-2          [-1, 128, 53, 53]          32,896
                  ReLU-3          [-1, 128, 53, 53]               0
              ConvReLU-4          [-1, 128, 53, 53]               0
    ================================================================

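Clearly, the summary ignores the torch.cat() call inside the forward method: no entry with output shape [-1, 160, 53, 53] is listed, even though the printed shapes show the concatenation happens.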
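A likely explanation, assuming the summary is built from PyTorch forward hooks registered on submodules: `torch.cat` is a plain function, not an `nn.Module`, so no hook ever fires for it. The minimal sketch below illustrates this with a hand-rolled hook recorder (not torchsummary's own code). `UpBlock`, `Concat`, `record_calls`, the `scale_factor=2` upsampling and the 27x27 input are hypothetical choices made only to reproduce the shapes above; wrapping the concatenation in a small module is one possible workaround, not an official fix.

```python
import torch
import torch.nn as nn


class Concat(nn.Module):
    """Hypothetical wrapper: exposes torch.cat as a module so forward hooks can see it."""

    def __init__(self, dim: int = 1):
        super().__init__()
        self.dim = dim

    def forward(self, x, skip):
        return torch.cat((x, skip), self.dim)


class UpBlock(nn.Module):
    """Hypothetical block; sizes chosen only to reproduce the shapes in the report."""

    def __init__(self, wrap_cat: bool):
        super().__init__()
        self.upSamp = nn.Upsample(scale_factor=2)
        self.conv = nn.Conv2d(64, 128, kernel_size=2)
        self.cat = Concat(dim=1) if wrap_cat else None

    def forward(self, x, skip):
        x = self.conv(self.upSamp(x))
        if self.cat is not None:
            return self.cat(x, skip)  # module call: its forward hook fires
        return torch.cat((x, skip), 1)  # plain function call: no hook fires


def record_calls(model, *inputs):
    """Hook every submodule (except the model itself); return (name, shape) in firing order."""
    calls, handles = [], []
    for name, module in model.named_modules():
        if module is model:
            continue

        def hook(mod, args, out, name=name):
            calls.append((name, tuple(out.shape)))

        handles.append(module.register_forward_hook(hook))
    with torch.no_grad():
        model(*inputs)
    for handle in handles:
        handle.remove()
    return calls


x = torch.randn(2, 64, 27, 27)
skip = torch.randn(2, 32, 53, 53)
plain = record_calls(UpBlock(wrap_cat=False), x, skip)
wrapped = record_calls(UpBlock(wrap_cat=True), x, skip)
print(plain)    # only upSamp and conv are recorded
print(wrapped)  # the concatenation now shows up as "cat"
```

With the functional call, the 160-channel output never reaches a hook; once the same operation lives inside a module, it is recorded like any other layer.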

Also, the summary lists the block (ConvReLU) after its inner components (Conv2d and ReLU). Only one of them should be shown: either the block itself or its components.
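The ordering is consistent with how PyTorch hooks fire: a hook added with `register_forward_hook` runs after the module's `forward()` returns, so a parent block's hook fires only after the hooks of the children it called, while `register_forward_pre_hook` runs before `forward()`. Assuming torchsummary records a layer when its forward hook fires, the sketch below reproduces the Conv2d, ReLU, ConvReLU order. The `ConvReLU` class is a hypothetical reconstruction: `kernel_size=2` is inferred from the 32,896 parameters (128·64·2·2 + 128) and the 54 to 53 spatial shrink, and `hook_order` is an illustrative helper, not part of any library.

```python
import torch
import torch.nn as nn


class ConvReLU(nn.Module):
    # Hypothetical reconstruction of the block in the report.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(64, 128, kernel_size=2)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))


def hook_order(model, x):
    """Return module type names in pre-hook (before forward) and forward-hook (after forward) order."""
    pre, post, handles = [], [], []
    for module in model.modules():
        name = type(module).__name__
        handles.append(module.register_forward_pre_hook(lambda m, args, n=name: pre.append(n)))
        handles.append(module.register_forward_hook(lambda m, args, out, n=name: post.append(n)))
    with torch.no_grad():
        model(x)
    for handle in handles:
        handle.remove()
    return pre, post


model = nn.Sequential(nn.Upsample(scale_factor=2), ConvReLU())
pre, post = hook_order(model, torch.randn(2, 64, 27, 27))
print(pre)   # parents before children
print(post)  # children before parents, matching the summary table
```

A tool that wanted parent-first output could record entries from pre-hooks instead, or skip modules that have children; which of the two the summary should show is the design question raised above.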