Allow comparing any numeric types in boolean functions
asmeurer opened this issue · 2 comments
Functions like `equal`, `greater`, and so on (and their operator equivalents) don't allow comparing dtypes that cannot be promoted together. This is particularly annoying because it makes it impossible to compare `uint64` with `int64` at all, since the two cannot be promoted.
```pycon
>>> import array_api_strict as xp
>>> xp.asarray(0, dtype=xp.int64) < xp.asarray(1, dtype=xp.uint64)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/aaronmeurer/Documents/array-api-strict/array_api_strict/_array_object.py", line 717, in __lt__
    other = self._check_allowed_dtypes(other, "real numeric", "__lt__")
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aaronmeurer/Documents/array-api-strict/array_api_strict/_array_object.py", line 179, in _check_allowed_dtypes
    res_dtype = _result_type(self.dtype, other.dtype)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aaronmeurer/Documents/array-api-strict/array_api_strict/_dtypes.py", line 217, in _result_type
    raise TypeError(f"{type1} and {type2} cannot be type promoted together")
TypeError: array_api_strict.int64 and array_api_strict.uint64 cannot be type promoted together
```
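Comparing an `int64` with a `uint64` is well-defined even though no common dtype exists. The sketch below shows one exact rule, written in plain Python on scalar values. The function `mixed_less` and its range checks are hypothetical and for illustration only; this is not how array-api-strict or NumPy implements comparisons.

```python
INT64_MIN, INT64_MAX = -(2**63), 2**63 - 1
UINT64_MAX = 2**64 - 1


def mixed_less(a_int64: int, b_uint64: int) -> bool:
    """Hypothetical exact ``a < b`` for an int64 value a and a uint64 value b."""
    # Check that each input fits in its nominal dtype.
    if not INT64_MIN <= a_int64 <= INT64_MAX:
        raise ValueError("a_int64 is out of int64 range")
    if not 0 <= b_uint64 <= UINT64_MAX:
        raise ValueError("b_uint64 is out of uint64 range")
    # A negative signed value is smaller than every unsigned value.
    if a_int64 < 0:
        return True
    # Both values are now non-negative and within uint64 range, so
    # comparing them as unsigned magnitudes is exact (no float64 detour).
    return a_int64 < b_uint64
```

With this rule, the comparison from the traceback, `mixed_less(0, 1)`, is simply `True`, and values near the top of the `int64` range (for example `2**63 - 1` vs `2**63`) still compare correctly.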
However, the standard doesn't actually say anywhere in `greater` or `__gt__` that the input types must be promotable:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.greater.html#greater
https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__gt__.html
It only says that they should be real numeric. So in principle, these operators should work even when comparing floats with integers.
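If float vs integer comparisons were allowed, one possible meaning is comparison by exact mathematical value, which is what Python itself does for `int` and `float`. The sketch below contrasts that with first converting the integer to a double (Python's `float` is an IEEE 754 double, like `float64`). The helper names are hypothetical, and this is only an illustration of the semantic choice, not something the standard specifies.

```python
def exact_gt(i: int, f: float) -> bool:
    # Python compares int and float by exact mathematical value.
    return i > f


def gt_via_float64(i: int, f: float) -> bool:
    # Converting the integer to a double first can round it.
    return float(i) > f


i = 2**53 + 1
f = float(2**53)
print(exact_gt(i, f))        # True
print(gt_via_float64(i, f))  # False: float(2**53 + 1) rounds to 2**53
```

Above `2**53`, not every integer is representable as a double, so the two approaches can disagree; an exact rule avoids that at the cost of not being a plain elementwise promotion.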
And `equal` allows any data type: https://data-apis.org/array-api/latest/API_specification/generated/array_api.equal.html#equal, https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__eq__.html
It might be good to get some clarification in the standard about this, for instance on how `==` should behave when mixing certain dtype combinations.
It would be good to get some clarification in the standard for `equal` (data-apis/array-api#819).
Although we can probably just fall back to what NumPy does for now. The only potential problem is NumPy's pre-2.0 promotion behavior, which is another argument for making NumPy 2.0 a hard dependency (#21). I also need to double-check that NumPy 2.0 isn't internally promoting `uint64` and `int64` to `float64`, although if it is, I doubt I can reasonably work around it.
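A double-check like the one described for NumPy could be driven by a small table of edge cases. The sketch below is plain Python, and every name in it (`EDGE_CASES`, `find_mismatches`, `less_via_float64`) is hypothetical; `less_via_float64` only simulates a comparison that promotes both operands to `float64`, and does not call NumPy.

```python
# Hypothetical (int64, uint64) pairs chosen around 2**53, where float64
# stops representing every integer, and around 0 and 2**63, where the
# int64 and uint64 ranges diverge.
EDGE_CASES = [
    (-1, 0),
    (0, 0),
    (2**53 + 1, 2**53),
    (2**53, 2**53 + 1),
    (2**63 - 1, 2**63 - 1),
    (2**63 - 1, 2**63),
    (2**63 - 1, 2**64 - 1),
]


def find_mismatches(less):
    """Return the edge cases where ``less(a, b)`` disagrees with exact ``a < b``."""
    return [(a, b) for a, b in EDGE_CASES if less(a, b) != (a < b)]


def less_via_float64(a, b):
    # Simulates promoting both operands to float64 before comparing.
    return float(a) < float(b)


print(find_mismatches(less_via_float64))
# [(9007199254740992, 9007199254740993), (9223372036854775807, 9223372036854775808)]
```

To check NumPy itself, one could pass a wrapper with the same `less(a, b)` signature, for example one built on `np.less(np.int64(a), np.uint64(b))`, and see whether any edge cases come back as mismatches.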
The consortium decided this should remain undefined in the standard. So unless that changes, we should keep things the way they are here.