Error with complex numbers
seadra opened this issue · 0 comments
(Discovered in a different issue here)
The following code fails with an error when using SciMLSensitivity:
using DifferentialEquations, DiffEqFlux, Zygote, SciMLSensitivity, Optimization, OptimizationFlux, OptimizationOptimJL, ComponentArrays, Lux, Random, LinearAlgebra
const T = 10.0;
const ω = π/T;
const id = Matrix{Complex{Float64}}(I,2, 2);
const u0 = id;
const utarget = Matrix{Complex{Float64}}([im 0; 0 -im]);
ann = Lux.Chain(Lux.Dense(1,32), Lux.Dense(32,32,tanh), Lux.Dense(32,1));
rng = Random.default_rng();
ip, st = Lux.setup(rng, ann);
function f_nn(u, p, t)
    local a, _ = ann([t/T],p,st);
    local A = [a[1] 0.0; 0.0 -a[1]];
    return -(im*A)*u;
end
tspan = (0.0, T)
prob_ode = ODEProblem(f_nn, u0, tspan, ComponentArray(ip));
function loss_adjoint(p)
    local prediction = solve(prob_ode, BS5(), p=p, abstol=1e-7, reltol=1e-7)
    local usol = last(prediction)
    local loss = abs(1.0 - abs(tr(usol*utarget')/2))
    return loss
end
opt_f = Optimization.OptimizationFunction((x, p) -> loss_adjoint(x), Optimization.AutoZygote());
opt_prob = Optimization.OptimizationProblem(opt_f, ComponentArray(ip));
optimized_sol_nn = Optimization.solve(opt_prob, AMSGrad(0.001), maxiters = 100, progress=true);
The error message is:
┌ Warning: Reverse-Mode AD VJP choices all failed. Falling back to numerical VJPs
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/rYAAz/src/concrete_solve.jl:92
MethodError: no method matching default_relstep(::Nothing, ::Type{ComplexF64})
Closest candidates are:
  default_relstep(::Type, ::Any)
   @ FiniteDiff ~/.julia/packages/FiniteDiff/40JnL/src/epsilons.jl:25
  default_relstep(::Val{fdtype}, ::Type{T}) where {fdtype, T<:Number}
   @ FiniteDiff ~/.julia/packages/FiniteDiff/40JnL/src/epsilons.jl:26
Importing the (deprecated) DiffEqSensitivity package instead of SciMLSensitivity works OK.
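For anyone reading the MethodError above, here is a minimal sketch of why dispatch fails. It does not call FiniteDiff. `relstep` is a hypothetical stand-in that only copies the two candidate signatures from the error message, and its method bodies are placeholders I made up for illustration. Under that assumption, the sketch suggests the complex element type is accepted by the second candidate, and that the `nothing` in the first argument is what leaves no applicable method.

```julia
# Hypothetical stand-ins that mirror the two candidate signatures listed in
# the error message. The bodies are placeholders, not FiniteDiff's code.
relstep(v::Type, T) = relstep(v(), T)
relstep(::Val{fdtype}, ::Type{T}) where {fdtype, T<:Number} = sqrt(eps(real(T)))

# A Val instance matches the second signature, and ComplexF64 <: Number,
# so the complex element type alone does not block dispatch.
relstep_val = relstep(Val(:forward), ComplexF64)

# `nothing` is neither a Type nor a Val, so neither signature applies.
err = try
    relstep(nothing, ComplexF64)
    nothing
catch e
    e
end

println(relstep_val)
println(err isa MethodError)
```

If the same holds for the real `default_relstep`, the fallback to numerical VJPs is reaching FiniteDiff without a finite-difference type set, but I have not traced where the `nothing` comes from.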