alxndrTL/mamba.py

Why use element-wise multiplication rather than matrix multiplication in the function `selective_scan_seq`

cszhbo opened this issue · 2 comments

Hello. In the function `selective_scan_seq`, there are two points that confuse me:

  • BX = deltaB * (x.unsqueeze(-1)) # (B, L, ED, N)
  • h = deltaA[:, t] * h + BX[:, t]
    These two lines of code appear to use element-wise multiplication.

However, in the paper, the equation is
$$h_t = \bar{A} h_{t-1} + \bar{B} x_{t}$$
Both terms on the right-hand side of the equation are matrix multiplications.

I am curious whether these two lines of code use some trick to convert the matrix multiplications into element-wise ones?

Hello,
regarding the first multiplication $\bar{A} h_{t-1}$: in general $\bar{A}$ is $(N, N)$ and $h_{t-1}$ is $(N,)$, so the product would indeed be a matmul. However, $\bar{A}$ is assumed to be diagonal and stays diagonal throughout training, so the code stores only its diagonal entries as an $(N,)$ vector. Multiplying a diagonal matrix by a vector is the same as taking the element-wise product of the diagonal with that vector, so the element-wise multiplication with $A$ represented as a vector gives the correct result.
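To make the diagonal trick concrete, here is a small sketch (using NumPy rather than PyTorch for brevity; the shapes `N = 4` and the random values are just illustrative):

```python
import numpy as np

# Illustrative state dimension; the real model uses a larger N.
N = 4
rng = np.random.default_rng(0)

A_diag = rng.standard_normal(N)  # A_bar stored as its diagonal, shape (N,)
h = rng.standard_normal(N)       # hidden state h_{t-1}, shape (N,)

# Full (N, N) matmul with the diagonal matrix materialized...
full = np.diag(A_diag) @ h
# ...equals a cheap element-wise product with the diagonal vector.
elementwise = A_diag * h

assert np.allclose(full, elementwise)
```

This is also why the element-wise version is preferable: it costs $O(N)$ per step instead of $O(N^2)$ and never materializes the $(N, N)$ matrix.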

Concerning the second multiplication $\bar{B} x_t$: $\bar{B}$ is an $(N,)$ vector and $x_t$ is just a scalar, so the code simply scales the vector with an element-wise multiplication (the `unsqueeze` makes the broadcasting explicit and avoids any surprises).
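The first line you quoted does exactly this scalar scaling for every channel at once. A sketch with toy shapes (again in NumPy, where `x[..., None]` plays the role of `x.unsqueeze(-1)`; the sizes are made up for illustration):

```python
import numpy as np

# Toy shapes mirroring the code: batch B=2, length L=3, channels ED=5, state N=4.
Bsz, L, ED, N = 2, 3, 5, 4
rng = np.random.default_rng(1)

deltaB = rng.standard_normal((Bsz, L, ED, N))
x = rng.standard_normal((Bsz, L, ED))

# Each scalar x_t is broadcast across the N state dimensions of its channel,
# giving BX = deltaB * x.unsqueeze(-1) from the original code.
BX = deltaB * x[..., None]
assert BX.shape == (Bsz, L, ED, N)

# Sanity check against one entry computed explicitly:
assert np.allclose(BX[0, 0, 0], deltaB[0, 0, 0] * x[0, 0, 0])
```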

Note that in the two lines you quoted, the equation $\bar{A} h_{t-1} + \bar{B} x_t$ is computed for all $E \times D$ channels in parallel, but that does not change the math behind it: once you understand it for a single channel (as explained just above), PyTorch simply does the batched multiplication for you.
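The batching claim can be checked directly: running the recurrence `h = deltaA[:, t] * h + BX[:, t]` over all channels at once gives the same result as running it for one channel in isolation. A minimal sketch with made-up shapes and values:

```python
import numpy as np

# Toy shapes: batch B=2, length L=3, channels ED=5, state N=4.
Bsz, L, ED, N = 2, 3, 5, 4
rng = np.random.default_rng(2)

deltaA = rng.standard_normal((Bsz, L, ED, N)) * 0.1
BX = rng.standard_normal((Bsz, L, ED, N))

# Batched recurrence over all B * ED channels in parallel,
# as in selective_scan_seq:
h = np.zeros((Bsz, ED, N))
for t in range(L):
    h = deltaA[:, t] * h + BX[:, t]

# The same recurrence for a single channel (b=0, c=0) in isolation:
h1 = np.zeros(N)
for t in range(L):
    h1 = deltaA[0, t, 0] * h1 + BX[0, t, 0]

# The batched result for that channel matches the single-channel scan.
assert np.allclose(h[0, 0], h1)
```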

Hope that is clear enough!

Ok, I see. Thanks for your detailed and clear explanation.