torchsummary does not work with user-defined modules
aqkfatmtvvfb opened this issue · 0 comments
aqkfatmtvvfb commented
Code:
import torch
from torch import nn
from torch.nn import functional as F
import torchsummary
class MLP(nn.Module):
    """Two-layer perceptron: 20 input features -> 256 hidden (ReLU) -> 10 outputs."""

    def __init__(self):
        super(MLP, self).__init__()
        # Attribute names are part of the module's state_dict keys
        # ('hidden.*', 'out.*'), so they are kept unchanged.
        self.hidden = nn.Linear(20, 256)
        self.out = nn.Linear(256, 10)

    def forward(self, X):
        """Map an input whose last dimension is 20 to one whose last dimension is 10."""
        hidden_act = F.relu(self.hidden(X))
        return self.out(hidden_act)
class MySequential(nn.Module):
    """Minimal re-implementation of ``nn.Sequential``.

    Each positional argument is registered as a child module named
    ``str(index)``; ``forward`` applies them in registration order.
    """

    def __init__(self, *args):
        super().__init__()
        for idx, module in enumerate(args):
            # Use the public ``add_module`` API rather than writing the
            # private ``_modules`` dict: it validates the child's name and
            # rejects non-Module objects with a TypeError up front, instead
            # of failing later (e.g. inside ``parameters()``).
            self.add_module(str(idx), module)

    def forward(self, X):
        """Feed ``X`` through every registered child, in order."""
        # Iterate the registration dict directly rather than ``children()``,
        # which de-duplicates: a module passed twice must run twice.
        for block in self._modules.values():
            X = block(X)
        return X
class FixedHiddenMLP(nn.Module):
    """MLP with a constant random hidden weight and a reused linear layer.

    ``self.linear`` is applied twice, so its parameters are shared between
    both applications. ``forward`` reduces its result to a 0-dim tensor.
    """

    def __init__(self):
        super().__init__()
        # Register the constant weight as a non-persistent buffer instead of
        # a plain tensor attribute. Buffers follow ``.to(device)``,
        # ``.cuda()`` and dtype casts such as ``.double()`` together with the
        # module; a plain attribute would stay behind and make ``torch.mm``
        # fail on a device/dtype mismatch. ``persistent=False`` keeps it out
        # of ``state_dict``, as the plain attribute was. It is not an
        # ``nn.Parameter``, so it is never trained.
        self.register_buffer(
            'rand_weight',
            torch.rand((20, 20), requires_grad=False),
            persistent=False,
        )
        self.linear = nn.Linear(20, 20)

    def forward(self, X):
        """Return the sum of the (rescaled) activations as a 0-dim tensor."""
        X = self.linear(X)
        X = F.relu(torch.mm(X, self.rand_weight) + 1)
        # Same layer as above: parameters are shared between both calls.
        X = self.linear(X)
        # Halve until the total absolute value is at most 1. Out-of-place
        # division avoids mutating the tensor returned by ``self.linear``,
        # which forward/backward hooks (e.g. from summary tools) may have
        # captured or wrapped; the numerical result is identical.
        while X.abs().sum() > 1:
            X = X / 2
        # NOTE(review): the output is a 0-dim scalar with no batch dimension;
        # tools that record per-layer output shapes with a batch axis (such
        # as torchsummary) may not handle this — confirm against the trace.
        return X.sum()
class NestMLP(nn.Module):
    """MLP that nests an ``nn.Sequential`` block inside a custom module.

    20 input features pass through 64 and 32 hidden units (each followed by
    a ReLU) and then a final linear layer producing 16 features.
    """

    def __init__(self):
        super(NestMLP, self).__init__()
        # Layers are built in the same order as before, so parameter
        # initialization consumes the RNG identically.
        layers = [
            nn.Linear(20, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
        ]
        self.net = nn.Sequential(*layers)
        self.linear = nn.Linear(32, 16)

    def forward(self, X):
        """Apply the nested block, then the final 32 -> 16 projection."""
        features = self.net(X)
        return self.linear(features)
def count_parameters(model):
    """Return the total number of elements across ``model``'s trainable parameters.

    Parameters with ``requires_grad=False`` are skipped; a model without
    parameters yields 0.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total
if __name__ == '__main__':
    # Run the demo on the CPU.
    device = torch.device('cpu')
    # Chain a nested custom module, a plain linear layer, and a custom
    # module whose forward returns a 0-dim scalar.
    chimera = nn.Sequential(
        NestMLP(),
        nn.Linear(16, 20),
        FixedHiddenMLP(),
    )
    # NOTE(review): (20,) is passed as the per-sample input size; presumably
    # torchsummary adds a batch dimension itself — confirm for the installed
    # version.
    torchsummary.summary(chimera, (20,), device=device)
    n_trainable = count_parameters(chimera)
    print('parameters_count:', n_trainable)
Error message:
File "C:\Users\wangyu2\anaconda3\Lib\site-packages\torchsummary\torchsummary.py", line 143, in summary
raise RuntimeError(
RuntimeError: Failed to run torchsummary. See above stack traces for more details. Executed layers up to: [NestMLP: 1-1, Sequential: 2-1, Linear: 3-1, ReLU: 3-2, Linear: 3-3, ReLU: 3-4, Linear: 2-2, Linear: 1-2, Linear: 2-3, Linear: 2-4]