TylerYep/torchinfo

Support forward with multiple arguments

joaolcguerreiro opened this issue · 4 comments

Imagine I have a module like this:

```python
import torch.nn as nn


class Model(nn.Module):
    def __init__(self, generator, discriminator):
        super(Model, self).__init__()

        # Define Generator
        self.generator = generator

        # Define Discriminator
        self.discriminator = discriminator

    def forward(self, lr, hr):
        gen = self.generator(lr)

        return gen, self.discriminator(gen), self.discriminator(hr)
```

If I want to call `summary(model, input_size=..., depth=1)`, what should `input_size` look like? Is this supported?

I think `summary` could accept `input_size` as a list, in which case `forward` would receive as many arguments as there are elements in the list.

snimu commented

It is not supported via `input_size`, but you can easily work around that by using the `input_data` argument. For example:

```python
import torch
from torchinfo import summary

generator, discriminator = ...
model = Model(generator, discriminator)
lr = torch.randn(1, 2, 3, 4, 5)  # whatever lr is
hr = torch.randn(2, 5, 2, 5)  # whatever hr is

summary(model, input_data=(lr, hr), depth=1)
```

Because I don't know what the generator and discriminator are, or what `lr` and `hr` look like, I can't be more specific, and I can't promise this will work. In principle, though, you can generate pseudo-data of the right shapes and pass it to `summary`. When `input_data` is a tuple or list, its elements are passed to `forward` as separate positional arguments, so packaging multiple inputs this way handles models like yours.

If you try that and it still fails, write again. In that case I would need more detail to look into it more closely.
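Before calling `summary`, it can help to run the forward pass directly with the same tuple, unpacked. The traceback later in this thread shows torchinfo calling `model.to(device)(*x, **kwargs)`, so if `model(*input_data)` works, `summary` should receive the same arguments. This is a minimal sketch: the generator and discriminator were not shared, so it assumes hypothetical stand-ins (small `nn.Linear` layers with made-up feature sizes) and pseudo-data shaped to fit them.

```python
import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self, generator, discriminator):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator

    def forward(self, lr, hr):
        gen = self.generator(lr)
        return gen, self.discriminator(gen), self.discriminator(hr)


# Hypothetical stand-ins: the generator maps 8 features to 16,
# and the discriminator scores 16-feature inputs.
generator = nn.Linear(8, 16)
discriminator = nn.Sequential(nn.Linear(16, 1), nn.Sigmoid())
model = Model(generator, discriminator)

# Pseudo-data with a batch size of 2; shapes only need to fit the stand-ins.
lr = torch.randn(2, 8)
hr = torch.randn(2, 16)
input_data = (lr, hr)

# Unpacking the tuple passes lr and hr as two separate positional arguments,
# matching the model(*x, **kwargs) call visible in the traceback.
gen, score_fake, score_real = model(*input_data)
print(gen.shape, score_fake.shape, score_real.shape)
```

If this direct call runs, `summary(model, input_data=(lr, hr), depth=1)` should get past the forward pass with the real generator and discriminator swapped in.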

@snimu, I ran into the same error. My code looks like this:
```python
def test_batch(self, img, label):

    self.model.eval()

    with torch.no_grad():
        label_input, label_length, label_target = self.converter.test_encode(label)
        if self.use_gpu:
            img = img.cuda()
            #print(img.shape)
            label_input = label_input.cuda()

        if self.need_text:
            pred = self.model((img, label_input))
            from torchinfo import summary
            print(img.shape,label_input.shape)
            lr = torch.randn(288,1,32,100)
            hr = torch.randn(288,1)
            summary(self.model,input_data=(lr,hr),depth=1)
        else:
            pred = self.model((img,))

        pred, prob = self.postprocess(pred, self.postprocess_cfg)
        self.metric.measure(pred, prob, label)
        self.backup_metric.measure(pred, prob, label)
```
But I got this error:
```
torch.Size([288, 1, 32, 100]) torch.Size([288, 1])
Traceback (most recent call last):
  File "/home/zhengxin/anaconda3/envs/torch182/lib/python3.7/site-packages/torchinfo/torchinfo.py", line 288, in forward_pass
    _ = model.to(device)(*x, **kwargs)
  File "/home/zhengxin/anaconda3/envs/torch182/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
TypeError: forward() takes 2 positional arguments but 3 were given

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "tools/test.py", line 46, in <module>
    main()
  File "tools/test.py", line 42, in main
    runner()
  File "tools/../vedastr/runners/test_runner.py", line 49, in __call__
    self.test_batch(img, label)
  File "tools/../vedastr/runners/test_runner.py", line 32, in test_batch
    summary(self.model,input_data=(lr,hr),depth=1)
  File "/home/zhengxin/anaconda3/envs/torch182/lib/python3.7/site-packages/torchinfo/torchinfo.py", line 219, in summary
    model, x, batch_dim, cache_forward_pass, device, model_mode, **kwargs
  File "/home/zhengxin/anaconda3/envs/torch182/lib/python3.7/site-packages/torchinfo/torchinfo.py", line 300, in forward_pass
    ) from e
RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: []
```
Can you help me? Thanks.

snimu commented

It looks like your model's `forward` method only takes one input, but you have given it two.

Given your arguments, this is effectively the call that happens inside `summary`:

```python
# Setup:
model = Model()

# The call:
model(lr, hr)
```

You can see this in the following part of the error message: `TypeError: forward() takes 2 positional arguments but 3 were given`. The three arguments are `self`, `lr`, and `hr` (`self` is passed automatically). You haven't shared your model's code, but it seems clear that its `forward` method takes only a single argument besides `self`. That also matches how your snippet calls the model, `self.model((img, label_input))`, with both tensors packed into one tuple.
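The failure comes down to Python argument unpacking. Judging from the `model.to(device)(*x, **kwargs)` line in the traceback, torchinfo spreads a tuple or list given as `input_data` across the positional parameters of `forward`. The sketch below mimics only that step, as a simplification rather than torchinfo's actual code: `call_like_summary` is a hypothetical helper, and `OneArgModel` is a plain class standing in for a model whose `forward` takes a single `(img, label)` tuple, as `self.model((img, label_input))` suggests.

```python
class OneArgModel:
    """Stand-in for a model whose forward takes one (img, label) tuple."""

    def forward(self, inputs):
        img, label = inputs
        return f"img={img}, label={label}"

    def __call__(self, *args, **kwargs):
        # nn.Module.__call__ likewise hands its arguments on to forward().
        return self.forward(*args, **kwargs)


def call_like_summary(model, input_data, **kwargs):
    """Hypothetical helper: unpack input_data into positional arguments."""
    x = input_data if isinstance(input_data, (list, tuple)) else [input_data]
    return model(*x, **kwargs)


model = OneArgModel()
lr, hr = "lr-tensor", "hr-tensor"  # placeholders for the real tensors

# (lr, hr) is unpacked into two arguments, so forward gets self, lr, hr.
try:
    call_like_summary(model, (lr, hr))
    error_message = None
except TypeError as exc:
    error_message = str(exc)
    print("TypeError:", error_message)

# Wrapping the tuple in a list passes it through as a single argument.
result = call_like_summary(model, [(lr, hr)])
print(result)
```

Under that assumption, `summary(self.model, input_data=[(lr, hr)], depth=1)` may get past this particular error for a model called as `self.model((img, label_input))`. The pseudo-data should also match the dtypes of the real inputs.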

snimu commented

@joaolcguerreiro Is your issue resolved?