Non-exact zero sensitivities for GaussAdjoint
ArnoStrouwen opened this issue
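```julia
using OrdinaryDiffEq, SciMLSensitivity, Zygote

# The right-hand side ignores the parameters, so d(loss)/d(params) is exactly zero.
dynamics = (x, _p, _t) -> x

# Reference case: InterpolatingAdjoint returns exact zeros.
function loss1(params)
    u0 = zeros(2)
    problem = ODEProblem(dynamics, u0, (0.0, 1.0), params)
    rollout = solve(problem, Tsit5(), u0 = u0, p = params,
        sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP(allow_nothing=true)))
    sum(Array(rollout)[:, end])
end

# Problem case: GaussAdjoint returns tiny nonzero values instead of exact zeros.
function loss2(params)
    u0 = zeros(2)
    problem = ODEProblem(dynamics, u0, (0.0, 1.0), params)
    rollout = solve(problem, Tsit5(), u0 = u0, p = params,
        sensealg = GaussAdjoint(autojacvec = ZygoteVJP()))
    sum(Array(rollout)[:, end])
end
```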
```julia
julia> Zygote.gradient(loss1, zeros(123))[1]
123-element Vector{Float64}:
 0.0
 ⋮
 0.0

julia> Zygote.gradient(loss2, zeros(123))[1]
123-element Vector{Float64}:
 6.93297084729736e-310
 ⋮
 6.9332527492712e-310
```
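Because `dynamics` never uses its parameter argument, the exact gradient is zero. `InterpolatingAdjoint` returns exact zeros, whereas `GaussAdjoint` returns subnormal values on the order of 1e-310. One advantage `GaussAdjoint` has is that it does not need `allow_nothing`, but activating it does not resolve the issue either.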
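For completeness, here is a sketch of the `allow_nothing` variant mentioned above: the same setup, but with `ZygoteVJP(allow_nothing = true)` passed to `GaussAdjoint`. The function name `loss3` and the final `all(iszero, ...)` check are illustrative additions, not part of the original report.

```julia
using OrdinaryDiffEq, SciMLSensitivity, Zygote

# Same dynamics as above: the right-hand side ignores the parameters.
dynamics = (x, _p, _t) -> x

# Hypothetical helper: GaussAdjoint with allow_nothing enabled on the VJP.
function loss3(params)
    u0 = zeros(2)
    problem = ODEProblem(dynamics, u0, (0.0, 1.0), params)
    rollout = solve(problem, Tsit5(), u0 = u0, p = params,
        sensealg = GaussAdjoint(autojacvec = ZygoteVJP(allow_nothing = true)))
    sum(Array(rollout)[:, end])
end

g3 = Zygote.gradient(loss3, zeros(123))[1]

# The exact gradient is zero, so this should print `true`;
# according to the report, it does not.
println(all(iszero, g3))
```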