TylerYep/torchinfo

Can I use summary with tuple input? (len(inputs) = 3)


1. My forward function is:

```python
def forward(self, x: Tensor) -> Tensor:
    l0 = self.conv1(x[0])
    l1 = self.conv2(l0)
    l2 = self.conv3(x[1])
    feature = torch.cat([l1, l2, x[2]], 1)

    out = self.oce_block(feature)
    return out
```
2. My input is:

```python
print("inputs : ", len(inputs))
print("A : ", inputs[0].size())
print("B : ", inputs[1].size())
print("C : ", inputs[2].size())
```

```
A :  torch.Size([2, 128, 56, 56])
B :  torch.Size([2, 256, 28, 28])
C :  torch.Size([2, 512, 14, 14])
```
3. But I can't use `summary`:

```python
from torchsummary import summary
summary(bn, ([128, 56, 56], [256, 28, 28], [512, 14, 14]), batch_size=64)
```

Passing the tensors with `input_data` instead fails with:

```
TypeError Traceback (most recent call last)
Cell In[28], line 3
      1 from torchsummary import summary
----> 3 summary(bn, input_data=[inputs[0], inputs[1],inputs[2]])

TypeError: summary() got an unexpected keyword argument 'input_data'
```


snimu commented
```python
import torch
from torchinfo import summary

inputs = torch.tensor([[128, 56, 56], [256, 28, 28], [512, 14, 14]])
model = ...  # your model
summary(model, input_data=inputs)
```

Since your forward function takes a single tensor and splits it into three inside, this should work.

I changed it, but unfortunately it doesn't seem to work with a tuple input.

So I modified the forward function to take the three inputs as separate arguments, and that solved the problem.

Thank you.

```python
def forward(self, x1, x2, x3) -> Tensor:
    l0 = self.conv1(x1)
    l1 = self.conv2(l0)
    l2 = self.conv3(x2)
    feature = torch.cat([l1, l2, x3], 1)
    out = self.oce_block(feature)
    return out
```
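A note for anyone landing here with the same error. torchinfo's README says that when `forward()` takes several parameters, `input_data` should be a list of those arguments, and torchinfo unpacks that list into positional arguments. That is why a flat list of three tensors cannot reach a `forward(self, x)` that expects one tuple, and why switching to `forward(self, x1, x2, x3)` fixes it. The plain-Python sketch below mimics that dispatch with strings standing in for tensors, so it runs without torch. `call_like_torchinfo`, `forward_single` and `forward_three` are hypothetical names, and the helper is a simplified assumption about torchinfo's behavior, not its actual code.

```python
from typing import Any, Callable, Tuple


def call_like_torchinfo(forward: Callable[..., Any], input_data: Any) -> Any:
    """Hypothetical, simplified stand-in for how torchinfo passes
    input_data to a model (an assumption, not torchinfo's real code).

    list/tuple -> unpacked as positional arguments
    dict       -> unpacked as keyword arguments
    other      -> passed as a single argument
    """
    if isinstance(input_data, (list, tuple)):
        return forward(*input_data)
    if isinstance(input_data, dict):
        return forward(**input_data)
    return forward(input_data)


def forward_single(x: Tuple[str, str, str]) -> str:
    # Like the original forward(self, x): one argument, indexed inside.
    return "+".join([x[0], x[1], x[2]])


def forward_three(x1: str, x2: str, x3: str) -> str:
    # Like the fixed forward(self, x1, x2, x3): one parameter per input.
    return "+".join([x1, x2, x3])


inputs = ("A", "B", "C")  # strings standing in for the three tensors

# Three-parameter forward: a flat list maps one element to each parameter.
print(call_like_torchinfo(forward_three, list(inputs)))  # A+B+C

# One-parameter forward: wrap the tuple in a list so unpacking
# hands over the whole tuple as the single argument.
print(call_like_torchinfo(forward_single, [inputs]))  # A+B+C

# One-parameter forward with a flat list: three arguments arrive, so
# Python raises TypeError before forward_single's body runs.
try:
    call_like_torchinfo(forward_single, list(inputs))
except TypeError as exc:
    print("TypeError:", exc)
```

With torchinfo itself, the analogous calls would be `summary(bn, input_data=[inputs[0], inputs[1], inputs[2]])` for the three-argument `forward`, or, assuming nested inputs are handled the same way, `summary(bn, input_data=[tuple(inputs)])` to keep the original one-argument `forward`. Either way the import has to be `from torchinfo import summary`: the traceback above comes from `torchsummary`, whose `summary(model, input_size, ...)` has no `input_data` parameter, and that is what raises the `TypeError`.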