jax-ml/ml_dtypes

Type promotion behavior with Python scalars changed with NumPy 2


With NumPy 1.26.4

In [1]: import numpy as np, ml_dtypes

In [2]: np.arange(4, dtype=ml_dtypes.bfloat16) + 10
Out[2]: array([10, 11, 12, 13], dtype=bfloat16)

With NumPy 2.1.0

In [1]: import numpy as np, ml_dtypes

In [2]: np.arange(4, dtype=ml_dtypes.bfloat16) + 10
Out[2]: array([10., 11., 12., 13.], dtype=float32)

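A change in promotion behavior between major versions is not surprising in itself, but here we'd actually expect NumPy 2 to give the NumPy 1 result (bfloat16). Under NumPy 2's promotion rules (NEP 50), a Python scalar is weakly typed, so adding the Python int 10 should not change the dtype of a bfloat16 array, let alone promote it to float32.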
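
For comparison, here is a minimal sketch of the behavior we'd expect, shown with NumPy's built-in float16 so that it runs without ml_dtypes. The helper `add_python_scalar_keep_dtype` is a hypothetical name, not part of NumPy or ml_dtypes. It casts the scalar to the array's dtype before adding, which may serve as a workaround; that it fixes the bfloat16 case is an assumption, not something verified here.

```python
import numpy as np


def add_python_scalar_keep_dtype(arr, scalar):
    """Add a Python scalar to arr while keeping arr's dtype.

    Hypothetical helper: the scalar is cast to arr.dtype up front, so both
    operands already share a dtype and no promotion has to happen.
    """
    return arr + np.asarray(scalar, dtype=arr.dtype)


x = np.arange(4, dtype=np.float16)

# A Python int added to a built-in float dtype keeps the array's dtype.
weak = x + 10

# Same addition with the scalar cast explicitly.
explicit = add_python_scalar_keep_dtype(x, 10)

print(weak.dtype, explicit.dtype)
```

With a built-in dtype, both results stay float16 on NumPy 1 and NumPy 2. The report above is that the first form does not behave this way for bfloat16 on NumPy 2.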