Complex-number handling in GaussAdjoint: complex-valued state with real parameters throws a MethodError
ArnoStrouwen opened this issue · 0 comments
ArnoStrouwen commented
using SciMLSensitivity
using OrdinaryDiffEq
using Zygote
# Right-hand side of the linear ODE dx/dt = -p[1] * x, in 3-argument (out-of-place)
# form: a new value is returned instead of being written into a buffer. The time `t`
# is part of the ODE signature but unused. The state `x` may be complex while `p` is real.
dynamics = function (x, p, t)
    rate = p[1]
    return -x * rate
end
"""
    loss1(params)

Reference loss for the reproducer. It integrates `dynamics` over `(0.0, 1.0)`,
starting from the complex state `[1.0 + 1.0im]` with the real parameter vector
`params`. It returns the modulus of the sum of the last saved state.

Gradients are routed through `InterpolatingAdjoint(autodiff = false,
autojacvec = false)`.

With `dynamics(x, p, t) = -x * p[1]`, the exact solution gives
`abs(x(1)) = √2 * exp(-params[1])`. The gradient at `params = [1.0]` is therefore
`-√2 / e ≈ -0.5203`, which agrees with the value reported for this loss.
"""
function loss1(params)
    initial_state = [1.0 + 1.0im]
    tspan = (0.0, 1.0)
    prob = ODEProblem(dynamics, initial_state, tspan, params)
    adjoint_method = InterpolatingAdjoint(autodiff = false, autojacvec = false)
    # u0 and p are repeated as solve keywords, exactly as in the original report.
    sol = solve(prob, Tsit5(); u0 = initial_state, p = params,
        sensealg = adjoint_method)
    final_state = Array(sol)[:, end]
    return abs(sum(final_state))
end
"""
    loss2(params)

Same loss as the InterpolatingAdjoint reference: integrate `dynamics` over
`(0.0, 1.0)` from the complex state `[1.0 + 1.0im]` with real `params`, and return
the modulus of the summed last saved state. The only difference is that gradients
are routed through `GaussAdjoint(autodiff = false, autojacvec = false)`.

In the reported environment, differentiating this loss throws a `MethodError`.
The error comes from `SciMLBase.ParamJacobianWrapper` while FiniteDiff forms the
parameter Jacobian (complex state, `Float64` parameters). This function is kept
unchanged in behavior so that it still reproduces that failure.
"""
function loss2(params)
    initial_state = [1.0 + 1.0im]
    tspan = (0.0, 1.0)
    prob = ODEProblem(dynamics, initial_state, tspan, params)
    adjoint_method = GaussAdjoint(autodiff = false, autojacvec = false)
    # u0 and p are repeated as solve keywords, exactly as in the original report.
    sol = solve(prob, Tsit5(); u0 = initial_state, p = params,
        sensealg = adjoint_method)
    final_state = Array(sol)[:, end]
    return abs(sum(final_state))
end
# A single real parameter: the rate `p[1]` read by `dynamics`.
params = [1.0]
julia> Zygote.gradient(loss1, params)
([-0.5202599109305424],)
julia> Zygote.gradient(loss2, params)
ERROR: MethodError: no method matching (::SciMLBase.ParamJacobianWrapper{false, ODEFunction{…}, Float64, Vector{…}})(::Vector{ComplexF64}, ::Vector{Float64})
Closest candidates are:
(::SciMLBase.ParamJacobianWrapper{false})(::Any)
@ SciMLBase ~/.julia/packages/SciMLBase/slQep/src/function_wrappers.jl:87
Stacktrace:
[1] finite_difference_jacobian!(J::Matrix{…}, f::SciMLBase.ParamJacobianWrapper{…}, x::Vector{…}, cache::FiniteDiff.JacobianCache{…}, f_in::Nothing; relstep::Float64, absstep::Float64, colorvec::UnitRange{…}, sparsity::Nothing, dir::Bool)
@ FiniteDiff ~/.julia/packages/FiniteDiff/BgLLM/src/jacobians.jl:438
[2] finite_difference_jacobian!(J::Matrix{ComplexF64}, f::Function, x::Vector{Float64}, cache::FiniteDiff.JacobianCache{Vector{…}, Vector{…}, Vector{…}, Vector{…}, UnitRange{…}, Nothing, Val{…}(), ComplexF64}, f_in::Nothing)
@ FiniteDiff ~/.julia/packages/FiniteDiff/BgLLM/src/jacobians.jl:341
[3] finite_difference_jacobian!
@ FiniteDiff ~/.julia/packages/FiniteDiff/BgLLM/src/jacobians.jl:341 [inlined]
[4] jacobian!(J::Matrix{…}, f::Function, x::Vector{…}, fx::Vector{…}, alg::GaussAdjoint{…}, jac_config::FiniteDiff.JacobianCache{…})
@ SciMLSensitivity ~/SciML/SciMLSensitivity.jl/src/derivative_wrappers.jl:157
[5] vec_pjac!(out::Vector{…}, λ::Vector{…}, y::Vector{…}, t::Float64, S::SciMLSensitivity.GaussIntegrand{…})
@ SciMLSensitivity ~/SciML/SciMLSensitivity.jl/src/gauss_adjoint.jl:447
[6] GaussIntegrand
@ SciMLSensitivity ~/SciML/SciMLSensitivity.jl/src/gauss_adjoint.jl:490 [inlined]
[7] (::SciMLSensitivity.var"#255#256"{…})(out::Vector{…}, u::Vector{…}, t::Float64, integrator::OrdinaryDiffEq.ODEIntegrator{…})
@ SciMLSensitivity ~/SciML/SciMLSensitivity.jl/src/gauss_adjoint.jl:517
[8] (::DiffEqCallbacks.SavingIntegrandSumAffect{…})(integrator::OrdinaryDiffEq.ODEIntegrator{…})
@ DiffEqCallbacks ~/.julia/packages/DiffEqCallbacks/uVI0B/src/integrating_sum.jl:56
[9] apply_discrete_callback!
@ ~/.julia/packages/DiffEqBase/eLhx9/src/callbacks.jl:605 [inlined]
[10] apply_discrete_callback!
@ ~/.julia/packages/DiffEqBase/eLhx9/src/callbacks.jl:617 [inlined]
[11] handle_callbacks!(integrator::OrdinaryDiffEq.ODEIntegrator{…})
@ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/2nLli/src/integrators/integrator_utils.jl:346
[12] _loopfooter!(integrator::OrdinaryDiffEq.ODEIntegrator{…})
@ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/2nLli/src/integrators/integrator_utils.jl:253
[13] loopfooter!
@ ~/.julia/packages/OrdinaryDiffEq/2nLli/src/integrators/integrator_utils.jl:206 [inlined]
[14] solve!(integrator::OrdinaryDiffEq.ODEIntegrator{…})
@ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/2nLli/src/solve.jl:538
[15] #__solve#746
@ ~/.julia/packages/OrdinaryDiffEq/2nLli/src/solve.jl:6 [inlined]
[16] __solve
@ ~/.julia/packages/OrdinaryDiffEq/2nLli/src/solve.jl:1 [inlined]
[17] solve_call(_prob::ODEProblem{…}, args::Tsit5{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
@ DiffEqBase ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:609
[18] solve_call
@ DiffEqBase ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:567 [inlined]
[19] #solve_up#42
@ DiffEqBase ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:1058 [inlined]
[20] solve_up
@ DiffEqBase ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:1044 [inlined]
[21] #solve#40
@ DiffEqBase ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:981 [inlined]
[22] _adjoint_sensitivities(sol::ODESolution{…}, sensealg::GaussAdjoint{…}, alg::Tsit5{…}; t::Vector{…}, dgdu_discrete::Function, dgdp_discrete::Nothing, dgdu_continuous::Nothing, dgdp_continuous::Nothing, g::Nothing, abstol::Float64, reltol::Float64, checkpoints::Vector{…}, corfunc_analytical::Bool, callback::Nothing, kwargs::@Kwargs{…})
@ SciMLSensitivity ~/SciML/SciMLSensitivity.jl/src/gauss_adjoint.jl:536
[23] _adjoint_sensitivities
@ SciMLSensitivity ~/SciML/SciMLSensitivity.jl/src/gauss_adjoint.jl:503 [inlined]
[24] #adjoint_sensitivities#63
@ SciMLSensitivity ~/SciML/SciMLSensitivity.jl/src/sensitivity_interface.jl:386 [inlined]
[25] (::SciMLSensitivity.var"#adjoint_sensitivity_backpass#307"{@Kwargs{}, Tsit5{…}, GaussAdjoint{…}, Vector{…}, Vector{…}, SciMLBase.ChainRulesOriginator, Tuple{}, Colon, @NamedTuple{}})(Δ::Matrix{ComplexF64})
@ SciMLSensitivity ~/SciML/SciMLSensitivity.jl/src/concrete_solve.jl:515
[26] ZBack
@ ~/.julia/packages/Zygote/jxHJc/src/compiler/chainrules.jl:211 [inlined]
[27] (::Zygote.var"#291#292"{Tuple{NTuple{…}, Tuple{…}}, Zygote.ZBack{SciMLSensitivity.var"#adjoint_sensitivity_backpass#307"{…}}})(Δ::Matrix{ComplexF64})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:206
[28] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{NTuple{…}, Tuple{…}}, Zygote.ZBack{SciMLSensitivity.var"#adjoint_sensitivity_backpass#307"{…}}}})(Δ::Matrix{ComplexF64})
@ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
[29] #solve#40
@ ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:981 [inlined]
[30] (::Zygote.Pullback{Tuple{DiffEqBase.var"##solve#40", GaussAdjoint{…}, Vector{…}, Vector{…}, Val{…}, @Kwargs{}, typeof(solve), ODEProblem{…}, Tsit5{…}}, Any})(Δ::Matrix{ComplexF64})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[31] #291
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:206 [inlined]
[32] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{NTuple{…}, Tuple{…}}, Zygote.Pullback{Tuple{…}, Any}}})(Δ::Matrix{ComplexF64})
@ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
[33] solve
@ ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:971 [inlined]
[34] (::Zygote.Pullback{Tuple{typeof(Core.kwcall), @NamedTuple{…}, typeof(solve), ODEProblem{…}, Tsit5{…}}, Any})(Δ::Matrix{ComplexF64})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[35] loss2
@ ./REPL[7]:4 [inlined]
[36] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[37] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:91
[38] gradient(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:148
[39] top-level scope
@ REPL[10]:1
Some type information was truncated. Use `show(err)` to see complete types.