jax-ml/ml_dtypes

bfloat16s (at least) throw an error when given a format specifier


>>> import jax
>>> jax.__version__
'0.4.21'
>>> from jax.dtypes import bfloat16
>>> f"{bfloat16(1)}"
'1'
>>> f"{bfloat16(1):.2f}"
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: Unknown format code 'f' for object of type 'str'

This is a little surprising to me. In a larger context, the error message also gives little hint about what went wrong or where, since it refers to an object of type 'str' rather than to a bfloat16.
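
In case it helps anyone hitting the same thing, here is a minimal sketch of a workaround: convert the value to a Python float before applying the format spec. The traceback suggests the scalar is turned into a str before the spec is applied, but that is a guess on my part. To keep the snippet runnable without jax, it uses a hypothetical stand-in class (`StrFormattedScalar`) that mimics the reported behavior; `format_as_float` is a made-up helper name, not part of jax or ml_dtypes.

```python
def format_as_float(value, spec=""):
    """Format any scalar that supports float() using a float format spec."""
    return format(float(value), spec)


class StrFormattedScalar:
    """Hypothetical stand-in for a scalar whose __format__ goes through str.

    This is NOT the real bfloat16 implementation; it only reproduces the
    symptom shown in the traceback above.
    """

    def __init__(self, value):
        self._value = float(value)

    def __float__(self):
        return self._value

    def __str__(self):
        return f"{self._value:g}"

    def __format__(self, spec):
        # Applying a float spec such as ".2f" to a str raises ValueError.
        return format(str(self), spec)


x = StrFormattedScalar(1)
print(format_as_float(x, ".2f"))
```

With a real bfloat16 the equivalent would presumably be `f"{float(bfloat16(1)):.2f}"`, assuming the scalar supports conversion via `float()`.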