SciML/SciMLSensitivity.jl

Error with complex numbers

seadra opened this issue · 0 comments

seadra commented

(Discovered in a different issue here)

The following code fails with an error when using SciMLSensitivity:

using DifferentialEquations, DiffEqFlux, Zygote, SciMLSensitivity, Optimization, OptimizationFlux, OptimizationOptimJL, ComponentArrays, Lux, Random, LinearAlgebra

const T = 10.0;
const ω = π/T;

# Complex initial state (the 2x2 identity) and the target unitary.
const id = Matrix{Complex{Float64}}(I, 2, 2);
const u0 = id;
const utarget = Matrix{Complex{Float64}}([im 0; 0 -im]);

# Small neural network producing the scalar control a(t).
ann = Lux.Chain(Lux.Dense(1, 32), Lux.Dense(32, 32, tanh), Lux.Dense(32, 1));
rng = Random.default_rng();
ip, st = Lux.setup(rng, ann);

# Schrödinger-type RHS: du/dt = -i*A(t)*u, with the NN output a(t) on the diagonal of A.
function f_nn(u, p, t)
    a, _ = ann([t / T], p, st)
    A = [a[1] 0.0; 0.0 -a[1]]
    return -(im * A) * u
end

tspan = (0.0, T)
prob_ode = ODEProblem(f_nn, u0, tspan, ComponentArray(ip));

# Loss: infidelity of the final state with respect to the target unitary.
function loss_adjoint(p)
    prediction = solve(prob_ode, BS5(), p = p, abstol = 1e-7, reltol = 1e-7)
    usol = last(prediction)
    return abs(1.0 - abs(tr(usol * utarget') / 2))
end

opt_f = Optimization.OptimizationFunction((x, p) -> loss_adjoint(x), Optimization.AutoZygote());
opt_prob = Optimization.OptimizationProblem(opt_f, ComponentArray(ip));
optimized_sol_nn = Optimization.solve(opt_prob, AMSGrad(0.001), maxiters = 100, progress=true);

The script fails with the following warning and error:

┌ Warning: Reverse-Mode AD VJP choices all failed. Falling back to numerical VJPs
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/rYAAz/src/concrete_solve.jl:92

MethodError: no method matching default_relstep(::Nothing, ::Type{ComplexF64})

Closest candidates are:
  default_relstep(::Type, ::Any)
   @ FiniteDiff ~/.julia/packages/FiniteDiff/40JnL/src/epsilons.jl:25
  default_relstep(::Val{fdtype}, ::Type{T}) where {fdtype, T<:Number}
   @ FiniteDiff ~/.julia/packages/FiniteDiff/40JnL/src/epsilons.jl:26
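
If I am reading the MethodError correctly, the numerical-VJP fallback ends up passing nothing to FiniteDiff in the slot where a finite-difference type is expected, and neither default_relstep method accepts that. A minimal sketch of the same dispatch failure, independent of the ODE solve (assuming only that FiniteDiff is installed):

using FiniteDiff

# Both existing methods require a Type or a Val{fdtype} as the first argument,
# so passing `nothing` reproduces the MethodError from the stack trace above.
FiniteDiff.default_relstep(nothing, ComplexF64)
# ERROR: MethodError: no method matching default_relstep(::Nothing, ::Type{ComplexF64})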

Importing the (deprecated) DiffEqSensitivity package instead of SciMLSensitivity makes the same script run without error.
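
As a possible workaround until this is fixed, the complex state can be kept out of the sensitivity machinery entirely by stacking real and imaginary parts into a real-valued state. The following is only a sketch of that reformulation (f_nn_real, u0_real, and loss_adjoint_real are names introduced here, not part of the original script); it uses the fact that for real A, d/dt(ur + i*ui) = -i*A*(ur + i*ui) splits into dur/dt = A*ui and dui/dt = -A*ur:

# Hypothetical real-valued reformulation of f_nn: u stacks [real; imag] of the 2x2 state.
function f_nn_real(u, p, t)
    a, _ = ann([t / T], p, st)
    A = [a[1] 0.0; 0.0 -a[1]]
    ur = u[1:2, :]    # real part of the state
    ui = u[3:4, :]    # imaginary part of the state
    return vcat(A * ui, -A * ur)    # dur/dt = A*ui, dui/dt = -A*ur
end

u0_real = vcat(real(u0), imag(u0));
prob_real = ODEProblem(f_nn_real, u0_real, tspan, ComponentArray(ip));

# The complex solution is reassembled only inside the loss, after the solve.
function loss_adjoint_real(p)
    prediction = solve(prob_real, BS5(), p = p, abstol = 1e-7, reltol = 1e-7)
    usol = prediction[end][1:2, :] + im * prediction[end][3:4, :]
    return abs(1.0 - abs(tr(usol * utarget') / 2))
end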