SciML/SciMLSensitivity.jl

Complex number handling in GaussAdjoint

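ArnoStrouwen opened this issue · 0 comments

With a complex-valued initial state, `GaussAdjoint(autodiff=false, autojacvec=false)` fails with a `MethodError` raised from `vec_pjac!`. `InterpolatingAdjoint` with the same options returns a gradient for the same loss. Minimal reproducer: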

using SciMLSensitivity
using OrdinaryDiffEq
using Zygote
# Out-of-place right-hand side of the linear ODE du/dt = -p[1] * u.
# It returns a new array rather than mutating; `t` is unused.
dynamics = function (u, p, t)
    rate = p[1]
    return -u * rate
end

"""
    rollout_loss(params, sensealg)

Solve `du/dt = dynamics(u, params, t)` on the time span `(0.0, 1.0)` from the complex
initial state `[1.0 + 1.0im]` with `Tsit5()`, using `sensealg` for the reverse pass.
Return `abs` of the sum of the last saved state.

`loss1` and `loss2` both delegate here, so the adjoint method is the only thing that
differs between them.
"""
function rollout_loss(params, sensealg)
    # Complex-valued state: loss2 (GaussAdjoint) fails on it while loss1 succeeds.
    u0 = [1.0 + 1.0im]
    problem = ODEProblem(dynamics, u0, (0.0, 1.0), params)
    # `u0` and `p` are also passed to `solve`, exactly as in the original reproducer,
    # so the adjoint is taken through the same code path.
    rollout = solve(problem, Tsit5(); u0 = u0, p = params, sensealg = sensealg)
    # Column `end` of `Array(rollout)` is the state at the last saved time point.
    return abs(sum(Array(rollout)[:, end]))
end

"""
    loss1(params)

Reproducer loss differentiated with
`InterpolatingAdjoint(autodiff = false, autojacvec = false)` (the working reference).
"""
function loss1(params)
    return rollout_loss(params, InterpolatingAdjoint(autodiff = false, autojacvec = false))
end

"""
    loss2(params)

The same loss differentiated with `GaussAdjoint(autodiff = false, autojacvec = false)`.
"""
function loss2(params)
    return rollout_loss(params, GaussAdjoint(autodiff = false, autojacvec = false))
end
# Single ODE parameter p[1] = 1.0 at which the gradient is evaluated.
params = [1.0]
julia> Zygote.gradient(loss1, params)
([-0.5202599109305424],)

julia> Zygote.gradient(loss2, params)
ERROR: MethodError: no method matching (::SciMLBase.ParamJacobianWrapper{false, ODEFunction{…}, Float64, Vector{…}})(::Vector{ComplexF64}, ::Vector{Float64})

Closest candidates are:
  (::SciMLBase.ParamJacobianWrapper{false})(::Any)
   @ SciMLBase ~/.julia/packages/SciMLBase/slQep/src/function_wrappers.jl:87

Stacktrace:
  [1] finite_difference_jacobian!(J::Matrix{…}, f::SciMLBase.ParamJacobianWrapper{…}, x::Vector{…}, cache::FiniteDiff.JacobianCache{…}, f_in::Nothing; relstep::Float64, absstep::Float64, colorvec::UnitRange{…}, sparsity::Nothing, dir::Bool)
    @ FiniteDiff ~/.julia/packages/FiniteDiff/BgLLM/src/jacobians.jl:438
  [2] finite_difference_jacobian!(J::Matrix{ComplexF64}, f::Function, x::Vector{Float64}, cache::FiniteDiff.JacobianCache{Vector{…}, Vector{…}, Vector{…}, Vector{…}, UnitRange{…}, Nothing, Val{…}(), ComplexF64}, f_in::Nothing)
    @ FiniteDiff ~/.julia/packages/FiniteDiff/BgLLM/src/jacobians.jl:341
  [3] finite_difference_jacobian!
    @ FiniteDiff ~/.julia/packages/FiniteDiff/BgLLM/src/jacobians.jl:341 [inlined]
  [4] jacobian!(J::Matrix{…}, f::Function, x::Vector{…}, fx::Vector{…}, alg::GaussAdjoint{…}, jac_config::FiniteDiff.JacobianCache{…})
    @ SciMLSensitivity ~/SciML/SciMLSensitivity.jl/src/derivative_wrappers.jl:157
  [5] vec_pjac!(out::Vector{…}, λ::Vector{…}, y::Vector{…}, t::Float64, S::SciMLSensitivity.GaussIntegrand{…})
    @ SciMLSensitivity ~/SciML/SciMLSensitivity.jl/src/gauss_adjoint.jl:447
  [6] GaussIntegrand
    @ SciMLSensitivity ~/SciML/SciMLSensitivity.jl/src/gauss_adjoint.jl:490 [inlined]
  [7] (::SciMLSensitivity.var"#255#256"{})(out::Vector{…}, u::Vector{…}, t::Float64, integrator::OrdinaryDiffEq.ODEIntegrator{…})
    @ SciMLSensitivity ~/SciML/SciMLSensitivity.jl/src/gauss_adjoint.jl:517
  [8] (::DiffEqCallbacks.SavingIntegrandSumAffect{…})(integrator::OrdinaryDiffEq.ODEIntegrator{…})
    @ DiffEqCallbacks ~/.julia/packages/DiffEqCallbacks/uVI0B/src/integrating_sum.jl:56
  [9] apply_discrete_callback!
    @ ~/.julia/packages/DiffEqBase/eLhx9/src/callbacks.jl:605 [inlined]
 [10] apply_discrete_callback!
    @ ~/.julia/packages/DiffEqBase/eLhx9/src/callbacks.jl:617 [inlined]
 [11] handle_callbacks!(integrator::OrdinaryDiffEq.ODEIntegrator{…})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/2nLli/src/integrators/integrator_utils.jl:346
 [12] _loopfooter!(integrator::OrdinaryDiffEq.ODEIntegrator{…})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/2nLli/src/integrators/integrator_utils.jl:253
 [13] loopfooter!
    @ ~/.julia/packages/OrdinaryDiffEq/2nLli/src/integrators/integrator_utils.jl:206 [inlined]
 [14] solve!(integrator::OrdinaryDiffEq.ODEIntegrator{…})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/2nLli/src/solve.jl:538
 [15] #__solve#746
    @ ~/.julia/packages/OrdinaryDiffEq/2nLli/src/solve.jl:6 [inlined]
 [16] __solve
    @ ~/.julia/packages/OrdinaryDiffEq/2nLli/src/solve.jl:1 [inlined]
 [17] solve_call(_prob::ODEProblem{…}, args::Tsit5{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:609
 [18] solve_call
    @ DiffEqBase ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:567 [inlined]
 [19] #solve_up#42
    @ DiffEqBase ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:1058 [inlined]
 [20] solve_up
    @ DiffEqBase ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:1044 [inlined]
 [21] #solve#40
    @ DiffEqBase ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:981 [inlined]
 [22] _adjoint_sensitivities(sol::ODESolution{…}, sensealg::GaussAdjoint{…}, alg::Tsit5{…}; t::Vector{…}, dgdu_discrete::Function, dgdp_discrete::Nothing, dgdu_continuous::Nothing, dgdp_continuous::Nothing, g::Nothing, abstol::Float64, reltol::Float64, checkpoints::Vector{…}, corfunc_analytical::Bool, callback::Nothing, kwargs::@Kwargs{})
    @ SciMLSensitivity ~/SciML/SciMLSensitivity.jl/src/gauss_adjoint.jl:536
 [23] _adjoint_sensitivities
    @ SciMLSensitivity ~/SciML/SciMLSensitivity.jl/src/gauss_adjoint.jl:503 [inlined]
 [24] #adjoint_sensitivities#63
    @ SciMLSensitivity ~/SciML/SciMLSensitivity.jl/src/sensitivity_interface.jl:386 [inlined]
 [25] (::SciMLSensitivity.var"#adjoint_sensitivity_backpass#307"{@Kwargs{}, Tsit5{}, GaussAdjoint{}, Vector{}, Vector{}, SciMLBase.ChainRulesOriginator, Tuple{}, Colon, @NamedTuple{}})(Δ::Matrix{ComplexF64})
    @ SciMLSensitivity ~/SciML/SciMLSensitivity.jl/src/concrete_solve.jl:515
 [26] ZBack
    @ ~/.julia/packages/Zygote/jxHJc/src/compiler/chainrules.jl:211 [inlined]
 [27] (::Zygote.var"#291#292"{Tuple{NTuple{}, Tuple{}}, Zygote.ZBack{SciMLSensitivity.var"#adjoint_sensitivity_backpass#307"{}}})(Δ::Matrix{ComplexF64})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:206
 [28] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{NTuple{}, Tuple{}}, Zygote.ZBack{SciMLSensitivity.var"#adjoint_sensitivity_backpass#307"{}}}})(Δ::Matrix{ComplexF64})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [29] #solve#40
    @ ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:981 [inlined]
 [30] (::Zygote.Pullback{Tuple{DiffEqBase.var"##solve#40", GaussAdjoint{…}, Vector{…}, Vector{…}, Val{…}, @Kwargs{}, typeof(solve), ODEProblem{…}, Tsit5{…}}, Any})(Δ::Matrix{ComplexF64})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [31] #291
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:206 [inlined]
 [32] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{NTuple{}, Tuple{}}, Zygote.Pullback{Tuple{}, Any}}})(Δ::Matrix{ComplexF64})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [33] solve
    @ ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:971 [inlined]
 [34] (::Zygote.Pullback{Tuple{typeof(Core.kwcall), @NamedTuple{…}, typeof(solve), ODEProblem{…}, Tsit5{…}}, Any})(Δ::Matrix{ComplexF64})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [35] loss2
    @ ./REPL[7]:4 [inlined]
 [36] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [37] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{}, Tuple{}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:91
 [38] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:148
 [39] top-level scope
    @ REPL[10]:1
Some type information was truncated. Use `show(err)` to see complete types.