SciML/SciMLSensitivity.jl

Non-exact zero sensitivities for GaussAdjoint

ArnoStrouwen opened this issue · 0 comments

"""
    dynamics(x, _p, _t)

Out-of-place right-hand side of the test ODE `dx/dt = x`.

The parameters `_p` and time `_t` are deliberately unused. Any loss built on this ODE
therefore has an exact gradient of zero with respect to the parameters, which is what
makes it a clean reproducer for spurious non-zero sensitivities.

Defined as a named (hence `const`) function rather than a non-`const` global holding an
anonymous closure, so `loss1`/`loss2` do not perform an untyped global lookup.
"""
dynamics(x, _p, _t) = x
"""
    loss1(params)

Reference loss for the reproducer, solved with `InterpolatingAdjoint`.

Solve the ODE defined by `dynamics` on `t ∈ [0, 1]` from a zero initial state with
`Tsit5`, passing `u0` and `p` explicitly to `solve`. Return the sum of the final saved
state. Because `dynamics` ignores its parameters, the gradient with respect to `params`
should be exactly zero.
"""
function loss1(params)
    x0 = zeros(2)
    tspan = (0.0, 1.0)
    prob = ODEProblem(dynamics, x0, tspan, params)
    # NOTE(review): presumably `allow_nothing = true` is needed here because `dynamics`
    # never touches `p`, so the Zygote VJP w.r.t. `p` is `nothing`.
    method = InterpolatingAdjoint(autojacvec = ZygoteVJP(allow_nothing = true))
    sol = solve(prob, Tsit5(); u0 = x0, p = params, sensealg = method)
    final_state = Array(sol)[:, end]
    return sum(final_state)
end
"""
    loss2(params)

Same loss as `loss1`, but differentiated with `GaussAdjoint(autojacvec = ZygoteVJP())`.

Solve the ODE defined by `dynamics` on `t ∈ [0, 1]` from a zero initial state with
`Tsit5`, passing `u0` and `p` explicitly to `solve`. Return the sum of the final saved
state. `dynamics` ignores its parameters, so the gradient with respect to `params`
should be exactly zero, as `loss1` produces. This is the case the issue reports as
returning tiny non-zero values instead.
"""
function loss2(params)
    x0 = zeros(2)
    tspan = (0.0, 1.0)
    prob = ODEProblem(dynamics, x0, tspan, params)
    # Per the issue text, GaussAdjoint does not require `allow_nothing`, and setting it
    # does not change the faulty result.
    method = GaussAdjoint(autojacvec = ZygoteVJP())
    sol = solve(prob, Tsit5(); u0 = x0, p = params, sensealg = method)
    final_state = Array(sol)[:, end]
    return sum(final_state)
end
julia> Zygote.gradient(loss1, zeros(123))[1]
123-element Vector{Float64}:
 0.0
 ⋮
 0.0

julia> Zygote.gradient(loss2, zeros(123))[1]
123-element Vector{Float64}:
 6.93297084729736e-310
 ⋮
 6.9332527492712e-310

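One advantage of GaussAdjoint is that it does not need `allow_nothing`; however, enabling `allow_nothing` does not resolve the issue either. Since `dynamics` ignores its parameters, the exact gradient is zero. The returned values (~6.9e-310) are subnormal numbers that look like uninitialized memory rather than round-off error, which suggests the parameter-gradient buffer is allocated but never zeroed.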