Why use element-wise multiplication rather than matrix multiplication in the function `selective_scan_seq`
cszhbo opened this issue · 2 comments
Hello. In the function `selective_scan_seq`, there are two points that confuse me:
BX = deltaB * (x.unsqueeze(-1)) # (B, L, ED, N)
h = deltaA[:, t] * h + BX[:, t]
These two lines of code seem to be element-wise multiplication.
However, in the paper, the equation is

$$h_t = \bar{A} h_{t-1} + \bar{B} x_t$$
Both terms on the right-hand side of the equation are matrix multiplications.
I am curious whether these two lines of code use some trick to convert the matrix multiplication into an element-wise one.
Hello,
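regarding the first multiplication, `deltaA[:, t] * h`: the matrix A is diagonal, so multiplying the state by it is the same as multiplying element-wise by its diagonal.
Concerning the second multiplication, `deltaB * (x.unsqueeze(-1))`: for a single channel, the input at each time step is a scalar, so the product is just a scalar times a vector (the `unsqueeze` is there to avoid any surprising broadcasting).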
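For a single channel, the equivalence can be checked numerically. The following is a minimal sketch, not code from the repository: the size `N` and the names `a`, `b`, `h`, `x_t` are made up for illustration, and it assumes that the discretized A is diagonal (stored as a vector of its diagonal entries) and that the per-channel input is a scalar.

```python
import torch

torch.manual_seed(0)
N = 4                      # state size (hypothetical value)
a = torch.rand(N)          # assumed: diagonal entries of the discretized A for one channel
b = torch.rand(N)          # assumed: discretized B for one channel, a vector of length N
h = torch.rand(N)          # previous hidden state for this channel
x_t = torch.tensor(0.7)    # scalar input for this channel at time t

# Matrix form: a full (N, N) diagonal matrix times h, plus an (N, 1) matrix times a (1, 1) input.
h_matrix = torch.diag(a) @ h + (b.unsqueeze(-1) @ x_t.reshape(1, 1)).squeeze(-1)

# Element-wise form, mirroring the two lines quoted in the question.
h_elementwise = a * h + b * x_t

same = torch.allclose(h_matrix, h_elementwise)
print(same)
```

Under these assumptions the two forms should agree up to floating-point tolerance, because the off-diagonal zeros of the matrix contribute nothing to the product.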
Note that in the code, in the two lines you gave, the equation is computed for all E*D channels in parallel, but that does not change the math behind it: if you understand it with a single channel (as explained just above in my answer), then PyTorch just does the batched multiplication for you.
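To illustrate the batched case, here is a hedged sketch that compares the element-wise scan against an explicit per-channel matrix recurrence. The tensor sizes and the random inputs are hypothetical. Only the two lines quoted in the question are taken from the thread, and the reference loop is written here purely as a check.

```python
import torch

torch.manual_seed(0)
B, L, ED, N = 2, 5, 3, 4   # hypothetical small sizes: batch, length, channels, state

deltaA = torch.rand(B, L, ED, N)   # assumed: diagonal of the discretized A per step and channel
deltaB = torch.rand(B, L, ED, N)   # assumed: discretized B per step and channel
x = torch.rand(B, L, ED)           # one scalar input per channel and time step

# Batched element-wise scan, following the two lines quoted in the question.
BX = deltaB * x.unsqueeze(-1)      # (B, L, ED, N)
h = torch.zeros(B, ED, N)
hs_batched = []
for t in range(L):
    h = deltaA[:, t] * h + BX[:, t]
    hs_batched.append(h)
hs_batched = torch.stack(hs_batched, dim=1)   # (B, L, ED, N)

# Reference: explicit matrix recurrence, one batch item and one channel at a time.
hs_ref = torch.zeros(B, L, ED, N)
for b in range(B):
    for c in range(ED):
        h_c = torch.zeros(N)
        for t in range(L):
            A_bar = torch.diag(deltaA[b, t, c])      # (N, N) diagonal matrix
            B_bar = deltaB[b, t, c].unsqueeze(-1)    # (N, 1) column
            x_t = x[b, t, c].reshape(1)              # scalar input as a length-1 vector
            h_c = A_bar @ h_c + B_bar @ x_t          # matrix form of the recurrence
            hs_ref[b, t, c] = h_c

match = torch.allclose(hs_batched, hs_ref)
print(match)
```

As for the `unsqueeze`: without it, broadcasting would line up the last dimension of `x` (size ED) with the state dimension N, which would either raise a shape error or, if the sizes happened to be compatible, silently compute something else.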
Hope that is clear enough!
Ok, I see. Thanks for your detailed and clear explanation.