Cleaner All Pairs Difference
In PyTorch and JAX, an all-pairs difference can be computed with a one-liner of indexing black magic:
dt_segment_sum_jax = dA_cumsum_jax[:, :, :, :, None] - dA_cumsum_jax[:, :, :, None, :]
This appears in the reference implementation of Mamba 2.
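To make the one-liner less opaque: indexing with `None` inserts a size-1 axis, so `x[..., :, None]` has shape `(..., T, 1)` and `x[..., None, :]` has shape `(..., 1, T)`. Broadcasting the two gives `out[..., i, j] == x[..., i] - x[..., j]`. When `x` is a cumulative sum, each entry is a segment sum, `cumsum[i] - cumsum[j] == sum(dA[j+1 : i+1])`, which is presumably where the name `dt_segment_sum` comes from. Below is a minimal JAX sketch. The helper name `all_pairs_difference` and the toy inputs are hypothetical.

```python
import jax.numpy as jnp


def all_pairs_difference(x):
    """Hypothetical helper: out[..., i, j] == x[..., i] - x[..., j]."""
    # (..., T, 1) - (..., 1, T) broadcasts to (..., T, T).
    return jnp.expand_dims(x, -1) - jnp.expand_dims(x, -2)


dA = jnp.array([1.0, 2.0, 3.0, 4.0])
dA_cumsum = jnp.cumsum(dA)  # [1, 3, 6, 10]

seg = all_pairs_difference(dA_cumsum)

# Identical to the indexing one-liner from the issue (1-D case).
assert jnp.array_equal(seg, dA_cumsum[:, None] - dA_cumsum[None, :])

# Entry (i, j) with i > j is the sum of dA over the segment (j, i].
assert seg[3, 1] == dA[2] + dA[3]  # 10 - 3 == 3 + 4
```

Using `jnp.expand_dims` instead of `None` indexing changes nothing numerically. It only makes the "add an axis at -1 on one side and at -2 on the other" intent explicit.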
The PyTorch/JAX one-liner is neither readable nor self-explanatory. It was also not obvious how to express the equivalent in Haliax, because of a subset constraint (one operand's axes must be a subset of the other's). A potential solution:
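import jax.numpy as jnp
import haliax as hax
from haliax import Axis
from jax.random import PRNGKey

def test_all_pairs_difference():
    H = Axis("H", 7)
    W = Axis("W", 8)
    D = Axis("D", 9)
    T = Axis("T", 11)
    named1 = hax.random.uniform(PRNGKey(0), (H, W, D, T))
    # Haliax analogue of the Mamba 2 one-liner:
    # dt_segment_sum_jax = dA_cumsum_jax[:, :, :, :, None] - dA_cumsum_jax[:, :, :, None, :]
    # The left operand gains a T2 axis and the right has T renamed to T2,
    # so the right's axes are a subset of the left's and the check passes.
    named1_diff = named1.broadcast_axis(hax.Axis("T2", 11)) - named1.rename({"T": "T2"})
    # Move T and T2 to the end to match the layout of the vanilla version.
    named1_diff = named1_diff.rearrange((..., "T", "T2"))
    assert named1_diff.axes == (H, W, D, T, Axis("T2", 11))
    vanilla_diff = named1.array[:, :, :, :, None] - named1.array[:, :, :, None, :]
    assert jnp.all(named1_diff.array == vanilla_diff)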
This issue exists to request better support for this kind of operation.
What do you think about:
with hax.auto_broadcast():
    named1_diff = named1 - named1.rename({"T": "T2"})
Basically, the only thing stopping this from working is an explicit check I do to avoid accidentally combining arrays where one isn't a subset of the other.
The other thing I could do is relax the check to "at least one overlapping axis".
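The two options can be compared using only axis names. The sketch below is a hypothetical model of the rules as described in this thread, not Haliax's actual implementation. `subset_rule` mirrors the current check and `overlap_rule` the proposed relaxation. `result_axes` assumes the output keeps the left operand's axis order and appends any new axes from the right; that ordering is an assumption.

```python
def subset_rule(a_axes, b_axes):
    """Current check (as described): one operand's axes must be a subset of the other's."""
    a, b = set(a_axes), set(b_axes)
    return a <= b or b <= a


def overlap_rule(a_axes, b_axes):
    """Proposed relaxation: the operands share at least one axis."""
    return bool(set(a_axes) & set(b_axes))


def result_axes(a_axes, b_axes):
    """Union of axes; left order first, then new right axes (assumed ordering)."""
    return tuple(a_axes) + tuple(ax for ax in b_axes if ax not in a_axes)


lhs = ("H", "W", "D", "T")   # named1
rhs = ("H", "W", "D", "T2")  # named1.rename({"T": "T2"})

print(subset_rule(lhs, rhs))   # False: neither is a subset, so the check rejects it
print(overlap_rule(lhs, rhs))  # True: H, W and D are shared
print(result_axes(lhs, rhs))   # ('H', 'W', 'D', 'T', 'T2')
```

Under the relaxed rule, `named1 - named1.rename({"T": "T2"})` would be accepted and would naturally carry both `T` and `T2`. Two arrays with no axes in common, such as `("A",)` and `("B",)`, would still be rejected.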
I think it is certainly cleaner, but I wouldn't remove the explicit check. Wouldn't it be better to disable it explicitly?
Meaning you like the with hax.auto_broadcast() option?