apple/axlearn

BUG: axlearn.common.utils.as_tensor calls .numpy() which doesn't work with bfloat16

explainerauthors opened this issue

I noticed this problem while developing Medusa+, which depends on axlearn. Here is the raw trace of the error when converting an Ajax tensor in bfloat16 to torch.tensor: https://rio.apple.com/projects/ai-medusa-plus/pipeline-specs/ai-medusa-plus-unit_tests/pipelines/f4345942-bddc-47ba-ba51-f3c1019290d3/log#L718-L738

The problem is that NumPy does not support bfloat16, so if the source tensor is bfloat16, the call to .numpy() raises a TypeError.

The following script reveals this problem:

import torch
import numpy
import jax.numpy as jnp

x = torch.rand(
    (1,),
    dtype=torch.float32,
)
print(x.numpy())

x = torch.rand(
    (1,),
    dtype=torch.float16,
)
print(x.numpy())

x = torch.rand(
    (1,),
    dtype=torch.bfloat16,
)
print(x.numpy())

The first two calls to .numpy() succeed; the last one fails.

22:36 $ python3 medusa_plus/numpy_bf16.py
[0.686854]
[0.6177]
Traceback (most recent call last):
  File "/mnt/medusa-plus/medusa_plus/numpy_bf16.py", line 21, in <module>
    print(x.numpy())
TypeError: Got unsupported ScalarType BFloat16
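
For reference, a minimal workaround at the call site (a sketch, assuming a CPU torch tensor) is to upcast to float32 before calling .numpy():

import torch

x = torch.rand((1,), dtype=torch.bfloat16)
# NumPy has no bfloat16 dtype, so widen to float32 first.
print(x.float().numpy())

This prints the value successfully, at the cost of widening the dtype to float32.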

It seems that the PyTorch community does not have a good solution for this (pytorch/pytorch#90574). MLX works around it by upcasting bfloat16 to float32. Let us close this issue.
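
For anyone who hits this through axlearn, here is a minimal sketch of a bfloat16-aware torch-to-JAX conversion (a hypothetical helper named torch_to_jax, not the actual as_tensor implementation), mirroring the MLX approach of going through float32 and casting back on the JAX side:

import jax.numpy as jnp
import torch

def torch_to_jax(x: torch.Tensor) -> jnp.ndarray:
    # Hypothetical helper: convert a torch tensor to a JAX array, preserving bfloat16.
    if x.dtype == torch.bfloat16:
        # NumPy cannot hold bfloat16, so hop through float32 and cast back.
        return jnp.asarray(x.detach().cpu().float().numpy()).astype(jnp.bfloat16)
    return jnp.asarray(x.detach().cpu().numpy())

x = torch.rand((4,), dtype=torch.bfloat16)
print(torch_to_jax(x).dtype)  # bfloat16

The round trip through float32 is lossless for bfloat16 values, since every bfloat16 value is exactly representable in float32.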