AI-Hypercomputer/maxtext

FlashAttention Support - TPUv3

maciek-pioro opened this issue · 1 comment

Is FlashAttention supported on TPUv3? The same config that works on TPUv4 fails on TPUv3 with the following error:
```
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Mosaic failed to compile TPU kernel: Unsupported input data type in matrix multiplication.
```
However, after setting `attention` from `autoselected` to `dot_product`, the error disappears.

Yes, flash attention isn't supported on v3; you will have to use `dot_product` attention.
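
For reference, a minimal sketch of the workaround, assuming the usual MaxText pattern of overriding `base.yml` keys on the command line (the config path below matches the repo layout; any other flags you normally pass stay unchanged):

```sh
# Force dot-product attention instead of letting MaxText auto-select the
# Mosaic flash-attention kernel, which is not supported on TPUv3.
python3 MaxText/train.py MaxText/configs/base.yml attention=dot_product
```

Alternatively, set `attention: "dot_product"` directly in your config YAML.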