Support forward with multiple arguments
joaolcguerreiro opened this issue · 4 comments
Imagine I have a module like this:
```python
import torch.nn as nn


class Model(nn.Module):
    def __init__(self, generator, discriminator):
        super(Model, self).__init__()
        # Define Generator
        self.generator = generator
        # Define Discriminator
        self.discriminator = discriminator

    def forward(self, lr, hr):
        gen = self.generator(lr)
        return gen, self.discriminator(gen), self.discriminator(hr)
```
If I want to call `summary(model, input_size=..., depth=1)`, what should `input_size` look like? Is it supported?
I believe `summary` could handle an `input_size` given as a list, meaning that `forward` would receive as many arguments as there are elements in the list.
It is not supported via `input_size`, but you can easily circumvent that by using the `input_data` argument. For example:
```python
import torch
from torchinfo import summary

generator, discriminator = ...  # your own generator and discriminator modules
model = Model(generator, discriminator)

lr = torch.randn(1, 2, 3, 4, 5)  # whatever lr is
hr = torch.randn(2, 5, 2, 5)  # whatever hr is

summary(model, input_data=(lr, hr), depth=1)
```
Because I don't know what the `generator` or `discriminator` is, or what `lr` and `hr` are, I cannot be more specific, and I don't know for sure whether this will work. In principle, though, you can just generate pseudo-data and give that to `summary`. By packaging multiple inputs into a single `tuple` or `list`, you can handle models like yours.
If you try that and it still fails, then you can write again. If so, I would need more detail to look into it more closely.
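To illustrate why a tuple works: `summary` unpacks a `list` or `tuple` passed as `input_data` into positional arguments (the traceback later in this thread shows the call `model.to(device)(*x, **kwargs)`). Below is a minimal plain-Python sketch of that behavior. `TwoInputModel` and `fake_summary_call` are hypothetical stand-ins, not torchinfo code; strings take the place of tensors so it runs without PyTorch, and the other input types torchinfo handles are omitted.

```python
from typing import Any


class TwoInputModel:
    """Hypothetical stand-in for a model whose forward takes (lr, hr)."""

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Like nn.Module.__call__ (hooks aside), hand the arguments through to forward().
        return self.forward(*args, **kwargs)

    def forward(self, lr: Any, hr: Any) -> tuple:
        return ("gen", lr, hr)


def fake_summary_call(model: Any, input_data: Any, **kwargs: Any) -> Any:
    """Simplified sketch of how summary() feeds input_data to the model."""
    if isinstance(input_data, (list, tuple)):
        # Each element becomes one positional argument of forward().
        return model(*input_data, **kwargs)
    # A single input (e.g. one tensor) is passed as the only argument.
    return model(input_data, **kwargs)


print(fake_summary_call(TwoInputModel(), ("lr-data", "hr-data")))
# ('gen', 'lr-data', 'hr-data')
```

A `list` behaves the same as a `tuple` here, which is why either container works for a two-argument `forward`.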
@snimu, I'm getting the same error. My code looks like this:
```python
def test_batch(self, img, label):
    self.model.eval()
    with torch.no_grad():
        label_input, label_length, label_target = self.converter.test_encode(label)
        if self.use_gpu:
            img = img.cuda()
            #print(img.shape)
            label_input = label_input.cuda()
        if self.need_text:
            pred = self.model((img, label_input))
            from torchinfo import summary
            print(img.shape,label_input.shape)
            lr = torch.randn(288,1,32,100)
            hr = torch.randn(288,1)
            summary(self.model,input_data=(lr,hr),depth=1)
        else:
            pred = self.model((img,))
        pred, prob = self.postprocess(pred, self.postprocess_cfg)
        self.metric.measure(pred, prob, label)
        self.backup_metric.measure(pred, prob, label)
```
But I got this error:
```
torch.Size([288, 1, 32, 100]) torch.Size([288, 1])
Traceback (most recent call last):
  File "/home/zhengxin/anaconda3/envs/torch182/lib/python3.7/site-packages/torchinfo/torchinfo.py", line 288, in forward_pass
    _ = model.to(device)(*x, **kwargs)
  File "/home/zhengxin/anaconda3/envs/torch182/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
TypeError: forward() takes 2 positional arguments but 3 were given

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "tools/test.py", line 46, in <module>
    main()
  File "tools/test.py", line 42, in main
    runner()
  File "tools/../vedastr/runners/test_runner.py", line 49, in __call__
    self.test_batch(img, label)
  File "tools/../vedastr/runners/test_runner.py", line 32, in test_batch
    summary(self.model,input_data=(lr,hr),depth=1)
  File "/home/zhengxin/anaconda3/envs/torch182/lib/python3.7/site-packages/torchinfo/torchinfo.py", line 219, in summary
    model, x, batch_dim, cache_forward_pass, device, model_mode, **kwargs
  File "/home/zhengxin/anaconda3/envs/torch182/lib/python3.7/site-packages/torchinfo/torchinfo.py", line 300, in forward_pass
    ) from e
RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: []
```
Can you help me? Thanks.
It looks like your model's `forward` method only takes one input, but you have given it two.
Here is the call that effectively happens inside `summary`, given your arguments:
```python
# Setup:
model = Model()
# The call:
model(lr, hr)
```
You can see this from the following part of the error message: `TypeError: forward() takes 2 positional arguments but 3 were given`. The three arguments that were given are `self`, `lr`, and `hr` (`self` is passed automatically). You have not provided code for your model, but it seems clear to me that your model's forward pass only takes a single argument besides `self`.
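The posted code calls the model as `self.model((img, label_input))`, which suggests that its `forward` takes one tuple rather than two tensors. Since `summary` unpacks `input_data` with `*` (see `model.to(device)(*x, **kwargs)` in the traceback), `(lr, hr)` turns into two arguments, while an extra list around the tuple should deliver it as one. The plain-Python sketch below illustrates this, assuming that unpacking is the only thing happening; `TupleInputModel` and `call_like_summary` are hypothetical stand-ins, not torchinfo or vedastr code, and strings take the place of tensors.

```python
from typing import Any, Sequence


class TupleInputModel:
    """Hypothetical stand-in for a model whose forward takes ONE tuple argument."""

    def __call__(self, *args: Any) -> Any:
        return self.forward(*args)

    def forward(self, inputs: tuple) -> tuple:
        img, label_input = inputs  # the tuple is unpacked inside forward
        return img, label_input


def call_like_summary(model: Any, x: Sequence[Any]) -> Any:
    # Mirrors `model.to(device)(*x, **kwargs)` from the traceback (device and kwargs omitted).
    return model(*x)


model = TupleInputModel()

try:
    call_like_summary(model, ("img", "label"))  # becomes forward("img", "label")
except TypeError as err:
    # The message ends with "forward() takes 2 positional arguments but 3 were given".
    print(err)

# One extra list level: the tuple arrives as a single argument.
print(call_like_summary(model, [("img", "label")]))
# ('img', 'label')
```

Translated to torchinfo, this would presumably be `summary(self.model, input_data=[(lr, hr)], depth=1)`, or `input_data=[(img, label_input)]` to reuse the real tensors instead of guessing shapes and dtypes. This has not been tested against a specific torchinfo version.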
@joaolcguerreiro Is your issue resolved?