Make a helper function to squash/unsquash all axes (except some) into a single batch axis
dlwh opened this issue · 1 comment
dlwh commented
Lots of JAX functionality works with only a single batch axis, and custom kernels are easier to write when there is exactly one batch axis. On TPU, squashing the batch axes together is basically free (assuming the batch axes aren't at the beginning).
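A minimal sketch of the squash/unsquash idea, assuming plain positional axes rather than Haliax's named axes. The names `flatten_batch_axes` and `unflatten_batch_axes` are hypothetical and are not the Haliax API. The helper transposes every axis not listed in `keep` to the front, merges those axes into one batch axis, and returns enough information to undo the operation. When the batch axes are already leading, the permutation is the identity.

```python
import math
from typing import Sequence, Tuple

import jax
import jax.numpy as jnp


def flatten_batch_axes(
    x: jnp.ndarray, keep: Sequence[int]
) -> Tuple[jnp.ndarray, Tuple[int, ...], Tuple[int, ...]]:
    """Hypothetical helper: merge every axis NOT in `keep` into one leading batch axis.

    Returns (flat, batch_shape, perm), where flat has shape
    (prod(batch_shape), *kept_dims) and perm is the transpose that was applied.
    """
    keep = tuple(a % x.ndim for a in keep)  # normalize negative axis indices
    assert len(set(keep)) == len(keep), "duplicate axes in keep"
    batch = tuple(a for a in range(x.ndim) if a not in keep)
    perm = batch + keep  # batch axes first, kept axes last (in the given order)
    batch_shape = tuple(x.shape[a] for a in batch)
    kept_shape = tuple(x.shape[a] for a in keep)
    moved = jnp.transpose(x, perm)
    flat = moved.reshape((math.prod(batch_shape),) + kept_shape)
    return flat, batch_shape, perm


def unflatten_batch_axes(
    flat: jnp.ndarray, batch_shape: Tuple[int, ...], perm: Tuple[int, ...]
) -> jnp.ndarray:
    """Hypothetical inverse of flatten_batch_axes (kept dims must be unchanged)."""
    moved = flat.reshape(tuple(batch_shape) + tuple(flat.shape[1:]))
    # Invert the permutation: original axis perm[i] currently sits at position i.
    inverse = [0] * len(perm)
    for i, p in enumerate(perm):
        inverse[p] = i
    return jnp.transpose(moved, inverse)


# Usage: squash axes 0 and 2, run a single-batch-axis function via vmap, restore.
x = jnp.arange(2 * 3 * 4 * 5, dtype=jnp.float32).reshape(2, 3, 4, 5)
flat, batch_shape, perm = flatten_batch_axes(x, keep=(1, 3))
doubled = jax.vmap(lambda m: m * 2.0)(flat)  # m has shape (3, 5)
restored = unflatten_batch_axes(doubled, batch_shape, perm)
```

Here `flat` has shape `(8, 3, 5)`, so any function written for one batch axis (e.g. under `jax.vmap`, or a custom kernel) can consume it directly. The inverse only works when the function preserves the kept dimensions, which is why the sketch maps `m * 2.0` rather than a reduction.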
dlwh commented
Fixed in dev. It's called haliax.core.flatten_all_axes_but.