dfdx/Espresso.jl

Recover lowered code

dfdx opened this issue · 1 comment

dfdx commented

Example:

```julia
predict(W, b, x) = W * x .+ b
loss(W, b, x, y) = sum((predict(W, b, x) .- y)^2)
args, ex = funexpr(loss, (Matrix{Float64}, Vector{Float64}, Vector{Float64}, Vector{Float64}))
```

If everything is run from the REPL, this gives the expression:

```julia
:(begin 
        nothing
        return (Main.sum)((Base.literal_pow)(Main.^, (Base.broadcast)(Main.-, (Main.predict)(W, b, x), y), (Core.apply_type)(Base.Val, 2)))
    end)
```

We should add an additional step that recovers the original code.
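A minimal sketch of such a step, assuming the lowered expression is available as an ordinary Julia Expr. It handles the three patterns visible in the output above: it drops Main./Base./Core. prefixes, rewrites broadcast(op, a, b) back into dotted calls, and rewrites literal_pow(^, x, Val{n}) back into x ^ n. The names unprefix, recover and recover_lowered_sketch, and the short list of operators turned back into dotted infix form, are hypothetical; this is not Espresso.jl's actual implementation. The example input writes Main.:^ and Main.:- because Main.^ does not parse in current Julia.

```julia
# Hypothetical sketch of a "recover original code" pass (not Espresso.jl's code).
# It rewrites a parsed Expr bottom-up.

const PREFIXES = (:Main, :Base, :Core)  # module prefixes to drop
const DOT_OPS  = (:+, :-, :*, :/, :^)   # assumed list: broadcast(op, ...) -> a .op b

# Main.sum -> sum, Base.broadcast -> broadcast, Core.apply_type -> apply_type
unprefix(x) = x
function unprefix(ex::Expr)
    if ex.head == :. && length(ex.args) == 2 &&
       ex.args[1] in PREFIXES && ex.args[2] isa QuoteNode
        return ex.args[2].value
    end
    return Expr(ex.head, map(unprefix, ex.args)...)
end

recover(x) = x
function recover(ex::Expr)
    args = map(recover, ex.args)  # rewrite children first
    if ex.head == :call && !isempty(args)
        f = args[1]
        if f == :broadcast && length(args) >= 2
            g, rest = args[2], args[3:end]
            if g in DOT_OPS
                # broadcast(-, a, b) -> a .- b
                return Expr(:call, Symbol(".", g), rest...)
            else
                # broadcast(g, a) -> g.(a)
                return Expr(:., g, Expr(:tuple, rest...))
            end
        elseif f == :literal_pow && length(args) == 4
            # literal_pow(^, x, apply_type(Val, n)) -> x ^ n
            p = args[4]
            if p isa Expr && p.head == :call && length(p.args) == 3 &&
               p.args[1] == :apply_type && p.args[2] == :Val
                return Expr(:call, args[2], args[3], p.args[3])
            end
        end
    end
    return Expr(ex.head, args...)
end

recover_lowered_sketch(ex) = recover(unprefix(ex))

# The returned expression from above (prefixed operators written as Main.:^ / Main.:-).
ex = :((Main.sum)((Base.literal_pow)(Main.:^,
          (Base.broadcast)(Main.:-, (Main.predict)(W, b, x), y),
          (Core.apply_type)(Base.Val, 2))))
println(recover_lowered_sketch(ex))
```

Working bottom-up means that by the time literal_pow is rewritten, its argument has already been turned back into a dotted subtraction, so the result matches the original source expression.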

dfdx commented

Added a method recover_lowered (called automatically in funexpr) that fixes broadcasting, literal_pow and module-prefixed names, so the example above now works fine. However, lowered code also applies other kinds of magic, such as operation fusion. For example, in this code:

```julia
foo(x) = x .^ 2 .+ 1
```

Julia automatically fuses .^ 2 and .+ 1 into a single anonymous function (#361), stored in SSAValue(0) and broadcast over x in one call:

```julia
:(begin 
        nothing
        #361 = $(Expr(:new, :(Main.##361#362)))
        SSAValue(0) = #361
        return (Base.broadcast)(SSAValue(0), x)
    end)
```
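Semantically, the fused form amounts to building one anonymous function for the whole element-wise expression (#361 in the dump) and broadcasting it over x once, instead of first materializing x .^ 2 as a temporary array. The short sketch below, with made-up input values, checks that the dotted expression, a hand-written equivalent of the fused form, and an explicitly unfused version agree. It illustrates the semantics only and does not reproduce the lowered code.

```julia
# Made-up input; the names below are for illustration only.
x = [1.0, 2.0, 3.0]

# Dotted syntax: Julia fuses both operations into a single loop over x.
fused = x .^ 2 .+ 1

# Roughly what the fused lowered form does: one anonymous function
# for the whole element-wise expression, broadcast once.
f = t -> t ^ 2 + 1
by_hand = broadcast(f, x)

# Unfused version: materializes the intermediate array x .^ 2 first.
tmp = x .^ 2
unfused = tmp .+ 1
```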

Although it's possible to inspect the generated function further, e.g.:

```julia
julia> funexpr(tf, (Vector{Float64},))
(Symbol[Symbol("#temp#"), Symbol("#temp#")], :(begin 
        #temp#@_3 = #temp#@_2 ^ 2
        #temp#@_3 + 1
    end))
```

I don't believe these internals will stay stable enough to justify investing time in them. Defining functions in a separate file, so that Sugar.jl can find their source, is a much more stable approach.
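The file-based approach works because Julia records the file and line where each method was defined, so the original text can be read back instead of reconstructed from lowered code. A minimal sketch of that idea, assuming a source-reading tool does something along these lines (this is not Sugar.jl's API): it writes a hypothetical definition of foo to a temporary file, loads it with include, asks InteractiveUtils.functionloc where the method lives, and reads that line.

```julia
using InteractiveUtils: functionloc

# Write a hypothetical definition to a temporary file and load it.
path, io = mktemp()
write(io, "foo(x) = x .^ 2 .+ 1\n")
close(io)
include(path)

# Ask Julia where the method was defined, then read that line back.
file, line = functionloc(foo, (Vector{Float64},))
src = readlines(file)[line]
println(src)

# For a definition typed at the REPL, the recorded file is typically a
# pseudo-name such as "REPL[1]" rather than a readable path, so this
# lookup has nothing to read.
```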