Using Bfloat16 GEMM from OpenBLAS
skykongkong8 opened this issue · 2 comments
skykongkong8 commented
It seems the latest OpenBLAS supports bfloat16 GEMM (SBGEMM).
I guess upgrading the OpenBLAS version used here (0.3.18 -> 0.3.24) would simply bring those functions into NNTrainer.
1. hardware compatibility
- Uses a new data type, `bfloat16`, but it is defined as `uint16_t`.
- For ARM, it uses bfloat16 NEON intrinsics, which appear to be supported from Armv8.6-A with additional extensions.
- They have GEMM kernels implemented for arm64 / x86_64, but I am not sure about coverage: their arm64 kernels are named "neoversen2", which is an Armv9.0-A CPU. I am currently checking for more kernels.
- According to their issues, the bfloat16 GEMM kernels are claimed to be optimized for specific CPUs such as Neoverse N2 and Cooper Lake, so we should consider which processor to use when testing this (a runtime detection sketch follows below).
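For ARM targets, a minimal runtime check could look like the sketch below. This is an assumption on my side, not NNTrainer or OpenBLAS code: it targets Linux on AArch64, where the kernel reports the Armv8.6-A BF16 feature through `HWCAP2_BF16`, and it only tells us that the hardware has the instructions, not whether the OpenBLAS build actually contains bfloat16 kernels.

```c
/* Sketch: detect Armv8.6-A BF16 support at runtime (Linux / AArch64 only). */
#include <stdio.h>

#if defined(__aarch64__) && defined(__linux__)
#include <sys/auxv.h>   /* getauxval, AT_HWCAP2 */
#include <asm/hwcap.h>  /* HWCAP2_BF16 (needs reasonably recent kernel headers) */
#endif

static int cpu_has_bf16(void) {
#if defined(__aarch64__) && defined(__linux__) && defined(HWCAP2_BF16)
  /* The kernel sets HWCAP2_BF16 when the CPU implements FEAT_BF16. */
  return (getauxval(AT_HWCAP2) & HWCAP2_BF16) != 0;
#else
  return 0; /* unknown platform or old headers: assume no bf16, use fp32 path */
#endif
}

int main(void) {
  printf("bf16 NEON support: %s\n", cpu_has_bf16() ? "yes" : "no");
  return 0;
}
```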
2. accuracy
- The bfloat16 GEMM output is float, and when compared with the result of computing on bfloat16->float casted inputs, it seems the outputs are almost the same.
```c
// 20240704
// CC : bfloat16 output, C : float output
if (fabs (CC[i * m + j] - C[i * m + j]) > 1.0)
  ret++;
...
// DD : bfloat16->float casted input data ->GEMM output
if (CC[i * m + j] != DD[i * m + j])
  ret++;
...
if (ret != 0)
  fprintf (stderr, "FATAL ERROR SBGEMM - Return code: %d\n", ret);
...
```
- An interesting point is that they compare the float and bfloat16 GEMM outputs with `epsilon = 1.0`, even for small GEMM problems such as (100, 100) x (100, 100).
- Full Test code
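For reference, a minimal standalone sketch of a similar comparison is below. Assumptions: OpenBLAS is built with `BUILD_BFLOAT16=1` so that `cblas_sbgemm` is exposed, and the float->bfloat16 conversion is done by plain truncation here rather than by OpenBLAS's own conversion helpers.

```c
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <cblas.h> /* OpenBLAS built with BUILD_BFLOAT16=1 */

/* Truncate a float to bfloat16 by keeping its top 16 bits (no rounding). */
static uint16_t f32_to_bf16(float f) {
  uint32_t u;
  memcpy(&u, &f, sizeof(u));
  return (uint16_t)(u >> 16);
}

int main(void) {
  const int m = 100, n = 100, k = 100;
  float *A = malloc(sizeof(float) * m * k);
  float *B = malloc(sizeof(float) * k * n);
  float *C  = calloc(m * n, sizeof(float)); /* fp32 reference output */
  float *CC = calloc(m * n, sizeof(float)); /* sbgemm (bf16) output  */
  uint16_t *Ab = malloc(sizeof(uint16_t) * m * k);
  uint16_t *Bb = malloc(sizeof(uint16_t) * k * n);

  for (int i = 0; i < m * k; ++i) A[i] = (float)rand() / RAND_MAX - 0.5f;
  for (int i = 0; i < k * n; ++i) B[i] = (float)rand() / RAND_MAX - 0.5f;
  for (int i = 0; i < m * k; ++i) Ab[i] = f32_to_bf16(A[i]);
  for (int i = 0; i < k * n; ++i) Bb[i] = f32_to_bf16(B[i]);

  /* fp32 GEMM on the original inputs */
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k,
              1.0f, A, k, B, n, 0.0f, C, n);
  /* bfloat16 GEMM: bf16 inputs, fp32 output */
  cblas_sbgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k,
               1.0f, Ab, k, Bb, n, 0.0f, CC, n);

  int ret = 0;
  for (int i = 0; i < m * n; ++i)
    if (fabsf(CC[i] - C[i]) > 1.0f) /* same epsilon as the OpenBLAS test */
      ++ret;
  printf("elements differing by more than 1.0: %d\n", ret);

  free(A); free(B); free(C); free(CC); free(Ab); free(Bb);
  return ret != 0;
}
```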
3. latency measurement
- WTD (to be done; a minimal timing sketch follows below)
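A minimal timing sketch could look like the following (same `cblas_sbgemm` and truncation-based conversion assumptions as above; for a real measurement we would fix the thread count, add warm-up runs, and repeat on the specific CPUs mentioned in section 1):

```c
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <cblas.h> /* OpenBLAS built with BUILD_BFLOAT16=1 */

/* Elapsed wall-clock time in milliseconds. */
static double elapsed_ms(struct timespec s, struct timespec e) {
  return (e.tv_sec - s.tv_sec) * 1e3 + (e.tv_nsec - s.tv_nsec) * 1e-6;
}

int main(void) {
  const int n = 1024, iters = 10;
  float *A = malloc(sizeof(float) * n * n);
  float *B = malloc(sizeof(float) * n * n);
  float *C = malloc(sizeof(float) * n * n);
  uint16_t *Ab = malloc(sizeof(uint16_t) * n * n);
  uint16_t *Bb = malloc(sizeof(uint16_t) * n * n);
  struct timespec s, e;

  for (int i = 0; i < n * n; ++i) {
    uint32_t u;
    A[i] = (float)rand() / RAND_MAX;
    B[i] = (float)rand() / RAND_MAX;
    memcpy(&u, &A[i], sizeof(u)); Ab[i] = (uint16_t)(u >> 16); /* truncate to bf16 */
    memcpy(&u, &B[i], sizeof(u)); Bb[i] = (uint16_t)(u >> 16);
  }

  clock_gettime(CLOCK_MONOTONIC, &s);
  for (int it = 0; it < iters; ++it)
    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, n, n, n,
                1.0f, A, n, B, n, 0.0f, C, n);
  clock_gettime(CLOCK_MONOTONIC, &e);
  printf("sgemm : %.2f ms/iter\n", elapsed_ms(s, e) / iters);

  clock_gettime(CLOCK_MONOTONIC, &s);
  for (int it = 0; it < iters; ++it)
    cblas_sbgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, n, n, n,
                 1.0f, Ab, n, Bb, n, 0.0f, C, n);
  clock_gettime(CLOCK_MONOTONIC, &e);
  printf("sbgemm: %.2f ms/iter\n", elapsed_ms(s, e) / iters);

  free(A); free(B); free(C); free(Ab); free(Bb);
  return 0;
}
```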
4. note
Bfloat16 is more robust to inf / NaN than fp16 (it keeps float32's 8-bit exponent), which can be useful for mixed-precision training and fp16-compute / fp32-accumulate schemes; a small sketch of the dynamic-range point follows below.
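A tiny sketch of that dynamic-range point (the fp16 limit 65504 is hard-coded as a constant to avoid relying on compiler `_Float16` support; the bfloat16 conversion is plain truncation):

```c
#include <stdint.h>
#include <stdio.h>
#include <string.h>

/* Round-trip a float through bfloat16 by truncating the low mantissa bits. */
static float through_bf16(float f) {
  uint32_t u;
  memcpy(&u, &f, sizeof(u));
  u &= 0xFFFF0000u; /* keep sign, 8-bit exponent, top 7 mantissa bits */
  memcpy(&f, &u, sizeof(f));
  return f;
}

int main(void) {
  const float big = 1.0e30f;      /* far above fp16's largest finite value */
  const float fp16_max = 65504.f; /* largest finite IEEE fp16 value */

  /* bfloat16 keeps float32's exponent range, so 1e30 stays finite ... */
  printf("1e30 via bf16 : %g\n", through_bf16(big));
  /* ... while the same value exceeds fp16's range and would overflow to inf. */
  printf("fp16 max      : %g (1e30 would overflow in fp16)\n", fp16_max);
  return 0;
}
```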
skykongkong8 commented
It seems introducing this function into the current NNTrainer is not appropriate, and almost every participant in the project is aware of this issue.
This issue can be raised again when the proper moment to introduce such functions has come.