JuliaDiff/ForwardDiff.jl

Gradients with respect to struct fields?


In Zygote.jl, we can take the gradient with respect to all fields of a struct foo passed through a function bar via

g = Zygote.gradient(f -> bar(f), foo)

Can this be done in ForwardDiff as well?

Reproducer:

using Zygote
using ForwardDiff

struct Foo
    x::Number
    t::Number
    c::Number
end

function bar(f::Foo)
    return f.x - f.c*f.t
end

foo = Foo(2, 3, 3e8)
println(foo)

g = Zygote.gradient(f -> bar(f), foo)

println(g)


# Throws a MethodError: ForwardDiff.gradient expects an AbstractArray, not a Foo
g = ForwardDiff.gradient(f -> bar(f), foo)
println(g)

Not straightforward. ForwardDiff differentiates with respect to real numbers and abstract arrays of them, not arbitrary structs. You might be able to hack something together with generated functions.
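One workaround that fits this constraint (a sketch, not something proposed in this thread) is to pack the fields into a vector and rebuild the struct inside the closure. It assumes every field is a real number and that the struct's field types are loose enough to hold ForwardDiff's dual numbers, as is the case for `Foo` with its `::Number` fields.

```julia
using ForwardDiff

struct Foo
    x::Number
    t::Number
    c::Number
end

bar(f::Foo) = f.x - f.c * f.t

foo = Foo(2, 3, 3e8)

# Pack the fields into a vector; mixed Int/Float64 values promote to Float64.
v = [foo.x, foo.t, foo.c]

# Rebuild a Foo from the (dual-valued) vector inside the differentiated closure.
g = ForwardDiff.gradient(w -> bar(Foo(w...)), v)
```

The result is a plain vector ordered like the fields, so the caller has to map entries back to field names.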

If you want to do this yourself and only have a struct of real numbers, it is fairly simple:

julia> using ForwardDiff: Dual, partials

julia> make_dual(z::Foo) = Foo(Dual(z.x,1,0,0), Dual(z.t,0,1,0), Dual(z.c,0,0,1));

julia> get_Foo(dy::Dual) = (; x=partials(dy,1), t=partials(dy,2), c=partials(dy,3));

julia> get_Foo(bar(make_dual(foo)))
(x = 1.0, t = -3.0e8, c = -3.0)

julia> Zygote.gradient(bar, foo)[1]
(x = 1.0, t = -3.0e8, c = -3.0)

With a bit more work, you could automate this for any struct of numbers, for example as a function struct_gradient(f, x), and even allow structs of structs.
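A minimal sketch of what such a `struct_gradient(f, x)` could look like for flat structs is given below. The function is hypothetical and not part of ForwardDiff. It assumes every field is a real number, that the default constructor accepts `Dual` values in each field (true for `Foo`, but not for a concrete type such as `Bar{Float64}`), and that no nested differentiation is involved, since it uses the untagged `Dual` constructor as in the manual example.

```julia
using ForwardDiff: Dual, partials

struct Foo
    x::Number
    t::Number
    c::Number
end

bar(f::Foo) = f.x - f.c * f.t

# Hypothetical helper: gradient of f with respect to each field of a flat struct.
function struct_gradient(f, s::T) where {T}
    names = fieldnames(T)
    n = length(names)
    # Seed field i with a one-hot set of n partials, mirroring make_dual.
    duals = ntuple(i -> Dual(getfield(s, i), ntuple(j -> Int(i == j), n)...), n)
    y = f(T(duals...))
    # Read partial i back out and label it with the matching field name.
    return NamedTuple{names}(ntuple(i -> partials(y, i), n))
end

g = struct_gradient(bar, Foo(2, 3, 3e8))
```

Supporting structs of structs would need a recursive walk over the fields to count and seed the leaves, which this sketch does not attempt.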

Allowing structs that contain arrays will be much trickier, mainly because of ForwardDiff's chunk mode.
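For context on chunk mode: ForwardDiff evaluates a gradient in several passes, and each dual number carries only as many partials as the chunk size rather than one per input. A struct holding arrays would therefore need its seeds distributed across passes instead of being assigned once as in the hand-written version. The sketch below only illustrates chunking on a plain vector; the function `f` and input `x` are made up for illustration.

```julia
using ForwardDiff

f(v) = sum(abs2, v)      # illustrative function; its gradient is 2 .* v
x = collect(1.0:4.0)

# Chunk size 2: each pass seeds 2 of the 4 inputs, so two passes are needed.
cfg = ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk{2}())
g = ForwardDiff.gradient(f, x, cfg)
```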