jax-ml/ml_dtypes

numpy casting error: Cannot cast array data from dtype(bfloat16) to dtype(bfloat16) according to the rule 'unsafe'

idanori opened this issue · 5 comments

python 3.10.6
tensorflow==2.12
jax==0.4.25
ml-dtypes>=0.3.2

Each of the following casts fails with a TypeError:
Cannot cast array data from dtype(bfloat16) to dtype(bfloat16) according to the rule 'unsafe'

import numpy as np
import tensorflow as tf

# Reproduction: tensorflow 2.12 and ml-dtypes>=0.3.2 installed in the same environment.
const = tf.constant([1, 2], dtype=tf.bfloat16)
const_numpy = const.numpy()  # NumPy array whose dtype prints as bfloat16

# 1. ndarray.astype with the dtype given by name
try:
    const_numpy.astype('bfloat16')
except TypeError as e:
    print(f"astype cast {const_numpy.dtype} to 'bfloat16' error: {e}")

# 2. ndarray.astype with an explicit np.dtype object
try:
    const_numpy.astype(np.dtype('bfloat16'))
except TypeError as e:
    print(f"astype cast {const_numpy.dtype} to np.dtype('bfloat16') error: {e}")

# 3. np.asarray with the dtype given by name
try:
    np.asarray(const_numpy, 'bfloat16')
except TypeError as e:
    print(f"asarray cast {const_numpy.dtype} to 'bfloat16' error: {e}")

# 4. np.asarray with an explicit np.dtype object
try:
    np.asarray(const_numpy, np.dtype('bfloat16'))
except TypeError as e:
    print(f"asarray cast {const_numpy.dtype} to np.dtype('bfloat16') error: {e}")

Hi, this looks unrelated to ml_dtypes. Tensorflow 2.12 bundled its own definitions of bfloat16 and other custom types; it wasn't until version 2.14 that tensorflow began depending on and using ml_dtypes. I'd suggest updating to a more recent tensorflow release.

tensorflow 2.12 requires jax>=0.3.15, which in turn requires ml-dtypes>=0.2.0.

From the output of pip install tensorflow==2.12:
Collecting jax>=0.3.15 (from tensorflow==2.12)
Collecting ml-dtypes>=0.2.0 (from jax>=0.3.15->tensorflow==2.12)
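A quick way to confirm which versions pip actually resolved into the environment is to query the installed distribution metadata from Python. The sketch below uses only the standard-library importlib.metadata module; the helper name installed_versions is hypothetical, and the package list simply mirrors the versions quoted in this report.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("tensorflow", "jax", "ml-dtypes", "numpy")):
    """Return {distribution name: installed version, or None if not installed}."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            # The distribution is absent from this environment.
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver if ver is not None else 'not installed'}")
```

Pasting this output into a report makes it clear whether a newer ml-dtypes was pulled in alongside an older tensorflow.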

I'm not sure why tensorflow 2.12 lists jax as a dependency, but regardless it's definitely the case that tensorflow 2.12 defines and registers its own bfloat16 type. I would raise an issue in the tensorflow repository (I suspect they will suggest updating to a more recent version).

If that is the case, why does pip install ml-dtypes==0.3.1 solve the issue?

If you're importing both a newer ml_dtypes and an older tensorflow in the same environment, then having two different bfloat16 declarations can cause issues. ml_dtypes 0.3.1 was released in an era when tensorflow still defined its own bfloat16 type, so it has code to check whether another package has registered bfloat16. We removed that in later versions of ml_dtypes once jax and tensorflow removed their own bfloat16 registrations.
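One way to see which bfloat16 an array actually carries is to compare its dtype against the one ml_dtypes registers. This is a minimal sketch, assuming numpy is installed; the helper name uses_ml_dtypes_bfloat16 is hypothetical, and it only tells you whether the dtype is ml_dtypes' bfloat16, not which other package registered a competing one.

```python
import numpy as np


def uses_ml_dtypes_bfloat16(arr):
    """Return True if arr's dtype is the bfloat16 registered by ml_dtypes.

    Two dtypes can both print as 'bfloat16' yet compare unequal when they come
    from different registrations (e.g. an older tensorflow's own copy vs ml_dtypes').
    """
    try:
        import ml_dtypes
    except ImportError:
        # ml_dtypes is not installed, so arr cannot carry its bfloat16.
        return False
    return arr.dtype == np.dtype(ml_dtypes.bfloat16)
```

Per the explanation in this thread, with tensorflow 2.12 the check would be expected to disagree between const_numpy from the reproduction and an array created with dtype=ml_dtypes.bfloat16, which is the situation where NumPy finds no cast between two types that share the name bfloat16.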

The net result is that if you use an old tensorflow release in the same environment as a new ml_dtypes release, you'll run into compatibility issues like the one you're seeing. The solution is to not import an old tensorflow and a new ml_dtypes in the same environment.
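The rule of thumb above can be written down as a small version check, for example to fail fast in a setup script. This is a sketch under assumptions drawn from this thread rather than an official compatibility matrix: tensorflow releases before 2.14 are treated as registering their own bfloat16, and ml_dtypes releases after 0.3.1 as no longer checking for an existing registration. The function names are hypothetical, and the version parser only looks at the leading numeric components.

```python
import re

# Assumed thresholds, taken from the discussion in this issue rather than from
# any official compatibility table.
TF_FIRST_USING_ML_DTYPES = (2, 14, 0)   # earlier tensorflow registers its own bfloat16
ML_DTYPES_LAST_WITH_CHECK = (0, 3, 1)   # later ml_dtypes drops the duplicate-registration check


def release_tuple(ver, width=3):
    """Parse the leading numeric part of a version string, e.g. '2.12.0rc1' -> (2, 12, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", ver)
    if match is None:
        raise ValueError(f"unrecognized version string: {ver!r}")
    parts = [int(p) for p in match.group().split(".")]
    parts += [0] * (width - len(parts))  # pad so '2.14' compares equal to '2.14.0'
    return tuple(parts[:width])


def bfloat16_conflict_likely(tf_version, ml_dtypes_version):
    """Return True for an old-tensorflow / new-ml_dtypes combination.

    Either argument may be None to mean "not installed", in which case the two
    bfloat16 registrations cannot clash.
    """
    if tf_version is None or ml_dtypes_version is None:
        return False
    old_tf = release_tuple(tf_version) < TF_FIRST_USING_ML_DTYPES
    new_ml_dtypes = release_tuple(ml_dtypes_version) > ML_DTYPES_LAST_WITH_CHECK
    return old_tf and new_ml_dtypes


if __name__ == "__main__":
    # The environment from this report versus the workaround the reporter found.
    print(bfloat16_conflict_likely("2.12.0", "0.3.2"))
    print(bfloat16_conflict_likely("2.12.0", "0.3.1"))
```

The installed versions can be fed in from importlib.metadata.version; the hard-coded thresholds are the part most likely to need adjusting if the boundaries stated in this thread turn out to be imprecise.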

My suggestion would be to update tensorflow, as that's the easiest fix here. Alternatively, you can install the versions of the tensorflow dependencies that were current when v2.12 was released, and you should not see these kinds of compatibility issues.