stanford-crfm/haliax

Make a helper function to squash/unsquash all axes (except some) into a single batch axis

dlwh opened this issue · 1 comment

dlwh commented

Lots of JAX stuff works with only a single batch axis, and other custom kernels are easier to write if there's a single batch axis. On TPU, this is basically free (assuming the batch axes aren't at the beginning).

dlwh commented

Fixed in dev. It's called haliax.core.flatten_all_axes_but.
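
For readers who want to see the general squash/unsquash idea, here is a minimal sketch in plain jax.numpy using positional axes. It is not haliax's implementation and does not use its named-axis API: the helper name `flatten_all_but`, its `keep_axes` parameter, and the returned `unflatten` closure are all hypothetical, chosen only for illustration.

```python
import math

import jax.numpy as jnp


def flatten_all_but(x, keep_axes):
    """Hypothetical helper: merge every axis not in keep_axes into one leading batch axis.

    Returns the flattened array and a function that restores the original layout.
    """
    keep = tuple(a % x.ndim for a in keep_axes)
    batch = tuple(a for a in range(x.ndim) if a not in keep)
    perm = batch + keep  # batch axes first, kept axes last
    batch_shape = tuple(x.shape[a] for a in batch)
    kept_shape = tuple(x.shape[a] for a in keep)

    flat = jnp.transpose(x, perm).reshape((math.prod(batch_shape),) + kept_shape)

    def unflatten(y):
        # Split the single batch axis back into the original batch axes.
        y = y.reshape(batch_shape + y.shape[1:])
        # Invert the permutation so each axis returns to its original position.
        inverse = [0] * len(perm)
        for i, p in enumerate(perm):
            inverse[p] = i
        return jnp.transpose(y, inverse)

    return flat, unflatten


x = jnp.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
flat, unflatten = flatten_all_but(x, keep_axes=(2,))
restored = unflatten(flat)
```

In this sketch, a function that expects a single batch axis can be applied to `flat`, and `unflatten` then restores the original axis order, provided the result still has one leading batch axis followed by the same number of kept axes.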