keras-team/keras-contrib

PyTorch code translation to Keras code

albertopolito opened this issue · 0 comments

Good morning,
I'm a beginner, and this is the first time I have used Keras to implement a neural network.
I would like to write the same network as the one at this link, with the same activation function and forward mechanism.
I see that there is a tool that converts ONNX models to Keras models, but it does not seem to work well with this code.
So I would like to translate it manually from PyTorch to Keras.
I have some questions:
- How can I write the following activation function in Keras:

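import torch

# thresh and lens are globals that are not shown in this snippet.
class ActFun(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        # Spike (1.0) where the input is strictly above the threshold.
        return input.gt(thresh).float()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        # Let the gradient through only where |input - thresh| < lens.
        temp = abs(input - thresh) < lens
        return grad_input * temp.float()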
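
To make the expected behaviour concrete, here is a small framework-free sketch in plain Python of what I understand `ActFun` to compute element-wise. The values `thresh = 0.5` and `lens = 0.5`, and the helper names `act_forward` and `act_backward`, are only assumptions for illustration; they do not come from the snippet above.

```python
# Illustrative values only: the snippet above does not show how they are set.
thresh = 0.5  # firing threshold (assumed)
lens = 0.5    # half-width of the window in which the gradient passes (assumed)


def act_forward(inputs):
    """Forward pass: emit a spike (1.0) where the input is strictly above thresh."""
    return [1.0 if x > thresh else 0.0 for x in inputs]


def act_backward(inputs, grad_output):
    """Backward pass: keep the incoming gradient only where |x - thresh| < lens."""
    return [g if abs(x - thresh) < lens else 0.0
            for x, g in zip(inputs, grad_output)]


inputs = [0.25, 0.75, 1.5]
spikes = act_forward(inputs)
grads = act_backward(inputs, [1.0, 1.0, 1.0])
print(spikes)  # [0.0, 1.0, 1.0]
print(grads)   # [1.0, 1.0, 0.0]
```

As far as I can tell, a Keras version should give the same spikes and the same gradients for these inputs, so the values could serve as a quick check.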

- How can I write the following membrane potential update mechanism in Keras:

def mem_update(ops, x, mem, spike):
    mem = mem * decay * (1. - spike) + ops(x)
    spike = act_fun(mem)  # act_fun: approximation firing function
    return mem, spike

class SCNN(nn.Module):

    ...

    def forward(self, input, time_window=20):
        c1_mem = c1_spike = torch.zeros(batch_size, cfg_cnn[0][1], cfg_kernel[0], cfg_kernel[0], device=device)
        c2_mem = c2_spike = torch.zeros(batch_size, cfg_cnn[1][1], cfg_kernel[1], cfg_kernel[1], device=device)

        h1_mem = h1_spike = h1_sumspike = torch.zeros(batch_size, cfg_fc[0], device=device)
        h2_mem = h2_spike = h2_sumspike = torch.zeros(batch_size, cfg_fc[1], device=device)

        for step in range(time_window):  # simulation time steps
            x = input > torch.rand(input.size(), device=device)  # prob. firing

            c1_mem, c1_spike = mem_update(self.conv1, x.float(), c1_mem, c1_spike)

            x = F.avg_pool2d(c1_spike, 2)

            c2_mem, c2_spike = mem_update(self.conv2, x, c2_mem, c2_spike)

            x = F.avg_pool2d(c2_spike, 2)
            x = x.view(batch_size, -1)

            h1_mem, h1_spike = mem_update(self.fc1, x, h1_mem, h1_spike)
            h1_sumspike += h1_spike
            h2_mem, h2_spike = mem_update(self.fc2, h1_spike, h2_mem, h2_spike)
            h2_sumspike += h2_spike

        outputs = h2_sumspike / time_window
        return outputs
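
To show the update rule I am trying to reproduce, here is a minimal plain-Python sketch of `mem_update` for a single neuron over a few time steps. It is a simplification under stated assumptions: `decay = 0.5` and `thresh = 0.5` are made-up values, the layer call `ops(x)` is replaced by a precomputed input `current`, and the input sequence is invented for illustration.

```python
# Illustrative values only: they are not taken from the code above.
decay = 0.5   # membrane leak factor (assumed)
thresh = 0.5  # firing threshold (assumed)


def act_fun(mem):
    """Forward part of the firing function: spike when mem is above thresh."""
    return 1.0 if mem > thresh else 0.0


def mem_update(current, mem, spike):
    """One time step for a single neuron; `current` stands in for ops(x)."""
    # Leak the old potential, reset it to zero if the neuron fired on the
    # previous step, then add the new input.
    mem = mem * decay * (1.0 - spike) + current
    spike = act_fun(mem)
    return mem, spike


mem, spike = 0.0, 0.0
trace = []
for current in [0.25, 0.25, 0.5, 0.25]:  # hypothetical input sequence
    mem, spike = mem_update(current, mem, spike)
    trace.append((mem, spike))

# Average spike count over the window, like h2_sumspike / time_window.
rate = sum(s for _, s in trace) / len(trace)
print(trace)  # [(0.25, 0.0), (0.375, 0.0), (0.6875, 1.0), (0.25, 0.0)]
print(rate)   # 0.25
```

The potential builds up over the first three steps, the neuron fires once it crosses the threshold, and the `(1.0 - spike)` factor clears the old potential on the following step. This state carried from one step to the next is the part I do not know how to express in Keras.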

Thanks in advance for your time and your help.