Tangent backing error when nesting forward-mode AD
oxinabox opened this issue · 2 comments
oxinabox commented
Here is a trivial example of nested forward mode:
using Diffractor
# Trivial inner function; `g` takes its forward-mode derivative.
f(x) = 3x^2
# First-order forward-mode pass over `f` at `x`. `f` is wrapped in a ZeroBundle,
# and `x` is given the tangent `(1.0,)`. Differentiating `g` itself (next line)
# is what makes this nested AD.
g(x) = Diffractor.∂☆{1}()(Diffractor.ZeroBundle{1}(f), Diffractor.TaylorBundle{1}(x, (1.0,)))
# Outer forward-mode pass over `g` at the Int primal 10 with tangent `(1.0,)`.
# Currently throws an ArgumentError from the ChainRulesCore `Tangent` constructor.
Diffractor.∂☆{1}()(Diffractor.ZeroBundle{1}(g), Diffractor.TaylorBundle{1}(10, (1.0,)))
Here is the output from running that:
julia> Diffractor.∂☆{1}()(Diffractor.ZeroBundle{1}(g), Diffractor.TaylorBundle{1}(10, (1.0,)))
ERROR: ArgumentError: Tangent for the primal Diffractor.UniformTangent{ChainRulesCore.ZeroTangent} should be backed by a NamedTuple type, not by Tuple{ChainRulesCore.ZeroTangent}.
Stacktrace:
[1] _backing_error(P::Type, G::Type, E::Type)
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/a4mIA/src/tangent_types/tangent.jl:62
[2] ChainRulesCore.Tangent{Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}, Tuple{ChainRulesCore.ZeroTangent}}(backing::Tuple{ChainRulesCore.ZeroTangent})
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/a4mIA/src/tangent_types/tangent.jl:36
[3] (ChainRulesCore.Tangent{Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}})(args::ChainRulesCore.ZeroTangent)
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/a4mIA/src/tangent_types/tangent.jl:48
[4] partial(x::Diffractor.CompositeBundle{1, Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}, Tuple{Diffractor.TangentBundle{1, ChainRulesCore.ZeroTangent, Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}}}}, i::Int64)
@ Diffractor ~/.julia/packages/Diffractor/HBYjZ/src/stage1/forward.jl:7
[5] first_partial(x::Diffractor.CompositeBundle{1, Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}, Tuple{Diffractor.TangentBundle{1, ChainRulesCore.ZeroTangent, Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}}}})
@ Diffractor ~/.julia/packages/Diffractor/HBYjZ/src/stage1/forward.jl:11
[6] map
@ ./tuple.jl:291 [inlined]
[7] map(f::typeof(Diffractor.first_partial), t::Tuple{Diffractor.TangentBundle{1, typeof(Diffractor._TangentBundle), Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}}, Diffractor.TangentBundle{1, Val{1}, Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}}, Diffractor.TangentBundle{1, typeof(f), Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}}, Diffractor.CompositeBundle{1, Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}, Tuple{Diffractor.TangentBundle{1, ChainRulesCore.ZeroTangent, Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}}}}})
@ Base ./tuple.jl:292
[8] (::Diffractor.∂☆internal{1})(::Diffractor.TangentBundle{1, typeof(Diffractor._TangentBundle), Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}}, ::Vararg{Diffractor.AbstractTangentBundle{1}})
@ Diffractor ~/.julia/packages/Diffractor/HBYjZ/src/stage1/forward.jl:110
[9] (::Diffractor.∂☆{1})(::Diffractor.TangentBundle{1, typeof(Diffractor._TangentBundle), Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}}, ::Vararg{Diffractor.AbstractTangentBundle{1}})
@ Diffractor ~/.julia/packages/Diffractor/HBYjZ/src/stage1/forward.jl:139
[10] TangentBundle
@ ~/.julia/packages/Diffractor/HBYjZ/src/tangent.jl:251 [inlined]
[11] (::Diffractor.∂☆recurse{1})(::Diffractor.TangentBundle{1, Type{Diffractor.TangentBundle{1, B, Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}} where B}, Diffractor.UniformTangent{ChainRulesCore.NoTangent}}, ::Diffractor.TangentBundle{1, typeof(f), Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}})
@ Diffractor ~/.julia/packages/Diffractor/HBYjZ/src/stage1/recurse_fwd.jl:0
[12] (::Diffractor.∂☆internal{1})(::Diffractor.TangentBundle{1, Type{Diffractor.TangentBundle{1, B, Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}} where B}, Diffractor.UniformTangent{ChainRulesCore.NoTangent}}, ::Vararg{Diffractor.AbstractTangentBundle{1}})
@ Diffractor ~/.julia/packages/Diffractor/HBYjZ/src/stage1/forward.jl:112
[13] (::Diffractor.∂☆{1})(::Diffractor.TangentBundle{1, Type{Diffractor.TangentBundle{1, B, Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}} where B}, Diffractor.UniformTangent{ChainRulesCore.NoTangent}}, ::Vararg{Diffractor.AbstractTangentBundle{1}})
@ Diffractor ~/.julia/packages/Diffractor/HBYjZ/src/stage1/forward.jl:139
[14] g
@ ~/JuliaEnvs/DAECompiler.jl/scratch/jac_scratch.jl:54 [inlined]
[15] (::Diffractor.∂☆recurse{1})(::Diffractor.TangentBundle{1, typeof(g), Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}}, ::Diffractor.TangentBundle{1, Int64, Diffractor.TaylorTangent{Tuple{Float64}}})
@ Diffractor ~/.julia/packages/Diffractor/HBYjZ/src/stage1/recurse_fwd.jl:0
[16] (::Diffractor.∂☆internal{1})(::Diffractor.TangentBundle{1, typeof(g), Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}}, ::Vararg{Diffractor.AbstractTangentBundle{1}})
@ Diffractor ~/.julia/packages/Diffractor/HBYjZ/src/stage1/forward.jl:112
[17] (::Diffractor.∂☆{1})(::Diffractor.TangentBundle{1, typeof(g), Diffractor.UniformTangent{ChainRulesCore.ZeroTangent}}, ::Vararg{Diffractor.AbstractTangentBundle{1}})
@ Diffractor ~/.julia/packages/Diffractor/HBYjZ/src/stage1/forward.jl:139
[18] top-level scope
@ ~/JuliaEnvs/DAECompiler.jl/scratch/jac_scratch.jl:55
I am not sure how we should handle this.
I suspect it is possible to rewrite some (maybe all?) cases to use ∂☆{2} instead,
but I would need to think about it more.
We definitely shouldn't just be erroring, though.
oxinabox commented
So I believe the cause of this is that `CompositeBundle`
only works (or at least is only tested) as the tangent bundle for `Tuple`s.
But here it is being asked to represent the tangent of structs (in particular, for the tangent bundle struct).
When you pull the partial out for a struct, it constructs a `CRC.Tangent{P, <:Tuple}`,
i.e. a `Tangent` backed by a `Tuple` rather than a `NamedTuple`,
and that is what gives the error.
Here is a simpler case that triggers it without nesting AD:
# Plain two-field struct (untyped fields). Its forward-mode result is a
# CompositeBundle whose primal is a non-Tuple type.
struct Foo
x
y
end
# Builds a `Foo` from `x` and `2x`, so both fields depend on the input.
foo_dub(x) = Foo(x, 2x)
# Forward-mode pass over `foo_dub` at 10.0 with tangent `(1.0,)`.
# The error below is raised by `first_partial`, not by this call.
dz = Diffractor.∂☆{1}()(Diffractor.ZeroBundle{1}(foo_dub), Diffractor.TaylorBundle{1}(10.0, (1.0,)))
# Extracting the first partial tries to build `Tangent{Foo}` backed by a Tuple,
# which ChainRulesCore rejects.
Diffractor.first_partial(dz)
which errors with:
julia> Diffractor.first_partial(dz)
ERROR: ArgumentError: Tangent for the primal Foo should be backed by a NamedTuple type, not by Tuple{Float64, Float64}.
Stacktrace:
[1] _backing_error(P::Type, G::Type, E::Type)
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/a4mIA/src/tangent_types/tangent.jl:62
[2] ChainRulesCore.Tangent{Foo, Tuple{Float64, Float64}}(backing::Tuple{Float64, Float64})
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/a4mIA/src/tangent_types/tangent.jl:36
[3] (ChainRulesCore.Tangent{Foo})(::Float64, ::Vararg{Float64})
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/a4mIA/src/tangent_types/tangent.jl:48
[4] partial(x::Diffractor.CompositeBundle{1, Foo, Tuple{Diffractor.TangentBundle{1, Float64, Diffractor.TaylorTangent{Tuple{Float64}}}, Diffractor.TangentBundle{1, Float64, Diffractor.TaylorTangent{Tuple{Float64}}}}}, i::Int64)
@ Diffractor ~/JuliaEnvs/DAECompiler.jl/dev/Diffractor/src/stage1/forward.jl:7
[5] first_partial(x::Diffractor.CompositeBundle{1, Foo, Tuple{Diffractor.TangentBundle{1, Float64, Diffractor.TaylorTangent{Tuple{Float64}}}, Diffractor.TangentBundle{1, Float64, Diffractor.TaylorTangent{Tuple{Float64}}}}})
@ Diffractor ~/JuliaEnvs/DAECompiler.jl/dev/Diffractor/src/stage1/forward.jl:11
[6] top-level scope
@ ~/JuliaEnvs/DAECompiler.jl/scratch/jac_scratch.jl:65