bfloat16s (at least) throw an error when given a format specifier
colehaus commented
>>> import jax
>>> jax.__version__
'0.4.21'
>>> from jax.dtypes import bfloat16
>>> f"{bfloat16(1)}"
'1'
>>> f"{bfloat16(1):.2f}"
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: Unknown format code 'f' for object of type 'str'
This is a little surprising to me. In a larger program the error message also gives little hint about what went wrong or where, because it complains about an object of type 'str' even though the value being formatted is a bfloat16.
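A possible workaround, sketched under the assumption that a Python float is an acceptable intermediate: convert the scalar with float() before applying the format spec. Every bfloat16 value is exactly representable as a Python float (a 64-bit double), so the conversion loses nothing. The 'str' in the traceback suggests, though this is only an inference from the message, that the bfloat16 scalar formats its string representation, so a float-only code such as 'f' ends up being applied to a str. The variable names below are illustrative.

```python
from jax.dtypes import bfloat16

x = bfloat16(1)

# Convert to a Python float first so the spec is handled by float.__format__,
# which understands float codes such as "f" and "e". This is a workaround
# only; it does not change how bfloat16 itself handles format specs.
two_decimals = f"{float(x):.2f}"
scientific = f"{float(x):.3e}"

print(two_decimals)  # 1.00
print(scientific)    # 1.000e+00
```

float() is used rather than a NumPy cast because a plain Python float accepts the full format-spec mini-language.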