jax-ml/ml_dtypes

`data` attribute raises a `ValueError`


Accessing the `numpy.ndarray.data` attribute raises a `ValueError` for arrays whose dtype comes from ml_dtypes.

bfloat16 example:

```python
import numpy as np
from ml_dtypes import bfloat16

x = np.array([0], dtype=bfloat16)
x.data
# ValueError: cannot include dtype 'E' in a buffer
```

float8_e4m3fnuz example:

```python
import numpy as np
from ml_dtypes import float8_e4m3fnuz

x = np.array([0], dtype=float8_e4m3fnuz)
x.data
# ValueError: cannot include dtype 'G' in a buffer
```
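For context: `ndarray.data` returns a `memoryview` built through the PEP 3118 buffer protocol, which needs a format code for every dtype. NumPy has no such code for user-defined dtypes like these, so it reports the dtype's type character (`'E'` and `'G'` above) and gives up. If a view of the raw bits is acceptable, one possible alternative is to reinterpret the array as an unsigned integer dtype of the same width before taking `.data`. This is a sketch under that assumption; `raw_memoryview` is a hypothetical helper, not an ml_dtypes or NumPy API.

```python
import numpy as np
from ml_dtypes import bfloat16, float8_e4m3fnuz

# Unsigned integer dtypes NumPy can export via the buffer protocol, keyed by itemsize (bytes).
_UINT_BY_ITEMSIZE = {1: np.uint8, 2: np.uint16, 4: np.uint32, 8: np.uint64}


def raw_memoryview(arr: np.ndarray) -> memoryview:
    """Hypothetical helper: memoryview over arr's bytes, typed as same-width unsigned ints.

    The returned view shares memory with arr; nothing is copied.
    """
    uint_dtype = _UINT_BY_ITEMSIZE[arr.dtype.itemsize]
    return arr.view(uint_dtype).data


x = np.array([0, 1.5], dtype=bfloat16)  # 2-byte elements -> uint16 view
mv = raw_memoryview(x)
print(mv.itemsize, mv.nbytes)  # 2 4

y = np.array([0.5], dtype=float8_e4m3fnuz)  # 1-byte elements -> uint8 view
print(raw_memoryview(y).nbytes)  # 1
```

The memoryview elements are bit patterns, not float values, so consumers must reinterpret them (for example with `np.frombuffer(..., dtype=bfloat16)`).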

The current workaround is to use the `__array_interface__` attribute, which exposes the data pointer as an integer address rather than as a buffer object:

```python
x.__array_interface__['data'][0]
```
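`__array_interface__['data']` is an `(address, read_only)` tuple, so this workaround gives a raw integer pointer. A minimal sketch of using that pointer, assuming a C-contiguous array that stays alive while the pointer is in use; `data_pointer` and `copy_bytes_via_pointer` are hypothetical helper names, not part of ml_dtypes or NumPy.

```python
import ctypes

import numpy as np
from ml_dtypes import bfloat16


def data_pointer(arr: np.ndarray) -> tuple[int, bool]:
    """Hypothetical helper: (address, read_only) taken from the array interface."""
    address, read_only = arr.__array_interface__["data"]
    return address, read_only


def copy_bytes_via_pointer(arr: np.ndarray) -> bytes:
    """Hypothetical helper: copy arr.nbytes bytes starting at arr's data pointer.

    Assumes a C-contiguous layout; arr must stay alive during the read.
    """
    if not arr.flags["C_CONTIGUOUS"]:
        raise ValueError("pointer copy assumes a C-contiguous array")
    address, _ = data_pointer(arr)
    return ctypes.string_at(address, arr.nbytes)


x = np.array([0, 1.5, -2], dtype=bfloat16)
address, read_only = data_pointer(x)
print(hex(address), read_only)
print(copy_bytes_via_pointer(x) == x.tobytes())  # True
```

Because `ctypes.string_at` copies, the returned `bytes` no longer aliases the array; writes to `x` afterwards are not reflected in it.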

ml_dtypes version: 0.2.0
numpy version: 1.24.3