Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

torchsummary does not work with user defined module #205

Open
aqkfatmtvvfb opened this issue Jul 24, 2024 · 0 comments
Open

torchsummary does not work with user defined module #205

aqkfatmtvvfb opened this issue Jul 24, 2024 · 0 comments

Comments

@aqkfatmtvvfb
Copy link

Code:

import torch
from torch import nn
from torch.nn import functional as F
import torchsummary


class MLP(nn.Module):
    """Two-layer perceptron: 20 inputs -> 256 hidden units (ReLU) -> 10 outputs."""

    def __init__(self):
        super().__init__()
        # Creation order determines both the parameter-initialisation RNG
        # sequence and the state_dict key order, so it is kept fixed.
        self.hidden = nn.Linear(20, 256)
        self.out = nn.Linear(256, 10)

    def forward(self, X):
        """Map an input whose last dimension is 20 to 10 output features."""
        activated = F.relu(self.hidden(X))
        logits = self.out(activated)
        return logits


class MySequential(nn.Module):
    """Apply the given modules one after another, in argument order.

    Each module is registered as a child named by its position ("0", "1", ...),
    so it shows up in ``parameters()``, ``state_dict()`` and ``.to(...)``.
    """

    def __init__(self, *args):
        super().__init__()
        for idx, module in enumerate(args):
            # add_module is the public registration API. Unlike writing into
            # the private self._modules dict, it validates the argument and
            # raises TypeError immediately for anything that is not a Module,
            # instead of failing later inside forward().
            self.add_module(str(idx), module)

    def forward(self, X):
        """Feed ``X`` through every registered module in insertion order."""
        # Iterate _modules directly (not children()) so a module passed more
        # than once is still applied once per occurrence.
        for block in self._modules.values():
            X = block(X)
        return X


class FixedHiddenMLP(nn.Module):
    """MLP with a frozen random projection and data-dependent rescaling.

    Forward pass: Linear(20, 20) -> multiply by a fixed random 20x20 matrix,
    add 1, ReLU -> the *same* Linear(20, 20) again -> halve until the sum of
    absolute values is at most 1 -> return the sum of all elements.
    """

    def __init__(self):
        super().__init__()
        # Store the fixed weight as a buffer rather than a plain attribute.
        # A plain tensor attribute is invisible to nn.Module, so
        # model.to(device) / .double() would leave it behind as CPU float32
        # and torch.mm would fail on the device/dtype mismatch. A buffer moves
        # and casts with the module but is not a Parameter, so it is never
        # trained. persistent=False keeps it out of state_dict, as before.
        # It is created before self.linear to preserve the original RNG order.
        self.register_buffer('rand_weight', torch.rand((20, 20)), persistent=False)
        self.linear = nn.Linear(20, 20)

    def forward(self, X):
        """Run the fixed-hidden-layer network.

        Args:
            X: 2-D input of shape (batch, 20); torch.mm only accepts matrices.

        Returns:
            A 0-dim tensor: the sum over *every* element of the rescaled
            output, so the batch dimension is reduced away as well.
            NOTE(review): summary tools that expect a per-sample output shape
            may not handle this scalar output — TODO confirm.
        """
        X = self.linear(X)
        X = F.relu(torch.mm(X, self.rand_weight) + 1)
        # The same Linear layer is applied twice, so its parameters are shared.
        X = self.linear(X)
        # Python control flow on tensor values: halve until sum(|X|) <= 1.
        # NOTE(review): if X contains inf this never terminates (inf / 2 == inf).
        while X.abs().sum() > 1:
            X /= 2
        return X.sum()


class NestMLP(nn.Module):
    """An nn.Sequential (20 -> 64 -> 32, ReLU after each) followed by Linear(32, 16)."""

    def __init__(self):
        super().__init__()
        # Layers are built in the same order as they are applied, which also
        # fixes the RNG order for parameter initialisation.
        stages = [
            nn.Linear(20, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
        ]
        self.net = nn.Sequential(*stages)
        self.linear = nn.Linear(32, 16)

    def forward(self, X):
        """Map an input whose last dimension is 20 to 16 output features."""
        features = self.net(X)
        projected = self.linear(features)
        return projected


def count_parameters(model):
    """Return the total element count of *model*'s trainable parameters.

    Only parameters with ``requires_grad`` set are counted; frozen ones are
    skipped. A module with no parameters yields 0.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total


if __name__ == '__main__':
    cpu = torch.device('cpu')
    # Compose all the custom blocks above into one nested model.
    model = nn.Sequential(NestMLP(), nn.Linear(16, 20), FixedHiddenMLP())
    # NOTE(review): this call is reported to raise RuntimeError. The
    # "executed layers" list in the report ends with the inner Linear calls
    # of FixedHiddenMLP, so the failure appears to occur later in
    # FixedHiddenMLP.forward or its hook — possibly its 0-dim X.sum()
    # output. TODO confirm with the full traceback.
    torchsummary.summary(model, (20,), device=cpu)
    print('parameters_count:', count_parameters(model))

Error message:

  File "C:\Users\wangyu2\anaconda3\Lib\site-packages\torchsummary\torchsummary.py", line 143, in summary
    raise RuntimeError(
RuntimeError: Failed to run torchsummary. See above stack traces for more details. Executed layers up to: [NestMLP: 1-1, Sequential: 2-1, Linear: 3-1, ReLU: 3-2, Linear: 3-3, ReLU: 3-4, Linear: 2-2, Linear: 1-2, Linear: 2-3, Linear: 2-4]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant