nnstreamer/nntrainer

Using Bfloat16 GEMM from OpenBLAS

skykongkong8 opened this issue · 2 comments

It seems the latest OpenBLAS supports bfloat16 GEMM.
I guess upgrading the OpenBLAS version from here (0.3.18 -> 0.3.24) will simply bring their functions into NNTrainer.

1. hardware compatibility

  • Uses a new data type, bfloat16, which is defined as uint16_t.
  • For ARM, it uses bfloat16 NEON intrinsics, which appear to be supported from Armv8.6-A (with additional extensions); see the feature-check sketch after this list.
  • They have the GEMM kernels implemented for arm64 / x86_64, but what I am not sure about is that their kernels are named "neoversen2", which is an Armv9.0-A CPU. Currently checking for more kernels.
  • According to their issues, their bfloat16 GEMM kernels are claimed to be optimized for specific CPUs such as neoverse-n2 and cooperlake. Thus we should consider which processor we are going to use when testing this.
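
As a minimal, hypothetical sketch of that compatibility check (the ACLE macro __ARM_FEATURE_BF16_VECTOR_ARITHMETIC and the -march=armv8.6-a+bf16 flag are my assumptions here, not something taken from the OpenBLAS sources), one could verify at compile time whether the toolchain actually exposes the bfloat16 NEON intrinsics the optimized kernels rely on:

  // bf16_check.c -- build e.g. with: gcc -march=armv8.6-a+bf16 bf16_check.c
  #include <stdio.h>

  int main(void) {
  #if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC)
    puts("bfloat16 NEON intrinsics available on this target");
  #else
    puts("no bfloat16 NEON intrinsics; a bfloat16 GEMM would need a generic kernel");
  #endif
    return 0;
  }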

2. accuracy

  • The bfloat16 GEMM output is float, and when it is compared with the result of running GEMM on input data casted from bfloat16 to float, the outputs are almost the same (a minimal call sketch follows this section's list).
  // 20240704
  // CC : bfloat16 GEMM output, C : float GEMM output
  if (fabs (CC[i * m + j] - C[i * m + j]) > 1.0)
    ret++;
  ...
  // DD : GEMM output computed from input data casted bfloat16 -> float
  if (CC[i * m + j] != DD[i * m + j])
    ret++;
  ...
  if (ret != 0)
    fprintf (stderr, "FATAL ERROR SBGEMM - Return code: %d\n", ret);
  ...
  • An interesting point is that they compare float vs. bfloat16 GEMM output with epsilon = 1.0, even for small GEMM problems such as (100, 100) x (100, 100). (With roughly 8 bits of mantissa precision, bfloat16 has a per-element relative error of about 2^-8 ≈ 0.4%, so an accumulated difference on the order of 1 over K = 100 products is plausible.)
  • Full Test code
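
For reference, the OpenBLAS entry point here is sbgemm (cblas_sbgemm in the CBLAS interface): A and B are bfloat16 (the uint16_t-backed type mentioned above) and C is float. Below is a minimal, hypothetical sketch of calling it, assuming an OpenBLAS build with bfloat16 support enabled; the fp32_to_bf16 truncation helper is my own illustration, not an OpenBLAS API.

  #include <cblas.h>    // OpenBLAS CBLAS header (bfloat16 support built in)
  #include <stdint.h>
  #include <stdio.h>
  #include <string.h>

  // Illustrative fp32 -> bf16 conversion: bfloat16 is the upper 16 bits of an
  // IEEE fp32, so plain truncation is enough for a sketch (round-to-nearest
  // would be preferable in real code).
  static bfloat16 fp32_to_bf16(float x) {
    uint32_t bits;
    memcpy(&bits, &x, sizeof(bits));
    return (bfloat16)(bits >> 16);
  }

  int main(void) {
    enum { M = 2, N = 2, K = 2 };
    float a_f[M * K] = {1.f, 2.f, 3.f, 4.f};
    float b_f[K * N] = {5.f, 6.f, 7.f, 8.f};
    bfloat16 A[M * K], B[K * N];
    float C[M * N] = {0};

    for (int i = 0; i < M * K; ++i) A[i] = fp32_to_bf16(a_f[i]);
    for (int i = 0; i < K * N; ++i) B[i] = fp32_to_bf16(b_f[i]);

    // C (MxN, float) = 1.0 * A (MxK, bf16) * B (KxN, bf16) + 0.0 * C
    cblas_sbgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K,
                 1.0f, A, K, B, N, 0.0f, C, N);

    // These small integers are exact in bfloat16, so C should be [19 22; 43 50].
    printf("C = [%g %g; %g %g]\n", C[0], C[1], C[2], C[3]);
    return 0;
  }

Linking is the usual -lopenblas; if the sbgemm symbols are missing, the library was most likely built without its bfloat16 option.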

3. latency measurement

  • WTD

4. note

Bfloat16 is more robust to inf / NaN than fp16 (it keeps the same 8-bit exponent range as fp32), which can be useful for mixed-precision training and fp16/fp32 accumulation.
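
As a quick illustration of that robustness (plain truncation is used for the conversion; this is only a sketch, not NNTrainer code): a value such as 1e20 overflows to +inf in IEEE fp16, whose largest finite value is 65504, but survives a bfloat16 round trip because bfloat16 keeps the full 8-bit fp32 exponent.

  #include <stdint.h>
  #include <stdio.h>
  #include <string.h>

  // Truncate an fp32 to bfloat16 (its upper 16 bits) and widen it back.
  static float bf16_round_trip(float x) {
    uint32_t bits;
    memcpy(&bits, &x, sizeof(bits));
    bits &= 0xFFFF0000u;   // keep sign, 8-bit exponent, top 7 mantissa bits
    memcpy(&x, &bits, sizeof(x));
    return x;
  }

  int main(void) {
    float big = 1.0e20f;   // far above the fp16 maximum of 65504 -> +inf in fp16
    printf("bf16 round trip of 1e20: %g (still finite)\n", bf16_round_trip(big));
    return 0;
  }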

cibot: Thank you for posting issue #2668. The person in charge will reply soon.

It seems that introducing this feature into the current NNTrainer is not appropriate, and almost every participant in the project is aware of this issue.
This issue can be raised again when the proper moment to introduce such functions comes.