under-Peter/OMEinsum.jl

use @adjoint! instead of @adjoint for einsum! AD

GiggleLiu opened this issue · 3 comments

According to the discussion in #17 (comment), the AD rule for einsum! should be defined with @adjoint! rather than @adjoint.

The Zygote source code contains some examples of defining @adjoint!.
Because accumulating gradients into the gradient tensor of a mutable structure is a bit tricky, we will revisit this issue later (after dispatch).

The Zygote paper:
https://arxiv.org/abs/1810.07951

In that comment, Mike says not to define adjoints for mutating functions, so I think we should use a regular @adjoint and define it only for einsum (without the !).
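A minimal sketch of what "a regular @adjoint on the non-mutating function" could look like, assuming Zygote's `@adjoint` macro. Matrix multiplication stands in for einsum here, and `matmul` is a hypothetical name, not part of OMEinsum. The in-place kernel (`mul!`) stays hidden inside the primal and the pullback is written by hand, so Zygote never has to differentiate through an in-place write.

```julia
using LinearAlgebra: mul!
using Zygote
using Zygote: @adjoint

# Hypothetical non-mutating wrapper: allocate the output, then fill it
# with the in-place kernel mul!. Callers never see the mutation.
function matmul(A::AbstractMatrix, B::AbstractMatrix)
    C = similar(A, promote_type(eltype(A), eltype(B)), size(A, 1), size(B, 2))
    return mul!(C, A, B)
end

# Regular @adjoint (not @adjoint!) on the non-mutating function.
# Zygote uses this rule instead of tracing the body, so mul! is never differentiated.
@adjoint function matmul(A, B)
    C = matmul(A, B)
    # For C = A * B: dA = ΔC * B' and dB = A' * ΔC.
    return C, ΔC -> (ΔC * B', A' * ΔC)
end

A, B = rand(2, 3), rand(3, 4)
gA, gB = Zygote.gradient((A, B) -> sum(matmul(A, B)), A, B)
```

The same pattern would apply to an out-of-place einsum: the rule is attached to the allocating function, and whatever mutation happens inside it is an implementation detail.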

At least for the moment; I can imagine progress on Zygote's side in the next few weeks.

OK, I see. If I understand correctly, he means that neither @adjoint nor @adjoint! is currently safe for in-place functions. That means giving up on differentiating broadcasting operations that have to be in-place.

I would suggest the following design to support fully featured AD:

einsum(ixs, xs, iy, y_shape) = einsum!...   # define the adjoint on this method: it is fully featured and exposes no in-place operations

einsum(ixs, xs, iy) = einsum(ixs, xs, iy, generate_output_shape(ixs, xs, iy))   # output shape needs the input sizes, not just the labels
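The two-method design above can be sketched in plain Julia, assuming index labels are tuples of `Char`s. `toy_einsum!`, `toy_einsum`, `toy_output_shape` and `toy_einsum_pullback` are hypothetical stand-ins, not OMEinsum's API, and the kernel is a naive loop chosen for clarity rather than speed. The pullback shows what an adjoint attached to the non-mutating method would compute: the cotangent of each input is itself an einsum of the output cotangent with the remaining inputs.

```julia
# Hypothetical naive in-place kernel: overwrites y with the contraction of xs
# by looping over every combination of label values.
function toy_einsum!(ixs, xs, iy, y)
    sizes = Dict{Char,Int}()
    for (ix, x) in zip(ixs, xs), (l, n) in zip(ix, size(x))
        sizes[l] = n
    end
    # Labels that appear only in the output are broadcast over.
    for (l, n) in zip(iy, size(y))
        sizes[l] = n
    end
    labels = collect(keys(sizes))
    fill!(y, zero(eltype(y)))
    for vals in Iterators.product((1:sizes[l] for l in labels)...)
        env = Dict(zip(labels, vals))
        p = prod(x[Tuple(env[l] for l in ix)...] for (ix, x) in zip(ixs, xs))
        idx = Tuple(env[l] for l in iy)
        y[idx...] += p
    end
    return y
end

# Hypothetical helper: output shape from the input labels and array sizes.
function toy_output_shape(ixs, xs, iy)
    sizes = Dict(l => n for (ix, x) in zip(ixs, xs) for (l, n) in zip(ix, size(x)))
    return Tuple(sizes[l] for l in iy)
end

# Out-of-place entry point: allocates y, then fills it with the mutating kernel.
# Under the proposal, the AD rule is attached here, never to toy_einsum!.
function toy_einsum(ixs, xs, iy, y_shape)
    T = promote_type(map(eltype, xs)...)
    return toy_einsum!(ixs, xs, iy, zeros(T, y_shape))
end
toy_einsum(ixs, xs, iy) = toy_einsum(ixs, xs, iy, toy_output_shape(ixs, xs, iy))

# Pullback of toy_einsum: contract ȳ with every other input, producing labels ixs[k].
# Assumes no label repeats within a single input (no diagonals).
function toy_einsum_pullback(ixs, xs, iy, ȳ)
    return ntuple(length(xs)) do k
        others = [i for i in eachindex(xs) if i != k]
        oix = Tuple(ixs[i] for i in others)
        oxs = Tuple(xs[i] for i in others)
        toy_einsum((oix..., iy), (oxs..., ȳ), ixs[k], size(xs[k]))
    end
end

# Usage: matrix multiplication "ij,jk->ik" and its gradients.
A, B = rand(2, 3), rand(3, 4)
ixs, iy = (('i', 'j'), ('j', 'k')), ('i', 'k')
C = toy_einsum(ixs, (A, B), iy)
ȳ = ones(2, 4)
gA, gB = toy_einsum_pullback(ixs, (A, B), iy, ȳ)
```

Since the pullback itself only calls the out-of-place `toy_einsum`, an AD rule built from it (for example with Zygote's `@adjoint`) keeps the in-place write hidden inside the primal.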

Cool?

Works for me.