SciML/SciMLSensitivity.jl

SciMLSensitivity 7.63 leads to scalar indexing on CUDA

Closed this issue · 1 comment

The NeuralODE test in Lux started failing after the update: https://buildkite.com/julialang/lux-dot-jl/builds/3090#01909095-3f90-4a59-953b-72852e5fd962/153-315

The problem seems to originate from commit 14eff95.

cc @DhairyaLGandhi since you made that change

Stacktrace

ERROR: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations do not execute on the GPU, but very slowly on the CPU,
and therefore should be avoided.

If you want to allow scalar iteration, use allowscalar or @allowscalar
to enable scalar iteration globally or for the operations in question.
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] errorscalar(op::String)
@ GPUArraysCore ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:155
[3] _assertscalar(op::String, behavior::GPUArraysCore.ScalarIndexing)
@ GPUArraysCore ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:128
[4] assertscalar(op::String)
@ GPUArraysCore ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:116
[5] getindex
@ ~/.julia/packages/GPUArrays/8Y80U/src/host/indexing.jl:48 [inlined]
[6] scalar_getindex
@ ~/.julia/packages/GPUArrays/8Y80U/src/host/indexing.jl:34 [inlined]
[7] _getindex
@ ~/.julia/packages/GPUArrays/8Y80U/src/host/indexing.jl:17 [inlined]
[8] getindex
@ ~/.julia/packages/GPUArrays/8Y80U/src/host/indexing.jl:15 [inlined]
[9] GetStateIndex
@ ~/.julia/packages/SymbolicIndexingInterface/Xc8In/src/state_indexing.jl:41 [inlined]
[10] AbstractStateGetIndexer
@ ~/.julia/packages/SymbolicIndexingInterface/Xc8In/src/value_provider_interface.jl:157 [inlined]
[11] CallWith
@ ~/.julia/packages/SymbolicIndexingInterface/Xc8In/src/value_provider_interface.jl:208 [inlined]
[12] iterate
@ ./generator.jl:48 [inlined]
[13] _collect(c::Vector{SymbolicIndexingInterface.GetStateIndex{CartesianIndex{…}}}, itr::Base.Generator{Vector{SymbolicIndexingInterface.GetStateIndex{…}}, SymbolicIndexingInterface.CallWith{Tuple{…}}}, ::Base.EltypeUnknown, isz::Base.HasShape{1})
@ Base ./array.jl:800
[14] collect_similar
@ ./array.jl:709 [inlined]
[15] map
@ ./abstractarray.jl:3374 [inlined]
[16] MultipleGetters
@ ~/.julia/packages/SymbolicIndexingInterface/Xc8In/src/state_indexing.jl:173 [inlined]
[17] MultipleGetters
@ ~/.julia/packages/SymbolicIndexingInterface/Xc8In/src/state_indexing.jl:165 [inlined]
[18] _broadcast_getindex_evalf
@ ./broadcast.jl:673 [inlined]
[19] _broadcast_getindex
@ ./broadcast.jl:646 [inlined]
[20] getindex
@ ./broadcast.jl:605 [inlined]
[21] macro expansion
@ ./broadcast.jl:968 [inlined]
[22] macro expansion
@ ./simdloop.jl:77 [inlined]
[23] copyto!
@ ./broadcast.jl:967 [inlined]
[24] copyto!
@ ./broadcast.jl:920 [inlined]
[25] copy
@ ./broadcast.jl:892 [inlined]
[26] materialize
@ ./broadcast.jl:867 [inlined]
[27] (::SymbolicIndexingInterface.MultipleGetters{…})(ts::SymbolicIndexingInterface.Timeseries, ::SymbolicIndexingInterface.IndexerBoth, prob::ODESolution{…}, i::UnitRange{…})
@ SymbolicIndexingInterface ~/.julia/packages/SymbolicIndexingInterface/Xc8In/src/state_indexing.jl:184
[28] MultipleGetters
@ ~/.julia/packages/SymbolicIndexingInterface/Xc8In/src/state_indexing.jl:165 [inlined]
[29] (::SymbolicIndexingInterface.MultipleGetters{…})(prob::ODESolution{…}, i::UnitRange{…})
@ SymbolicIndexingInterface ~/.julia/packages/SymbolicIndexingInterface/Xc8In/src/value_provider_interface.jl:157
[30] _concrete_solve_adjoint(::ODEProblem{…}, ::Tsit5{…}, ::InterpolatingAdjoint{…}, ::CuArray{…}, ::ComponentVector{…}, ::SciMLBase.ChainRulesOriginator; save_start::Bool, save_end::Bool, saveat::Vector{…}, save_idxs::Nothing, kwargs::@kwargs{…})
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/6yPzM/src/concrete_solve.jl:467
[31] _concrete_solve_adjoint
@ ~/.julia/packages/SciMLSensitivity/6yPzM/src/concrete_solve.jl:359 [inlined]
[32] #_solve_adjoint#75
@ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:1537 [inlined]
[33] _solve_adjoint
@ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:1510 [inlined]
[34] #rrule#4
@ ~/.julia/packages/DiffEqBase/c8MAQ/ext/DiffEqBaseChainRulesCoreExt.jl:26 [inlined]
[35] rrule
@ ~/.julia/packages/DiffEqBase/c8MAQ/ext/DiffEqBaseChainRulesCoreExt.jl:22 [inlined]
[36] rrule
@ ~/.julia/packages/ChainRulesCore/I1EbV/src/rules.jl:140 [inlined]
[37] chain_rrule_kw
@ ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:235 [inlined]
[38] macro expansion
@ ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0 [inlined]
[39] _pullback
@ ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:87 [inlined]
[40] _apply
@ ./boot.jl:948 [inlined]
[41] adjoint
@ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
[42] _pullback
@ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
[43] #solve#51
@ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:1003 [inlined]
[44] _pullback(::Zygote.Context{…}, ::DiffEqBase.var"##solve#51", ::InterpolatingAdjoint{…}, ::Nothing, ::Nothing, ::Val{…}, ::@kwargs{…}, ::typeof(solve), ::ODEProblem{…}, ::Tsit5{…})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[45] _apply
@ ./boot.jl:948 [inlined]
[46] adjoint
@ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
[47] _pullback
@ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
[48] solve
@ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:993 [inlined]
[49] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(solve), ::ODEProblem{…}, ::Tsit5{…})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[50] #2
@ /mnt/research/lux/Lux.jl/src/helpers/compact.jl:345 [inlined]
[51] _pullback(::Zygote.Context{…}, ::var"#2#3", ::@NamedTuple{…}, ::CuArray{…}, ::ComponentVector{…}, ::@NamedTuple{…})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[52] CompactLuxLayer
@ /mnt/research/lux/Lux.jl/src/helpers/compact.jl:480 [inlined]
[53] _pullback(::Zygote.Context{…}, ::CompactLuxLayer{…}, ::CuArray{…}, ::ComponentVector{…}, ::@NamedTuple{…})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[54] apply
@ ~/.julia/packages/LuxCore/qeN7D/src/LuxCore.jl:175 [inlined]
[55] _pullback(::Zygote.Context{…}, ::typeof(LuxCore.apply), ::CompactLuxLayer{…}, ::CuArray{…}, ::ComponentVector{…}, ::@NamedTuple{…})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[56] applychain
@ /mnt/research/lux/Lux.jl/src/layers/containers.jl:0 [inlined]
[57] _pullback(::Zygote.Context{…}, ::typeof(Lux.applychain), ::@NamedTuple{…}, ::CuArray{…}, ::ComponentVector{…}, ::@NamedTuple{…})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[58] Chain
@ /mnt/research/lux/Lux.jl/src/layers/containers.jl:510 [inlined]
[59] _pullback(::Zygote.Context{…}, ::Chain{…}, ::CuArray{…}, ::ComponentVector{…}, ::@NamedTuple{…})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[60] AbstractLossFunction
@ /mnt/research/lux/Lux.jl/src/helpers/losses.jl:8 [inlined]
[61] _pullback(::Zygote.Context{…}, ::CrossEntropyLoss{…}, ::Chain{…}, ::ComponentVector{…}, ::@NamedTuple{…}, ::Tuple{…})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[62] pullback
@ ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:90 [inlined]
[63] pullback
@ ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:88 [inlined]
[64] compute_gradients(::AutoZygote, objective_function::CrossEntropyLoss{…}, data::Tuple{…}, ts::Lux.Experimental.TrainState{…})
@ LuxZygoteExt /mnt/research/lux/Lux.jl/ext/LuxZygoteExt/training.jl:3
[65] single_train_step!(backend::AutoZygote, obj_fn::CrossEntropyLoss{true, Nothing, Int64, typeof(mean), Nothing}, data::Tuple{CuArray{…}, CuArray{…}}, ts::Lux.Experimental.TrainState{Nothing, Nothing, Chain{…}, ComponentVector{…}, @NamedTuple{…}, Adam, Optimisers.Leaf{…}})
@ Lux.Experimental /mnt/research/lux/Lux.jl/src/contrib/training.jl:218
[66] train(model_function::Function; cpu::Bool, kwargs::@kwargs{})
@ Main ./REPL[18]:17
[67] train(model_function::Function)
@ Main ./REPL[18]:1
[68] top-level scope
@ REPL[19]:1
Some type information was truncated. Use show(err) to see complete types.

@DhairyaLGandhi, this is why we don't put unnecessary refactors into big PRs.