erfanzar/EasyDeL

Mosaic kernels cannot be automatically partitioned. Please wrap the call in a shard_map or xmap

nyl199310 opened this issue · 3 comments

Hi, I'm testing the attention mechanisms on a Kaggle TPU VM v3-8, and it reported the following:

pallas_flash is Failed :

Mosaic kernels cannot be automatically partitioned. Please wrap the call in a shard_map or xmap.

Does that mean Pallas attention will not be supported on TPU v3? Thank you.

Hi, and thanks for using EasyDeL.
Actually, I've been developing that while mostly focusing on CPU and GPU, so I forgot to test it on TPUs.
I'll fix that soon, and thanks for letting me know.

I have fixed the issue related to shard_map and xmap, but some custom kernels are still unsupported or produce incorrect computations on TPU v3, and Pallas flash attention cannot be run on TPU v3, at least with the current JAX version (0.4.28).
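
For reference, here is a minimal sketch of the pattern the error message asks for: wrapping a kernel call in `shard_map` so each device runs it on its own shard instead of relying on automatic partitioning. The `attention` function below is a plain-JAX stand-in for the Pallas kernel (hypothetical, for illustration only), and the mesh shape and axis name are assumptions for a v3-8.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map


def attention(q, k, v):
    # Stand-in for the Pallas flash-attention kernel: plain softmax attention.
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
    return jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(scores, axis=-1), v)


# Assume the 8 TPU cores form a 1-D data-parallel mesh named "dp".
devices = np.array(jax.devices()).reshape(-1)
mesh = Mesh(devices, axis_names=("dp",))

# Shard the batch dimension across "dp"; inside shard_map each device sees
# only its local batch slice, so the kernel never needs to be auto-partitioned.
sharded_attention = shard_map(
    attention,
    mesh=mesh,
    in_specs=(P("dp"), P("dp"), P("dp")),
    out_specs=P("dp"),
)

q = k = v = jnp.ones((8, 4, 128, 64))  # (batch, heads, seq_len, head_dim)
out = jax.jit(sharded_attention)(q, k, v)
print(out.shape)  # (8, 4, 128, 64)
```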

OK, I hope it can be supported in a future version. Thanks for your explanation!