Tracing Zygote.gradient
jw3126 opened this issue · 1 comment
jw3126 commented
Tracing `Zygote.gradient` does not currently work:
julia> module MWE
# Minimal reproducer: calling Zygote.gradient directly works, but tracing the
# call to Zygote.gradient itself with Ghost's `trace` throws the MethodError below.
using Zygote, Ghost
# Scalar-valued test function; the gradient returned is a Fill of ones (see output further down).
f(x) = sum(x)
x = [1]
Zygote.gradient(f, x) # works
# Fails: per the stack trace, a `nothing` value reaches Ghost's set_return! and
# cannot be converted to a Ghost.Variable.
trace(Zygote.gradient,f, x)
end#module
ERROR: MethodError: Cannot `convert` an object of type Nothing to an object of type Ghost.Variable
Closest candidates are:
convert(::Type{T}, ::T) where T at essentials.jl:205
Ghost.Variable(::Union{Nothing, Integer}, ::Union{Nothing, Ghost.AbstractOp}) at /home/jan/.julia/packages/Ghost/1rVCi/src/tape.jl:22
Ghost.Variable(::Any, ::Any) at /home/jan/.julia/packages/Ghost/1rVCi/src/tape.jl:22
Stacktrace:
[1] setproperty!(x::Ghost.Frame, f::Symbol, v::Nothing)
@ Base ./Base.jl:34
[2] set_return!(t::Ghost.IRTracer, arg_sid_ref::Base.RefValue{GlobalRef})
@ Ghost ~/.julia/packages/Ghost/1rVCi/src/trace.jl:188
[3] (::Ghost.IRTracer)(::Zygote.ZBack{ChainRules.var"#isempty_pullback#446"}, ::Nothing)
@ Ghost ~/.julia/packages/IRTools/46viC/src/reflection/dynamo.jl:114
[4] record_or_recurse!(::Ghost.IRTracer, ::Int64, ::Vector{Any}, ::Function, ::Vararg{Any, N} where N)
@ Ghost ~/.julia/packages/Ghost/1rVCi/src/trace.jl:453
[5] IRTracer
@ ./reduce.jl:501 [inlined]
[6] (::Ghost.IRTracer)(::typeof(∂(#sum#221)), ::Int64)
@ Ghost ~/.julia/packages/IRTools/46viC/src/reflection/dynamo.jl:0
[7] record_or_recurse!(::Ghost.IRTracer, ::Int64, ::Vector{Any}, ::typeof(∂(#sum#221)), ::Vararg{Any, N} where N)
@ Ghost ~/.julia/packages/Ghost/1rVCi/src/trace.jl:453
[8] IRTracer
@ ./reduce.jl:501 [inlined]
[9] (::Ghost.IRTracer)(::typeof(∂(sum)), ::Int64)
@ Ghost ~/.julia/packages/IRTools/46viC/src/reflection/dynamo.jl:0
[10] record_or_recurse!(::Ghost.IRTracer, ::Int64, ::Vector{Any}, ::typeof(∂(sum)), ::Vararg{Any, N} where N)
@ Ghost ~/.julia/packages/Ghost/1rVCi/src/trace.jl:453
[11] IRTracer
@ ./reduce.jl:528 [inlined]
[12] (::Ghost.IRTracer)(::typeof(∂(#sum#222)), ::Int64)
@ Ghost ~/.julia/packages/IRTools/46viC/src/reflection/dynamo.jl:0
[13] record_or_recurse!(::Ghost.IRTracer, ::Int64, ::Vector{Any}, ::typeof(∂(#sum#222)), ::Vararg{Any, N} where N)
@ Ghost ~/.julia/packages/Ghost/1rVCi/src/trace.jl:453
[14] IRTracer
@ ./reduce.jl:528 [inlined]
[15] (::Ghost.IRTracer)(::typeof(∂(sum)), ::Int64)
@ Ghost ~/.julia/packages/IRTools/46viC/src/reflection/dynamo.jl:0
[16] record_or_recurse!(::Ghost.IRTracer, ::Int64, ::Vector{Any}, ::typeof(∂(sum)), ::Vararg{Any, N} where N)
@ Ghost ~/.julia/packages/Ghost/1rVCi/src/trace.jl:453
[17] IRTracer
@ ./REPL[1]:3 [inlined]
[18] (::Ghost.IRTracer)(::typeof(∂(f)), ::Int64)
@ Ghost ~/.julia/packages/IRTools/46viC/src/reflection/dynamo.jl:0
[19] record_or_recurse!(::Ghost.IRTracer, ::Int64, ::Vector{Any}, ::typeof(∂(f)), ::Vararg{Any, N} where N)
@ Ghost ~/.julia/packages/Ghost/1rVCi/src/trace.jl:453
[20] IRTracer
@ ~/.julia/packages/Zygote/0da6K/src/compiler/interface.jl:41 [inlined]
[21] (::Ghost.IRTracer)(::Zygote.var"#46#47"{typeof(∂(f))}, ::Int64)
@ Ghost ~/.julia/packages/IRTools/46viC/src/reflection/dynamo.jl:0
[22] record_or_recurse!(::Ghost.IRTracer, ::Int64, ::Vector{Any}, ::Function, ::Vararg{Any, N} where N)
@ Ghost ~/.julia/packages/Ghost/1rVCi/src/trace.jl:453
[23] IRTracer
@ ~/.julia/packages/Zygote/0da6K/src/compiler/interface.jl:59 [inlined]
[24] (::Ghost.IRTracer)(::typeof(Zygote.gradient), ::typeof(Main.MWE.f), ::Vector{Int64})
@ Ghost ~/.julia/packages/IRTools/46viC/src/reflection/dynamo.jl:0
[25] trace(::Function, ::Function, ::Vararg{Any, N} where N; is_primitive::Function, primitives::Nothing, ctx::Dict{Any, Any})
@ Ghost ~/.julia/packages/Ghost/1rVCi/src/trace.jl:640
[26] trace(::Function, ::Function, ::Vararg{Any, N} where N)
@ Ghost ~/.julia/packages/Ghost/1rVCi/src/trace.jl:630
[27] top-level scope
@ REPL[1]:6
dfdx commented
This is also fixed in #14.
However, note that in this example you trace `Zygote.gradient()` itself, not the derivative function created by `Zygote.gradient()`. Perhaps you were looking for `Zygote.pullback()`?
# Obtain the pullback closure `pb` for f at x (the tape below shows its type as
# Zygote.var"#46#47"{typeof(∂(f))}).
val, pb = Zygote.pullback(f, x)
# Trace the pullback itself rather than Zygote.gradient. The 1.0 argument is the
# value fed into pb (inp %2 in the tape); it becomes the entries of the returned
# Fill gradient. It is not itself the derivative's value.
trace(pb, 1.0)
# output
((1-element Fill{Float64}: entries equal to 1.0,), Tape{Dict{Any, Any}}
inp %1::Zygote.var"#46#47"{typeof(∂(f))}
inp %2::Float64
%3 = getfield(%1, back)::typeof(∂(f))
const %4 = nothing::Nothing
%5 = getfield(%3, t)::Tuple{Zygote.var"#2675#back#615"{Zygote.var"#611#613"{Vector{Int64}}}}
%6 = getindex(%5, 1)::Zygote.var"#2675#back#615"{Zygote.var"#611#613"{Vector{Int64}}}
%7 = getfield(%6, #2674#_back)::Zygote.var"#611#613"{Vector{Int64}}
%8 = getfield(%7, xs)::Vector{Int64}
%9 = size(%8)::Tuple{Int64}
%10 = apply_type(FillArrays.Fill, Float64, 1)::UnionAll
%11 = apply_type(FillArrays.Fill, Float64, 1)::UnionAll
%12 = broadcasted(oneto, %9)::Broadcasted{}
%13 = materialize(%12)::Tuple{Base.OneTo{Int64}}
%14 = apply_type(FillArrays.Fill, Float64, 1, Tuple{Base.OneTo{Int64}})::DataType
%15 = apply_type(FillArrays.Fill, Float64, 1, Tuple{Base.OneTo{Int64}})::DataType
%16 = convert(Float64, %2)::Float64
%17 = convert(Tuple{Base.OneTo{Int64}}, %13)::Tuple{Base.OneTo{Int64}}
%18 = __new__(%15, %16, %17)::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}
%19 = tuple(%18)::Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}
%20 = tuple(nothing)::Tuple{Nothing}
%21 = _apply_iterate(iterate, tuple, %20, %19)::Tuple{Nothing, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}
%22 = getindex(%21, 2)::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}
%23 = tuple(nothing, %22)::Tuple{Nothing, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}
const %24 = tail::typeof(Base.tail)
%25 = %24(%23)::Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}
)