Nothing handling for GaussAdjoint
ArnoStrouwen opened this issue · 0 comments
ArnoStrouwen commented
using SciMLSensitivity
using OrdinaryDiffEq
using Zygote
# Out-of-place ODE right-hand side: du/dt = u. Parameters and time are ignored.
dynamics(u, _p, _t) = u
"""
    loss1(params)

Solve `dynamics` from a zero initial state of length 2 over `t ∈ [0, 1]` with `Tsit5`,
and return the sum of the final state.

Gradients go through `InterpolatingAdjoint` with `ZygoteVJP(allow_nothing = true)`.
Per the transcript in this issue, `Zygote.gradient(loss1, nothing)` returns `(nothing,)`.
"""
function loss1(params)
    initial_state = zeros(2)
    time_span = (0.0, 1.0)
    prob = ODEProblem(dynamics, initial_state, time_span, params)
    # u0 and p are passed to `solve` again (not only to the problem) so that the
    # adjoint receives them directly, exactly as in the original reproduction.
    adjoint = InterpolatingAdjoint(autojacvec = ZygoteVJP(allow_nothing = true))
    sol = solve(prob, Tsit5(); u0 = initial_state, p = params, sensealg = adjoint)
    final_state = Array(sol)[:, end]
    return sum(final_state)
end
"""
    loss2(params)

Solve `dynamics` from a zero initial state of length 2 over `t ∈ [0, 1]` with `Tsit5`,
and return the sum of the final state.

This is identical to the `InterpolatingAdjoint` variant except that gradients go through
`GaussAdjoint` with `ZygoteVJP(allow_nothing = true)`.

NOTE(review): per the transcript in this issue, `Zygote.gradient(loss2, nothing)`
currently fails with `StackOverflowError`, recursing through
`SciMLSensitivity.allocate_zeros(::Nothing)` → `Functors.fmap`.
"""
function loss2(params)
    initial_state = zeros(2)
    time_span = (0.0, 1.0)
    prob = ODEProblem(dynamics, initial_state, time_span, params)
    # u0 and p are passed to `solve` again (not only to the problem) so that the
    # adjoint receives them directly, exactly as in the original reproduction.
    adjoint = GaussAdjoint(autojacvec = ZygoteVJP(allow_nothing = true))
    sol = solve(prob, Tsit5(); u0 = initial_state, p = params, sensealg = adjoint)
    final_state = Array(sol)[:, end]
    return sum(final_state)
end
julia> Zygote.gradient(loss1, nothing)
(nothing,)
julia> Zygote.gradient(loss2, nothing)
ERROR: StackOverflowError:
Stacktrace:
[1] fmap(::Function, ::Nothing; exclude::Function, walk::Functors.DefaultWalk, cache::IdDict{Any, Any}, prune::Functors.NoKeyword)
@ Functors ~/.julia/packages/Functors/rlD70/src/maps.jl:7
[2] fmap(::Function, ::Nothing)
@ Functors ~/.julia/packages/Functors/rlD70/src/maps.jl:3
[3] allocate_zeros(x::Nothing)
@ SciMLSensitivity ~/SciML/SciMLSensitivity.jl/src/parameters_handling.jl:79
[4] (::Functors.ExcludeWalk{Functors.DefaultWalk, typeof(SciMLSensitivity.allocate_zeros), typeof(Functors.isleaf)})(::Function, ::Nothing)
@ Functors ~/.julia/packages/Functors/rlD70/src/walks.jl:106
[5] (::Functors.CachedWalk{Functors.ExcludeWalk{Functors.DefaultWalk, typeof(SciMLSensitivity.allocate_zeros), typeof(Functors.isleaf)}, Functors.NoKeyword})(::Function, ::Nothing)
@ Functors ~/.julia/packages/Functors/rlD70/src/walks.jl:146
[6] execute(::Functors.CachedWalk{Functors.ExcludeWalk{Functors.DefaultWalk, typeof(SciMLSensitivity.allocate_zeros), typeof(Functors.isleaf)}, Functors.NoKeyword}, ::Nothing)
@ Functors ~/.julia/packages/Functors/rlD70/src/walks.jl:38
[7] fmap(::Function, ::Nothing; exclude::Function, walk::Functors.DefaultWalk, cache::IdDict{Any, Any}, prune::Functors.NoKeyword)
@ Functors ~/.julia/packages/Functors/rlD70/src/maps.jl:11
--- the last 6 lines are repeated 7996 more times ---
[47984] fmap(::Function, ::Nothing)
@ Functors ~/.julia/packages/Functors/rlD70/src/maps.jl:3
[47985] allocate_zeros(x::Nothing)
@ SciMLSensitivity ~/SciML/SciMLSensitivity.jl/src/parameters_handling.jl:79
[47986] (::Functors.ExcludeWalk{Functors.DefaultWalk, typeof(SciMLSensitivity.allocate_zeros), typeof(Functors.isleaf)})(::Function, ::Nothing)
@ Functors ~/.julia/packages/Functors/rlD70/src/walks.jl:106