FluxML/Functors.jl

Taggable functors

darsnack opened this issue · 4 comments

In Flux, we have trainable to designate a subset of leaves as nodes to walk when updating parameters for training. In FluxPrune.jl, I defined pruneable to designate a subset of leaves for pruning (note that these cannot be the same as the trainable nodes).

Right now this creates an unfortunate circumstance as discussed in FluxML/Flux.jl#1946. Users need to @functor their types and remember to define trainable if necessary. To use FluxPrune.jl, they may also need to remember to define pruneable. On the developer side of things, we can use the walk keyword of fmap to walk the differently labeled leaf nodes. But this usually requires defining a separate walk function for each subset that you are hoping to target.

An alternative would be to build this information directly into what @functor defines. Right now, each child of a functor has a name and a value. I propose adding "tags" which would be a tuple of symbols. Then we could do something like

@functor Conv trainable=(weight, bias) pruneable=(weight,)

Ideally, this mechanism should be dynamic, meaning that if Flux.jl already defines the trainable leaves of a type, then another package like FluxPrune.jl should be able to add a pruneable tag on top of that.

My hope is that we make it easier on users by only having one line for making your type Flux-compatible. And we make it easier on developers by making it easy to filter nodes when walking by tag. I haven't spent a lot of time on the implementation aspect, but I just wanted to float the notion of tags first and get some feedback.

Definitely thought about this before. The part I got stuck on was where/how to store this tag metadata. Do you have any proposals there?

I was thinking that functor should be broken up. Instead we can have children(x::T), rebuilder(::Type{T}), and tags(::Type{T}). Of course, the convenience macro, @functor, would define all three. I haven't thought too deeply about a simple convenience syntax for the macro that would support the initial declaration + adding more tags.

children would return the named tuple that is returned by functor, and rebuilder would return the function that puts the struct back together. tags would return a named tuple similar to children but instead of actual values for each key, it stores a tuple of symbols corresponding to the tags (defaulting to empty).
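To make the split concrete, here is a rough sketch of what the three functions could look like for a toy Conv type. Everything here is hypothetical: children, rebuilder, tags, and the tagged helper are names from (or invented for) this proposal, not the current Functors.jl API, and the struct is a stand-in rather than Flux's real Conv.

```julia
# Toy stand-in for a layer type (not Flux's actual Conv).
struct Conv{W,B}
    weight::W
    bias::B
end

# children: the named tuple of fields, i.e. what `functor` returns first today.
children(c::Conv) = (weight = c.weight, bias = c.bias)

# rebuilder: a function that puts the struct back together from its children.
rebuilder(::Type{<:Conv}) = nt -> Conv(nt.weight, nt.bias)

# tags: same keys as children, but each value is a tuple of symbols.
# An untagged child would simply map to the empty tuple ().
tags(::Type{<:Conv}) = (weight = (:trainable, :pruneable), bias = (:trainable,))

# Hypothetical helper: keep only the children that carry a given tag.
function tagged(x, tag::Symbol)
    cs = children(x)
    ts = tags(typeof(x))
    keep = filter(k -> tag in ts[k], keys(cs))
    return NamedTuple{keep}(cs)
end

c = Conv([1.0, 2.0], [0.5])
tagged(c, :trainable)   # weight and bias
tagged(c, :pruneable)   # weight only
```

With this shape, a walk that targets one subset only needs the tag symbol, rather than a separate function like trainable or pruneable per subset.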

I like the idea of splitting things up. How would e.g. trainable fit in under this design? @flexiblefunctor is another question mark. rebuilder sounds vaguely like ProjectTo, perhaps there's something we could learn from that too.

Maybe @functor T defines functor(::Type{T}, x, ::Val), and @functor T trainable=(x,y) defines that and also functor(::Type{T}, x, ::Val{:trainable}).
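As a sketch of how that could look if written out by hand (the three-argument functor is the suggestion above, not something Functors.jl defines today; the Dense struct, the act field, and the two-argument convenience method are made up for illustration):

```julia
# Toy layer type, not Flux's actual Dense.
struct Dense{W,B,F}
    weight::W
    bias::B
    act::F
end

# What `@functor Dense` might expand to: a fallback for any tag, exposing all fields.
function functor(::Type{<:Dense}, x, ::Val)
    cs = (weight = x.weight, bias = x.bias, act = x.act)
    return cs, nt -> Dense(nt.weight, nt.bias, nt.act)
end

# What `@functor Dense trainable=(weight, bias)` might additionally define.
# The rebuild closure fills the non-trainable field from the original x.
function functor(::Type{<:Dense}, x, ::Val{:trainable})
    cs = (weight = x.weight, bias = x.bias)
    return cs, nt -> Dense(nt.weight, nt.bias, x.act)
end

# Hypothetical convenience entry point: pick the method by tag symbol.
functor(x, tag::Symbol = :all) = functor(typeof(x), x, Val(tag))

d = Dense([1.0 2.0], [0.0], tanh)
cs, re = functor(d, :trainable)      # only weight and bias
d2 = re((weight = [3.0 4.0], bias = [1.0]))
```

One consequence of the ::Val fallback in this sketch is that an unknown tag such as :pruneable silently returns all children, which may or may not be the desired default.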

Since it's a weird macro anyway, perhaps it can check for the existence of the simplest method before defining it (so as not to overwrite). Or some notation like @functor T +pruneable=(w,) could tell it to define only the ::Val{:pruneable} method?