deprecate Flux.params
CarloLucibello opened this issue · 7 comments
Is there any reason why we keep it around?
If we need a vector or iterable over the trainable leaves, we can build something (`trainables(model)`?) on top of `Functors.fleaves`, so that we have a function decoupled from Zygote.
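Here is one possible shape for such a function. It is a minimal sketch that assumes Optimisers.jl is installed alongside Flux. Instead of `Functors.fleaves`, it recurses through `Optimisers.trainable`, so it only visits fields that a layer marks as trainable. The name `trainable_arrays` is hypothetical and is not an existing Flux function.

```julia
using Flux, Optimisers

# Hypothetical helper (not part of Flux): collect a model's trainable parameter
# arrays into a Vector, recursing only into what `Optimisers.trainable` exposes,
# so non-trainable state (e.g. BatchNorm running statistics) is skipped.
trainable_arrays(model) = _collect_trainable!(AbstractArray[], model)

# Numeric arrays are the leaves we keep; the `===` check skips shared (tied) arrays.
function _collect_trainable!(out, x::AbstractArray{<:Number})
    any(y -> y === x, out) || push!(out, x)
    return out
end

# Anything else: recurse into its trainable children. Leaves such as activation
# functions or `bias = false` have no children, so nothing is collected for them.
function _collect_trainable!(out, x)
    foreach(c -> _collect_trainable!(out, c), Optimisers.trainable(x))
    return out
end

model = Chain(Dense(2 => 1, tanh), Dense(1 => 1))
ps = trainable_arrays(model)   # weight and bias of each Dense layer, in field order
```

Because it depends only on Functors/Optimisers traversal, nothing here requires Zygote.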
I think we were waiting for a couple more features to land so we could have parity with some of the remaining use cases people might use implicit params for. FluxML/Optimisers.jl#57 is the main one I can think of.
I think that's the only one left.
Would there be an alternative way to perform `copy!` between a flat vector and a `Params`-like object, or perhaps even directly into `nn` (a `Flux.Chain`), something like `copy!(x, nn)` and `copy!(nn, x)`?
Along these lines, I also wanted to ask: will Flux.jl use ComponentArrays the way Lux.jl does? And would that be optional, as in Lux.jl, with `NamedTuple` being the default for parameters?
That already exists, roughly:
```julia
julia> using Flux

julia> model = Chain(Dense(2 => 1, tanh), Dense(1 => 1));

julia> st = Flux.state(model)
(layers = ((weight = Float32[0.5213037 0.35699493], bias = Float32[0.0], σ = ()), (weight = Float32[0.96851003;;], bias = Float32[0.0], σ = ())),)

julia> Flux.loadmodel!(model, st);  # this is a nested copyto!

julia> using ComponentArrays

julia> ca = ComponentArray(; Flux.state(model)...)
ComponentVector{Tuple{@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}, σ::Tuple{}}, @NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}, σ::Tuple{}}}}(layers = ((weight = Float32[0.5213037 0.35699493], bias = Float32[0.0], σ = ()), (weight = Float32[0.96851003;;], bias = Float32[0.0], σ = ())))

julia> ca.layers[1].weight .= NaN
1×2 Matrix{Float32}:
 NaN  NaN

julia> Flux.loadmodel!(model, ca)
Chain(
  Dense(2 => 1, tanh),                  # 3 parameters  (some NaN)
  Dense(1 => 1),                        # 2 parameters
)                   # Total: 4 arrays, 5 parameters, 276 bytes.
```
The caveats are: (1) what `Flux.state` returns includes non-trainable parameters; (2) I've no idea what'll happen to shared parameters, and ComponentArrays ignores them; and (3) this is designed for loading from disk, not for use within gradients, so Zygote may hate it, but that's fixable. (Edit: (4) my use of `ComponentArray` does not seem to produce something backed by one big vector, e.g. `getfield(ca, :data)`; maybe I need to read their docs.)
`Flux.loadmodel!` is for nested structures; we also have `Flux.destructure`, which is about flat vectors of parameters (and should respect points 1, 2, 3).
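For the `copy!(x, nn)` / `copy!(nn, x)` question, here is a rough sketch that uses only `Flux.destructure` and `Flux.loadmodel!`. The toy model is made up for illustration. Note that `re(x)` builds a new model before `loadmodel!` copies its arrays back into `model`, so this round trip is not allocation-free.

```julia
using Flux

model = Chain(Dense(2 => 1, tanh), Dense(1 => 1))

# `flat` is a new Vector holding the trainable parameters; `re` rebuilds a model from one.
flat, re = Flux.destructure(model)

# Roughly "copy!(x, nn)": model parameters -> a preallocated flat vector.
x = similar(flat)
copyto!(x, flat)

# Roughly "copy!(nn, x)": flat vector -> the existing model's arrays.
x .= 0
Flux.loadmodel!(model, re(x))
```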
Possibly off-topic here, but perhaps worth opening an issue, ideally with an example of what you wish would work?
Hi Michael, thanks a lot for the detailed reply (and sorry for my delayed response); I wasn't aware of `Flux.state`. My use case has been to use Flux.jl with Optim.jl, which requires a flat vector. With `Flux.Params` I could use the existing `copy!` methods provided by Zygote.jl (earlier by FluxOptTools.jl) to copy between `Flux.Params` and a flat vector. This was also useful for converting the gradient into a flat vector for Optim.jl. Of course, all uses of `copy!` happened outside the code that Zygote differentiates.
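As a point of comparison for this Optim.jl workflow, here is a sketch of a flat-vector loss and an in-place gradient built on `Flux.destructure`, with no `Params` involved. The toy data and the names `flat_loss` and `flat_grad!` are made up for illustration. The `g!(G, p)` form is the one Optim.jl's `optimize(f, g!, x0, method)` accepts, but Optim appears only in a comment so that the block runs with Flux alone.

```julia
using Flux

model = Chain(Dense(2 => 4, tanh), Dense(4 => 1))
xs = rand(Float32, 2, 16)   # toy inputs, for illustration only
ys = rand(Float32, 1, 16)   # toy targets

flat0, re = Flux.destructure(model)

# Loss as a function of the flat parameter vector (re(p) rebuilds the model each call).
flat_loss(p) = Flux.mse(re(p)(xs), ys)

# In-place gradient, already flat, in the (G, p) form that Optim.jl expects.
function flat_grad!(G, p)
    G .= Flux.gradient(flat_loss, p)[1]
    return G
end

G = flat_grad!(similar(flat0), flat0)
# With Optim.jl loaded, e.g.: res = optimize(flat_loss, flat_grad!, flat0, LBFGS())
# and afterwards: Flux.loadmodel!(model, re(Optim.minimizer(res)))
```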
Now, if I understand correctly, I have to write my own `copy!` to convert between the output of `Flux.state` and a flat vector. This would also be useful for the object returned by `Zygote.gradient(loss, model)` in the new explicit Flux style, which seems similar to `st`. That is not very hard, but the problem you mentioned, "(1) what `Flux.state` returns includes non-trainable parameters", needs to be tackled. (Is `trainables(model)` intended to solve this issue?)
As for `destructure`, it makes the whole process more expensive because a new model is created every single epoch; I have observed that this hurts performance, so I have set it aside.
As for ComponentArrays, I think it works when we have nested `NamedTuple`s, which for a neural network would be a layer-wise `NamedTuple` of `NamedTuple`s. `Flux.state` doesn't return that: for a `Chain`, its `layers` field is a `Tuple` of `NamedTuple`s, hence the discrepancy observed above. Still, this doesn't seem conceptually far from the intended usage.
So for now I can write a `copy!` between the output of `Flux.state` and a flat vector that ignores the non-trainable parameters, but I would be happy to hear whether the `trainables(model)` and ComponentArrays solutions work out! Thanks a lot!
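Here is a minimal sketch of such an in-place `copy!` pair. It assumes Optimisers.jl is available. The hypothetical helpers `flat_copy!`, `model_copy!` and `trainable_list` are not Flux API. They walk the model through `Optimisers.trainable`, which skips non-trainable state, and they copy array by array, so no new model is constructed. The sketch handles only the model itself. A gradient tree from `Zygote.gradient` can contain `nothing` for unused fields, and these helpers do not cover that case.

```julia
using Flux, Optimisers

# Collect trainable numeric arrays (deduplicated by identity), in field order.
_trainables!(out, x::AbstractArray{<:Number}) = (any(y -> y === x, out) || push!(out, x); out)
_trainables!(out, x) = (foreach(c -> _trainables!(out, c), Optimisers.trainable(x)); out)
trainable_list(model) = _trainables!(AbstractArray[], model)

function _check_length(ps, x)
    n = sum(length, ps; init = 0)
    n == length(x) || throw(DimensionMismatch("model has $n trainable entries, vector has $(length(x))"))
end

# Roughly "copy!(x, nn)": model parameters -> flat vector, in place.
function flat_copy!(x::AbstractVector, model)
    ps = trainable_list(model)
    _check_length(ps, x)
    i = 0
    for a in ps
        copyto!(x, i + 1, a, 1, length(a))   # copy `a` into x[i+1 : i+length(a)]
        i += length(a)
    end
    return x
end

# Roughly "copy!(nn, x)": flat vector -> the model's existing arrays, in place.
function model_copy!(model, x::AbstractVector)
    ps = trainable_list(model)
    _check_length(ps, x)
    i = 0
    for a in ps
        copyto!(a, 1, x, i + 1, length(a))   # copy x[i+1 : i+length(a)] into `a`
        i += length(a)
    end
    return model
end

model = Chain(Dense(2 => 1, tanh), Dense(1 => 1))
x = flat_copy!(zeros(Float32, 5), model)   # model -> vector
model_copy!(model, x .+ 1)                 # vector -> model
```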
Hi @kishore-nori, could you open a new issue and provide a specific example that we can reason about? Your case seems to be well served by `destructure`; if it's slow, we should try to understand why.
Sure, I will come up with an MWE and open an issue, thank you. By the way, I have realized that the idea of `destructure!` (FluxML/Optimisers.jl#165) would be really beneficial and would fit my purpose well.