SciML/SciMLSensitivity.jl

Nothing handling for GaussAdjoint: `Zygote.gradient` with `p = nothing` hits a `StackOverflowError` in `allocate_zeros(::Nothing)`

ArnoStrouwen opened this issue · 0 comments

using SciMLSensitivity
using OrdinaryDiffEq
using Zygote
# Three-argument (out-of-place) right-hand side that returns the state unchanged;
# the parameter and time arguments are ignored.
# Written as a named function rather than an anonymous function assigned to a
# non-const global: a named function is a constant, concretely typed binding, so
# `loss1`/`loss2` do not pay for an untyped global lookup, and re-running the
# script simply redefines the method.
dynamics(x, _p, _t) = x

"""
    loss1(params)

Solve `dynamics` from a zero 2-element initial state over `(0.0, 1.0)` with `Tsit5`,
using `InterpolatingAdjoint` with a Zygote VJP (`allow_nothing = true`) as the
sensitivity algorithm, and return the sum of the last saved state.

`params` is given both to `ODEProblem` and as the `p` keyword of `solve`.
"""
function loss1(params)
    initial_state = zeros(2)
    timespan = (0.0, 1.0)
    prob = ODEProblem(dynamics, initial_state, timespan, params)

    adjoint_alg = InterpolatingAdjoint(autojacvec = ZygoteVJP(allow_nothing = true))
    sol = solve(prob, Tsit5(); u0 = initial_state, p = params, sensealg = adjoint_alg)

    # Last column of the dense solution array = state at the final saved time.
    final_state = Array(sol)[:, end]
    return sum(final_state)
end
"""
    loss2(params)

Solve `dynamics` from a zero 2-element initial state over `(0.0, 1.0)` with `Tsit5`,
using `GaussAdjoint` with a Zygote VJP (`allow_nothing = true`) as the sensitivity
algorithm, and return the sum of the last saved state.

`params` is given both to `ODEProblem` and as the `p` keyword of `solve`.
"""
function loss2(params)
    initial_state = zeros(2)
    timespan = (0.0, 1.0)
    prob = ODEProblem(dynamics, initial_state, timespan, params)

    adjoint_alg = GaussAdjoint(autojacvec = ZygoteVJP(allow_nothing = true))
    sol = solve(prob, Tsit5(); u0 = initial_state, p = params, sensealg = adjoint_alg)

    # Last column of the dense solution array = state at the final saved time.
    final_state = Array(sol)[:, end]
    return sum(final_state)
end
julia> Zygote.gradient(loss1, nothing)
(nothing,)

julia> Zygote.gradient(loss2, nothing)
ERROR: StackOverflowError:
Stacktrace:
     [1] fmap(::Function, ::Nothing; exclude::Function, walk::Functors.DefaultWalk, cache::IdDict{Any, Any}, prune::Functors.NoKeyword)
       @ Functors ~/.julia/packages/Functors/rlD70/src/maps.jl:7
     [2] fmap(::Function, ::Nothing)
       @ Functors ~/.julia/packages/Functors/rlD70/src/maps.jl:3
     [3] allocate_zeros(x::Nothing)
       @ SciMLSensitivity ~/SciML/SciMLSensitivity.jl/src/parameters_handling.jl:79
     [4] (::Functors.ExcludeWalk{Functors.DefaultWalk, typeof(SciMLSensitivity.allocate_zeros), typeof(Functors.isleaf)})(::Function, ::Nothing)
       @ Functors ~/.julia/packages/Functors/rlD70/src/walks.jl:106
     [5] (::Functors.CachedWalk{Functors.ExcludeWalk{Functors.DefaultWalk, typeof(SciMLSensitivity.allocate_zeros), typeof(Functors.isleaf)}, Functors.NoKeyword})(::Function, ::Nothing)
       @ Functors ~/.julia/packages/Functors/rlD70/src/walks.jl:146
     [6] execute(::Functors.CachedWalk{Functors.ExcludeWalk{Functors.DefaultWalk, typeof(SciMLSensitivity.allocate_zeros), typeof(Functors.isleaf)}, Functors.NoKeyword}, ::Nothing)
       @ Functors ~/.julia/packages/Functors/rlD70/src/walks.jl:38
     [7] fmap(::Function, ::Nothing; exclude::Function, walk::Functors.DefaultWalk, cache::IdDict{Any, Any}, prune::Functors.NoKeyword)
       @ Functors ~/.julia/packages/Functors/rlD70/src/maps.jl:11
--- the last 6 lines are repeated 7996 more times ---
 [47984] fmap(::Function, ::Nothing)
       @ Functors ~/.julia/packages/Functors/rlD70/src/maps.jl:3
 [47985] allocate_zeros(x::Nothing)
       @ SciMLSensitivity ~/SciML/SciMLSensitivity.jl/src/parameters_handling.jl:79
 [47986] (::Functors.ExcludeWalk{Functors.DefaultWalk, typeof(SciMLSensitivity.allocate_zeros), typeof(Functors.isleaf)})(::Function, ::Nothing)
       @ Functors ~/.julia/packages/Functors/rlD70/src/walks.jl:106