FlashAttention Support - TPUv3
maciek-pioro opened this issue · 1 comment
maciek-pioro commented
Is FlashAttention supported on TPUv3? The same config that works on TPUv4 fails on TPUv3 with the following error:
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Mosaic failed to compile TPU kernel: Unsupported input data type in matrix multiplication.
However, after changing `attention` from `autoselected` to `dot_product`, the error disappears.
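For reference, a minimal sketch of applying that workaround, assuming the MaxText convention of overriding config keys as `key=value` arguments on the command line (the script path and base config name here are illustrative and may differ in your setup):

```shell
# Force dot_product attention instead of the autoselected (flash) kernel,
# which avoids the Mosaic compilation error on TPUv3.
# Paths below are illustrative; adjust to match your checkout.
python3 MaxText/train.py MaxText/configs/base.yml attention=dot_product
```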
gobbleturk commented
Yes, flash attention isn't supported on v3; you will have to use `dot_product` attention.