FluxML/Flux.jl

deprecate Flux.params

CarloLucibello opened this issue · 7 comments

Is there any reason why we keep it around?

If we need a vector (or iterable) over the trainable leaves, we can build something (trainables(model)?) on top of Functors.fleaves, giving us a function that is decoupled from Zygote.
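Here is a minimal sketch of that kind of helper, written in plain Julia with no Zygote dependency. The name array_leaves is hypothetical: it is not a Flux, Functors or Optimisers function. It also walks a toy Tuple/NamedTuple tree shaped like Flux.state output rather than a model struct, so it cannot tell trainable arrays from non-trainable ones. The sketch does show one design point. Leaves are tracked by object identity, so an array shared between layers is returned only once.

```julia
# Hypothetical helper (not a Flux, Functors or Optimisers API): collect every numeric
# array leaf of a nested Tuple/NamedTuple tree once, in depth-first order.
function array_leaves(x)
    out = AbstractArray[]
    seen = IdDict{Any,Bool}()    # keyed by object identity, so a shared array is listed once
    collect_leaves!(out, seen, x)
    return out
end

collect_leaves!(out, seen, x) = nothing    # functions, (), numbers: not array leaves

function collect_leaves!(out, seen, x::AbstractArray{<:Number})
    if !haskey(seen, x)
        seen[x] = true
        push!(out, x)
    end
    return nothing
end

function collect_leaves!(out, seen, x::Union{Tuple,NamedTuple})
    foreach(child -> collect_leaves!(out, seen, child), values(x))
    return nothing
end

# A toy tree shaped like Flux.state output, where both layers share the array W.
W = [0.5 0.25]
st = (layers = ((weight = W, bias = [0.0], σ = ()),
                (weight = W, bias = [1.0], σ = ())),)
leaves = array_leaves(st)    # W, [0.0], [1.0]: W appears only once
```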

I think we were waiting for a couple more features to land so we could have parity with some of the remaining use cases people might use implicit params for. FluxML/Optimisers.jl#57 is the main one I can think of.

I think that's the only one left.

Would there be an alternative way to perform copy! between a flat vector and a Params-like object, or perhaps directly into nn (a Flux.Chain), i.e. something like copy!(x, nn) and copy!(nn, x)?

Along these lines, I also wanted to ask whether Flux.jl will support ComponentArrays the way Lux.jl does, and whether that would be optional, as in Lux.jl, with NamedTuples as the default parameter container.

That already exists, roughly:

```julia
julia> model = Chain(Dense(2 => 1, tanh), Dense(1 => 1));

julia> st = Flux.state(model)
(layers = ((weight = Float32[0.5213037 0.35699493], bias = Float32[0.0], σ = ()), (weight = Float32[0.96851003;;], bias = Float32[0.0], σ = ())),)

julia> Flux.loadmodel!(model, st);  # this is a nested copyto!

julia> using ComponentArrays

julia> ca = ComponentArray(; Flux.state(model)...)
ComponentVector{Tuple{@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}, σ::Tuple{}}, @NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}, σ::Tuple{}}}}(layers = ((weight = Float32[0.5213037 0.35699493], bias = Float32[0.0], σ = ()), (weight = Float32[0.96851003;;], bias = Float32[0.0], σ = ())))

julia> ca.layers[1].weight .= NaN
1×2 Matrix{Float32}:
 NaN  NaN

julia> Flux.loadmodel!(model, ca)
Chain(
  Dense(2 => 1, tanh),                  # 3 parameters  (some NaN)
  Dense(1 => 1),                        # 2 parameters
)                   # Total: 4 arrays, 5 parameters, 276 bytes.
```
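The "nested copyto!" comment can be made concrete with a hypothetical sketch. nested_copyto! is not a Flux function. Flux.loadmodel! itself walks the model's structs and handles more cases than this. The sketch walks two trees of the same shape in parallel and copies each source array into the matching destination array. The destination's arrays are therefore updated in place rather than replaced.

```julia
# Hypothetical sketch of a "nested copyto!" (not Flux's implementation).
nested_copyto!(dst, src) = dst    # non-array leaves (functions, (), numbers): left as-is

function nested_copyto!(dst::AbstractArray{<:Number}, src::AbstractArray{<:Number})
    size(dst) == size(src) || throw(DimensionMismatch("size $(size(dst)) vs $(size(src))"))
    return copyto!(dst, src)
end

function nested_copyto!(dst::NamedTuple, src::NamedTuple)
    keys(dst) == keys(src) || throw(ArgumentError("field names differ"))
    foreach(k -> nested_copyto!(dst[k], src[k]), keys(dst))
    return dst
end

function nested_copyto!(dst::Tuple, src::Tuple)
    length(dst) == length(src) || throw(ArgumentError("tuple lengths differ"))
    foreach(nested_copyto!, dst, src)
    return dst
end

# `model` stands in for a model's arrays (σ holds a function, not an array);
# `st` has the same shape, like the output of Flux.state.
model = (layers = ((weight = zeros(1, 2), bias = zeros(1), σ = tanh),
                   (weight = zeros(1, 1), bias = zeros(1), σ = identity)),)
st = (layers = ((weight = [0.5 0.25], bias = [1.0], σ = ()),
                (weight = fill(2.0, 1, 1), bias = [0.0], σ = ())),)
w = model.layers[1].weight
nested_copyto!(model, st)    # w now holds [0.5 0.25]; no new arrays were created
```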

The caveats are: (1) what Flux.state returns includes non-trainable parameters; (2) I've no idea what will happen to shared parameters, since ComponentArrays ignores them; and (3) this is designed for loading from disk, not for use within gradients, so Zygote may hate it, but that's fixable. (Edit: (4) my use of ComponentArray does not seem to produce something backed by one big vector, e.g. getfield(ca, :data); maybe I need to read their docs.)
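Caveat (1) could be handled by skipping non-trainable fields when collecting arrays. The sketch below is hypothetical and skips fields by name. Its default list, (:μ, :σ²), mimics BatchNorm's running statistics. Flux itself decides trainability per layer type through its trainable method, not by field name. The toy tree only imitates the shape of Flux.state output.

```julia
# Hypothetical helper: collect numeric array leaves, skipping NamedTuple fields whose
# names appear in `skip`. This name-based rule is only a stand-in for Flux's own
# per-layer `trainable` definitions.
function trainable_arrays(x; skip = (:μ, :σ²))
    out = AbstractArray[]
    walk!(out, x, skip)
    return out
end

function walk!(out, x, skip)
    if x isa AbstractArray{<:Number}
        push!(out, x)
    elseif x isa NamedTuple
        for (k, v) in pairs(x)
            k in skip || walk!(out, v, skip)    # drop skipped fields and everything below them
        end
    elseif x isa Tuple
        foreach(v -> walk!(out, v, skip), x)
    end
    return out
end

# A Dense-like layer followed by a BatchNorm-like layer whose running
# statistics (μ, σ²) should not be optimised.
st = (layers = ((weight = ones(2, 2), bias = zeros(2), σ = ()),
                (β = zeros(2), γ = ones(2), μ = zeros(2), σ² = ones(2))),)
ps = trainable_arrays(st)    # weight, bias, β, γ
```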

Flux.loadmodel! is for nested structures; we also have Flux.destructure, which works with flat vectors of parameters (and should respect points 1, 2 and 3).
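A hypothetical sketch of what destructure provides helps contrast it with loadmodel!. It returns a flat copy of all array leaves, plus a closure that builds a new tree from a flat vector. destructure_sketch is not Flux's implementation. Unlike Flux.destructure, it neither deduplicates shared arrays nor skips non-trainable ones.

```julia
# Hypothetical sketch of the destructure idea (not Flux's implementation).
function destructure_sketch(tree)
    arrays = AbstractArray[]
    gather(x) = x isa AbstractArray{<:Number} ? push!(arrays, x) :
                x isa Union{Tuple,NamedTuple} ? foreach(gather, values(x)) : nothing
    gather(tree)
    flat = isempty(arrays) ? Float64[] : reduce(vcat, [vec(a) for a in arrays])

    # Rebuild a tree of the same shape, with freshly allocated arrays, from `v`.
    function re(v::AbstractVector)
        length(v) == length(flat) ||
            throw(DimensionMismatch("expected $(length(flat)) entries, got $(length(v))"))
        offset = Ref(0)
        function build(x)
            if x isa AbstractArray{<:Number}
                n = length(x)
                y = reshape(v[offset[]+1:offset[]+n], size(x))    # indexing copies the slice
                offset[] += n
                return y
            elseif x isa NamedTuple
                return NamedTuple{keys(x)}(Tuple(build(c) for c in values(x)))
            elseif x isa Tuple
                return Tuple(build(c) for c in x)
            else
                return x    # functions, numbers: kept as-is
            end
        end
        return build(tree)
    end
    return flat, re
end

st = (layers = ((weight = [0.5 0.25], bias = [0.0], σ = ()),
                (weight = fill(0.75, 1, 1), bias = [0.0], σ = ())),)
flat, re = destructure_sketch(st)    # flat == [0.5, 0.25, 0.0, 0.75, 0.0]
st2 = re(flat .* 2)                  # a new tree; st is unchanged
```

Because re allocates fresh arrays on every call, calling it once per optimiser step creates a new set of parameters each time.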

This is possibly off-topic here, but perhaps worth opening an issue, ideally with an example of what you wish would work?

Hi Michael, thanks a lot for the detailed reply (and sorry for my slow response); I wasn't aware of Flux.state. My use case is using Flux.jl with Optim.jl, which requires a flat vector. With Flux.Params I could use the existing copy! provided by Zygote.jl (earlier by FluxOptTools.jl) between Flux.Params and a flat vector, which was also useful for converting the gradient into a flat vector for Optim.jl. Of course, all uses of copy! happened outside of Zygote's differentiation.
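The Params-based workflow can also be sketched without Params. Keep the parameter arrays in a nested structure, copy a flat vector into them, and flatten the gradient in the same order. This hypothetical sketch assumes Optim.jl's fg!(F, G, x) convention, as used with Optim.only_fg!. The helper names flatten! and unflatten! are illustrative. The gradients are written by hand for a toy linear least-squares model, so no AD package is needed. A plain gradient-descent loop stands in for Optim.

```julia
# Depth-first list of the numeric array leaves of a Tuple/NamedTuple tree.
function leaves!(out, x)
    if x isa AbstractArray{<:Number}
        push!(out, x)
    elseif x isa Union{Tuple,NamedTuple}
        foreach(c -> leaves!(out, c), values(x))
    end
    return out
end

# Hypothetical helpers: flat vector -> arrays (in place), and arrays -> flat vector.
function unflatten!(arrays, v)
    i = 0
    for a in arrays
        copyto!(a, 1, v, i + 1, length(a))
        i += length(a)
    end
    return arrays
end
function flatten!(v, arrays)
    i = 0
    for a in arrays
        copyto!(v, i + 1, a, 1, length(a))
        i += length(a)
    end
    return v
end

# Toy data for y = W*x .+ b, fitted exactly by W = [1 2], b = [0.5].
X = [1.0 2.0 3.0; 0.0 1.0 0.0]                     # 2 features × 3 samples
Y = [1.5 4.5 3.5]
model = (weight = zeros(1, 2), bias = zeros(1))
grads = (weight = zeros(1, 2), bias = zeros(1))    # preallocated, same shape as model
ps, gs = leaves!(Any[], model), leaves!(Any[], grads)

# Optim-style fg!: fill G if requested, return the loss if requested.
function fg!(F, G, x)
    unflatten!(ps, x)                               # load the flat vector into the model
    r = model.weight * X .+ model.bias .- Y         # residuals, 1 × N
    N = size(X, 2)
    if G !== nothing
        grads.weight .= (2 / N) .* (r * X')         # d(mean squared error)/dW
        grads.bias .= (2 / N) .* vec(sum(r; dims = 2))
        flatten!(G, gs)                             # same order as the parameters
    end
    return F === nothing ? nothing : sum(abs2, r) / N
end

# Stand-in for something like Optim.optimize(Optim.only_fg!(fg!), zeros(3), LBFGS()).
x, G = zeros(3), zeros(3)
for _ in 1:2000
    fg!(nothing, G, x)
    x .-= 0.1 .* G
end
```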

Now, if I understand correctly, I have to write my own copy! to convert between the output of Flux.state and a flat vector. This would also be useful for the object (which seems similar to st) returned by Zygote.gradient(loss, model) in the new explicit Flux style. That is not very hard, but the problem you mentioned, "(1) what Flux.state returns includes non-trainable parameters", needs to be tackled. (Is trainables(model) intended to solve this issue?)
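Here is a hypothetical sketch of flattening a Zygote-style gradient. In such gradients, nothing marks "no gradient", for example on an activation-function field or possibly a whole unused layer. flat_grad! walks a toy parameter tree and the gradient in parallel and writes zeros wherever the gradient is nothing. It keeps the same depth-first order as the parameter arrays. Both trees are hand-written stand-ins, not actual Flux or Zygote output.

```julia
# Hypothetical helper: flatten `grad` into `g`, following the array leaves of `ps_tree`.
function flat_grad!(g::AbstractVector, ps_tree, grad)
    i = Ref(0)
    function walk(m, d)
        if m isa AbstractArray{<:Number}
            n = length(m)
            if d === nothing
                fill!(view(g, i[]+1:i[]+n), 0)    # no gradient: treat as zero
            else
                copyto!(g, i[] + 1, d, 1, n)
            end
            i[] += n
        elseif m isa NamedTuple
            for k in keys(m)
                walk(m[k], d === nothing ? nothing : d[k])
            end
        elseif m isa Tuple
            for j in eachindex(m)
                walk(m[j], d === nothing ? nothing : d[j])
            end
        end
        return nothing
    end
    walk(ps_tree, grad)
    i[] == length(g) || throw(DimensionMismatch("parameters have $(i[]) entries, g has $(length(g))"))
    return g
end

ps_tree = (layers = ((weight = [0.5 0.25], bias = [0.0], σ = ()),
                     (weight = fill(0.75, 1, 1), bias = [0.0], σ = ())),)
# A gradient of the same shape in which the second layer received no gradient.
grad = (layers = ((weight = [0.1 0.2], bias = [0.3], σ = nothing), nothing),)
g = flat_grad!(zeros(5), ps_tree, grad)    # [0.1, 0.2, 0.3, 0.0, 0.0]
```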

As for destructure, it makes the whole process more expensive because a new model is created at every epoch. I have observed that this hurts performance, so I have set it aside.

As for ComponentArrays, I think it works when the parameters are nested NamedTuples; for a neural network that would be a layer-wise NamedTuple of NamedTuples. For a Chain, Flux.state instead returns a Tuple of NamedTuples, hence the discrepancy observed above. Still, this does not seem conceptually far from the intended usage.
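One way to get the nested-NamedTuple shape is to give every Tuple positional names. This sketch is hypothetical. name_tuples and the layer_1, layer_2, ... keys are illustrative choices, and it is untested here whether ComponentArray then builds a single flat backing vector from the result. The arrays themselves are not copied, so the renamed tree still refers to the original arrays.

```julia
# Hypothetical helper: recursively turn every Tuple into a NamedTuple with keys
# :layer_1, :layer_2, ..., leaving arrays and other leaves untouched.
name_tuples(x) = x
name_tuples(x::NamedTuple) = map(name_tuples, x)
function name_tuples(x::Tuple)
    ks = ntuple(i -> Symbol("layer_", i), length(x))
    return NamedTuple{ks}(map(name_tuples, x))
end

st = (layers = ((weight = [0.5 0.25], bias = [0.0], σ = ()),
                (weight = fill(0.75, 1, 1), bias = [0.0], σ = ())),)
nt = name_tuples(st)
nt.layers.layer_1.weight    # the same array object as st.layers[1].weight
```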

So for now I can write a copy! between the output of Flux.state and a flat vector that ignores the non-trainable parameters, but I would be happy to know whether the trainables(model) and ComponentArrays approaches work. Thanks a lot!

Hi @kishore-nori, could you open a new issue and provide a specific example that we can reason about? Your case seems well served by destructure; if it's slow, we should try to understand why.

Sure, I will come up with an MWE and open an issue, thank you. By the way, I have realized that the idea of destructure! (FluxML/Optimisers.jl#165) would be really beneficial and would fit my purpose well.