`summary()` fails for models that contain an `nn.LSTM`
z-a-f opened this issue · 0 comments
z-a-f commented
If a model contains an `nn.LSTM`, `summary()` fails. This is probably related to #130.
Minimal failing example:
class Model(nn.Module):
    """Minimal model wrapping a single ``nn.LSTM`` layer, used to reproduce the bug.

    ``forward`` deliberately returns the LSTM's result unchanged. That result
    is a tuple, and one of its elements is itself a tuple. The nested tuple
    is what makes torchsummary's forward hook fail when it calls ``o.size()``
    on every element of a tuple output.
    """

    def __init__(self):
        super().__init__()
        # nn.LSTM(input_size=5, hidden_size=5)
        self.lstm = nn.LSTM(5, 5)

    def forward(self, x):
        # Return the full LSTM result (not just the output tensor) so the
        # summary() failure still reproduces.
        return self.lstm(x)
# Instantiate the repro model and ask torchsummary for a summary.
# summary() attaches forward hooks and runs a forward pass (see the traceback
# below). The hook raises AttributeError because it calls .size() on each
# element of the LSTM's tuple output, and one of those elements is itself a
# tuple.
model = Model()
summary(model, (3, 5), device='cpu')
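raises the following error. The forward hook calls `.size()` on every element of a tuple output, but `nn.LSTM` returns the nested tuple `(output, (h_n, c_n))`, and the inner `(h_n, c_n)` tuple has no `.size()`: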
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-87-26dd936cf377> in <module>
8
9 model = Model()
---> 10 summary(model, (3, 5), device='cpu')
~/miniconda3/envs/pytorch-dev/lib/python3.6/site-packages/torchsummary/torchsummary.py in summary(model, input_size, batch_size, device)
70 # make a forward pass
71 # print(x.shape)
---> 72 model(*x)
73
74 # remove these hooks
~/Git/pytorch-dev/pytorch/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),
<ipython-input-87-26dd936cf377> in forward(self, x)
5
6 def forward(self, x):
----> 7 return self.lstm(x)
8
9 model = Model()
~/Git/pytorch-dev/pytorch/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
729 _global_forward_hooks.values(),
730 self._forward_hooks.values()):
--> 731 hook_result = hook(self, input, result)
732 if hook_result is not None:
733 result = hook_result
~/miniconda3/envs/pytorch-dev/lib/python3.6/site-packages/torchsummary/torchsummary.py in hook(module, input, output)
21 if isinstance(output, (list, tuple)):
22 summary[m_key]["output_shape"] = [
---> 23 [-1] + list(o.size())[1:] for o in output
24 ]
25 else:
~/miniconda3/envs/pytorch-dev/lib/python3.6/site-packages/torchsummary/torchsummary.py in <listcomp>(.0)
21 if isinstance(output, (list, tuple)):
22 summary[m_key]["output_shape"] = [
---> 23 [-1] + list(o.size())[1:] for o in output
24 ]
25 else:
AttributeError: 'tuple' object has no attribute 'size'