What is the precision during computing?
beginlner opened this issue · 1 comment
beginlner commented
The input and output matrices are float32, but according to the JAX documentation the computation itself seems to use bfloat16. How can I switch to full-precision mode?
matejbalog commented
See the last test in benchmarking/test_correctness.py, which includes an example of configuring the precision.
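
For a quick illustration, here is a minimal sketch of two general JAX mechanisms for requesting full float32 matmul precision. It is not copied from the test file and may differ from what that test does. The helper name `matmul_full_precision` and the 64x64 random inputs are hypothetical, chosen only for the example.

```python
import jax
import jax.numpy as jnp
import numpy as np


def matmul_full_precision(a, b):
    # Hypothetical helper: per-call override that asks for the highest
    # available matmul precision instead of the backend default.
    return jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST)


# Illustrative float32 inputs (assumed shapes, not from the repository).
rng = np.random.default_rng(0)
a = rng.standard_normal((64, 64)).astype(np.float32)
b = rng.standard_normal((64, 64)).astype(np.float32)

# Option 1: pass `precision` on the individual call.
out_call = matmul_full_precision(a, b)

# Option 2: set a scoped default for every matmul inside the block.
with jax.default_matmul_precision("float32"):
    out_scoped = jnp.matmul(a, b)

# Compare against a float64 NumPy reference to see the actual error.
reference = a.astype(np.float64) @ b.astype(np.float64)
max_err = float(np.max(np.abs(np.asarray(out_call) - reference)))
print("max abs error vs float64 reference:", max_err)
```

The per-call argument is handy when only a few operations need full precision. The context manager is easier when a whole function should run at float32. On backends that already compute in float32 by default (typically CPU), both options should give the same result as the default setting.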