Define an activation function using a differential equation
RohanRajagopal opened this issue · 4 comments
Is there a way to define an activation function using a differential equation, in order to train a recurrent neural network on sequence data?
For instance, use the following differential equation as the activation function, where the input to it is the weighted external signal Iext:
using OrdinaryDiffEq, ComponentArrays, UnPack

function hopf_oscillator(du, u, p, t)
    @unpack ω_h, μ, ω_ext, I0, W = p
    signal = fill(cos(ω_ext * t), 10)   # external drive fanned out to 10 inputs
    Iext = W' * signal
    du[1] = u[1] * (μ - u[1]^2) + I0 * Iext * cos(u[2])
    du[2] = ω_h - I0 * cos(ω_ext * t) * sin(u[2]) / u[1]
end

u0_h = [0.1, 0.1]
tspan = (0.0, 200.0)
p_h = ComponentArray(ω_h = 1.0, μ = 1.0, ω_ext = 1.0, I0 = 1.0, W = rand(10))
prob_h = ODEProblem(hopf_oscillator, u0_h, tspan, p_h)   # dt is a solve kwarg, not an ODEProblem one
The gradients can then be calculated using something like the following, I think:
dp_h = Zygote.gradient(objective_function, p_h)
(or perhaps dp_h = Zygote.jacobian(objective_function, p_h)), and the parameter updates computed from that.
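For the gradient call to work, the objective has to be a scalar-valued function that runs the solve internally, and SciMLSensitivity has to be loaded so that Zygote can differentiate through solve. A minimal sketch of what I mean, with a placeholder loss:

using SciMLSensitivity, Zygote

function objective_function(p)
    _prob = remake(prob_h, p = p)
    sol = solve(_prob, Rosenbrock23(), saveat = 0.1)
    sum(abs2, sol[1, :])   # placeholder loss on the amplitude variable
end

dp_h = Zygote.gradient(objective_function, p_h)[1]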
I realize the question is still somewhat vague, but any nudge in the right direction would be deeply appreciated. Thank you very much.
What did you try? I don't see what the issue would be.
This is what I have currently, for a single node:
using OrdinaryDiffEq, ComponentArrays, UnPack, Zygote, SciMLSensitivity

tspan = (0.0, 200.0)
time_steps = collect(1.0:1.0:10.0)

function hopf_oscillator(du, u, p, t)
    @unpack W = p
    ω_h = 1.0
    μ = 1.0
    ω_ext = 1.0
    I0 = 0.1
    signal = [cos(i * t) for i in 1:5]   # five driving inputs at different frequencies
    Iext = W' * signal
    du[1] = u[1] * (μ - u[1]^2) + I0 * Iext * cos(u[2])
    du[2] = ω_h - I0 * Iext * sin(u[2]) / u[1]
end
I understand that the feedforward 'signal' reaching this node has to be interpolated to be made compatible with the solver.
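For instance, with DataInterpolations.jl the sampled feedforward signal becomes a callable interpolant that the solver can query at any t; the grid and signal below are placeholders:

using DataInterpolations

ts = 0.0:0.1:200.0
samples = cos.(ts)                       # placeholder for the upstream layer's output
itp = LinearInterpolation(samples, ts)   # callable: itp(t) for any t in the span

# inside hopf_oscillator one would then evaluate, e.g.:
# signal = [itp(t) for i in 1:5]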
u0_h = [0.1, 0.1]
p_h = ComponentArray(W = rand(5))
prob_h = ODEProblem(hopf_oscillator, u0_h, tspan, p_h)
function gradients_hopf_with_loss_function(p)
    _prob = remake(prob_h, u0 = u0_h, tspan = (0.0, 20.0), p = p)
    sol = solve(_prob, Rosenbrock23(), saveat = time_steps)
    theta = sol[2, :]              # phase at the saved time steps
    D = cos.(time_steps)           # target signal
    r = sol[1, end]                # final amplitude
    x = r * cos.(theta) .- D
    sum(abs2, x)                   # squared-error loss
end
for i in 1:50
    dp_h = Zygote.gradient(gradients_hopf_with_loss_function, p_h)
    for j in 1:length(p_h)
        p_h[j] = p_h[j] - 0.1 * dp_h[1][j]   # plain gradient-descent step
    end
end
And the loss does indeed decrease. Is this right?
Thank you very much for your time.
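As an aside, I gather the hand-written update loop could be replaced with Optimisers.jl, which also composes with Lux later on; a sketch assuming plain gradient descent at the same step size:

using Optimisers

opt_state = Optimisers.setup(Descent(0.1), p_h)
for i in 1:50
    grad = Zygote.gradient(gradients_hopf_with_loss_function, p_h)[1]
    opt_state, p_h = Optimisers.update(opt_state, p_h, grad)
end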
I don't get the question. That is working, right?
Yes! I just wanted to confirm if I was doing it right. Thank you very much.
Now I need to dress it in Lux, if that's possible.
Thank you very much.
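In case it is useful, this is roughly the shape I have in mind for the Lux version (a minimal sketch assuming Lux ≥ 1.0's AbstractLuxLayer API; HopfLayer is a made-up name, and prob_h and time_steps are reused from above):

using Lux, Random, ComponentArrays

# Hypothetical layer whose parameters are the input weights W
# and whose forward pass solves the oscillator ODE.
struct HopfLayer <: Lux.AbstractLuxLayer
    n::Int
end

Lux.initialparameters(rng::AbstractRNG, l::HopfLayer) = (W = rand(rng, l.n),)
Lux.initialstates(::AbstractRNG, ::HopfLayer) = NamedTuple()

function (l::HopfLayer)(x, ps, st)
    # the input x is ignored here; a real layer would interpolate it as the drive
    _prob = remake(prob_h, p = ComponentArray(W = ps.W))
    sol = solve(_prob, Rosenbrock23(), saveat = time_steps)
    return Array(sol), st
end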