JuliaMath/ChangesOfVariables.jl

ChainRules pullback doesn't support ZeroTangent.

torfjelde opened this issue · 8 comments

It seems the rrule for a mapped/broadcasted f doesn't support ZeroTangent:

julia> Zygote.gradient(x -> Bijectors.with_logabsdet_jacobian(stacked, x)[2], randn(3))
ERROR: MethodError: no method matching length(::ChainRulesCore.ZeroTangent)
Closest candidates are:
  length(::Union{Base.KeySet, Base.ValueIterator}) at abstractdict.jl:58
  length(::Union{DataStructures.OrderedRobinDict, DataStructures.RobinDict}) at ~/.julia/packages/DataStructures/59MD0/src/ordered_robin_dict.jl:86
  length(::Union{DataStructures.SortedDict, DataStructures.SortedMultiDict, DataStructures.SortedSet}) at ~/.julia/packages/DataStructures/59MD0/src/container_loops.jl:322
  ...
Stacktrace:
  [1] length(g::Base.Generator{ChainRulesCore.ZeroTangent, ChangesOfVariables.ChangesOfVariablesChainRulesCoreExt.var"#1#2"{Tuple{Float64, Float64}, Float64}})
    @ Base ./generator.jl:50
  [2] _similar_shape(itr::Base.Generator{ChainRulesCore.ZeroTangent, ChangesOfVariables.ChangesOfVariablesChainRulesCoreExt.var"#1#2"{Tuple{Float64, Float64}, Float64}}, #unused#::Base.HasLength)
    @ Base ./array.jl:663
  [3] collect(itr::Base.Generator{ChainRulesCore.ZeroTangent, ChangesOfVariables.ChangesOfVariablesChainRulesCoreExt.var"#1#2"{Tuple{Float64, Float64}, Float64}})
    @ Base ./array.jl:786
  [4] map(f::Function, A::ChainRulesCore.ZeroTangent)
    @ Base ./abstractarray.jl:2961
  [5] (::ChangesOfVariables.ChangesOfVariablesChainRulesCoreExt.WithLadjOnMappedPullback{Tuple{Float64, Float64}})(thunked_ΔΩ::ChainRulesCore.Tangent{Any, Tuple{ChainRulesCore.ZeroTangent, Float64}})
    @ ChangesOfVariables.ChangesOfVariablesChainRulesCoreExt ~/.julia/packages/ChangesOfVariables/qC6bf/ext/ChangesOfVariablesChainRulesCoreExt.jl:12
  [6] ZBack
    @ ~/.julia/packages/Zygote/TSj5C/src/compiler/chainrules.jl:211 [inlined]
  [7] Pullback
    @ ~/.julia/packages/ChangesOfVariables/qC6bf/src/with_ladj.jl:121 [inlined]
  [8] (::Zygote.Pullback{Tuple{typeof(with_logabsdet_jacobian), Base.Fix1{typeof(broadcast), typeof(exp)}, Vector{Float64}}, Tuple{…}})(Δ::Tuple{Nothing, Float64})
    @ Zygote ./compiler/interface2.jl:0
...
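In the call above, only the second output (the log-abs-det-Jacobian) is differentiated. The cotangent of the transformed values therefore arrives as ZeroTangent(), and the pullback passes it straight to map. The sketch below shows a guarded pullback. It uses a hypothetical exp_with_ladj (elementwise exp, for which log|det J| = sum(x)), not the package's actual WithLadjOnMappedPullback. Each cotangent component is checked against AbstractZero before any array operation touches it.

```julia
using ChainRulesCore

# Hypothetical primal (not from ChangesOfVariables): y = exp.(x).
# For an elementwise exp the Jacobian is Diagonal(y), so log|det J| = sum(x).
exp_with_ladj(x::AbstractVector{<:Real}) = (exp.(x), sum(x))

function ChainRulesCore.rrule(::typeof(exp_with_ladj), x::AbstractVector{<:Real})
    y, ladj = exp_with_ladj(x)
    function exp_with_ladj_pullback(ΔΩ)
        Δ = unthunk(ΔΩ)
        # Neither output is used: nothing to propagate.
        Δ isa AbstractZero && return (NoTangent(), ZeroTangent())
        Δy, Δladj = unthunk(Δ[1]), unthunk(Δ[2])
        # Check each component for a zero before doing any array arithmetic on it.
        Δx_from_y = Δy isa AbstractZero ? zero(y) : Δy .* y  # J' * Δy with J = Diagonal(y)
        Δx_from_ladj = Δladj isa AbstractZero ? zero(y) : fill(Δladj, length(x))  # gradient of sum(x) is ones
        return (NoTangent(), Δx_from_y .+ Δx_from_ladj)
    end
    return (y, ladj), exp_with_ladj_pullback
end

# Usage: only the ladj output is differentiated, as in the failing Zygote call.
(y, ladj), pb = rrule(exp_with_ladj, [0.0, 1.0, 2.0])
pb((ZeroTangent(), 1.0))  # returns (NoTangent(), [1.0, 1.0, 1.0])
```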

Thanks @torfjelde, will fix!

Cheers @oschulz!

I posted it because at the time I didn't have time to address it myself.

But on closer inspection, it seems to require more changes than I had originally hoped. That is, it seems to me that we need to change the definition of WithLadjOnMappedPullback to also store the size of ys, since we now need to return a ZeroTangent for every y in ys? Or am I misunderstanding something here?
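One way to read this proposal is sketched below with hypothetical names (densify_cotangent, mapped_exp_pullback) and a simple elementwise-exp example. The pullback keeps size(ys) and expands a collapsed zero into one ZeroTangent() per element, so the elementwise map keeps working. It relies on ChainRulesCore's zero arithmetic, where ZeroTangent() * a is ZeroTangent() and ZeroTangent() + a is a.

```julia
using ChainRulesCore

# Hypothetical helper: expand a collapsed zero cotangent into a dense array of
# ZeroTangent() with the primal output's shape; any other cotangent is passed through.
densify_cotangent(Δys, sz::Dims) = Δys isa AbstractZero ? fill(ZeroTangent(), sz) : Δys

# Hypothetical elementwise pullback for y_i = exp(x_i) with ladj = sum(x):
# the gradient contribution for x_i is Δy_i * y_i + Δladj.
function mapped_exp_pullback(Δys, Δladj, ys::AbstractArray)
    Δys_dense = densify_cotangent(unthunk(Δys), size(ys))
    # ZeroTangent() * y == ZeroTangent() and ZeroTangent() + Δladj == Δladj.
    return map((Δy, y) -> Δy * y + Δladj, Δys_dense, ys)
end

mapped_exp_pullback(ZeroTangent(), 1.0, [1.0, 2.0])  # returns [1.0, 1.0]
```

The cost of this approach is an extra array allocation and per-element work on values already known to be zero.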

Uh, good question. I think you're right...

Ah, wait, @torfjelde, I think we don't have to fill an array with ZeroTangent(), though I'm not quite sure why the result is NoTangent() instead of ZeroTangent() here:

julia> using ChainRulesCore, ChainRules

julia> unthunk(rrule(sum, rand(5))[2](ZeroTangent())[2])
NoTangent()

@oxinabox: Frames, if you have a moment, could you advise us? Do we need to pull back to a vector of NoTangent(), or can we just pull back to a single NoTangent()?

In theory, the AD should treat a vector of ZeroTangents the same as a single ZeroTangent.
And that treatment is: don't call the pullback, just return ZeroTangent()
(the same goes for NoTangent).
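The sketch below shows this rule from the AD side. It assumes the AD routes every pullback call through a wrapper like the hypothetical call_pullback. A cotangent that is an AbstractZero (ZeroTangent or NoTangent), or an array made up only of such zeros, is treated as zero. In that case the pullback is never called, and every input simply receives ZeroTangent().

```julia
using ChainRulesCore

# A cotangent counts as zero if it is an AbstractZero (ZeroTangent or NoTangent),
# or an array made up entirely of AbstractZeros.
is_zero_cotangent(Δ) = Δ isa AbstractZero
is_zero_cotangent(Δ::AbstractArray) = all(d -> d isa AbstractZero, Δ)

# Hypothetical AD-side wrapper: for a zero cotangent the rule's pullback is never
# called; each of the `nargs` slots (function plus arguments) gets ZeroTangent().
function call_pullback(pullback, Δ, nargs::Int)
    is_zero_cotangent(Δ) && return ntuple(_ -> ZeroTangent(), nargs)
    return pullback(Δ)
end

# Usage with a toy pullback for x -> sum(x) on a length-3 vector.
sum3_pullback(Δ) = (NoTangent(), fill(Δ, 3))
call_pullback(sum3_pullback, fill(ZeroTangent(), 3), 2)  # returns (ZeroTangent(), ZeroTangent())
call_pullback(sum3_pullback, 2.0, 2)                     # returns (NoTangent(), [2.0, 2.0, 2.0])
```

If an AD applied this check consistently, individual rules would not need their own AbstractZero guards.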

Whether a particular AD actually does that is another question.
A lot of rules assume they are never passed ZeroTangent or NoTangent, because they expect the AD system to handle it.
That this one returns NoTangent rather than the correct ZeroTangent is probably because the rule's author didn't consider this case.

I've tried this a few ways now; it's tricky to get right in all circumstances. I did, though, find a way to get decent AD speed by changing the implementation of the primal a bit. So I'd say let's take the rrule out for now: #21

Closed by #21.