ChainRules pullback doesn't support ZeroTangent.
torfjelde opened this issue · 8 comments
It seems the rrule for mapped/broadcasted f doesn't support ZeroTangent:
julia> Zygote.gradient(x -> Bijectors.with_logabsdet_jacobian(stacked, x)[2], randn(3))
ERROR: MethodError: no method matching length(::ChainRulesCore.ZeroTangent)
Closest candidates are:
length(::Union{Base.KeySet, Base.ValueIterator}) at abstractdict.jl:58
length(::Union{DataStructures.OrderedRobinDict, DataStructures.RobinDict}) at ~/.julia/packages/DataStructures/59MD0/src/ordered_robin_dict.jl:86
length(::Union{DataStructures.SortedDict, DataStructures.SortedMultiDict, DataStructures.SortedSet}) at ~/.julia/packages/DataStructures/59MD0/src/container_loops.jl:322
...
Stacktrace:
[1] length(g::Base.Generator{ChainRulesCore.ZeroTangent, ChangesOfVariables.ChangesOfVariablesChainRulesCoreExt.var"#1#2"{Tuple{Float64, Float64}, Float64}})
@ Base ./generator.jl:50
[2] _similar_shape(itr::Base.Generator{ChainRulesCore.ZeroTangent, ChangesOfVariables.ChangesOfVariablesChainRulesCoreExt.var"#1#2"{Tuple{Float64, Float64}, Float64}}, #unused#::Base.HasLength)
@ Base ./array.jl:663
[3] collect(itr::Base.Generator{ChainRulesCore.ZeroTangent, ChangesOfVariables.ChangesOfVariablesChainRulesCoreExt.var"#1#2"{Tuple{Float64, Float64}, Float64}})
@ Base ./array.jl:786
[4] map(f::Function, A::ChainRulesCore.ZeroTangent)
@ Base ./abstractarray.jl:2961
[5] (::ChangesOfVariables.ChangesOfVariablesChainRulesCoreExt.WithLadjOnMappedPullback{Tuple{Float64, Float64}})(thunked_ΔΩ::ChainRulesCore.Tangent{Any, Tuple{ChainRulesCore.ZeroTangent, Float64}})
@ ChangesOfVariables.ChangesOfVariablesChainRulesCoreExt ~/.julia/packages/ChangesOfVariables/qC6bf/ext/ChangesOfVariablesChainRulesCoreExt.jl:12
[6] ZBack
@ ~/.julia/packages/Zygote/TSj5C/src/compiler/chainrules.jl:211 [inlined]
[7] Pullback
@ ~/.julia/packages/ChangesOfVariables/qC6bf/src/with_ladj.jl:121 [inlined]
[8] (::Zygote.Pullback{Tuple{typeof(with_logabsdet_jacobian), Base.Fix1{typeof(broadcast), typeof(exp)}, Vector{Float64}}, Tuple{Zygote.var"#2149#back#299"{Zygote.var"#back#298"{:x, Zygote.Context{false}, Base.Fix1{typeof(broadcast), typeof(exp)}, typeof(exp)}}, Zygote.Pullback{Tuple{Type{Base.Fix1}, typeof(with_logabsdet_jacobian), typeof(exp)}, Tuple{Zygote.ZBack{ChainRules.var"#fieldtype_pullback#420"}, Zygote.Pullback{Tuple{typeof(Base._stable_typeof), typeof(exp)}, Tuple{Zygote.ZBack{ChainRules.var"#typeof_pullback#45"}}}, Zygote.var"#2176#back#309"{Zygote.Jnew{Base.Fix1{typeof(with_logabsdet_jacobian), typeof(exp)}, Nothing, false}}, Zygote.Pullback{Tuple{typeof(convert), Type{typeof(with_logabsdet_jacobian)}, typeof(with_logabsdet_jacobian)}, Tuple{}}, Zygote.Pullback{Tuple{typeof(convert), Type{typeof(exp)}, typeof(exp)}, Tuple{}}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#420"}}}, Zygote.var"#4118#back#1381"{Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), Base.Fix1{typeof(with_logabsdet_jacobian), typeof(exp)}, Vector{Float64}}, Tuple{Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#4086#back#1368"{Zygote.var"#∇broadcasted#1364"{Tuple{Vector{Float64}}, Vector{Tuple{Tuple{Float64, Float64}, Zygote.var"#2379#back#440"{Zygote.Pullback{Tuple{Zygote.var"#fallback_Fix1#439"{typeof(exp), typeof(with_logabsdet_jacobian)}, Float64}, Tuple{Zygote.var"#2149#back#299"{Zygote.var"#back#298"{:x, Zygote.Context{false}, Zygote.var"#fallback_Fix1#439"{typeof(exp), typeof(with_logabsdet_jacobian)}, typeof(exp)}}, Zygote.var"#2149#back#299"{Zygote.var"#back#298"{:f, Zygote.Context{false}, Zygote.var"#fallback_Fix1#439"{typeof(exp), typeof(with_logabsdet_jacobian)}, typeof(with_logabsdet_jacobian)}}, Zygote.Pullback{Tuple{typeof(with_logabsdet_jacobian), typeof(exp), Float64}, Tuple{Zygote.var"#1982#back#200"{typeof(identity)}, Zygote.ZBack{ChainRules.var"#exp_pullback#1319"{Float64, 
ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}}}}}}}, Val{2}}}}}, Zygote.var"#2841#back#683"{Zygote.var"#map_back#677"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}, Zygote.var"#1982#back#200"{typeof(identity)}, Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1169"{Tuple{Nothing, Nothing}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), Vector{Float64}}, Tuple{}}, Zygote.var"#1982#back#200"{typeof(identity)}}}}, Zygote.ZBack{ChangesOfVariables.ChangesOfVariablesChainRulesCoreExt.WithLadjOnMappedPullback{Tuple{Float64, Float64}}}, Zygote.var"#2149#back#299"{Zygote.var"#back#298"{:f, Zygote.Context{false}, Base.Fix1{typeof(broadcast), typeof(exp)}, typeof(broadcast)}}}})(Δ::Tuple{Nothing, Float64})
@ Zygote ./compiler/interface2.jl:0
...
Thanks @torfjelde, will fix!
Cheers @oschulz!
I posted it because at the time I didn't have time to address it myself.
But on closer inspection, it seems to require more changes than I had originally hoped. That is, it seems to me that we need to change the definition of WithLadjOnMappedPullback to also wrap the size of ys, since we now need to return a ZeroTangent for every y in ys? Or am I misunderstanding here?
Uh, good question. I think you're right ...
Ah, wait - @torfjelde, I think we don't have to fill an array with ZeroTangent() (though I'm not quite sure why the result is NoTangent() instead of ZeroTangent() here):
julia> using ChainRulesCore, ChainRules
julia> unthunk(rrule(sum, rand(5))[2](ZeroTangent())[2])
NoTangent()
@oxinabox - Frames, if you have a moment, could you advise us? Do we need to pull back to a vector of NoTangent() or can we just pull back to a single NoTangent()?
In theory a vector of ZeroTangents should be treated the same as a ZeroTangent by the AD.
And that treatment is: don't call the pullback, just return a ZeroTangent() (the same goes for NoTangent).
Whether or not a particular AD actually does that is a question.
A lot of rules assume they are never passed ZeroTangent or NoTangent because they expect the AD system to handle it.
The fact that this one returns NoTangent rather than the correct ZeroTangent is probably because the rule writer didn't consider this case.
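For illustration, the short-circuit described above can also be written into a rule by hand, so that it does not rely on the AD skipping zero cotangents. The following is only a minimal sketch, not the ChangesOfVariables implementation: pullback_over_elements and its scaling step are hypothetical stand-ins for the real rule body, and the only ChainRulesCore pieces assumed are AbstractZero (the common supertype of ZeroTangent and NoTangent) and the two zero types themselves.

```julia
using ChainRulesCore

# Hypothetical per-element pullback step, standing in for the real rule body.
# For a regular array cotangent it maps over the elements as usual.
pullback_over_elements(Δys::AbstractArray, factor::Real) = map(Δy -> factor * Δy, Δys)

# Guard method: a structural zero is propagated unchanged instead of being mapped over.
pullback_over_elements(Δys::AbstractZero, factor::Real) = Δys

pullback_over_elements([1.0, 2.0], 3.0)      # ordinary cotangent: elementwise scaling
pullback_over_elements(ZeroTangent(), 3.0)   # passes the ZeroTangent through, map is never called
pullback_over_elements(NoTangent(), 3.0)     # passes the NoTangent through likewise
```

Dispatching on AbstractZero keeps the guard out of the main method, and it returns a single zero tangent rather than an array filled with them, in line with the advice above.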