FluxML/Functors.jl

What's the cache for?

willtebbutt opened this issue · 3 comments

I was very pleased to discover that this has been carved out from Flux, but was slightly surprised by the following performance:

using Functors, BenchmarkTools

using Functors: functor

struct Bar{T}
    x::T
end

@functor Bar

bar = Bar(5.0)

julia> @benchmark fmap(Float32, bar)
BenchmarkTools.Trial:
  memory estimate:  608 bytes
  allocs estimate:  16
  --------------
  minimum time:     952.304 ns (0.00% GC)
  median time:      984.652 ns (0.00% GC)
  mean time:        1.040 μs (1.92% GC)
  maximum time:     76.549 μs (96.27% GC)
  --------------
  samples:          10000
  evals/sample:     23

Digging down a little, functor seems to be performant:

julia> @benchmark functor($bar)
BenchmarkTools.Trial:
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     1.455 ns (0.00% GC)
  median time:      1.470 ns (0.00% GC)
  mean time:        1.589 ns (0.00% GC)
  maximum time:     39.906 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1000

👍

Similarly, isleaf seems to be fine:

using Functors: isleaf

julia> @benchmark isleaf($bar)
BenchmarkTools.Trial:
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     0.020 ns (0.00% GC)
  median time:      0.032 ns (0.00% GC)
  mean time:        0.030 ns (0.00% GC)
  maximum time:     0.050 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1000

👍

So there's something else going on in fmap and fmap1, which I assume has something to do with the IdDict that's being used. I would be interested to know a) what the cache is needed for (it's not obvious to me), and b) whether there's a way to get rid of this overhead, since it seems unnecessary in this simple case.

edit: I realised while out on a walk that it's probably something to do with diamonds in the dependency graph for any particular data structure. Is this the case?

You hit the nail on the head. If we do something like d = Dense(10, 10); model = Chain(d, d) then we currently treat this as a graph, rather than a tree. So d is only seen once when mapping, and gets updated the same way in both locations, with the graph structure preserved. And optimisers only update d once, etc.
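
To make this concrete, here is a minimal sketch using only Functors (the Affine and Chain2 types are invented for illustration, so the example doesn't depend on Flux). Because fmap records each node it has already visited in an IdDict, the shared child is transformed once and the sharing survives in the output:

using Functors

struct Affine{T}
    w::T
end
@functor Affine

struct Chain2{A,B}
    a::A
    b::B
end
@functor Chain2

shared = Affine([1.0, 2.0])
model = Chain2(shared, shared)      # the same node appears twice: a "diamond"

out = fmap(x -> Float32.(x), model)

out.a === out.b                     # true: the cached result is reused,
                                    # so the graph structure is preserved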

I've generally not expected that functors would need to be all that fast (though they certainly could all be inlined away, if not for this caching bit). Is there a particular need or were you just curious? It may be possible to optimise it away in some cases.
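
For what it's worth, a cache-free traversal is only a few lines. The fmap_nocache helper below is hypothetical (not part of the package), but it is roughly what could be inlined away if the IdDict weren't needed; the trade-off is that a shared node like d above would be transformed twice and come back as two independent copies:

using Functors: functor, isleaf

# Hypothetical cache-free variant: recurse straight through `functor`
# with no IdDict, so nothing stops the compiler from inlining it,
# but shared nodes are duplicated rather than preserved.
function fmap_nocache(f, x)
    isleaf(x) && return f(x)
    children, rebuild = functor(x)
    rebuild(map(c -> fmap_nocache(f, c), children))
end

fmap_nocache(Float32, Bar(5.0))     # Bar{Float32}(5.0f0), using Bar from above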

We could get rid of this behaviour entirely; there are much better ways to get weight sharing and we're moving to a more functional worldview anyway, so enforcing that the above example gets viewed as a tree wouldn't be all that bad.

I think I agree that it's unlikely to be a bottleneck in my code at the minute, either. There are some slightly niche Bayesian numerics applications that I've had my eye on for a while where this overhead might become significant, but they're a way down the line.

Again, this was more out of curiosity.

Might be nice to add some comments / developer docs explaining this.

Resolved by #4