FluxML/Functors.jl

`fmapreduce`


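This package probably wants a way to write a `mapreduce` over the leaves of a nested structure, to replace e.g. `sum(norm(p) for p in params(m))` in Flux. Below is the most minimal attempt I could come up with, but it is not Zygote-friendly: the gradient with respect to `m` comes back as `nothing`. Can this be fixed, and is there a better way?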
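For reference, the quantity the Flux idiom computes can be written as an ordinary `mapreduce` over the leaves that `fmap` visits. The helper `collect_leaves` below is hypothetical, not a Functors.jl function, and since it mutates a vector it is no more Zygote-friendly than the attempt that follows. It only pins down the intended result, assuming `fmap` visits each distinct leaf once (it caches repeated objects).

```julia
using Functors, LinearAlgebra

# Hypothetical helper (not part of Functors.jl): gather every leaf that fmap
# visits into a vector, in traversal order.
function collect_leaves(x)
    out = Any[]
    fmap(x) do y
        push!(out, y)  # record the leaf
        y              # return it unchanged so fmap can rebuild the structure
    end
    return out
end

m = ([1.0, 2.0], (x = [3.0, 4.0], y = 5.0), 6.0)

# Analogue of sum(norm(p) for p in params(m)), but over all leaves,
# including the scalar leaves 5.0 and 6.0.
total = mapreduce(norm, +, collect_leaves(m))
```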

```julia
julia> using Functors, Zygote

julia> const INIT = Base._InitialValue();

julia> function fmapreduce(f, op, x; init = INIT, walk = (f, x) -> foreach(f, Functors.children(x)), kw...)
         fmap(x; walk, kw...) do y
           init = init===INIT ? f(y) : op(init, f(y))
         end
         init===INIT ? Base.mapreduce_empty(f, op) : init
       end
fmapreduce (generic function with 1 method)

julia> m = ([1,2], (x=[3,4], y=5), 6);

julia> fmapreduce(sum, +, m)
21

julia> gradient(fmapreduce, sum, +, m)
(nothing, nothing, nothing)
```
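One possible direction, offered as an untested sketch rather than a fix: drop the mutable accumulator and build the result by recursion, so the whole computation is a pure function of `x` that Zygote can trace without following a boxed variable. `fmapreduce_rec` and `_foldr` are hypothetical names. Unlike `fmap`, this version has no cache, so a leaf shared between two branches is counted twice, and having no `init` it does not handle empty structures.

```julia
using Functors

# Hypothetical non-mutating fmapreduce (not part of Functors.jl):
# apply f at each leaf and combine the results with op, recursing via children.
function fmapreduce_rec(f, op, x; exclude = Functors.isleaf)
    exclude(x) && return f(x)
    # values turns a NamedTuple of children into a plain Tuple; Tuples pass through.
    parts = map(c -> fmapreduce_rec(f, op, c; exclude), values(Functors.children(x)))
    return _foldr(op, parts)
end

# Right fold over a tuple by recursion on Base.tail, with no accumulator variable.
_foldr(op, t::Tuple{Any}) = first(t)
_foldr(op, t::Tuple) = op(first(t), _foldr(op, Base.tail(t)))
# Fallback for other iterables of children, e.g. arrays of arrays.
_foldr(op, xs) = reduce(op, xs)

m = ([1.0, 2.0], (x = [3.0, 4.0], y = 5.0), 6.0)
total = fmapreduce_rec(sum, +, m)  # 3.0 + ((7.0 + 5.0) + 6.0) == 21.0
```

Because the fold nests as `op(a, op(b, c))`, `op` should be associative. With Zygote loaded, `gradient(m -> fmapreduce_rec(sum, +, m), m)` is intended to return a structural gradient for `m` rather than `nothing`, but this has not been checked across Zygote versions.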