Lux branching
ArnoStrouwen opened this issue · 7 comments
```julia
using Lux, ComponentArrays, OrdinaryDiffEq, SciMLSensitivity, Random

rng = Random.default_rng()
tspan = (0.0f0, 8.0f0)
ann = Chain(Dense(1, 32, tanh), Dense(32, 32, tanh), Dense(32, 1))
ps, st = Lux.setup(rng, ann)
p = ComponentArray(ps)
θ, ax = getdata(p), getaxes(p)

function dxdt_(dx, x, p, t)
    # Rebuild the structured view of the flat parameter vector
    ps = ComponentArray(p, ax)
    x1, x2 = x
    dx[1] = x2
    dx[2] = first(ann([t], ps, st))[1]^3
end

x0 = [-4.0f0, 0.0f0]
ts = Float32.(collect(0.0:0.01:tspan[2]))
prob = ODEProblem(dxdt_, x0, tspan, θ)

SciMLSensitivity.hasbranching(dxdt_, copy(x0), x0, θ, tspan[1])
```
This returns `true`, so the compiled-tape ReverseDiff VJP will not be selected unless a `sensealg` is specified manually. The default sensitivity polyalgorithm is therefore very slow on this example.
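For reference, a minimal sketch of the manual workaround (the Zygote outer gradient and the sum-of-squares loss are illustrative assumptions, not part of the original report): explicitly requesting the compiled ReverseDiff tape bypasses the polyalgorithm's branching check.

```julia
using Zygote

function loss(θ)
    # Assumed loss purely for illustration; the key part is the manual sensealg.
    sol = solve(prob, Tsit5(); p = θ, saveat = ts,
                sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true)))
    sum(abs2, Array(sol))
end

Zygote.gradient(loss, θ)
```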
@avik-pal where is the branching? We should probably special case this.
There shouldn't be any. Most layers that do have branching carry a special case in their states to compile it away.
Are we sure the component array call isn't creating a branch?
> Are we sure the component array call isn't creating a branch?
Could be, I was not very precise with my language. Given how prominently ComponentArrays features in SciML+Lux tutorials, I see them as a package deal. Do you think we should move away from them?
No, we'd just need to specialize it. We could make a ComponentArrays extension that does the bypass (https://github.com/SciML/FunctionProperties.jl/blob/main/src/FunctionProperties.jl#L86-L94) and see if that fixes it.
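A hedged sketch of what that extension might look like (`HasBranchingCtx` is assumed to be FunctionProperties.jl's internal Cassette context, and treating the `ComponentArray` constructor as a leaf call is the proposed bypass, not verified here):

```julia
# Hypothetical package extension: don't recurse into ComponentArray
# construction when scanning for branches; just execute it via the fallback,
# so its internal control flow is not counted as user branching.
module FunctionPropertiesComponentArraysExt

using Cassette, ComponentArrays
using FunctionProperties: HasBranchingCtx  # assumed internal name

function Cassette.overdub(ctx::HasBranchingCtx, T::Type{<:ComponentArray}, args...)
    Cassette.fallback(ctx, T, args...)
end

end # module
```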
Based on those docs, I think we will have to special-case Lux as well; it relies on quite a few branches being compiled out (though not in this particular case).
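For example (an illustrative sketch, not taken from this issue): Lux's `Dropout` keeps the train/test decision in its state as a `Val`, so the branch is resolved by dispatch at compile time instead of appearing as runtime control flow.

```julia
using Lux, Random

rng = Random.default_rng()
layer = Dropout(0.5f0)
ps, st = Lux.setup(rng, layer)

# The training flag lives in the state; switching to test mode makes the
# dropout branch a dispatch decision rather than a runtime `if`.
st_test = Lux.testmode(st)
y, _ = layer(ones(Float32, 4), ps, st_test)
```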
It seems Cassette is having larger issues on v1.10: SciML/FunctionProperties.jl#9