sksq96/pytorch-summary

torchsummary does not work with user-defined modules

aqkfatmtvvfb opened this issue · 0 comments

Code:

import torch
from torch import nn
from torch.nn import functional as F
import torchsummary


class MLP(nn.Module):
    """Two-layer perceptron: 20 input features -> 256 hidden units (ReLU) -> 10 outputs."""

    def __init__(self):
        super().__init__()
        # The attribute names determine the parameter names in state_dict
        # ("hidden.weight", "out.bias", ...), so they are kept unchanged.
        self.hidden = nn.Linear(20, 256)
        self.out = nn.Linear(256, 10)

    def forward(self, X):
        """Apply the hidden layer, a ReLU, then the output layer to ``X``."""
        activated = F.relu(self.hidden(X))
        result = self.out(activated)
        return result


class MySequential(nn.Module):
    """Minimal re-implementation of ``nn.Sequential``.

    Each positional argument is registered as a child module named by its
    position ("0", "1", ...); ``forward`` feeds the input through the
    children in that order.
    """

    def __init__(self, *args):
        super().__init__()
        for idx, module in enumerate(args):
            # add_module is the public API for registering children. Unlike
            # writing into the private ``_modules`` dict, it validates the
            # name and rejects objects that are neither an ``nn.Module`` nor
            # ``None`` (raising TypeError).
            self.add_module(str(idx), module)

    def forward(self, X):
        """Apply every registered child module to ``X`` in insertion order."""
        # Iterate ``_modules`` directly (as nn.Sequential does) rather than
        # ``children()``, which de-duplicates and would skip a module that
        # was passed more than once.
        for block in self._modules.values():
            X = block(X)
        return X


class FixedHiddenMLP(nn.Module):
    """MLP with a constant random hidden weight and a reused linear layer.

    ``forward`` returns a 0-dim tensor (the sum of the final activations),
    not a per-sample output.

    NOTE(review): the 0-dim return value is a likely reason shape-recording
    tools such as torchsummary fail on models containing this module (the
    reported executed-layer list stops right after this module's second
    ``linear`` call). Confirm against the full stack trace.
    """

    def __init__(self):
        super().__init__()
        # Register the constant weight as a buffer so that ``.to(device)``,
        # ``.cuda()``, ``.double()`` etc. convert it together with the
        # parameters. A plain tensor attribute is left behind and makes
        # ``torch.mm`` fail with a device/dtype mismatch.
        # ``persistent=False`` keeps ``state_dict()`` keys identical to the
        # plain-attribute version. Buffers are not returned by
        # ``parameters()``, so the weight is still never trained.
        self.register_buffer('rand_weight', torch.rand((20, 20)), persistent=False)
        self.linear = nn.Linear(20, 20)

    def forward(self, X):
        """Run ``X`` through linear -> fixed matmul + 1 -> ReLU -> the same linear.

        The activations are then halved until their absolute sum is at most
        1. The function returns their sum as a 0-dim tensor.

        Raises:
            ValueError: if the activations contain ``inf``, for which the
                halving loop would never terminate.
        """
        X = self.linear(X)
        X = F.relu(torch.mm(X, self.rand_weight) + 1)
        # The same ``self.linear`` layer is applied a second time, so its
        # parameters are shared between the two uses.
        X = self.linear(X)
        # inf / 2 == inf, so the loop below would spin forever.
        if torch.isinf(X).any():
            raise ValueError(
                'FixedHiddenMLP: activations contain inf; cannot normalize'
            )
        while X.abs().sum() > 1:
            # Out-of-place division: avoids mutating a tensor that
            # ``X.abs()`` saved for autograd.
            X = X / 2
        return X.sum()


class NestMLP(nn.Module):
    """Nested MLP: a 20 -> 64 -> 32 ReLU stack followed by a 32 -> 16 linear layer."""

    def __init__(self):
        super().__init__()
        stack = [
            nn.Linear(20, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
        ]
        self.net = nn.Sequential(*stack)
        self.linear = nn.Linear(32, 16)

    def forward(self, X):
        """Pass ``X`` through the inner stack, then the final linear layer."""
        features = self.net(X)
        return self.linear(features)


def count_parameters(model):
    """Return the total number of elements across ``model``'s trainable parameters.

    Parameters with ``requires_grad=False`` are skipped. ``parameters()``
    yields each shared parameter only once, so it is counted once.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total


if __name__ == '__main__':
    # Repro: summarize a model that mixes built-in and user-defined modules,
    # then report its trainable parameter count.
    cpu = torch.device('cpu')
    model = nn.Sequential(
        NestMLP(),
        nn.Linear(16, 20),
        FixedHiddenMLP(),
    )
    torchsummary.summary(model, (20,), device=cpu)
    n_trainable = count_parameters(model)
    print('parameters_count:', n_trainable)

Error message:

  File "C:\Users\wangyu2\anaconda3\Lib\site-packages\torchsummary\torchsummary.py", line 143, in summary
    raise RuntimeError(
RuntimeError: Failed to run torchsummary. See above stack traces for more details. Executed layers up to: [NestMLP: 1-1, Sequential: 2-1, Linear: 3-1, ReLU: 3-2, Linear: 3-3, ReLU: 3-4, Linear: 2-2, Linear: 1-2, Linear: 2-3, Linear: 2-4]