Output scale not being used with `te_gemm` in FP8
snarayan21 opened this issue · 3 comments
Hey, I'm using the `te_gemm` function defined in the PyTorch extensions here, and I'm trying to apply a scaling factor to the output. My GEMM inputs are in fp8e4m3 and the output is in bf16.
For the `D_scale` argument, I am passing in tensors like `torch.tensor([4.0], device='cuda')`, but changing the value of the scaling factor has no impact on the output. Am I doing something wrong here? Or is the scaling factor only applied when the output is of a certain dtype?
Hello @snarayan21. Yes, the purpose of this argument is to be the scaling factor for the output when the operator is producing FP8. The cuBLAS API does not otherwise have a parameter for scaling the output. The closest thing is `alpha`: cuBLAS performs `D = alpha * A * B + beta * C`. Currently in `te_gemm` we do not give a way to specify `alpha` (and `beta` is set to 1 by the `accumulate` option, which performs `D = A * B + D`).
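For reference, the cuBLAS epilogue semantics described above can be illustrated in plain PyTorch (this is just a numerical sketch of the math, not the actual FP8 kernel; the tensor names are made up for the example):

```python
import torch

torch.manual_seed(0)
A = torch.randn(4, 8)
B = torch.randn(8, 3)
C = torch.randn(4, 3)

# cuBLAS GEMM epilogue: D = alpha * A @ B + beta * C
alpha, beta = 2.0, 1.0
D = alpha * (A @ B) + beta * C

# te_gemm exposes no alpha (it is fixed at 1); setting accumulate=True
# corresponds to beta = 1 with C aliased to D, i.e. D = A @ B + D.
D_acc = A @ B + C
```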
What is the use case you are interested in that would benefit from that scaling?
@ptrendx Oh I see. Since the model I have doesn't use bias terms, it would be nice to just specify `alpha`... but would I get the same result by modifying the `A_scale_inverse` in the `te_gemm` function? As in, could I multiply the existing `A_scale_inverse` by `alpha` to get what I want? The model has a few scaling factors that I would like to fuse into the FP8 GEMMs, if possible.
Yes, if you don't have bias that would be possible. You just need to be careful not to overwrite the scale inverse given as an input; instead, create a new tensor there and pass that to the GEMM.
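A minimal sketch of that workaround, assuming dequantization multiplies the FP8 values of A by `A_scale_inverse` (so scaling that factor by `alpha` yields `alpha * (A @ B)` from the same kernel). The `te_gemm` call is elided since it takes many more arguments (workspace, dtypes, layouts); the variable names are illustrative:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

alpha = 4.0
# Existing scale-inverse produced by the FP8 recipe (illustrative value).
a_scale_inv = torch.tensor([0.5], device=device)

# Do NOT modify a_scale_inv in place -- it may be shared with other ops
# or updated by the amax/scale bookkeeping. Create a new tensor instead.
a_scale_inv_with_alpha = a_scale_inv * alpha

# te_gemm(A_fp8, a_scale_inv_with_alpha, ...)  # pass the new tensor here
```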