dfdx/Ghost.jl

Tracing Zygote.gradient

jw3126 opened this issue · 1 comment

Tracing `Zygote.gradient` currently does not work:

julia> module MWE
       using Zygote, Ghost
       f(x) = sum(x)
       x = [1]
       Zygote.gradient(f, x) # works
       trace(Zygote.gradient,f, x)
       end#module

ERROR: MethodError: Cannot `convert` an object of type Nothing to an object of type Ghost.Variable
Closest candidates are:
  convert(::Type{T}, ::T) where T at essentials.jl:205
  Ghost.Variable(::Union{Nothing, Integer}, ::Union{Nothing, Ghost.AbstractOp}) at /home/jan/.julia/packages/Ghost/1rVCi/src/tape.jl:22
  Ghost.Variable(::Any, ::Any) at /home/jan/.julia/packages/Ghost/1rVCi/src/tape.jl:22
Stacktrace:
  [1] setproperty!(x::Ghost.Frame, f::Symbol, v::Nothing)
    @ Base ./Base.jl:34
  [2] set_return!(t::Ghost.IRTracer, arg_sid_ref::Base.RefValue{GlobalRef})
    @ Ghost ~/.julia/packages/Ghost/1rVCi/src/trace.jl:188
  [3] (::Ghost.IRTracer)(::Zygote.ZBack{ChainRules.var"#isempty_pullback#446"}, ::Nothing)
    @ Ghost ~/.julia/packages/IRTools/46viC/src/reflection/dynamo.jl:114
  [4] record_or_recurse!(::Ghost.IRTracer, ::Int64, ::Vector{Any}, ::Function, ::Vararg{Any, N} where N)
    @ Ghost ~/.julia/packages/Ghost/1rVCi/src/trace.jl:453
  [5] IRTracer
    @ ./reduce.jl:501 [inlined]
  [6] (::Ghost.IRTracer)(::typeof((#sum#221)), ::Int64)
    @ Ghost ~/.julia/packages/IRTools/46viC/src/reflection/dynamo.jl:0
  [7] record_or_recurse!(::Ghost.IRTracer, ::Int64, ::Vector{Any}, ::typeof((#sum#221)), ::Vararg{Any, N} where N)
    @ Ghost ~/.julia/packages/Ghost/1rVCi/src/trace.jl:453
  [8] IRTracer
    @ ./reduce.jl:501 [inlined]
  [9] (::Ghost.IRTracer)(::typeof((sum)), ::Int64)
    @ Ghost ~/.julia/packages/IRTools/46viC/src/reflection/dynamo.jl:0
 [10] record_or_recurse!(::Ghost.IRTracer, ::Int64, ::Vector{Any}, ::typeof((sum)), ::Vararg{Any, N} where N)
    @ Ghost ~/.julia/packages/Ghost/1rVCi/src/trace.jl:453
 [11] IRTracer
    @ ./reduce.jl:528 [inlined]
 [12] (::Ghost.IRTracer)(::typeof((#sum#222)), ::Int64)
    @ Ghost ~/.julia/packages/IRTools/46viC/src/reflection/dynamo.jl:0
 [13] record_or_recurse!(::Ghost.IRTracer, ::Int64, ::Vector{Any}, ::typeof((#sum#222)), ::Vararg{Any, N} where N)
    @ Ghost ~/.julia/packages/Ghost/1rVCi/src/trace.jl:453
 [14] IRTracer
    @ ./reduce.jl:528 [inlined]
 [15] (::Ghost.IRTracer)(::typeof((sum)), ::Int64)
    @ Ghost ~/.julia/packages/IRTools/46viC/src/reflection/dynamo.jl:0
 [16] record_or_recurse!(::Ghost.IRTracer, ::Int64, ::Vector{Any}, ::typeof((sum)), ::Vararg{Any, N} where N)
    @ Ghost ~/.julia/packages/Ghost/1rVCi/src/trace.jl:453
 [17] IRTracer
    @ ./REPL[1]:3 [inlined]
 [18] (::Ghost.IRTracer)(::typeof((f)), ::Int64)
    @ Ghost ~/.julia/packages/IRTools/46viC/src/reflection/dynamo.jl:0
 [19] record_or_recurse!(::Ghost.IRTracer, ::Int64, ::Vector{Any}, ::typeof((f)), ::Vararg{Any, N} where N)
    @ Ghost ~/.julia/packages/Ghost/1rVCi/src/trace.jl:453
 [20] IRTracer
    @ ~/.julia/packages/Zygote/0da6K/src/compiler/interface.jl:41 [inlined]
 [21] (::Ghost.IRTracer)(::Zygote.var"#46#47"{typeof((f))}, ::Int64)
    @ Ghost ~/.julia/packages/IRTools/46viC/src/reflection/dynamo.jl:0
 [22] record_or_recurse!(::Ghost.IRTracer, ::Int64, ::Vector{Any}, ::Function, ::Vararg{Any, N} where N)
    @ Ghost ~/.julia/packages/Ghost/1rVCi/src/trace.jl:453
 [23] IRTracer
    @ ~/.julia/packages/Zygote/0da6K/src/compiler/interface.jl:59 [inlined]
 [24] (::Ghost.IRTracer)(::typeof(Zygote.gradient), ::typeof(Main.MWE.f), ::Vector{Int64})
    @ Ghost ~/.julia/packages/IRTools/46viC/src/reflection/dynamo.jl:0
 [25] trace(::Function, ::Function, ::Vararg{Any, N} where N; is_primitive::Function, primitives::Nothing, ctx::Dict{Any, Any})
    @ Ghost ~/.julia/packages/Ghost/1rVCi/src/trace.jl:640
 [26] trace(::Function, ::Function, ::Vararg{Any, N} where N)
    @ Ghost ~/.julia/packages/Ghost/1rVCi/src/trace.jl:630
 [27] top-level scope
    @ REPL[1]:6
dfdx commented

This is also fixed in #14.

However, note that in this example you are tracing `Zygote.gradient()` itself, not the derivative function that `Zygote.gradient()` builds internally. Perhaps you were looking for `Zygote.pullback()`, which returns that derivative function so it can be traced directly:

val, pb = Zygote.pullback(f, x)
trace(pb, 1.0)   # 1.0 is the seed passed to the pullback: the gradient of the (scalar) output with respect to itself

# output

((1-element Fill{Float64}: entries equal to 1.0,), Tape{Dict{Any, Any}}
  inp %1::Zygote.var"#46#47"{typeof((f))}
  inp %2::Float64
  %3 = getfield(%1, back)::typeof((f))
  const %4 = nothing::Nothing
  %5 = getfield(%3, t)::Tuple{Zygote.var"#2675#back#615"{Zygote.var"#611#613"{Vector{Int64}}}}
  %6 = getindex(%5, 1)::Zygote.var"#2675#back#615"{Zygote.var"#611#613"{Vector{Int64}}}
  %7 = getfield(%6, #2674#_back)::Zygote.var"#611#613"{Vector{Int64}}
  %8 = getfield(%7, xs)::Vector{Int64}
  %9 = size(%8)::Tuple{Int64}
  %10 = apply_type(FillArrays.Fill, Float64, 1)::UnionAll
  %11 = apply_type(FillArrays.Fill, Float64, 1)::UnionAll
  %12 = broadcasted(oneto, %9)::Broadcasted{}
  %13 = materialize(%12)::Tuple{Base.OneTo{Int64}}
  %14 = apply_type(FillArrays.Fill, Float64, 1, Tuple{Base.OneTo{Int64}})::DataType
  %15 = apply_type(FillArrays.Fill, Float64, 1, Tuple{Base.OneTo{Int64}})::DataType
  %16 = convert(Float64, %2)::Float64
  %17 = convert(Tuple{Base.OneTo{Int64}}, %13)::Tuple{Base.OneTo{Int64}}
  %18 = __new__(%15, %16, %17)::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}
  %19 = tuple(%18)::Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}
  %20 = tuple(nothing)::Tuple{Nothing}
  %21 = _apply_iterate(iterate, tuple, %20, %19)::Tuple{Nothing, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}
  %22 = getindex(%21, 2)::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}
  %23 = tuple(nothing, %22)::Tuple{Nothing, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}
  const %24 = tail::typeof(Base.tail)
  %25 = %24(%23)::Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}
)