TylerYep/torchinfo

ERROR

neverstoplearn opened this issue · 4 comments

Describe the bug
Traceback (most recent call last):
  File "tools/test.py", line 46, in <module>
    main()
  File "tools/test.py", line 42, in main
    runner()
  File "tools/../vedastr/runners/test_runner.py", line 44, in __call__
    self.test_batch(img, label)
  File "tools/../vedastr/runners/test_runner.py", line 29, in test_batch
    summary(self.model,[img,label_input])
  File "/home/zhengxin/anaconda3/envs/torch182/lib/python3.7/site-packages/torchinfo/torchinfo.py", line 216, in summary
    input_data, input_size, batch_dim, device, dtypes
  File "/home/zhengxin/anaconda3/envs/torch182/lib/python3.7/site-packages/torchinfo/torchinfo.py", line 249, in process_input
    correct_input_size = get_correct_input_sizes(input_size)
  File "/home/zhengxin/anaconda3/envs/torch182/lib/python3.7/site-packages/torchinfo/torchinfo.py", line 524, in get_correct_input_sizes
    if not input_size or any(size <= 0 for size in flatten(input_size)):
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
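The last frame shows where this fails. The list [img, label_input] is passed as the second positional argument of summary, which is input_size, so the tensors are treated as sizes. For a tensor, `size <= 0` is an elementwise boolean tensor, and any() then has to take its truth value, which PyTorch refuses. The sketch below reproduces just that check; `flatten` and `has_invalid_size` are hypothetical, simplified stand-ins for torchinfo's internals, and the tensor shapes are assumptions.

```python
import torch


def flatten(nested):
    """Yield leaf items from arbitrarily nested lists/tuples (simplified stand-in)."""
    for item in nested:
        if isinstance(item, (list, tuple)):
            yield from flatten(item)
        else:
            yield item


def has_invalid_size(input_size):
    # Mirrors the any(...) part of the check in the traceback:
    # any(size <= 0 for size in flatten(input_size))
    return any(size <= 0 for size in flatten(input_size))


# Shapes made of plain ints are what the check expects: nothing is <= 0.
shapes_ok = not has_invalid_size([(2, 1, 32, 100), (2, 26)])

# Tensors: `tensor <= 0` yields a bool tensor with many elements, and any()
# calls bool() on it, which raises RuntimeError in PyTorch.
img = torch.zeros(2, 1, 32, 100)                    # assumed image batch shape
label_input = torch.zeros(2, 26, dtype=torch.long)  # assumed token-id shape
try:
    has_invalid_size([img, label_input])
    tensor_error = None
except RuntimeError as err:
    tensor_error = str(err)

print(shapes_ok)  # True
print(tensor_error)
```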

To Reproduce
class TestRunner(InferenceRunner):
    def __init__(self, test_cfg, deploy_cfg, common_cfg=None):
        super(TestRunner, self).__init__(deploy_cfg, common_cfg)

        self.test_dataloader = self._build_dataloader(test_cfg['data'])
        if not isinstance(self.test_dataloader, dict):
            self.test_dataloader = dict(all=self.test_dataloader)
        self.postprocess_cfg = test_cfg.get('postprocess_cfg', None)

    def test_batch(self, img, label):

        self.model.eval()

        with torch.no_grad():
            label_input, label_length, label_target = self.converter.test_encode(label)
            if self.use_gpu:
                img = img.cuda()
                #print(img.shape)
                label_input = label_input.cuda()

            if self.need_text:
                pred = self.model((img, label_input))
                from torchinfo import summary
                summary(self.model,[img,label_input])
            else:
                pred = self.model((img,))

            pred, prob = self.postprocess(pred, self.postprocess_cfg)
            self.metric.measure(pred, prob, label)
            self.backup_metric.measure(pred, prob, label)

Expected behavior
torchinfo should print the model summary.

Additional context

When I use summary(self.model, input_data=(img, label_input), depth=1), I get this error:
Traceback (most recent call last):
  File "/home/zhengxin/anaconda3/envs/torch182/lib/python3.7/site-packages/torchinfo/torchinfo.py", line 288, in forward_pass
    _ = model.to(device)(*x, **kwargs)
  File "/home/zhengxin/anaconda3/envs/torch182/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
TypeError: forward() takes 2 positional arguments but 3 were given

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "tools/test.py", line 46, in <module>
    main()
  File "tools/test.py", line 42, in main
    runner()
  File "tools/../vedastr/runners/test_runner.py", line 44, in __call__
    self.test_batch(img, label)
  File "tools/../vedastr/runners/test_runner.py", line 29, in test_batch
    summary(self.model,input_data=(img,label_input),depth=1)
  File "/home/zhengxin/anaconda3/envs/torch182/lib/python3.7/site-packages/torchinfo/torchinfo.py", line 219, in summary
    model, x, batch_dim, cache_forward_pass, device, model_mode, **kwargs
  File "/home/zhengxin/anaconda3/envs/torch182/lib/python3.7/site-packages/torchinfo/torchinfo.py", line 300, in forward_pass
    ) from e
RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: []
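One plausible reading of this second traceback: for list or tuple input_data, torchinfo runs `model.to(device)(*x, **kwargs)` (the line 288 frame), so input_data=(img, label_input) calls the model with two positional arguments. The model in the report, however, is called as self.model((img, label_input)), i.e. with a single tuple, hence "takes 2 positional arguments but 3 were given" (self plus one expected argument, versus self plus two). The sketch below uses a hypothetical `TupleInputModel` with made-up shapes to show the mismatch, and shows that wrapping the tuple in a one-element list keeps it intact through the `*x` unpacking. The thread does not say how the issue was finally solved.

```python
import torch
from torch import nn


class TupleInputModel(nn.Module):
    """Hypothetical model whose forward() takes ONE argument, an (img, label_input)
    tuple, matching the reporter's call self.model((img, label_input))."""

    def forward(self, inputs):
        img, label_input = inputs
        return img.sum(dim=1) + label_input.sum(dim=1)


model = TupleInputModel()
img = torch.ones(2, 3)          # assumed shapes, for illustration only
label_input = torch.ones(2, 4)

# Unpacking a tuple, as in model(*x): forward() receives two positional arguments.
x = (img, label_input)
try:
    model(*x)
    unpack_error = None
except TypeError as err:
    unpack_error = str(err)     # "...forward() takes 2 positional arguments but 3 were given"

# Wrapping the tuple in a one-element list: model(*x) now passes the tuple as one argument.
x = [(img, label_input)]
out = model(*x)
print(out)
```

By the same reasoning, the torchinfo call would presumably be summary(self.model, input_data=[(img, label_input)]); this is an inference from the traceback, not something confirmed in the thread.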

snimu commented

You are using input_size where you should be using input_data. input_size is meant for shapes, so that summary can automatically construct input tensors of those sizes, while input_data is meant for passing actual data to the model directly.
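To make the input_size side of that distinction concrete: shapes go in, and random tensors of those shapes are generated. The helper below, `build_dummy_inputs`, is a hypothetical, simplified illustration of the idea, not torchinfo's actual code, and the shapes are assumptions. Because label_input holds integer token ids, the second tensor is cast to torch.long; torchinfo exposes a dtypes argument for this case, since generated inputs are floats by default.

```python
import torch


def build_dummy_inputs(input_size, dtypes=None):
    """Hypothetical helper: turn a list of shapes into random input tensors.

    Simplified illustration of what input_size is for; not torchinfo's code.
    """
    if dtypes is None:
        dtypes = [torch.float] * len(input_size)
    tensors = []
    for shape, dtype in zip(input_size, dtypes):
        tensor = torch.rand(*shape)       # uniform floats in [0, 1)
        tensors.append(tensor.to(dtype))  # e.g. cast to torch.long for token ids
    return tensors


# input_size style: only shapes (and dtypes) are given; the data is generated.
dummy = build_dummy_inputs([(2, 1, 32, 100), (2, 26)], dtypes=[torch.float, torch.long])
print([tuple(t.shape) for t in dummy], [t.dtype for t in dummy])
```

With input_data, by contrast, no tensors are generated: the real img and label_input tensors are handed to the model as given.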

snimu commented

@neverstoplearn Has your issue been resolved? If it hasn't could you provide more context on the precise model that was used?

@neverstoplearn Has your issue been resolved? If it hasn't could you provide more context on the precise model that was used?

neverstoplearn commented

Solved, thanks.