Mosaic kernels cannot be automatically partitioned. Please wrap the call in a shard_map or xmap
nyl199310 opened this issue · 3 comments
nyl199310 commented
Hi, I'm testing the attention mechanisms on a Kaggle TPU VM v3-8, and it reports the following:
pallas_flash is Failed :
Mosaic kernels cannot be automatically partitioned. Please wrap the call in a shard_map or xmap.
Does that mean Pallas attention will not be supported on TPU v3? Thank you.
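For context, the error seems to be asking for the kernel call to be wrapped in `shard_map` so each device runs its own shard. A minimal sketch of that pattern (not EasyDeL's actual code; `attention_kernel` and the `"dp"` mesh axis are just placeholders) would look like:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

def attention_kernel(x):
    # placeholder for a Pallas/Mosaic kernel that XLA cannot auto-partition
    return jnp.tanh(x)

# one-dimensional device mesh over all available devices
mesh = Mesh(np.array(jax.devices()), axis_names=("dp",))

# shard_map runs the kernel per device shard instead of asking the
# compiler to partition the Mosaic kernel automatically
sharded_kernel = shard_map(
    attention_kernel, mesh=mesh, in_specs=P("dp"), out_specs=P("dp")
)

x = jnp.ones((8, 128))
y = jax.jit(sharded_kernel)(x)
```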
erfanzar commented
Hi, and thanks for using EasyDeL.
Actually, I've been developing that while mostly focusing on CPU and GPU, so I forgot to test it on TPUs.
I'll fix that soon, and thanks for letting me know.
erfanzar commented
I have fixed the issue related to shard_map and xmap, but some custom kernels are still unsupported or produce incorrect computations on TPU v3, and Pallas flash attention cannot be run on TPU v3, at least with the current JAX version (0.4.28).
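Until then, one rough workaround is to check the device kind at runtime and skip the Pallas backend on v3 hardware. This is only a sketch, not EasyDeL's actual API; the backend names are placeholders:

```python
import jax

def pick_attention_backend():
    # device_kind is e.g. "TPU v3", "TPU v4", or a GPU/CPU name
    kind = jax.devices()[0].device_kind
    if kind.startswith("TPU v3"):
        # Pallas flash attention is not usable here (as of JAX 0.4.28)
        return "vanilla"
    return "pallas_flash"
```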
nyl199310 commented
OK, I hope it can be supported in a future version. Thanks for the explanation!