FluxML/ZygoteRules.jl

Adjoint rule that falls back when no adjoint exists?


I want to write an adjoint definition like this:

ZygoteRules.@adjoint function (f::ODEFunction)(u,p,t)
  if f.vjp === nothing
    ZygoteRules.adjoint(f.f,u,p,t)
  else
    f.vjp(u,p,t)
  end
end

If I depended on Zygote directly, I could call the pullback, but here all I have access to are the adjoints. What do I return to signal "there is no definition"?

You can use more or less exactly what you've written, with ZygoteRules._pullback in place of ZygoteRules.adjoint. Be a bit careful with it, though: @adjoint itself defines an overload of _pullback, so calling _pullback with the same arguments as the rule you're defining re-enters that rule and can overflow the stack. In your case you're unwrapping f to f.f, which can be AD'd on its own, so it's fine.
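To make the dispatch argument concrete, here is a dependency-free Julia sketch. It does not use Zygote: mock_pullback and Wrapped are hypothetical stand-ins for ZygoteRules._pullback and ODEFunction, and the generic fallback uses central finite differences in place of real AD. Only the dispatch structure is meant to carry over: the rule for Wrapped unwraps to w.f, so the call lands on the generic method instead of re-entering itself.

```julia
# Hypothetical wrapper, standing in for ODEFunction: holds a function and
# an optional user-supplied pullback (`nothing` when none is given).
struct Wrapped{F,V}
    f::F
    vjp::V
end
(w::Wrapped)(u, p, t) = w.f(u, p, t)

# Hypothetical stand-in for the generic `_pullback`. It "differentiates"
# any function of scalars by central finite differences (a mock, not AD).
# Like Zygote, the pullback returns (∂f, ∂args...); ∂f is `nothing` here.
function mock_pullback(f, args::Real...)
    y = f(args...)
    h = 1e-6
    grads = ntuple(length(args)) do i
        up = ntuple(j -> j == i ? args[j] + h : args[j], length(args))
        dn = ntuple(j -> j == i ? args[j] - h : args[j], length(args))
        (f(up...) - f(dn...)) / (2h)
    end
    return y, Δ -> (nothing, map(d -> Δ * d, grads)...)
end

# The "rule" for the wrapper: more specific than the generic method above.
function mock_pullback(w::Wrapped, u::Real, p::Real, t::Real)
    if w.vjp === nothing
        # Unwrap to w.f: this dispatches to the generic fallback.
        # Calling mock_pullback(w, u, p, t) here instead would re-enter
        # this same method forever and overflow the stack.
        return mock_pullback(w.f, u, p, t)
    else
        return w.vjp(u, p, t)
    end
end

# Usage: no vjp supplied, so the fallback differentiates g itself.
g(u, p, t) = u^2 * p + t
y, back = mock_pullback(Wrapped(g, nothing), 3.0, 2.0, 1.0)
back(1.0)  # approximately (nothing, 12.0, 9.0, 1.0)
```

Supplying a vjp function instead of nothing makes the rule return that pullback unchanged, which mirrors the else branch of the original definition.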