Gradients with respect to struct fields?
NAThompson commented
In Zygote.jl, we can take the gradient with respect to all fields of a struct foo passed through a function bar via
g = Zygote.gradient(f -> bar(f), foo)
Can this be done in ForwardDiff as well?
Reproducer:
using Zygote
using ForwardDiff
struct Foo
    x::Number
    t::Number
    c::Number
end
function bar(f::Foo)
    return f.x - f.c*f.t
end
foo = Foo(2, 3, 3e8)
println(foo)
g = Zygote.gradient(f -> bar(f), foo)
println(g)
# ForwardDiff.gradient only accepts an AbstractArray input, so this call throws a MethodError
g = ForwardDiff.gradient(f -> bar(f), foo)
println(g)
KristofferC commented
Not straightforward. ForwardDiff differentiates w.r.t. numbers and abstract vectors. You might be able to hack something together with generated functions.
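Since ForwardDiff does accept abstract vectors, one simple workaround (not the generated-function approach mentioned above) is to pack the fields into a vector and rebuild the struct inside the closure. This is only a sketch: it assumes every field is a real number and that the default positional constructor of Foo accepts Dual values, which holds here because the fields are typed ::Number.

```julia
using ForwardDiff

struct Foo
    x::Number
    t::Number
    c::Number
end

bar(f::Foo) = f.x - f.c * f.t

foo = Foo(2, 3, 3e8)

# Pack the fields into a plain vector, in field order.
v = Float64[foo.x, foo.t, foo.c]

# Rebuild a Foo from the (dual-valued) vector inside the closure.
g = ForwardDiff.gradient(w -> bar(Foo(w...)), v)

# g is ordered like the fields: [d/dx, d/dt, d/dc]
println(g)
```

The result comes back as a plain vector rather than a named tuple, so the field order has to be tracked by hand.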
mcabbott commented
If you want to do this yourself, and only have a struct of real numbers, then it will be fairly simple:
julia> using ForwardDiff: Dual, partials
julia> make_dual(z::Foo) = Foo(Dual(z.x,1,0,0), Dual(z.t,0,1,0), Dual(z.c,0,0,1));
julia> get_Foo(dy::Dual) = (; x=partials(dy,1), t=partials(dy,2), c=partials(dy,3));
julia> get_Foo(bar(make_dual(foo)))
(x = 1.0, t = -3.0e8, c = -3.0)
julia> Zygote.gradient(bar, foo)[1]
(x = 1.0, t = -3.0e8, c = -3.0)
With a bit more work you could automate this into a struct_gradient(f, x) that works with many structs of numbers, and even allows structs of structs.
Allowing structs containing arrays will be much more tricky, basically thanks to ForwardDiff's chunk mode.
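A rough sketch of what such a struct_gradient(f, x) could look like for flat structs of real numbers. The helper is hypothetical and not part of ForwardDiff. It assumes that every field is a real number and that the struct's default positional constructor accepts Dual values (so a concretely parameterised type such as a Foo{Float64} would not work as written). It does not handle nested structs or arrays.

```julia
using ForwardDiff: Dual, partials

struct Foo
    x::Number
    t::Number
    c::Number
end

bar(f::Foo) = f.x - f.c * f.t

# Hypothetical helper, not a ForwardDiff API.
function struct_gradient(f, x::T) where {T}
    names = fieldnames(T)
    n = length(names)
    # Seed field i with the i-th unit vector as its partials.
    duals = ntuple(n) do i
        Dual(float(getfield(x, i)), ntuple(j -> Float64(i == j), n)...)
    end
    # Rebuild the struct from the duals and evaluate f once.
    y = f(T(duals...))
    # Read the partials back into a NamedTuple keyed by field name.
    return NamedTuple{names}(ntuple(i -> partials(y, i), n))
end

foo = Foo(2, 3, 3e8)
g = struct_gradient(bar, foo)
println(g)  # (x = 1.0, t = -3.0e8, c = -3.0)
```

All fields are seeded at once with a single Dual of width n, which is fine for a handful of fields but is exactly what chunk mode avoids for long inputs.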