data-apis/array-api

`unique_inverse` diverges from NumPy

seberg opened this issue

We need to revert the changes to unique for its inverse, IMO. The motivation for the choice of unique_inverse was never quite correct:
Yes, reconstruction was impossible for axis=None with a 1-D inverse array. However, that problem was exclusive to axis=None, while the choice made here is actively less helpful in all other cases.

The choice in the array API makes the reconstruction:

  • Use take(unique_vals, unique_inverse) if axis=None
  • Use take_along_axis(unique_vals, unique_inverse, axis=axis) otherwise.
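The two reconstruction rules above can be sketched in NumPy as follows. The toy array x is an illustrative assumption, and the inverse is explicitly reshaped in each branch so the sketch does not depend on which inverse shape the installed NumPy version happens to return.

```python
import numpy as np

x = np.array([[1, 2], [3, 4], [1, 2]])  # illustrative input

# axis=None: the inverse is given the input's shape, so a plain take
# over the 1-D unique values reconstructs x.
u_flat, inv_flat = np.unique(x, return_inverse=True)
rec_none = np.take(u_flat, inv_flat.reshape(x.shape))

# axis=0: the inverse has size 1 on every axis except `axis`
# (shape (3, 1) here), and take_along_axis broadcasts it across columns.
u_rows, inv_rows = np.unique(x, axis=0, return_inverse=True)
inv_rows = inv_rows.reshape(-1, 1)
rec_rows = np.take_along_axis(u_rows, inv_rows, axis=0)

assert np.array_equal(rec_none, x)
assert np.array_equal(rec_rows, x)
```

Note that the take_along_axis branch only works because of broadcasting: inv_rows has shape (3, 1) while u_rows has shape (2, 2).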

This is not helpful, because axis= in unique doesn't specify the core axis (as it does for sum, argmax, etc.); rather, it specifies the non-core axis!
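A small NumPy sketch of the core versus non-core distinction, using an illustrative toy array: a reduction consumes axis, while unique enumerates slices along axis and leaves every other axis intact. The inverse is flattened with reshape(-1) so the output is the same regardless of NumPy version.

```python
import numpy as np

x = np.array([[0, 1, 0, 1],
              [1, 0, 1, 0],
              [0, 1, 0, 1]])  # illustrative input

# Reductions consume `axis`: one result per position on the other axes.
col_sums = np.sum(x, axis=0)       # shape (4,)
col_argmax = np.argmax(x, axis=0)  # shape (4,)

# unique keeps the other axes intact: each row (a slice along axis 0) is
# one "element", and only the length of axis 0 changes.
rows, inv = np.unique(x, axis=0, return_inverse=True)
print(rows.shape)       # (2, 4)
print(inv.reshape(-1))  # one index per row of x: [0 1 0]
```

The natural inverse therefore has one entry per slice along axis, i.e. it is 1-D.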
In clearer terms:

  • np.take(unique_vals, unique_inverse, axis=axis) already works in every case, provided only the axis=None inverse is changed (to have the input's shape).
  • To clarify: currently, every axis of unique_inverse except axis has size 1. That is not what you get in any "typical" take_along_axis use case, where the dimensions match exactly except along axis. (Yes, take_along_axis can broadcast, but its purpose is precisely that you don't broadcast all axes except axis.)
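The single-rule alternative can be sketched as below, assuming the convention described in the first bullet: an input-shaped inverse for axis=None and a 1-D inverse otherwise. The helper reconstruct is hypothetical, and the explicit reshapes impose that convention regardless of what the installed NumPy returns.

```python
import numpy as np

def reconstruct(uniq, inv, axis=None):
    """Hypothetical helper: one reconstruction rule for every `axis`.

    Assumes `inv` has the input's shape when axis=None and is 1-D
    (one index per slice along `axis`) otherwise.
    """
    return np.take(uniq, inv, axis=axis)

x = np.array([[1, 2], [3, 4], [1, 2]])  # illustrative input

# axis=None: 1-D unique values, inverse shaped like the input.
u_flat, inv_flat = np.unique(x, return_inverse=True)
inv_flat = inv_flat.reshape(x.shape)
assert np.array_equal(reconstruct(u_flat, inv_flat), x)

# axis=0 and axis=1: a 1-D inverse, as in NumPy 1.x.
for ax in (0, 1):
    u_ax, inv_ax = np.unique(x, axis=ax, return_inverse=True)
    assert np.array_equal(reconstruct(u_ax, inv_ax.reshape(-1), axis=ax), x)
```

With this convention no broadcasting is involved: np.take with the same axis= argument inverts unique in all cases.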

I am not sure I have the energy to push back against this choice any further, but I am very much convinced that it is a mistake for unique (except as a way to make axis=None more useful). I am also convinced that NumPy must revert it (this was discussed in a NumPy meeting).

It still seems preferable not to diverge, and I think that means this should be fixed here (and wherever it has already been implemented). IMO, the choice was a mistake based on a wrong concept of what axis= means here. I was also mistaken when I assumed it was the only useful choice that allows simple reconstruction in all cases; that is why I thought NumPy should align in this direction rather than choose the proper middle ground.

xref numpy/numpy#26914

I thought this came from here, because that was the reason given for the change in NumPy, but it doesn't quite: axis= doesn't exist here, and axis= is the problematic part (I have no qualms about changing axis=None).