gpu_fast_attention does not pass segment_ids to JAX Pallas attention mha
Cjkkkk opened this issue:
Call site: praxis/praxis/layers/gpu_fast_attention.py, line 133 (commit 43db271)
JAX Pallas attention: https://github.com/google/jax/blame/main/jax/experimental/pallas/ops/attention.py#L163
It seems JAX has added `segment_ids` as a required argument to `mha`, but praxis has not been updated to pass it.
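To show what the new argument controls, here is a minimal plain-JAX sketch of segment-masked attention. It is not the Pallas kernel. `reference_mha` is a hypothetical helper. The tensor layout `[batch, seq_len, num_heads, head_dim]` and the rule that `segment_ids=None` means "no segment mask" are assumptions. If the Pallas `mha` follows the same convention, passing `segment_ids=None` at the praxis call site would presumably keep the previous unmasked behavior. That is unverified here.

```python
import jax
import jax.numpy as jnp


def reference_mha(q, k, v, segment_ids=None, sm_scale=1.0, causal=False):
    """Hypothetical reference attention with optional segment masking.

    q, k, v: [batch, seq_len, num_heads, head_dim] (assumed layout, self-attention).
    segment_ids: [batch, seq_len] integer array, or None to disable segment masking.
    """
    # Attention logits: [batch, num_heads, q_len, kv_len].
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) * sm_scale
    batch, seq_len = q.shape[0], q.shape[1]
    mask = jnp.ones((batch, 1, seq_len, k.shape[1]), dtype=bool)
    if segment_ids is not None:
        # A query token may only attend to key tokens carrying the same segment id.
        same_segment = segment_ids[:, :, None] == segment_ids[:, None, :]
        mask = mask & same_segment[:, None, :, :]
    if causal:
        # Lower-triangular mask: position i attends to positions <= i.
        causal_mask = jnp.tril(jnp.ones((seq_len, k.shape[1]), dtype=bool))
        mask = mask & causal_mask[None, None, :, :]
    # Masked logits get the most negative finite value, so softmax gives them ~0 weight.
    logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min)
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)
```

With two sequences packed into one row (segment ids `[0, 0, 1, 1]`), each half of the output should match attention computed on that half alone, which is the property a segment mask exists to guarantee.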