[Feature Request] Matmul-based attention layers instead of Einsum to run efficiently on GPU
MoFHeka opened this issue · 4 comments
The Einsum kernels in Praxis cannot be lowered to cuDNN GEMM, which seriously hurts compute performance. The JAX version of the attention layer is much slower than the TensorFlow version.
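To make the request concrete, here is a minimal sketch (not actual Praxis code; the equation string and the shapes B/T/S/N/H are illustrative assumptions) of the einsum-style attention-score contraction and the explicit batched-matmul form this request asks for:

```python
import jax
import jax.numpy as jnp

def scores_einsum(q, k):
    # q: [B, T, N, H], k: [B, S, N, H] -> scores: [B, N, T, S]
    # Illustrative equation string, not copied from Praxis.
    return jnp.einsum('BTNH,BSNH->BNTS', q, k)

def scores_matmul(q, k):
    # The same contraction written as an explicit batched matmul,
    # which is the formulation requested here for GPU execution.
    q = jnp.transpose(q, (0, 2, 1, 3))   # [B, N, T, H]
    k = jnp.transpose(k, (0, 2, 3, 1))   # [B, N, H, S]
    return jnp.matmul(q, k)              # [B, N, T, S]

q = jax.random.normal(jax.random.PRNGKey(0), (2, 6, 4, 8))
k = jax.random.normal(jax.random.PRNGKey(1), (2, 5, 4, 8))
assert jnp.allclose(scores_einsum(q, k), scores_matmul(q, k), atol=1e-5)
```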
Why do you say cuDNN GEMM isn't used? Normally it should be. Can you provide an example where cuDNN GEMM isn't used in such a case?
And why did you file this here and not in the TransformerEngine repo?
Did I misunderstand the request?
@nouiz Yes, TransformerEngine does use cuDNN GEMM.
But the JAX (Flax or Praxis) attention layers are built from Einsum kernels, which cannot be lowered to cuDNN GEMM or to the latest cuDNN XLA FMHA kernel. When the attention layers run on GPU, they are only lowered to Triton kernels, according to the XLA dump log...
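For reference, a minimal sketch of how such a dump can be produced to check what the einsum lowers to. The `--xla_dump_to` and `--xla_dump_hlo_as_text` flags are standard XLA debug options; the jitted function and shapes below are illustrative rather than the actual Praxis layer, and the exact custom-call target names in the dump (cuBLAS GEMM, cuDNN fMHA, Triton fusions) vary by XLA version.

```python
import os
# Set the flags before JAX initializes its backend so they take effect.
os.environ['XLA_FLAGS'] = '--xla_dump_to=/tmp/xla_dump --xla_dump_hlo_as_text'

import jax
import jax.numpy as jnp

@jax.jit
def attn_scores(q, k):
    # Einsum-based score computation, similar in spirit to the wrapper-library
    # attention layers (equation string is illustrative).
    return jnp.einsum('BTNH,BSNH->BNTS', q, k)

q = jnp.ones((2, 128, 8, 64))
k = jnp.ones((2, 128, 8, 64))
attn_scores(q, k).block_until_ready()
# Then inspect /tmp/xla_dump/*after_optimizations*.txt and check whether the
# contraction became a library custom call or a Triton/fusion kernel.
```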
TE currently supports only a limited set of transformer models (MoE, for example, is difficult to support) and does not yet support LoRA SFT. So it may be necessary to optimize the layer composition in the JAX ecosystem itself.
Sorry, I'm not sure where this request belongs, because it doesn't look like the TE team should be responsible for it. As I understand it, the TE team is only responsible for the 'custom_call' part of JAX.
I think you will be interested in this PR: jax-ml/jax#18814
@nouiz Cool, thank you!
It would be nice if someone could also change the code in the JAX wrapper libraries such as Praxis and Flax, since their attention layers are currently written with einsum.
JAX-Toolbox uses Paxml, and Paxml uses Praxis. But Praxis is written without matmul, regardless of what kernels JAX core can generate.