PyTorch code translation to Keras code
albertopolito opened this issue
Good morning,
I'm a beginner, and this is the first time I have used Keras to implement a neural network.
I would like to write the same network as in this link, with the same activation function and forward mechanism.
I have seen that there is a tool that converts ONNX models to Keras models, but it does not seem to work with this code.
So I would like to translate it manually from PyTorch to Keras.
I have some questions:
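- How can I write the following activation function in Keras?
import torch

# thresh (firing threshold) and lens (surrogate-gradient window width)
# are hyperparameters defined elsewhere in the original code.

class ActFun(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # Heaviside step: emit a spike where the potential exceeds thresh
        ctx.save_for_backward(input)
        return input.gt(thresh).float()

    @staticmethod
    def backward(ctx, grad_output):
        # Rectangular surrogate gradient: pass the gradient only where |input - thresh| < lens
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        temp = abs(input - thresh) < lens
        return grad_input * temp.float()

act_fun = ActFun.apply  # functional handle used by mem_update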
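One way to express `ActFun` in Keras/TensorFlow is `tf.custom_gradient`, which lets a function return its forward value together with its own gradient function. This is a minimal sketch assuming TensorFlow 2.x; `THRESH` and `LENS` are hypothetical placeholder values standing in for the `thresh` and `lens` defined elsewhere in the original code.

```python
import tensorflow as tf

# Hypothetical values for illustration only; use thresh and lens from the original code.
THRESH = 0.5  # firing threshold
LENS = 0.5    # half-width of the surrogate-gradient window


@tf.custom_gradient
def act_fun(mem):
    """Spike function: Heaviside step forward, rectangular surrogate backward."""
    # Forward pass (like ActFun.forward): 1.0 where the potential exceeds the threshold.
    spikes = tf.cast(mem > THRESH, mem.dtype)

    def grad(upstream):
        # Backward pass (like ActFun.backward): let the incoming gradient through
        # only where |mem - thresh| < lens, and block it everywhere else.
        window = tf.cast(tf.abs(mem - THRESH) < LENS, mem.dtype)
        return upstream * window

    return spikes, grad


# Usage: forward values and surrogate gradients for a few potentials.
mem = tf.constant([0.0, 0.4, 0.6, 1.2])
with tf.GradientTape() as tape:
    tape.watch(mem)
    spikes = act_fun(mem)
grads = tape.gradient(spikes, mem)
print(spikes.numpy())  # [0. 0. 1. 1.]
print(grads.numpy())   # [0. 1. 1. 0.]
```

Because the function uses only TensorFlow ops, it can be called inside the `call()` method of a custom layer or a subclassed `tf.keras.Model`, where it receives concrete tensors.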
- How can I write the following membrane potential update mechanism in Keras?
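import torch.nn as nn
import torch.nn.functional as F

# decay (membrane decay constant) is a hyperparameter defined elsewhere in the original code.

def mem_update(ops, x, mem, spike):
    mem = mem * decay * (1. - spike) + ops(x)
    spike = act_fun(mem)  # act_fun : approximation firing function
    return mem, spike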
class SCNN(nn.Module):
    ...

    def forward(self, input, time_window=20):
        c1_mem = c1_spike = torch.zeros(batch_size, cfg_cnn[0][1], cfg_kernel[0], cfg_kernel[0], device=device)
        c2_mem = c2_spike = torch.zeros(batch_size, cfg_cnn[1][1], cfg_kernel[1], cfg_kernel[1], device=device)
        h1_mem = h1_spike = h1_sumspike = torch.zeros(batch_size, cfg_fc[0], device=device)
        h2_mem = h2_spike = h2_sumspike = torch.zeros(batch_size, cfg_fc[1], device=device)

        for step in range(time_window):  # simulation time steps
            x = input > torch.rand(input.size(), device=device)  # prob. firing
            c1_mem, c1_spike = mem_update(self.conv1, x.float(), c1_mem, c1_spike)
            x = F.avg_pool2d(c1_spike, 2)
            c2_mem, c2_spike = mem_update(self.conv2, x, c2_mem, c2_spike)
            x = F.avg_pool2d(c2_spike, 2)
            x = x.view(batch_size, -1)
            h1_mem, h1_spike = mem_update(self.fc1, x, h1_mem, h1_spike)
            h1_sumspike += h1_spike
            h2_mem, h2_spike = mem_update(self.fc2, h1_spike, h2_mem, h2_spike)
            h2_sumspike += h2_spike

        outputs = h2_sumspike / time_window
        return outputs
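A possible Keras translation of `mem_update` and `SCNN.forward` is a subclassed `tf.keras.Model` that runs the simulation loop inside `call()`. This is a sketch under several assumptions: TensorFlow 2.x, 28x28 single-channel inputs with values in [0, 1], and 3x3 same-padded convolutions. `THRESH`, `LENS`, `DECAY` and every layer size are hypothetical placeholders, because the hyperparameters, `cfg_cnn`, `cfg_fc` and the layer definitions are elided above. `h1_sumspike` is dropped since it does not contribute to the output.

```python
import tensorflow as tf

# Hypothetical hyperparameters; copy thresh, lens and decay from the original code.
THRESH, LENS, DECAY = 0.5, 0.5, 0.2


@tf.custom_gradient
def act_fun(mem):
    # Forward: hard threshold. Backward: rectangular surrogate window.
    def grad(upstream):
        return upstream * tf.cast(tf.abs(mem - THRESH) < LENS, mem.dtype)
    return tf.cast(mem > THRESH, mem.dtype), grad


def mem_update(op, x, mem, spike):
    # Same rule as the PyTorch mem_update; `op` can be any Keras layer.
    mem = mem * DECAY * (1.0 - spike) + op(x)
    return mem, act_fun(mem)


class SCNN(tf.keras.Model):
    def __init__(self, time_window=20, num_classes=10):
        super().__init__()
        self.time_window = time_window
        # Layer sizes are hypothetical placeholders for cfg_cnn / cfg_fc.
        self.conv1 = tf.keras.layers.Conv2D(32, 3, padding="same")
        self.conv2 = tf.keras.layers.Conv2D(32, 3, padding="same")
        self.pool = tf.keras.layers.AveragePooling2D(2)
        self.flatten = tf.keras.layers.Flatten()
        self.fc1 = tf.keras.layers.Dense(128)
        self.fc2 = tf.keras.layers.Dense(num_classes)

    def call(self, inputs):
        # inputs: (batch, height, width, channels); Keras is channels-last.
        # States start as scalar zeros and broadcast to the right shape on the
        # first step, so the feature-map sizes need not be hard-coded.
        c1_mem = c1_spike = c2_mem = c2_spike = 0.0
        h1_mem = h1_spike = h2_mem = h2_spike = 0.0
        h2_sumspike = 0.0
        for _ in range(self.time_window):  # simulation time steps
            # Rate coding: each pixel fires with probability equal to its intensity.
            x = tf.cast(inputs > tf.random.uniform(tf.shape(inputs)), inputs.dtype)
            c1_mem, c1_spike = mem_update(self.conv1, x, c1_mem, c1_spike)
            x = self.pool(c1_spike)
            c2_mem, c2_spike = mem_update(self.conv2, x, c2_mem, c2_spike)
            x = self.flatten(self.pool(c2_spike))
            h1_mem, h1_spike = mem_update(self.fc1, x, h1_mem, h1_spike)
            h2_mem, h2_spike = mem_update(self.fc2, h1_spike, h2_mem, h2_spike)
            h2_sumspike = h2_sumspike + h2_spike
        # Average firing rate of the output layer over the time window.
        return h2_sumspike / self.time_window


# Usage: a batch of two 28x28 grayscale images, simulated for three time steps.
model = SCNN(time_window=3)
out = model(tf.random.uniform((2, 28, 28, 1)))  # shape (2, 10)
```

Two differences from the PyTorch version are worth noting. Keras is channels-last, so `Flatten` orders features differently from `x.view(batch_size, -1)`, and PyTorch weights would need rearranging before they could be copied over. The states also start as scalar zeros rather than explicit `torch.zeros(...)` buffers, which keeps the model independent of `batch_size`.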
Thanks in advance for your time and your help.