google/praxis

gpu_fast_attention not passing segment_ids to jax pallas attention mha

Cjkkkk opened this issue · 0 comments

Cjkkkk commented

q, k, v, sm_scale=1.0 / math.sqrt(h), backward_pass_impl=bwd_pass_impl

https://github.com/google/jax/blame/main/jax/experimental/pallas/ops/attention.py#L163
It seems JAX has added segment_ids as a required argument to the Pallas mha kernel, but praxis has not been updated to pass it, so the call above fails against current JAX.
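A minimal, self-contained sketch of the mismatch and the fix. The mha below is a stand-in with a hypothetical signature (the real kernel lives in jax.experimental.pallas.ops.attention); it assumes segment_ids is positional after v and that None is acceptable when inputs are not packed:

```python
import math

def mha(q, k, v, segment_ids, sm_scale=1.0, backward_pass_impl="triton"):
    # Stand-in for the Pallas kernel: segment_ids is now a required
    # positional argument (hypothetical signature for illustration).
    return ("called", sm_scale, segment_ids)

h = 64

# Old praxis-style call site: segment_ids is missing, so this raises TypeError.
try:
    mha("q", "k", "v", sm_scale=1.0 / math.sqrt(h), backward_pass_impl="triton")
except TypeError as e:
    print("old call fails:", e)

# Updated call site: thread segment_ids through (None when sequences
# are unpacked and no segment masking is needed).
out = mha("q", "k", "v", None,
          sm_scale=1.0 / math.sqrt(h), backward_pass_impl="triton")
print("new call ok:", out[0])
```

The same pattern applies to praxis: plumb a segment_ids value (or None) from the caller down to the Pallas attention call.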