use @adjoint! instead of @adjoint for einsum! AD
GiggleLiu opened this issue · 3 comments
According to the discussion in #17 (comment).
The Zygote source code has some examples of how @adjoint! is defined.
Since the way gradients are accumulated into the gradient tensor for mutable structures is a bit tricky, we will revisit this issue later (after dispatch).
This is the Zygote paper: https://arxiv.org/abs/1810.07951
In that comment, Mike says not to define adjoints for mutating functions, so I think we should use the regular @adjoint and just define it for einsum (without the !).
That is, at least for the moment: I can imagine progress on Zygote's side in the next few weeks.
OK, I see. If I understand correctly, he means that neither @adjoint nor @adjoint! is currently safe for in-place functions, which means giving up on differentiating through broadcasting operations that have to be in-place.
I would suggest the following design to support full-featured AD:
einsum(ixs, xs, iy, y_shape) = einsum!... # define the adjoint on this method: it is full-featured but exposes no in-place operation to AD
einsum(ixs, xs, iy) = einsum(ixs, xs, iy, generate_output_shape(ixs, iy))
Cool?
Works for me.
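For reference, the design agreed on above can be sketched roughly as follows. This is only an illustration under stated assumptions: matmul! and matmul are hypothetical stand-ins for einsum! and einsum (plain matrix multiplication instead of a general contraction), and it assumes Zygote is installed and that Zygote.@adjoint is given a function returning the primal value together with a pullback. The adjoint is attached only to the non-mutating method, so AD never has to trace the in-place kernel.

```julia
using LinearAlgebra
using Zygote

# Hypothetical in-place kernel standing in for einsum!; it overwrites y.
matmul!(y, a, b) = mul!(y, a, b)

# Non-mutating entry point with an explicit output shape.
# The buffer is allocated here, so callers never see a mutation.
function matmul(a, b, y_shape)
    y = zeros(promote_type(eltype(a), eltype(b)), y_shape...)
    return matmul!(y, a, b)
end

# Convenience method that derives the output shape from the inputs.
matmul(a, b) = matmul(a, b, (size(a, 1), size(b, 2)))

# The adjoint is defined on the non-mutating method only.
# y_shape is not differentiable, so its gradient is `nothing`.
Zygote.@adjoint function matmul(a, b, y_shape)
    y = matmul(a, b, y_shape)
    back(Δ) = (Δ * b', a' * Δ, nothing)
    return y, back
end

a = [1.0 2.0; 3.0 4.0]
b = [5.0 6.0; 7.0 8.0]

# Differentiate through the two-argument method; it forwards to the
# three-argument method, where the custom adjoint applies.
ga, gb = Zygote.gradient((a, b) -> sum(matmul(a, b)), a, b)

# For sum(a * b) the gradients are ones * b' and a' * ones.
@assert ga ≈ ones(2, 2) * b'
@assert gb ≈ a' * ones(2, 2)
```

The shape-inferring method needs no adjoint of its own in this sketch, since it only computes sizes and forwards to the method that has one.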