data-apis/array-api

result_type() for mixed arrays/Python scalars

shoyer opened this issue.

The array API's type promotion rules support mixed scalar/array operations, e.g., 1 + xp.arange(3).

For Xarray, we would like to be able to figure out the resulting dtype from this sort of operation before actually doing it (pydata/xarray#8946).

Ideally, we could use xp.result_type() for this purpose, but as documented, result_type only supports arrays and dtype objects. Could we extend result_type to also handle Python scalars? It is worth noting that this already works today in NumPy, e.g.,

>>> np.result_type(1, np.arange(3))
dtype('int64')
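A quick way to see the behavior being asked for: in NumPy, result_type() predicts the dtype that the operator actually produces. This is a small sketch with NumPy standing in for xp. The int8 array is an arbitrary choice that makes it visible that the Python scalar adopts the array's dtype instead of widening it.

```python
import numpy as np

# An int8 array makes promotion visible: a widening rule would give a larger dtype.
x = np.arange(3, dtype=np.int8)

# The dtype the operator actually produces for a Python scalar mixed with an array.
op_dtype = (1 + x).dtype

# The dtype result_type() predicts, without performing the computation.
predicted = np.result_type(1, x)

print(op_dtype, predicted)
```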

This makes sense to me. torch seems to support this as well. What should the result be if there are multiple Python scalars? Undefined?

> What should the result be if there are multiple Python scalars? Undefined?

This should indeed probably be undefined by the spec.

In most cases I imagine array libraries will have a default dtype, but different libraries will make different choices (e.g., int32 in JAX vs int64 in NumPy):

>>> np.result_type(1, 2)
dtype('int64')
>>> jnp.result_type(1, 2)
dtype('int32')
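Once at least one dtype object is supplied, the library default stops mattering, because the Python scalars adopt that dtype. The following minimal NumPy sketch illustrates this. The choice of np.int32 is arbitrary.

```python
import numpy as np

# Scalars only: NumPy falls back to its default integer dtype, which is a
# library choice (and, before NumPy 2, platform-dependent).
default = np.result_type(1, 2)

# One explicit dtype anchors the result; the scalars take on that dtype.
anchored = np.result_type(np.int32, 1, 2)

print(default, anchored)
```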

One concern I see with this is that libraries need not support Python scalars in functions, only for operators. So result_type(a, b) working does not imply that func(a, b) will work.

> One concern I see with this is that libraries need not support Python scalars in functions, only for operators. So result_type(a, b) working does not imply that func(a, b) will work.

In Xarray, we are thinking of defining something like:

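def as_shared_dtype(scalars_or_arrays):
    # Xarray's helper: find the array namespace shared by the inputs
    xp = get_array_namespace(scalars_or_arrays)
    # Promote across all inputs, including any Python scalars
    dtype = xp.result_type(*scalars_or_arrays)
    # Cast every input to the shared dtype (dtype is keyword-only in asarray)
    return tuple(xp.asarray(x, dtype=dtype) for x in scalars_or_arrays)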

Does Xarray automatically call asarray on scalar function arguments like NumPy does? The standard certainly recommends against doing that, because it's cleaner from a typing perspective. Implicitly calling asarray at the top of every function is considered a historical NumPy antipattern. It's not disallowed, but we should probably avoid standardizing things that encourage it.

The only time we call that function is when preparing arguments for where (and for concat / stack, but there we don't expect to encounter Python scalars), and as far as I can tell, where doesn't support Python scalars.

Xarray objects always contain array objects, but indeed there are functions like where() for which it's convenient to be able to use scalars.

I opened a separate issue to discuss: #807

This sounds like a useful change to me.

> What should the result be if there are multiple Python scalars? Undefined?
>
> This should indeed probably be undefined by the spec.

What is the problem? It seems well-defined to allow multiple. If multiple arrays and dtype objects are allowed, why not multiple Python scalars?

I'm not sure, but I think that was referring to a situation where you have no explicit dtypes, just (compatible) Python scalars. In that case, we'd have to make an arbitrary choice (or raise an error).

Ah of course. Agreed, there must be at least one array or dtype object.

Making this change to result_type seemed fair to everyone in the discussion we just had. Given that our type promotion rules include Python scalars, the function that can be used to apply those promotion rules should support them as well.
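The semantics discussed in this thread, where Python scalars are allowed but at least one array or dtype object is required, can be sketched as a library-agnostic wrapper. This is a minimal sketch under stated assumptions, not the standard's or NumPy's implementation. NumPy stands in for xp, and the names result_type_with_scalars and _scalar_fits are hypothetical. The compatibility checks follow the array API rule that a Python scalar is converted to the dtype of the array it is combined with: a bool for boolean arrays, an in-bounds int for integer arrays, an int or float for real floating-point arrays, and an int, float, or complex for complex arrays. Combinations the spec leaves unspecified, such as a float scalar with an integer array, raise TypeError here. That strictness is a choice of this sketch.

```python
import numpy as np

# Exact Python scalar types; type() rather than isinstance() so that NumPy
# scalars such as np.float64 (a float subclass) are treated like arrays.
_PY_SCALARS = (bool, int, float, complex)


def _scalar_fits(value, dtype):
    """Hypothetical helper: can this Python scalar be combined with `dtype`?"""
    kind = dtype.kind  # 'b' bool, 'i'/'u' integer, 'f' float, 'c' complex
    if type(value) is bool:
        return kind == "b"
    if type(value) is int:
        if kind in "iu":
            info = np.iinfo(dtype)
            return info.min <= value <= info.max
        return kind in "fc"
    if type(value) is float:
        return kind in "fc"
    return kind == "c"  # complex scalar


def result_type_with_scalars(*args):
    """Hypothetical result_type() variant that also accepts Python scalars."""
    scalars = [a for a in args if type(a) in _PY_SCALARS]
    others = [a for a in args if type(a) not in _PY_SCALARS]
    if not others:
        # Scalars alone would need a library-specific default dtype.
        raise TypeError("at least one array or dtype object is required")
    # Only arrays and dtypes take part in promotion; scalars never widen it.
    dtype = np.result_type(*others)
    for value in scalars:
        if not _scalar_fits(value, dtype):
            raise TypeError(f"Python scalar {value!r} cannot be combined with {dtype}")
    return dtype
```

For example, result_type_with_scalars(1, np.arange(3, dtype=np.int8)) gives int8, while result_type_with_scalars(1, 2) raises. Apart from the integer bounds check, the result depends only on the types of the inputs, never on the scalars' values.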