FluxML/Functors.jl

Recursive inference failure for multiple functors

gaurav-arya opened this issue · 3 comments

I've been using Functors in a case where I need it to be really fast for small tuples. With a single functor this works well, but fmap over multiple functors is much slower:

julia> using Functors, BenchmarkTools
julia> @btime Functors.fmap(+, (3,))
  0.859 ns (0 allocations: 0 bytes)
(3,)

julia> @btime Functors.fmap(+, (3,), (3,))
  404.070 ns (9 allocations: 464 bytes)
(6,)

JET.jl reports the following for the second call:

═════ 7 possible errors found ═════
┌ @ /home/gaurav/.julia/dev/Functors/src/maps.jl:3 Functors.:(var"#fmap#134")(tuple(Functors.isleaf, Functors.DefaultWalk(), Functors.IdDict(), Functors.NoKeyword(), #self#, f, x), ys...)
│┌ @ /home/gaurav/.julia/dev/Functors/src/maps.jl:11 fmap(tuple(_walk, f, x), ys...)
││┌ @ /home/gaurav/.julia/dev/Functors/src/maps.jl:1 walk(tuple(#132, x), ys...)
│││┌ @ /home/gaurav/.julia/dev/Functors/src/walks.jl:132 ret = walk.walk(tuple(recurse, x), ys...)
││││┌ @ /home/gaurav/.julia/dev/Functors/src/walks.jl:92 walk.walk(tuple(recurse, x), ys...)
│││││┌ @ /home/gaurav/.julia/dev/Functors/src/walks.jl:62 Functors._map(tuple(recurse, func), yfuncs...)
││││││┌ @ /home/gaurav/.julia/dev/Functors/src/walks.jl:1 Functors.map(tuple(f), x...)
│││││││┌ @ tuple.jl:298 f(t[1], s[1])
││││││││┌ @ /home/gaurav/.julia/dev/Functors/src/maps.jl:1 fmap(tuple(getfield(#self#, :walk), getfield(#self#, :f)), xs...)
│││││││││┌ @ /home/gaurav/.julia/dev/Functors/src/maps.jl:1 Functors.fmap(::Functors.CachedWalk{ExcludeWalk{DefaultWalk, typeof(+), typeof(Functors.isleaf)}, Functors.NoKeyword}, ::typeof(+), ::Int64, ::Int64)
││││││││││ failed to optimize: Functors.fmap(::Functors.CachedWalk{ExcludeWalk{DefaultWalk, typeof(+), typeof(Functors.isleaf)}, Functors.NoKeyword}, ::typeof(+), ::Int64, ::Int64)
│││││││││└──────────────────────────────────────────────────
││││││││┌ @ /home/gaurav/.julia/dev/Functors/src/maps.jl:1 (::Functors.var"#132#133"{Functors.CachedWalk{ExcludeWalk{DefaultWalk, typeof(+), typeof(Functors.isleaf)}, Functors.NoKeyword}, typeof(+)})(::Int64, ::Int64)
│││││││││ failed to optimize: (::Functors.var"#132#133"{Functors.CachedWalk{ExcludeWalk{DefaultWalk, typeof(+), typeof(Functors.isleaf)}, Functors.NoKeyword}, typeof(+)})(::Int64, ::Int64)
││││││││└──────────────────────────────────────────────────
│││││││┌ @ tuple.jl:298 map(::Functors.var"#132#133"{Functors.CachedWalk{ExcludeWalk{DefaultWalk, typeof(+), typeof(Functors.isleaf)}, Functors.NoKeyword}, typeof(+)}, ::Tuple{Int64}, ::Tuple{Int64})
││││││││ failed to optimize: map(::Functors.var"#132#133"{Functors.CachedWalk{ExcludeWalk{DefaultWalk, typeof(+), typeof(Functors.isleaf)}, Functors.NoKeyword}, typeof(+)}, ::Tuple{Int64}, ::Tuple{Int64})
│││││││└────────────────
││││││┌ @ /home/gaurav/.julia/dev/Functors/src/walks.jl:1 Functors._map(::Functors.var"#132#133"{Functors.CachedWalk{ExcludeWalk{DefaultWalk, typeof(+), typeof(Functors.isleaf)}, Functors.NoKeyword}, typeof(+)}, ::Tuple{Int64}, ::Tuple{Int64})
│││││││ failed to optimize: Functors._map(::Functors.var"#132#133"{Functors.CachedWalk{ExcludeWalk{DefaultWalk, typeof(+), typeof(Functors.isleaf)}, Functors.NoKeyword}, typeof(+)}, ::Tuple{Int64}, ::Tuple{Int64})
││││││└───────────────────────────────────────────────────
│││││┌ @ /home/gaurav/.julia/dev/Functors/src/walks.jl:59 (::DefaultWalk)(::Functors.var"#132#133"{Functors.CachedWalk{ExcludeWalk{DefaultWalk, typeof(+), typeof(Functors.isleaf)}, Functors.NoKeyword}, typeof(+)}, ::Tuple{Int64}, ::Tuple{Int64})
││││││ failed to optimize: (::DefaultWalk)(::Functors.var"#132#133"{Functors.CachedWalk{ExcludeWalk{DefaultWalk, typeof(+), typeof(Functors.isleaf)}, Functors.NoKeyword}, typeof(+)}, ::Tuple{Int64}, ::Tuple{Int64})
│││││└────────────────────────────────────────────────────
││││┌ @ /home/gaurav/.julia/dev/Functors/src/walks.jl:92 (::ExcludeWalk{DefaultWalk, typeof(+), typeof(Functors.isleaf)})(::Functors.var"#132#133"{Functors.CachedWalk{ExcludeWalk{DefaultWalk, typeof(+), typeof(Functors.isleaf)}, Functors.NoKeyword}, typeof(+)}, ::Tuple{Int64}, ::Tuple{Int64})
│││││ failed to optimize: (::ExcludeWalk{DefaultWalk, typeof(+), typeof(Functors.isleaf)})(::Functors.var"#132#133"{Functors.CachedWalk{ExcludeWalk{DefaultWalk, typeof(+), typeof(Functors.isleaf)}, Functors.NoKeyword}, typeof(+)}, ::Tuple{Int64}, ::Tuple{Int64})
││││└────────────────────────────────────────────────────
│││┌ @ /home/gaurav/.julia/dev/Functors/src/walks.jl:127 (::Functors.CachedWalk{ExcludeWalk{DefaultWalk, typeof(+), typeof(Functors.isleaf)}, Functors.NoKeyword})(::Functors.var"#132#133"{Functors.CachedWalk{ExcludeWalk{DefaultWalk, typeof(+), typeof(Functors.isleaf)}, Functors.NoKeyword}, typeof(+)}, ::Tuple{Int64}, ::Tuple{Int64})
││││ failed to optimize: (::Functors.CachedWalk{ExcludeWalk{DefaultWalk, typeof(+), typeof(Functors.isleaf)}, Functors.NoKeyword})(::Functors.var"#132#133"{Functors.CachedWalk{ExcludeWalk{DefaultWalk, typeof(+), typeof(Functors.isleaf)}, Functors.NoKeyword}, typeof(+)}, ::Tuple{Int64}, ::Tuple{Int64})
│││└─────────────────────────────────────────────────────

The "failed to optimize" entries come from an OptimizationFailureReport, which "will happen when there are (mutually) recursive calls and Julia compiler decided not to do inference in order to make sure the inference's termination". This seems to crop up only when multiple functors are passed.
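The trace shows the cycle that inference has to cut: fmap calls the walk, the walk calls map over the tuples, and map calls a closure that calls fmap again (maps.jl:1 and tuple.jl:298 above). The following is a hypothetical, cache-free sketch of that recursive shape for nested tuples; mini_fmap is an illustrative name, not Functors' implementation, and it makes no claim about how inference handles it.

```julia
# Hypothetical minimal multi-argument fmap over nested tuples (not Functors' code).
# Leaves (anything that is not a Tuple): apply f across all inputs.
mini_fmap(f, x, ys...) = f(x, ys...)

# Tuples: walk all inputs in lockstep. The recursion goes
# mini_fmap -> map -> closure -> mini_fmap, the same shape as in the JET trace.
mini_fmap(f, x::Tuple, ys::Tuple...) =
    map((xi, yis...) -> mini_fmap(f, xi, yis...), x, ys...)

mini_fmap(+, (3,))                                   # single functor
mini_fmap(+, (3,), (3,))                             # two functors
mini_fmap(+, ((1, 2), 3), ((10, 20), 30))            # nested tuples
```

With one functor the closure receives a single argument; with several, map zips the tuples, which is where the extra closure and splatting appear.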

I'd like to tweak the Functors.jl code so that this case is optimized, but I'm not really sure where to start, so any tips would be appreciated :)

The bigger issue is that CachedWalk explicitly breaks inference by storing and looking up values in an IdDict. Disabling the cache removes most of the allocations:
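A small illustration of why an untyped IdDict hurts inference, assuming (as the Functors.IdDict() call in the JET trace suggests) that the cache is constructed without type parameters: IdDict() is an IdDict{Any,Any}, so anything read back from it is inferred as Any. The cache and lookup names below are hypothetical.

```julia
# Untyped cache: IdDict() is IdDict{Any,Any}, so values come back as Any.
const untyped_cache = IdDict()
# Typed cache: the compiler knows every stored value is a Tuple{Int}.
const typed_cache = IdDict{Tuple{Int},Tuple{Int}}()

# Hypothetical "return the cached result if present, else the input" helpers.
lookup_untyped(x) = haskey(untyped_cache, x) ? untyped_cache[x] : x
lookup_typed(x::Tuple{Int}) = haskey(typed_cache, x) ? typed_cache[x] : x

# Inferred return types for a Tuple{Int} argument:
Base.return_types(lookup_untyped, (Tuple{Int},))  # Any: the cache hit is Any
Base.return_types(lookup_typed, (Tuple{Int},))    # Tuple{Int64}
```

Once a lookup returns Any, every downstream call on that value has to be dispatched at runtime, which is consistent with the extra allocations measured with the cache enabled.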

julia> @btime Functors.fmap(+, (3,), (3,));
  468.852 ns (9 allocations: 464 bytes)

julia> @btime Functors.fmap(+, (3,), (3,); cache=nothing);
  358.313 ns (2 allocations: 48 bytes)

After stripping off all the unnecessary intermediate walks and looking at the Cthulhu output, this doesn't seem to be an inference failure due to recursion, but rather case 2 of a runtime dispatch.

fmap(walk::Functors.AbstractWalk, f, x, ys...) @ Functors ~/.julia/packages/Functors/orBYx/src/maps.jl:1
Variables
  #self#::Core.Const(Functors.fmap)
  walk::Core.Const(Functors.DefaultWalk())
  f::Core.Const(+)
  x::Tuple{Int64}
  ys::Tuple{Tuple{Int64}}
  #23::Functors.var"#23#24"{Functors.DefaultWalk, typeof(+)}

Body::Tuple{Int64}
    @ ~/.julia/packages/Functors/orBYx/src/maps.jl:1 within `fmap`
1 ─ %1 = Functors.:(var"#23#24")::Core.Const(Functors.var"#23#24")
│   %2 = Core.typeof(walk)::Core.Const(Functors.DefaultWalk)
│   %3 = Core.typeof(f)::Core.Const(typeof(+))
│   %4 = Core.apply_type(%1, %2, %3)::Core.Const(Functors.var"#23#24"{Functors.DefaultWalk, typeof(+)})
│        (#23 = %new(%4, walk, f))
│   %6 = #23::Core.Const(Functors.var"#23#24"{Functors.DefaultWalk, typeof(+)}(Functors.DefaultWalk(), +))
│   %7 = Core.tuple(%6, x)::Tuple{Functors.var"#23#24"{Functors.DefaultWalk, typeof(+)}, Tuple{Int64}}
│   %8 = Core._apply_iterate(Base.iterate, walk, %7, ys)::Tuple{Int64} [constprop] Disabled by entry heuristic (limited accuracy)

This is corroborated by the @code_typed output, which contains an invoke of walk:

julia> @code_typed Functors.fmap(Functors.DefaultWalk(), +, (3,), (3,))
CodeInfo(
1 ─ %1 = Core.getfield(ys, 1)::Tuple{Int64}
│   %2 = invoke walk(Functors.var"#23#24"{Functors.DefaultWalk, typeof(+)}(Functors.DefaultWalk(), +)::Function, x::Tuple{Int64}, %1::Tuple{Int64})::Tuple{Int64}
└──      return %2
) => Tuple{Int64}
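One possible reading of the ::Function annotation on the closure argument in the invoke above is Julia's documented heuristic of not specializing on a Function argument that a method only passes along without calling. The helpers below are hypothetical and only illustrate that heuristic and the usual workaround of adding a type parameter; they are not taken from Functors.

```julia
# `g` is only forwarded to `map`, never called in this body, so Julia may
# compile this method for `g::Function` rather than g's concrete type.
pass_through(g, xs) = map(g, xs)

# A type parameter on `g` forces specialization on its concrete type.
pass_through_spec(g::G, xs) where {G} = map(g, xs)

double = x -> 2x
pass_through(double, (1, 2))
pass_through_spec(double, (1, 2))
```

Both versions return the same result; the difference is only in which method instance gets compiled. Note that reflection macros such as @code_typed show specialized code for the call you ask about, so the signature of the invoke inside a caller is the more reliable place to look.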

Hm, I see. I don't know enough to fully parse the implications of your logs above: do they point to where the runtime dispatch could be occurring? I'm happy to disable the cache in any case, so I'm mostly concerned about the failure to optimize with the cache disabled.

I've tried modifying the code to fix this, and the only thing that has worked is changing the recursive structure itself (in a hacky way): gaurav-arya@288059a

The outputs above are mostly with the cache disabled. A proper fix for this would likely require changes on the compiler side.