data-apis/array-api-strict

ufunc comparisons are not equivalent to arithmetic comparison operators

ev-br opened this issue · 5 comments

Consider

In [1]: import array_api_strict as xp

In [2]: M = 6

In [3]: n = xp.arange(0, M, dtype=xp.float64)

In [4]: n <= (M-1)/2.0
Out[4]: 
Array([ True,  True,  True, False, False,
       False], dtype=array_api_strict.bool)

In [5]: xp.less_equal(n, (M-1)/2.0)
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[5], line 1
----> 1 xp.less_equal(n, (M-1)/2.0)

File ~/miniforge3/envs/scipy-dev/lib/python3.12/site-packages/array_api_strict/_elementwise_functions.py:539, in less_equal(x1, x2)
    533 def less_equal(x1: Array, x2: Array, /) -> Array:
    534     """
    535     Array API compatible wrapper for :py:func:`np.less_equal <numpy.less_equal>`.
    536 
    537     See its docstring for more information.
    538     """
--> 539     if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
    540         raise TypeError("Only real numeric dtypes are allowed in less_equal")
    541     # Call result type here just to raise on disallowed type combinations

AttributeError: 'float' object has no attribute 'dtype'

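Right now, the standard only allows arrays as arguments to functions such as `less_equal`. Python scalars are accepted only by operators like `<=`, which is why the two calls in the example above behave differently. This will change in the next release of the standard (see data-apis/array-api#807), and array-api-strict will be updated at that point. We could also start implementing a draft 2024.12 version here, but so far that has been lower priority, so it hasn't happened yet.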

Thanks! I can work a bit on the draft standard, if that helps and isn't just stepping on your toes.

It won't be stepping on my toes. I think I'll be focusing on the compat library and the tests presently.

Regarding the draft standard work, I've started it in #82, and there is also work in #78. But I don't plan to implement scalar support here until it is part of the standard itself.

I'm going to close this and open a separate issue to track scalar support. It's mainly waiting on the upstream standard.