google-deepmind/alphatensor

What precision is used during computation?

beginlner opened this issue · 1 comment

The input and output matrices are float32, but according to the JAX documentation, it seems that bfloat16 is used during computation. How can I switch to full-precision mode?

See the last test in benchmarking/test_correctness.py, which includes an example of configuring the precision.
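For readers who want to see the general JAX-side controls, here is a minimal sketch. It is not the code from benchmarking/test_correctness.py, which may configure precision differently. It shows two standard ways to request full float32 matmul precision in JAX: the per-call `precision` argument (`jax.lax.Precision.HIGHEST`) and the `jax.default_matmul_precision` context manager. The matrix sizes and variable names are arbitrary choices for illustration. On CPU, JAX matmuls already run in float32, so the setting mainly matters on TPU, where the default uses bfloat16 passes, and possibly on GPUs with TF32 support.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Two random float32 matrices (the 64x64 shape is an arbitrary example).
rng = np.random.default_rng(0)
a = jnp.asarray(rng.standard_normal((64, 64)), dtype=jnp.float32)
b = jnp.asarray(rng.standard_normal((64, 64)), dtype=jnp.float32)

# Option 1: request full precision for a single matmul call.
c_highest = jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST)

# Option 2: raise the default matmul precision for everything in this block.
with jax.default_matmul_precision("float32"):
    c_default_f32 = jnp.matmul(a, b)

# Option 3 (process-wide), shown here only as a comment:
# jax.config.update("jax_default_matmul_precision", "float32")

# Reference product computed in float64 on the host with NumPy.
reference = np.asarray(a, dtype=np.float64) @ np.asarray(b, dtype=np.float64)

# Maximum absolute deviation of the full-precision JAX result from float64.
print(np.max(np.abs(np.asarray(c_highest, dtype=np.float64) - reference)))
```

The per-call argument keeps the change local to one operation, while the context manager is convenient when the matmuls happen inside code you would rather not edit.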