Can I use summary with tuple input? (len(inputs) = 3)
Closed this issue · 2 comments
AppleJoker94 commented
- my forward function is below
def forward(self, x: Tensor) -> Tensor:
    l0 = self.conv1(x[0])
    l1 = self.conv2(l0)
    l2 = self.conv3(x[1])
    feature = torch.cat([l1, l2, x[2]], 1)
    out = self.oce_block(feature)
    return out
- and my input is below
print("inputs : ", len(inputs))
print("A : ", inputs[0].size())
print("B : ", inputs[1].size())
print("C : ", inputs[2].size())
A : torch.Size([2, 128, 56, 56])
B : torch.Size([2, 256, 28, 28])
C : torch.Size([2, 512, 14, 14])
- but I can't use summary
from torchsummary import summary
summary(bn,([128, 56, 56],[256, 28, 28],[512, 14, 14]), batch_size=64)
TypeError Traceback (most recent call last)
Cell In[28], line 3
1 from torchsummary import summary
----> 3 summary(bn, input_data=[inputs[0], inputs[1],inputs[2]])
TypeError: summary() got an unexpected keyword argument 'input_data'
Can I use summary with tuple input? (len(inputs) = 3)
snimu commented
inputs = torch.tensor([128, 56, 56],[256, 28, 28],[512, 14, 14])
model = ... # your model
summary(model, input_data=inputs)
Since the forward function just takes a single tensor that is then split into three in there, this should work.
AppleJoker94 commented
I tried that, but unfortunately it seems hard to make it work with a tuple input.
So I changed the forward function to take three separate tensors, and that solved the problem.
Thank you.
def forward(self, x1, x2, x3) -> Tensor:
    l0 = self.conv1(x1)
    l1 = self.conv2(l0)
    l2 = self.conv3(x2)
    feature = torch.cat([l1, l2, x3], 1)
    out = self.oce_block(feature)
    return out
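For anyone landing here later, below is a rough, self-contained sketch of how a three-argument forward like this can be summarized. It is not the actual model from this thread: the class name ToyFusion, the channel counts, the stride-2 convolutions, and the 1x1 convolution standing in for oce_block are all assumptions made so the shapes from the post line up. It also assumes torchsummary's summary(model, input_size, batch_size=-1, device="cuda") signature, where input_size may be a list with one size tuple per forward argument (no batch dimension). That signature has no input_data keyword, which matches the TypeError above.

```python
import torch
from torch import Tensor, nn


class ToyFusion(nn.Module):
    """Hypothetical stand-in for the model in this thread; all layer sizes are assumptions."""

    def __init__(self) -> None:
        super().__init__()
        # Stride-2 convolutions bring every branch down to 14x14 so that
        # torch.cat along the channel dimension is valid.
        self.conv1 = nn.Conv2d(128, 64, kernel_size=3, stride=2, padding=1)  # 56 -> 28
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1)   # 28 -> 14
        self.conv3 = nn.Conv2d(256, 64, kernel_size=3, stride=2, padding=1)  # 28 -> 14
        # Placeholder for oce_block, whose definition is not shown in the thread.
        self.oce_block = nn.Conv2d(64 + 64 + 512, 32, kernel_size=1)

    def forward(self, x1: Tensor, x2: Tensor, x3: Tensor) -> Tensor:
        l0 = self.conv1(x1)
        l1 = self.conv2(l0)
        l2 = self.conv3(x2)
        feature = torch.cat([l1, l2, x3], 1)
        return self.oce_block(feature)


model = ToyFusion()

# One (C, H, W) tuple per forward argument, without the batch dimension.
input_sizes = [(128, 56, 56), (256, 28, 28), (512, 14, 14)]

# Sanity check with plain PyTorch before involving any summary tool.
dummy = [torch.rand(2, *size) for size in input_sizes]
out = model(*dummy)
print(out.shape)

try:
    from torchsummary import summary

    # device="cpu" avoids the default "cuda" on machines without a GPU.
    summary(model, input_sizes, device="cpu")
except ImportError:
    print("torchsummary is not installed; skipping the summary call")
```

Running the plain forward pass first makes it easier to tell a shape mismatch in the model apart from a problem with how the summary tool is being called.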