BUG: axlearn.common.utils.as_tensor calls .numpy() which doesn't work with bfloat16
explainerauthors opened this issue · 1 comment
I noticed this problem while developing Medusa+, which depends on axlearn. Here is the raw trace of the error when converting a JAX tensor in bfloat16 to torch.tensor: https://rio.apple.com/projects/ai-medusa-plus/pipeline-specs/ai-medusa-plus-unit_tests/pipelines/f4345942-bddc-47ba-ba51-f3c1019290d3/log#L718-L738
The problem is that NumPy does not support bfloat16, so if the source tensor is bfloat16, the call to .numpy() raises a TypeError.
The following script reveals this problem:
import torch
import numpy
import jax.numpy as jnp

x = torch.rand((1,), dtype=torch.float32)
print(x.numpy())

x = torch.rand((1,), dtype=torch.float16)
print(x.numpy())

x = torch.rand((1,), dtype=torch.bfloat16)
print(x.numpy())
The first two calls to .numpy() succeed; the last one fails:
22:36 $ python3 medusa_plus/numpy_bf16.py
[0.686854]
[0.6177]
Traceback (most recent call last):
File "/mnt/medusa-plus/medusa_plus/numpy_bf16.py", line 21, in <module>
print(x.numpy())
TypeError: Got unsupported ScalarType BFloat16
It seems that the PyTorch community does not have a good solution for this (pytorch/pytorch#90574). MLX upcasts bfloat16 to float32. Let us close this issue.
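For anyone hitting this, here is a minimal sketch of the upcast workaround mentioned above. This is not the actual axlearn implementation; the helper name torch_to_jax and the upcast-then-downcast strategy are illustrative assumptions:

import jax.numpy as jnp
import torch


def torch_to_jax(x: torch.Tensor) -> jnp.ndarray:
    # Illustrative helper: convert a torch tensor to a JAX array.
    # torch.Tensor.numpy() raises TypeError for bfloat16 because NumPy has
    # no native bfloat16 dtype, so we upcast to float32 first (mirroring the
    # MLX behavior mentioned above) and restore bfloat16 on the JAX side.
    if x.dtype == torch.bfloat16:
        return jnp.asarray(x.detach().to(torch.float32).numpy()).astype(jnp.bfloat16)
    return jnp.asarray(x.detach().numpy())


print(torch_to_jax(torch.rand((1,), dtype=torch.bfloat16)).dtype)  # bfloat16

The upcast costs an extra float32 copy on the host, but the round trip is lossless because every bfloat16 value is exactly representable in float32.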