JuliaArrays/LazyArrays.jl

Zygote compat is lacking

torfjelde opened this issue · 11 comments

It seems Zygote doesn't interact nicely with LazyArrays.jl, e.g.:

julia> f(x) = sum(BroadcastArray(exp, x))
f (generic function with 1 method)

julia> Zygote.gradient(f, randn(10))
ERROR: type Array has no field f
Stacktrace:
  [1] adjoint
    @ ~/.julia/packages/Zygote/AS0Go/src/lib/lib.jl:229 [inlined]
  [2] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
  [3] _pullback
    @ ~/.julia/packages/LazyArrays/NYra8/src/lazyapplying.jl:50 [inlined]
  [4] _pullback(::Zygote.Context{false}, ::typeof(LazyArrays.call), ::ArrayLayouts.DenseColumnMajor, ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface2.jl:0
  [5] _pullback
    @ ~/.julia/packages/LazyArrays/NYra8/src/lazyapplying.jl:52 [inlined]
  [6] _pullback
    @ ~/.julia/packages/LazyArrays/NYra8/src/lazybroadcasting.jl:82 [inlined]
  [7] _pullback
    @ ~/.julia/packages/LazyArrays/NYra8/src/lazybroadcasting.jl:57 [inlined]
  [8] _pullback(::Zygote.Context{false}, ::Type{BroadcastArray}, ::typeof(exp), ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface2.jl:0
  [9] _pullback
    @ ./REPL[48]:1 [inlined]
 [10] _pullback(ctx::Zygote.Context{false}, f::typeof(f), args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface2.jl:0
 [11] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface.jl:44
 [12] pullback
    @ ~/.julia/packages/Zygote/AS0Go/src/compiler/interface.jl:42 [inlined]
 [13] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface.jl:96
 [14] top-level scope
    @ REPL[50]:1

julia> g(x) = sum(LazyArray(@~ exp.(x)))
g (generic function with 1 method)

julia> Zygote.gradient(g, randn(10))
ERROR: MethodError: no method matching LazyArray(::Vector{Float64})
Closest candidates are:
  LazyArray(::Base.Broadcast.Broadcasted) at ~/.julia/packages/LazyArrays/NYra8/src/lazybroadcasting.jl:35
  LazyArray(::Applied) at ~/.julia/packages/LazyArrays/NYra8/src/lazyapplying.jl:193
Stacktrace:
 [1] macro expansion
   @ ~/.julia/packages/Zygote/AS0Go/src/compiler/interface2.jl:0 [inlined]
 [2] _pullback(ctx::Zygote.Context{false}, f::Type{LazyArray}, args::Vector{Float64})
   @ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface2.jl:9
 [3] _pullback
   @ ./REPL[53]:1 [inlined]
 [4] _pullback(ctx::Zygote.Context{false}, f::typeof(g), args::Vector{Float64})
   @ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface2.jl:0
 [5] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
   @ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface.jl:44
 [6] pullback
   @ ~/.julia/packages/Zygote/AS0Go/src/compiler/interface.jl:42 [inlined]
 [7] gradient(f::Function, args::Vector{Float64})
   @ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface.jl:96
 [8] top-level scope
   @ REPL[54]:1

The first error can be "fixed" (I'm not entirely certain this is the right approach) by defining a chain rule that forwards the BroadcastArray constructor to Broadcast.broadcasted:

julia> using ChainRulesCore

julia> function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::Type{LazyArrays.BroadcastArray}, f, args...)
           return ChainRulesCore.rrule_via_ad(config, Broadcast.broadcasted, f, args...)
       end

julia> Zygote.refresh()

julia> Zygote.gradient(f, randn(10))
([0.24117702568683322, 2.478340448616497, 2.433266795642693, 1.6163793920298133, 1.8859252985478665, 3.9539878829654223, 1.2578105524502685, 0.48545348574922, 0.8710494256114425, 3.0853524634917076],)

Maybe the rest can be addressed this way too.

Would rules from ChainRulesCore (CRC) be welcome here?

Hmm... that's a good question. I'm usually hesitant to add "*Core.jl" dependencies because many of them are of questionable value, but ChainRulesCore.jl might be an exception.

One alternative is to make a glue package à la FastTransformsForwardDiff. (I'm wondering whether that should have been FastTransformsChainRulesCore.jl...)

Either option is okay with me :)

Just say which option you prefer, and I can try to contribute to it.

Let's put it in a separate package for now so we can work out the kinks. We can always merge it back here (in the event there's a good reason to have it).

It seems this is a good use case for weak deps. Some packages have already started moving their ChainRules definitions to weak deps. The definitions would be loaded only on Julia >= 1.9 (if you don't want to use Requires on older Julia versions), but I think it would be the better long-term solution.

It would suck if we had to wait until Julia 1.9 before we could make use of this, though 😕

I assume it already works with the 1.9 beta, so you can already use it without compiling Julia from source.

Can we make a separate package that works now and becomes a weak dependency in Julia v1.9?

When a weak dependency is loaded, an extension (usually a single file in the ext subfolder) is loaded as well, and it is precompiled, in contrast to the Requires hacks! AFAIK no packages other than the weak dependency, the package itself and its hard dependencies are involved in or loaded by the extension, and making the glue package a hard dependency would defeat its purpose. An example is shown in this PR: JuliaMath/ChangesOfVariables.jl#12
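For concreteness, here is a minimal sketch of what such an extension could look like if it simply packaged the rule from the opening comment, assuming Julia >= 1.9 and the standard package-extension layout. The module name LazyArraysChainRulesCoreExt and the file path ext/LazyArraysChainRulesCoreExt.jl are hypothetical, the UUID is left as a placeholder, and the rule body is the one proposed in this issue, not a vetted implementation.

```julia
# Hypothetical file: ext/LazyArraysChainRulesCoreExt.jl
#
# Assumed Project.toml entries in LazyArrays.jl (replace the placeholder
# with ChainRulesCore's actual UUID):
#
#   [weakdeps]
#   ChainRulesCore = "<ChainRulesCore UUID>"
#
#   [extensions]
#   LazyArraysChainRulesCoreExt = "ChainRulesCore"
#
# On Julia >= 1.9, this module is loaded automatically once both
# LazyArrays and ChainRulesCore have been loaded.
module LazyArraysChainRulesCoreExt

using LazyArrays: BroadcastArray
using ChainRulesCore: ChainRulesCore, RuleConfig, HasReverseMode, rrule_via_ad

# Differentiate the lazy constructor BroadcastArray(f, args...) by asking the
# AD backend (e.g. Zygote) to differentiate the equivalent eager call
# Broadcast.broadcasted(f, args...). The returned pullback yields one tangent
# per positional argument, so the tangent slot for `broadcasted` lines up with
# the slot for the BroadcastArray type.
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode},
                              ::Type{BroadcastArray}, f, args...)
    return rrule_via_ad(config, Broadcast.broadcasted, f, args...)
end

end # module
```

One consequence worth checking: the primal this rule returns is whatever the AD backend produces for Broadcast.broadcasted (with Zygote this may be a materialized array rather than a lazy BroadcastArray), so laziness is likely lost on the differentiated path.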

I see. I think a weak dependency here would be fine. I would suggest forgetting the separate project and just requiring v1.9.

We use weak deps for ChangesOfVariables.jl now, and it works like a charm on Julia v1.9:

julia> @time_imports import ChangesOfVariables
      0.6 ms  ChangesOfVariables

julia> @time_imports import ChainRulesCore
      0.1 ms  Compat
     58.9 ms  ChainRulesCore
      0.4 ms  ChangesOfVariables → ChainRulesCoreExt