jeshraghian/snntorch

BNTT Layer missing forward function

  • snntorch version: 0.9.1
  • Python version: 3.10.12
  • Operating System: Ubuntu 22.04

Description

I tried to implement a BNTT layer using BatchNormTT2d(), following the documentation. However, the forward pass raises an error. Below is a minimal example, followed by the error it produces.

What I Did

# Basic torch tools
import torch
import torch.nn as nn
import snntorch as snn

class Net(nn.Module):
    def __init__(self):
        super().__init__()

        self.beta1 = torch.ones(64,44,44) * 0.7

        self.conv = nn.Conv2d(1,64, kernel_size=(7,7), stride=(2,2), padding=(3,3))
        self.bntt = snn.BatchNormTT2d(64, time_steps=30)
        self.lif1 = snn.Leaky(beta=self.beta1, learn_beta=True)

    def forward(self, x):
        mem = self.lif1.init_leaky()

        spk_rec = []
        mem_rec = []

        for step in range(x.shape[0]):
            cur1 = self.conv(x[step])
            cur1 = self.bntt(cur1)

            spk1, mem = self.lif1(cur1, mem)

            spk_rec.append(spk1)
            mem_rec.append(mem)

        return torch.stack(spk_rec), torch.stack(mem_rec)

model = Net()
# dimensions: time_steps, batch_size, channels, height, width
inp = torch.rand(30,32,1,88,88)
_, mem_rec = model.forward(inp)

print('No error, good job!')

Here is the error I receive:

raise NotImplementedError(f"Module [{type(self).__name__}] is missing the required \"forward\" function")
NotImplementedError: Module [ModuleList] is missing the required "forward" function
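
The error message names ModuleList, which suggests that BatchNormTT2d() returns an nn.ModuleList holding one BatchNorm2d layer per time step rather than a single callable layer. The sketch below is a possible workaround under that assumption, not a confirmed fix. It uses plain PyTorch only: the hand-built `bntt` list is a stand-in for what snn.BatchNormTT2d(4, time_steps=3) appears to return, and the small tensor sizes are made up for illustration.

```python
import torch
import torch.nn as nn

TIME_STEPS = 3  # illustrative values, smaller than in the report
CHANNELS = 4

# Assumption: stand-in for the object returned by
# snn.BatchNormTT2d(CHANNELS, time_steps=TIME_STEPS),
# i.e. one BatchNorm2d layer per time step inside an nn.ModuleList.
bntt = nn.ModuleList([nn.BatchNorm2d(CHANNELS) for _ in range(TIME_STEPS)])

# dimensions: time_steps, batch_size, channels, height, width
x = torch.rand(TIME_STEPS, 2, CHANNELS, 8, 8)

# Calling the container itself reproduces the reported error,
# because nn.ModuleList does not define forward().
try:
    bntt(x[0])
    direct_call_failed = False
except NotImplementedError:
    direct_call_failed = True

# Indexing the container by time step selects one BatchNorm2d layer,
# which can be called as usual.
out = torch.stack([bntt[step](x[step]) for step in range(TIME_STEPS)])

print(direct_call_failed, tuple(out.shape))
```

If that assumption holds, the corresponding change in the example above would be to replace `self.bntt(cur1)` with `self.bntt[step](cur1)` inside the loop.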