`fmapreduce`
mcabbott commented
This package probably wants a way to write `mapreduce`, to replace e.g. `sum(norm(p) for p in params(m))` in Flux. This seems like the minimal attempt, but it's not Zygote-friendly. Can this be fixed, and is there a better way?
julia> using Functors, Zygote
julia> const INIT = Base._InitialValue();
julia> function fmapreduce(f, op, x; init = INIT, walk = (f, x) -> foreach(f, Functors.children(x)), kw...)
         fmap(x; walk, kw...) do y
           init = init===INIT ? f(y) : op(init, f(y))
         end
         init===INIT ? Base.mapreduce_empty(f, op) : init
       end
fmapreduce (generic function with 1 method)
julia> m = ([1,2], (x=[3,4], y=5), 6);
julia> fmapreduce(sum, +, m)
21
julia> gradient(fmapreduce, sum, +, m)
(nothing, nothing, nothing)
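One possible direction, as a rough sketch only: the attempt above accumulates into a captured `init` that is reassigned inside the closure, which is likely what Zygote cannot see through. A purely recursive version avoids that mutation. The name `fmapreduce_rec` and its `isleaf` keyword are hypothetical, not part of Functors. The sketch does no caching, so shared nodes would be counted once per occurrence, and it has not been checked against Zygote here.

```julia
using Functors

# Hypothetical non-mutating variant: apply `f` at the leaves and combine
# the results of the children with `op`, with no captured state.
# Assumes every non-leaf node has at least one child.
function fmapreduce_rec(f, op, x; isleaf = Functors.isleaf)
    isleaf(x) && return f(x)
    mapreduce(c -> fmapreduce_rec(f, op, c; isleaf), op, Functors.children(x))
end

m = ([1,2], (x=[3,4], y=5), 6)
fmapreduce_rec(sum, +, m)  # 3 + 7 + 5 + 6 == 21, same as fmapreduce above
```

Unlike the `fmap`-based attempt, this has no `init`, `walk` or cache handling, so it is a starting point for discussion rather than a replacement.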