SciML/SciMLSensitivity.jl

ReverseDiff leads to wrong sensitivities over ODESolution

ThummeTo opened this issue · 3 comments

Hi,

Problem

As soon as ReverseDiff is used to compute sensitivities through an ODESolution, the resulting sensitivities are numerically wrong.
ForwardDiff, Zygote and FiniteDiff all return the same (correct) sensitivities. The ReverseDiff run still produces the correct ODE solution (see the generated plot); only its gradient is wrong.

I haven't dug into this, so it may be more of a ReverseDiff issue than a SciMLSensitivity one.

Expected behavior

Correct sensitivities, matching those from ForwardDiff, Zygote and FiniteDiff.

Minimal Reproducible Example 👇

Just an ODE describing an object accelerated by gravity:

```julia
using DifferentialEquations
using SciMLSensitivity.SciMLBase: RightRootFind
using ForwardDiff, ReverseDiff, Zygote, FiniteDiff
using Plots

# A falling ball (without contact, just gravity)
GRAVITY = 9.81
MASS = 1.0
NUM_STATES = 2

t_start = 0.0
t_step = 0.05
t_stop = 2.0
u0 = [1.0, 0.0] # start state: ball position (1.0) and velocity (0.0)
p = [GRAVITY, MASS]

solver = Rosenbrock23()

# setup BouncingBallODE
function fx(u, p, t)
    g, m = p
    return [u[2], -g]
end

ff = ODEFunction{false}(fx)
prob = ODEProblem{false}(ff, u0, (t_start, t_stop), p)

global us
function mysolve(p)
    global us

    solution = solve(prob; p=p, alg=solver, saveat=t_start:t_step:t_stop)

    us = solution

    # fix for ReverseDiff not returning an ODESolution
    if !isa(us, ReverseDiff.TrackedArray)
        us = collect(u[1] for u in solution.u)
    else
        us = solution[1,:]
    end

    return us
end

function loss(p)
    us = mysolve(p)
    return sum(abs.(us))
end

fig = plot()

# some gradients
grad_fi = FiniteDiff.finite_difference_gradient(loss, p)    # nice!
plot!(fig, us; label="FI")
grad_fd = ForwardDiff.gradient(loss, p)                     # nice!
plot!(fig, ForwardDiff.value.(us); label="FD")
grad_zg = Zygote.gradient(loss, p)[1]                       # nice!
plot!(fig, us; label="ZG")
grad_rd = ReverseDiff.gradient(loss, p)                     # this is wrong!
plot!(fig, ReverseDiff.value(us); label="RD")
```

[Plot: MWE_ReverseDiff, ball position over time for the FI, FD, ZG and RD runs]

Error

No error is thrown; only grad_rd contains wrong sensitivities:

```julia
julia> grad_fi
2-element Vector{Float64}:
 26.962499999817844
  0.0

julia> grad_fd
2-element Vector{Float64}:
 26.962500000000002
  0.0

julia> grad_zg
2-element Vector{Float64}:
 26.9625
  0.0

julia> grad_rd
2-element Vector{Float64}:
 67.9625
  0.0
```
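As a cross-check of which result is right, the gradient with respect to GRAVITY can be worked out by hand. This assumes the saved values follow the exact solution x(t) = 1 - g t²/2 (initial position 1.0, initial velocity 0.0), which the near-identical FiniteDiff and ForwardDiff results suggest. The 41 save points are t_k = 0.05k for k = 0, ..., 40, and the ball passes x = 0 at t = sqrt(2/g) ≈ 0.45, so x_k > 0 for k ≤ 9 and x_k < 0 for k ≥ 10:

```math
\begin{aligned}
L(g, m) &= \sum_{k=0}^{40} \lvert x_k \rvert, \qquad x_k = 1 - \frac{g}{2} t_k^2, \qquad t_k = 0.05\,k \\
\frac{\partial L}{\partial g} &= \sum_{k=0}^{40} \operatorname{sign}(x_k)\left(-\frac{t_k^2}{2}\right)
  = \frac{0.05^2}{2}\left(\sum_{k=10}^{40} k^2 - \sum_{k=0}^{9} k^2\right) \\
&= 0.00125\,(21855 - 285) = 26.9625 \\
\frac{\partial L}{\partial m} &= 0
\end{aligned}
```

This agrees with grad_fi, grad_fd and grad_zg. The second entry is 0 because MASS is unpacked in fx but never used. The ReverseDiff value of 67.9625 is therefore inconsistent with the exact gradient.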

Environment

Julia v1.9.3

The important packages from the environment:

```
Status `C:\Users\...\.julia\environments\v1.9\Project.toml`
  [2b5f629d] DiffEqBase v6.141.0
  [459566f4] DiffEqCallbacks v2.34.0
  [0c46a032] DifferentialEquations v7.11.0
  [6a86dc24] FiniteDiff v2.21.1
  [f6369f11] ForwardDiff v0.10.36
  [37e2e3b7] ReverseDiff v1.15.1
  [1ed8b502] SciMLSensitivity v7.47.0
  [e88e6eb3] Zygote v0.6.67
```

**Thanks in advance!**

This no longer occurs with the following package versions:

```
Status `C:\Users\...\.julia\environments\v1.9\Project.toml`
  [2b5f629d] DiffEqBase v6.143.0
  [459566f4] DiffEqCallbacks v2.35.0
  [0c46a032] DifferentialEquations v7.11.0
  [6a86dc24] FiniteDiff v2.21.1
  [f6369f11] ForwardDiff v0.10.36
  [37e2e3b7] ReverseDiff v1.15.1
  [1ed8b502] SciMLSensitivity v7.47.0
  [e88e6eb3] Zygote v0.6.67
```

Since only DiffEqBase (v6.141.0 → v6.143.0) and DiffEqCallbacks (v2.34.0 → v2.35.0) changed, was this related to one of those after all?
Thanks in any case!

I think it's related to the RAT v3 changes. Can you add a test for this so it doesn't regress?

Yes, I can open a PR.

What's RAT v3?