ROCm/rocWMMA

Clarification on Using fp16/bf16/half as Compute Type


Hello,

While reviewing your documentation, I came across the mention that when executing wmma, the compute type can be set to fp16/bf16/half, described as "native f32 accumulations downcasted to fp16/bf16/half". I'm curious whether using these lower-precision compute types (which are not enabled in CUDA) can lead to performance improvements. Additionally, could you guide me on how to set this compute type if it helps performance?

Thank you!

Hello @xinyi-li7 ,
Thanks for the question!

On some architectures, such as gfx9, the multiply-accumulate for floating-point types is always performed natively in fp32. This higher precision is desirable in many cases because it preserves accumulation accuracy. Setting the compute type to half precision there is only an approximation: the native fp32 result is cast back down to the smaller type. This may afford some reduction in VREG usage, but at the cost of conversions between fp32 and the half-precision type, which is not desirable. On gfx9, the highest performance for fp16/bf16/half inputs comes from using an fp32 compute type, which avoids that conversion overhead.

Gfx11 architectures, on the other hand, natively support BOTH fp32 and fp16/bf16 accumulation for fp16/bf16 multiply-accumulate. There is no extra conversion cost to using a half-precision compute type, so it can outperform an fp32 compute type, at the cost of reduced precision, which can introduce some tricky numerical issues. At the end of the day, if your data can tolerate lower-precision accumulation, then that is what we would recommend.

The data type of your accumulator fragments controls the compute type (see the API).
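
For illustration, here is a minimal sketch of what that looks like with rocWMMA's CUDA-WMMA-style fragment API. The block sizes (16x16x16), kernel signatures, and layouts below are illustrative assumptions, not something prescribed in this thread; the only point is that swapping the accumulator fragment's data type between `float32_t` and `float16_t` is what selects the compute type.

```cpp
#include <hip/hip_runtime.h>
#include <rocwmma/rocwmma.hpp>

using rocwmma::float16_t;
using rocwmma::float32_t;

constexpr int BlockM = 16;
constexpr int BlockN = 16;
constexpr int BlockK = 16;

// fp32 compute type (recommended on gfx9): accumulator fragment holds float32_t.
__global__ void mma_f32_acc(float16_t const* a, float16_t const* b, float32_t* c,
                            int lda, int ldb, int ldc)
{
    rocwmma::fragment<rocwmma::matrix_a, BlockM, BlockN, BlockK, float16_t, rocwmma::row_major> fragA;
    rocwmma::fragment<rocwmma::matrix_b, BlockM, BlockN, BlockK, float16_t, rocwmma::col_major> fragB;
    rocwmma::fragment<rocwmma::accumulator, BlockM, BlockN, BlockK, float32_t> fragAcc;

    rocwmma::fill_fragment(fragAcc, 0.0f);
    rocwmma::load_matrix_sync(fragA, a, lda);
    rocwmma::load_matrix_sync(fragB, b, ldb);
    rocwmma::mma_sync(fragAcc, fragA, fragB, fragAcc);
    rocwmma::store_matrix_sync(c, fragAcc, ldc, rocwmma::mem_row_major);
}

// fp16 compute type: only the accumulator's data type changes. On gfx9 this is
// a downcast of the native fp32 result; on gfx11 it maps to native fp16 accumulation.
__global__ void mma_f16_acc(float16_t const* a, float16_t const* b, float16_t* c,
                            int lda, int ldb, int ldc)
{
    rocwmma::fragment<rocwmma::matrix_a, BlockM, BlockN, BlockK, float16_t, rocwmma::row_major> fragA;
    rocwmma::fragment<rocwmma::matrix_b, BlockM, BlockN, BlockK, float16_t, rocwmma::col_major> fragB;
    rocwmma::fragment<rocwmma::accumulator, BlockM, BlockN, BlockK, float16_t> fragAcc;

    rocwmma::fill_fragment(fragAcc, static_cast<float16_t>(0));
    rocwmma::load_matrix_sync(fragA, a, lda);
    rocwmma::load_matrix_sync(fragB, b, ldb);
    rocwmma::mma_sync(fragAcc, fragA, fragB, fragAcc);
    rocwmma::store_matrix_sync(c, fragAcc, ldc, rocwmma::mem_row_major);
}
```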

Hope this helps!

Hi @cgmillette,
I see. So the output type of the accumulator is essentially the compute type. I'm currently using MI100 and MI250X, so I guess there's no native lower-precision support for me. Thank you so much for your patient reply!