ReverseDiff leads to wrong sensitivities over ODESolution
ThummeTo opened this issue · 3 comments
Hi,
Problem
As soon as ReverseDiff is used to determine sensitivities over an ODESolution, the sensitivities are numerically wrong.
ForwardDiff, Zygote and FiniteDiff all return the same (correct) sensitivities. Note that the ReverseDiff run still returns the correct ODE solution; see the generated plot.
I did not dig in deeper, so maybe this is more of a ReverseDiff issue than a SciMLSensitivity one...
Expected behavior
Correct sensitivity determination (like for ForwardDiff, Zygote and FiniteDiff).
Minimal Reproducible Example 👇
Just an ODE describing an object accelerated by gravity ...
using DifferentialEquations
using SciMLSensitivity.SciMLBase: RightRootFind
using ForwardDiff, ReverseDiff, Zygote, FiniteDiff
using Plots
# A falling ball (without contact, just gravity)
GRAVITY = 9.81
MASS = 1.0
NUM_STATES = 2
t_start = 0.0
t_step = 0.05
t_stop = 2.0
u0 = [1.0, 0.0] # start state: ball position (1.0) and velocity (0.0)
p = [GRAVITY, MASS]
solver = Rosenbrock23()
# setup BouncingBallODE
function fx(u, p, t)
    g, m = p
    return [u[2], -g]
end
ff = ODEFunction{false}(fx)
prob = ODEProblem{false}(ff, u0, (t_start, t_stop), p)
global us
function mysolve(p)
    global us
    solution = solve(prob; p=p, alg=solver, saveat=t_start:t_step:t_stop)
    us = solution
    # fix for ReverseDiff not returning an ODESolution
    if !isa(us, ReverseDiff.TrackedArray)
        us = collect(u[1] for u in solution.u)
    else
        us = solution[1, :]
    end
    return us
end
function loss(p)
    us = mysolve(p)
    return sum(abs.(us))
end
fig = plot()
# some gradients
grad_fi = FiniteDiff.finite_difference_gradient(loss, p) # nice!
plot!(fig, us; label="FI")
grad_fd = ForwardDiff.gradient(loss, p) # nice!
plot!(fig, ForwardDiff.value.(us); label="FD")
grad_zg = Zygote.gradient(loss, p)[1] # nice!
plot!(fig, us; label="ZG")
grad_rd = ReverseDiff.gradient(loss, p) # this is wrong!
plot!(fig, ReverseDiff.value(us); label="RD")
Error
No error is thrown; the sensitivities in grad_rd are simply wrong (compare with the other gradients):
...
julia> grad_fi
2-element Vector{Float64}:
26.962499999817844
0.0
julia> grad_fd
2-element Vector{Float64}:
26.962500000000002
0.0
julia> grad_zg
2-element Vector{Float64}:
26.9625
0.0
julia> grad_rd
2-element Vector{Float64}:
67.9625
0.0
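As a side note (not part of the original report): one way to narrow down where such a discrepancy comes from is to bypass the automatic sensitivity-algorithm selection and pass an explicit `sensealg` to `solve`. The sketch below assumes the `sensealg` keyword is respected when ReverseDiff tracks through `solve`; `ReverseDiffAdjoint` is exported by SciMLSensitivity. This is an illustration, not a confirmed workaround for this bug:

```julia
using DifferentialEquations, SciMLSensitivity, ReverseDiff

# Same falling-ball ODE as in the MRE above
fx(u, p, t) = [u[2], -p[1]]
prob = ODEProblem{false}(ODEFunction{false}(fx), [1.0, 0.0], (0.0, 2.0), [9.81, 1.0])

# Pin the sensitivity algorithm explicitly instead of relying on the
# default heuristic, so the differentiation path is controlled.
function loss_explicit(p)
    sol = solve(prob, Rosenbrock23(); p=p, saveat=0.05,
                sensealg=ReverseDiffAdjoint())
    return sum(abs.(sol[1, :]))
end

ReverseDiff.gradient(loss_explicit, [9.81, 1.0])
```

Comparing this against other `sensealg` choices can help separate a default-selection bug from a bug in a particular adjoint path.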
Environment (please complete the following information):
Julia v1.9.3
The important packages from the environment:
Status `C:\Users\...\.julia\environments\v1.9\Project.toml`
[2b5f629d] DiffEqBase v6.141.0
[459566f4] DiffEqCallbacks v2.34.0
[0c46a032] DifferentialEquations v7.11.0
[6a86dc24] FiniteDiff v2.21.1
[f6369f11] ForwardDiff v0.10.36
[37e2e3b7] ReverseDiff v1.15.1
[1ed8b502] SciMLSensitivity v7.47.0
[e88e6eb3] Zygote v0.6.67
**Thanks in advance!**
This doesn't occur anymore with the following releases:
Status `C:\Users\...\.julia\environments\v1.9\Project.toml`
[2b5f629d] DiffEqBase v6.143.0
[459566f4] DiffEqCallbacks v2.35.0
[0c46a032] DifferentialEquations v7.11.0
[6a86dc24] FiniteDiff v2.21.1
[f6369f11] ForwardDiff v0.10.36
[37e2e3b7] ReverseDiff v1.15.1
[1ed8b502] SciMLSensitivity v7.47.0
[e88e6eb3] Zygote v0.6.67
So this was more likely related to DiffEqBase or DiffEqCallbacks?
Thanks in any case!
I think it's related to the RAT v3 things. Can you add a test for this so it doesn't regress?
Yes, I can open a PR.
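A minimal sketch of what such a regression test could look like, reusing the falling-ball MRE from above (the test-set name and tolerance are my own choices, not from an existing test file):

```julia
using Test, DifferentialEquations, ForwardDiff, ReverseDiff

# Falling ball: position and velocity under constant gravity
fx(u, p, t) = [u[2], -p[1]]
prob = ODEProblem{false}(ODEFunction{false}(fx), [1.0, 0.0], (0.0, 2.0), [9.81, 1.0])

function loss(p)
    sol = solve(prob, Rosenbrock23(); p=p, saveat=0.05)
    return sum(abs.(sol[1, :]))
end

@testset "ReverseDiff gradient matches ForwardDiff over ODESolution" begin
    p = [9.81, 1.0]
    grad_fd = ForwardDiff.gradient(loss, p)
    grad_rd = ReverseDiff.gradient(loss, p)
    # Before the fix, grad_rd[1] was 67.9625 instead of ≈26.9625
    @test isapprox(grad_rd, grad_fd; rtol=1e-4)
end
```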
What's RAT v3?