Taggable functors
darsnack opened this issue · 4 comments
In Flux, we have `trainable` to designate a subset of leaves as nodes to walk when updating parameters for training. In FluxPrune.jl, I defined `pruneable` to designate a subset of leaves for pruning (note that these need not coincide with the `trainable` nodes, so `trainable` cannot simply be reused).
Right now this creates an unfortunate situation, as discussed in FluxML/Flux.jl#1946. Users need to `@functor` their types and remember to define `trainable` if necessary. To use FluxPrune.jl, they may also need to remember to define `pruneable`. On the developer side, we can use the `walk` keyword of `fmap` to walk the differently labeled leaf nodes, but this usually requires defining a separate walk function for each subset you want to target.
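The status quo described above can be sketched in plain Julia, without Flux or Functors.jl. This is a minimal sketch under stated assumptions: `ToyConv`, `ToyChain`, `children`, `rebuild`, and `walk_subset` are hypothetical stand-ins (not Functors.jl's `fmap`/`walk` API), and `trainable`/`pruneable` here only mimic the roles described in the issue. Each concern needs its own selector, and the walker has to be told which one to use.

```julia
# Hypothetical stand-ins for illustration; not Flux's layers or Functors.jl's API.
struct ToyConv
    weight::Vector{Float64}
    bias::Vector{Float64}
    stride::Int
end

struct ToyChain
    a::ToyConv
    b::ToyConv
end

# Structure: which fields are children, and how to put a node back together.
children(x) = NamedTuple()                     # anything else is a leaf
children(c::ToyConv) = (weight = c.weight, bias = c.bias, stride = c.stride)
children(c::ToyChain) = (a = c.a, b = c.b)
rebuild(::ToyConv, nt) = ToyConv(nt.weight, nt.bias, nt.stride)
rebuild(::ToyChain, nt) = ToyChain(nt.a, nt.b)

# One selector per concern, each declared separately (mimicking `trainable`/`pruneable`).
trainable(x) = keys(children(x))               # default: every child
trainable(::ToyConv) = (:weight, :bias)
pruneable(x) = keys(children(x))
pruneable(::ToyConv) = (:weight,)

# Recursive map that only descends into the fields named by `select(x)`;
# targeting another subset means passing another selector.
function walk_subset(f, select, x)
    nt = children(x)
    isempty(nt) && return f(x)                 # leaf: apply f
    ks = keys(nt)
    keep = select(x)
    vals = map(k -> k in keep ? walk_subset(f, select, nt[k]) : nt[k], ks)
    return rebuild(x, NamedTuple{ks}(vals))
end

m = ToyChain(ToyConv([1.0, -2.0], [0.5], 1), ToyConv([3.0], [-1.0], 2))
pruned = walk_subset(zero, pruneable, m)          # only weights are zeroed
scaled = walk_subset(w -> 2 .* w, trainable, m)   # weights and biases are scaled
```

Every additional tag in this style means another selector function that users must remember to define, which is the friction the issue describes.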
An alternative would be to build this information directly into what `@functor` defines. Right now, each child of a functor has a name and a value. I propose adding "tags", which would be a tuple of symbols per child. Then we could do something like:
@functor Conv trainable=(weight, bias) pruneable=(weight,)
Ideally, this mechanism should be dynamic: if Flux.jl already defines the trainable leaves of a type, then another package like FluxPrune.jl should be able to add a `pruneable` tag on top of that.
My hope is that this makes things easier for users, since a single line would make a type Flux-compatible, and easier for developers, since walks could filter nodes by tag. I haven't spent much time on the implementation, but I wanted to float the notion of tags first and get some feedback.
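One way to get the dynamic behaviour asked for above, sketched here as an assumption rather than anything Flux or Functors.jl does, is a global registry that any package can extend. `TAGS`, `addtag!`, `tagged`, and `ToyConv` are hypothetical names.

```julia
# Hypothetical global registry (an assumption, not a Functors.jl feature):
# type => (tag => names of the fields carrying that tag).
const TAGS = Dict{Type, Dict{Symbol, Tuple{Vararg{Symbol}}}}()

# Add (or replace) one tag for a type, leaving tags declared elsewhere intact.
function addtag!(T::Type, tag::Symbol, fields::Tuple{Vararg{Symbol}})
    bytag = get!(() -> Dict{Symbol, Tuple{Vararg{Symbol}}}(), TAGS, T)
    bytag[tag] = fields
    return nothing
end

# Field names of `T` carrying `tag`; empty if nothing was declared.
function tagged(T::Type, tag::Symbol)
    haskey(TAGS, T) || return ()
    return get(TAGS[T], tag, ())
end

struct ToyConv
    weight::Vector{Float64}
    bias::Vector{Float64}
end

# One package declares the trainable fields...
addtag!(ToyConv, :trainable, (:weight, :bias))
# ...and another later layers a pruneable tag on top, without redefining anything.
addtag!(ToyConv, :pruneable, (:weight,))
```

A real package would have to make such registrations in its `__init__` function, since changes a package makes to another module's globals during precompilation are not kept. The registry is also keyed on exact types, so a tag registered for a parametric `UnionAll` would not be found for its concrete instances without extra lookup logic.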
Definitely thought about this before. The part I got stuck on was where/how to store this tag metadata. Do you have any proposals there?
I was thinking that `functor` should be broken up. Instead, we could have `children(x::T)`, `rebuilder(::Type{T})`, and `tags(::Type{T})`. Of course, the convenience macro `@functor` would define all three. I haven't thought too deeply about a simple convenience syntax for the macro that would support both the initial declaration and adding more tags later.
`children` would return the named tuple of children that `functor` currently returns, and `rebuilder` would return the function that puts the struct back together. `tags` would return a named tuple with the same keys as `children`, but instead of actual values, each key maps to a tuple of symbols naming its tags (defaulting to empty).
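The three-function split could look like the following minimal sketch. The names `children`, `rebuilder`, and `tags` come from the comment above, but their exact signatures, the helpers `tagged`/`maptag`, and the types `ToyConv`/`ToyScale` are assumptions for illustration; nothing here is part of Functors.jl.

```julia
# Hypothetical split of `functor` into three functions, using the proposed names.
struct ToyConv
    weight::Vector{Float64}
    bias::Vector{Float64}
    stride::Int
end

struct ToyScale
    s::Float64
end

children(c::ToyConv) = (weight = c.weight, bias = c.bias, stride = c.stride)
rebuilder(::Type{ToyConv}) = nt -> ToyConv(nt.weight, nt.bias, nt.stride)

# Default: one entry per field, each with no tags.
tags(::Type{T}) where {T} = NamedTuple{fieldnames(T)}(ntuple(_ -> (), fieldcount(T)))
# What `@functor ToyConv trainable=(weight, bias) pruneable=(weight,)` might generate.
tags(::Type{ToyConv}) = (weight = (:trainable, :pruneable), bias = (:trainable,), stride = ())

# Names of the children of `T` that carry `tag`.
tagged(::Type{T}, tag::Symbol) where {T} = Tuple(k for (k, ts) in pairs(tags(T)) if tag in ts)

# Apply `f` to the children carrying `tag`, keep the rest, and rebuild.
function maptag(f, tag::Symbol, x::T) where {T}
    nt = children(x)
    keep = tagged(T, tag)
    vals = map(k -> k in keep ? f(nt[k]) : nt[k], keys(nt))
    return rebuilder(T)(NamedTuple{keys(nt)}(vals))
end

c = ToyConv([1.0, 2.0], [3.0], 1)
pruned = maptag(zero, :pruneable, c)   # weight zeroed; bias and stride untouched
```

Because `tags(::Type{ToyConv})` returns one fixed named tuple, a second package could only add `pruneable` by redefining that method, which is exactly the open question about adding tags after the initial declaration.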
I like the idea of splitting things up. How would e.g. `trainable` fit in under this design? `@flexiblefunctor` is another question mark. `rebuilder` sounds vaguely like `ProjectTo`, so perhaps there's something we could learn from that too.
Maybe `@functor T` defines `functor(::Type{T}, x, ::Val)`, and `@functor T trainable=(x,y)` defines that and also `functor(::Type{T}, x, ::Val{:trainable})`.
Since it's a weird macro anyway, perhaps it can check for the existence of the simplest method before defining it (so as not to overwrite it). Or some notation like `@functor T +pruneable=(w,)` could tell it to define only the `::Val{:pruneable}` method?
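The dispatch-on-tag idea can be sketched by hand-writing the methods the macro might generate. This is hypothetical: the three-argument `functor` below is not an existing Functors.jl method, and `ToyConv` and `children_for` are illustrative names.

```julia
struct ToyConv
    weight::Vector{Float64}
    bias::Vector{Float64}
    stride::Int
end

# What `@functor ToyConv` might expand to: the fallback for any tag.
function functor(::Type{ToyConv}, x, ::Val)
    nt = (weight = x.weight, bias = x.bias, stride = x.stride)
    return nt, y -> ToyConv(y.weight, y.bias, y.stride)
end

# What `@functor ToyConv trainable=(weight, bias)` might add: a more specific method.
function functor(::Type{ToyConv}, x, ::Val{:trainable})
    nt = (weight = x.weight, bias = x.bias)
    return nt, y -> ToyConv(y.weight, y.bias, x.stride)   # untouched field comes from x
end

# What a second package's `@functor ToyConv +pruneable=(weight,)` might add:
# only this method, without redefining the ones above.
function functor(::Type{ToyConv}, x, ::Val{:pruneable})
    return (weight = x.weight,), y -> ToyConv(y.weight, x.bias, x.stride)
end

# Pick the children for a tag; an undeclared tag falls back to the `::Val` method.
children_for(x, tag::Symbol) = first(functor(typeof(x), x, Val(tag)))

c = ToyConv([1.0], [2.0], 3)
nt, re = functor(ToyConv, c, Val(:pruneable))
c2 = re((weight = [0.0],))   # rebuilds with a new weight, keeping bias and stride
```

Because `Val{:trainable}` is more specific than `Val`, Julia's dispatch picks the tagged method when one exists and falls back to the plain one otherwise, so adding a tag amounts to adding a method. The existence check suggested above could use `hasmethod`, e.g. `hasmethod(functor, Tuple{Type{ToyConv}, Any, Val})`.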