yt-project/unyt

__array_function__ support

Closed this issue · 13 comments

I spent a little bit of time this weekend working on adding __array_function__ to unyt_array. The work is on the array-function branch on my fork.

Unfortunately I think it might take quite a bit of effort to get a working implementation. Right now, if a type defines __array_function__, that signals to numpy that your implementation will work for all numpy functions that support __array_function__ (afaics pretty much the whole public API, although I can't find a public listing of all the functions I'd need to wrap, so I need to double-check that by inspecting the numpy source code). If you implement __array_function__ but do not have a wrapper for a function, you can either return NotImplemented for that function, in which case users will see a TypeError when they try to use it, or you can attempt to coerce the input data to ndarray. The latter is really what I'd like, since it would preserve backward compatibility. However, it's very difficult because the APIs of the functions that support __array_function__ are so heterogeneous that there isn't a single easy heuristic one can apply to coerce arguments to ndarray.
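To make the NotImplemented case concrete, here is a minimal sketch. The PartialArray class is hypothetical and is not unyt code: it wraps only np.sum and declines every other function, which NumPy turns into a TypeError for the caller.

```python
import numpy as np


class PartialArray:
    """Hypothetical duck array that only knows how to handle np.sum."""

    def __init__(self, data):
        self.data = np.asarray(data)

    def __array_function__(self, func, types, args, kwargs):
        if func is np.sum:
            # The one function this sketch wraps.
            return np.sum(self.data)
        # Declining here makes NumPy raise TypeError for the caller.
        return NotImplemented


x = PartialArray([1, 2, 3])
total = np.sum(x)  # handled by the wrapper above

try:
    np.mean(x)  # no wrapper, so __array_function__ returns NotImplemented
except TypeError as exc:
    error = exc
else:
    error = None
```

Every function without a wrapper fails this way, which is why an incomplete implementation would break existing user code.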

@l-johnston if you'd like to take a crack at this starting from my branch please feel free. I'll comment here if I go back to working on this and systematically wrapping the whole numpy API.

If we do add __array_function__ support I think the resulting product would have to be unyt 3.0; it's too much of a change in API for downstream users for it not to be. The fact that we're changing the types of what comes out of operations in downstream code is also very concerning, and I'm not sure what breakage would happen. We probably need to test downstream projects that depend on unyt to make sure that adding __array_function__ support doesn't cause a lot of breakage.

Falling back to coercing to ndarray is possible though not necessarily advisable. Dask does something like this in its __array_function__ method. The main thing you need is some sort of "nested coercion" function that handles nested containers of your array type.

In practice, I think something like the following could work pretty well:

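import numpy as np

def nested_coerce(args):
    # Recursively replace MyArray instances (a placeholder for your own
    # array type) with plain ndarrays inside dicts, tuples, and lists.
    if isinstance(args, dict):
        return {k: nested_coerce(v) for k, v in args.items()}
    elif isinstance(args, tuple):
        return tuple(nested_coerce(a) for a in args)
    elif isinstance(args, list):
        return [nested_coerce(a) for a in args]
    elif isinstance(args, MyArray):
        return np.asarray(args)
    else:
        # Anything else (scalars, strings, None, ...) passes through untouched.
        return args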

You can call this like args, kwargs = nested_coerce((args, kwargs)) inside your __array_function__ method. It won't handle completely arbitrary containers, but generally those aren't valid for use with NumPy anyway.
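For illustration, a minimal sketch of how that call could be wired into __array_function__, assuming the array type is an ndarray subclass. MyArray here is a hypothetical stand-in, not unyt_array. After coercion every argument is a plain ndarray, so calling func again does not re-enter this method.

```python
import numpy as np


class MyArray(np.ndarray):
    """Hypothetical ndarray subclass that falls back to plain ndarrays."""

    def __array_function__(self, func, types, args, kwargs):
        # Strip MyArray from every (possibly nested) argument, then let
        # NumPy run its normal implementation on plain ndarrays.
        args, kwargs = nested_coerce((args, kwargs))
        return func(*args, **kwargs)


def nested_coerce(args):
    if isinstance(args, dict):
        return {k: nested_coerce(v) for k, v in args.items()}
    elif isinstance(args, tuple):
        return tuple(nested_coerce(a) for a in args)
    elif isinstance(args, list):
        return [nested_coerce(a) for a in args]
    elif isinstance(args, MyArray):
        # np.asarray drops the subclass and returns a base ndarray.
        return np.asarray(args)
    else:
        return args


a = np.arange(3).view(MyArray)
b = np.arange(3, 6).view(MyArray)
out = np.concatenate([a, b])  # the list argument is coerced element by element
```

With this fallback the result is a plain ndarray, so any metadata carried by the subclass (units, in unyt's case) is silently dropped.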

My worry about that is all of the isinstance calls are going to introduce a lot of runtime overhead, which would be a regression relative to the current state of things in unyt.

Separate from that, I worry about this sort of API guessing code being brittle and breaking in unexpected cases.

No disagreement there!

I agree that this is pretty brittle. I would say that any default fall-back to coercion probably falls in this category.

One of the unique features of __array_function__ relative to older NumPy protocols is that it doesn't include "convenient" fall-back behavior by default. This was very much one of the core design goals (though it does make this transition hard).

Astropy’s approach - calling __array_function__ on the superclass - might make it easier to bootstrap this:

https://github.com/astropy/astropy/blob/master/astropy/units/quantity.py#L1524

It’ll probably help to crib off their implementation, I bet a lot of the code would make sense in unyt as well.
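A minimal sketch of the superclass idea, assuming an ndarray subclass. This is not Astropy's actual code: UnitArray and its seen list are hypothetical, and the list exists only to show which functions get dispatched.

```python
import numpy as np


class UnitArray(np.ndarray):
    """Hypothetical ndarray subclass used to illustrate delegation."""

    seen = []  # names of the NumPy functions dispatched to this class

    def __array_function__(self, func, types, args, kwargs):
        UnitArray.seen.append(func.__name__)
        # ndarray's own __array_function__ runs NumPy's default
        # implementation, which is what a subclass got before overriding.
        return super().__array_function__(func, types, args, kwargs)


a = np.arange(4.0).view(UnitArray)
r = np.reshape(a, (2, 2))
```

Starting from a blanket delegation like this, individual functions could then be special-cased one at a time instead of wrapping the whole API up front.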

Ah yes, I forgot this is easier with subclasses :)

@ngoldbaum I can work on the implementation following the Astropy example - set me as the Assignee so that we don't have duplicate efforts.

No need to assign any one person to this issue, I think there’s enough work to go around. Also be careful, you may spend a bunch of time working on an approach that I eventually decide is not the way to go. All that to say please check in here with updates and don’t go off on your own and come back with something totally finished but that I have some fundamental issue with.

For one, I’d prefer it if supporting this didn’t require 2000 new lines of code like it apparently did for astropy.

Also before committing to add __array_function__ support I want to make sure it can be done without causing lots of brokenness in downstream packages - particularly yt. Preferably I’d like to evaluate that without doing all the work of fully implementing it.

All that to say thank you for your enthusiasm but I’d like to do this with some caution and deliberation to gauge risk and avoid spending a lot of time and effort for naught.

To that end, I created a list of all the NumPy functions that utilize __array_function__. There are 294 functions. Do we want to first categorize them?
Numpy API.xlsx

I would double-check that against this list: numpy/numpy#15544 (comment).

Also, Excel files aren't useful for me since I don't have a copy of Excel installed. Probably easier to stick to e.g. a text gist?

I can see several categories of functions:

  • Functions that work fine out of the box with unyt operands and don't need any wrapping at all from us so we just call __array_function__ on the superclass. Right now we don't have any wrappers for those in unyt but we still silently rely on them working. For example, np.add, which dispatches to the add ufunc and __array_ufunc__.
  • Functions that work fine as long as there's some extra unit checking on the inputs and outputs. There are several of those already in unyt.
  • There may also be functions that require custom re-implementations on our end. Right now unyt doesn't have any of these but astropy seems to.
  • Finally, functions that aren't wrapped and don't fall into any of those categories, e.g. because they were added in a version of numpy that doesn't exist yet.
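The categories above could map onto a dispatch table roughly like the following sketch. Every name in it is hypothetical (UnitArray, its string unit attribute, PASS_THROUGH, HANDLERS, implements), and none of it is existing unyt code. The handler accepts only the axis keyword to keep the example short.

```python
import numpy as np

PASS_THROUGH = {np.reshape, np.transpose}  # work out of the box on the subclass
HANDLERS = {}  # functions needing unit checks or a custom re-implementation


def implements(numpy_function):
    """Register a wrapper for one NumPy function."""
    def decorator(func):
        HANDLERS[numpy_function] = func
        return func
    return decorator


class UnitArray(np.ndarray):
    """Hypothetical unit-aware array; the unit is just a string here."""

    def __new__(cls, data, unit):
        obj = np.asarray(data, dtype=float).view(cls)
        obj.unit = unit
        return obj

    def __array_finalize__(self, obj):
        # Views and slices inherit the unit of the array they came from.
        self.unit = getattr(obj, "unit", None)

    def __array_function__(self, func, types, args, kwargs):
        if func in HANDLERS:
            return HANDLERS[func](*args, **kwargs)
        if func in PASS_THROUGH:
            return super().__array_function__(func, types, args, kwargs)
        # Unknown or not-yet-wrapped functions: the caller sees a TypeError.
        return NotImplemented


@implements(np.concatenate)
def concatenate(arrays, axis=0):
    units = {getattr(a, "unit", None) for a in arrays}
    if len(units) != 1:
        raise ValueError("cannot concatenate arrays with different units")
    plain = [np.asarray(a) for a in arrays]  # avoid re-entering the dispatcher
    return UnitArray(np.concatenate(plain, axis=axis), units.pop())


a = UnitArray([1, 2], "m")
b = UnitArray([3], "m")
c = np.concatenate([a, b])  # goes through the unit-checking handler
r = np.reshape(a, (2, 1))   # goes through the superclass
```

Whether the last branch should return NotImplemented or fall back to ndarray coercion is exactly the backward-compatibility question raised earlier in the thread.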

Converted the Excel file to a gist:
Numpy API.txt

Closing along with #200.