Question about the attention calculation code "alpha = (query * key).sum(dim=-1) / scale"
muzi-peng opened this issue · 3 comments
Hi,
Could you tell me why the attention calculation in your code is done as a Hadamard product followed by a sum over the last dimension, rather than as a dot-product operation?
Thank you so much!
Because (q * k).sum() is equivalent to q.T @ k
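(For illustration only, not from the repository: a minimal sketch of this equivalence for 1-D tensors, where summing the Hadamard product over the last dimension gives exactly the dot product. Variable names here are hypothetical.)

```python
import torch

# Hypothetical 1-D query/key vectors, just to check the identity.
q = torch.randn(3)
k = torch.randn(3)

hadamard_then_sum = (q * k).sum(dim=-1)  # elementwise product, then sum over the last dim
dot_product = q @ k                      # ordinary dot product, i.e. q.T @ k for column vectors

print(torch.allclose(hadamard_then_sum, dot_product))  # True
```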
Oh, I get it. Thank you so much for your reply!
> Because (q * k).sum() is equivalent to q.T @ k
Hi Dr. Zhou, thanks for your great work, but I do not understand why (q * k).sum() is equivalent to q.T @ k.
Here is an example:

```python
>>> import torch
>>> q = torch.randn(2, 3)
>>> q
tensor([[-1.4198, -1.4788, -0.8260],
        [-0.0783,  1.2059,  0.5165]])
>>> k = torch.randn(2, 3)
>>> k
tensor([[-0.9287, -0.4349,  1.5053],
        [ 1.0446, -1.4643,  0.6810]])
>>> (q * k).sum(dim=-1)
tensor([ 0.7185, -1.4959])
>>> q.T @ k
tensor([[ 1.2368,  0.7322, -2.1906],
        [ 2.6331, -1.1226, -1.4048],
        [ 1.3067, -0.3970, -0.8916]])
```
And obviously, they are not equal to each other.
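One way to reconcile the two results (an observation about the example above, not a verified statement about the repository's code): when q and k carry a leading batch dimension, (q * k).sum(dim=-1) computes one dot product per row, i.e. the diagonal of q @ k.T, rather than the full matrix product q.T @ k. A minimal sketch with the same shapes as the example:

```python
import torch

q = torch.randn(2, 3)  # two query vectors of dimension 3
k = torch.randn(2, 3)  # two key vectors of dimension 3

row_wise = (q * k).sum(dim=-1)          # dot product of each row of q with the matching row of k, shape (2,)
diag_of_full = torch.diagonal(q @ k.T)  # the same values, taken from the diagonal of the full (2, 2) score matrix

print(torch.allclose(row_wise, diag_of_full))  # True
```

So the stated equivalence holds per query-key pair along the last dimension; presumably that per-pair dot product is what the attention code intends, though that is an assumption about the code rather than something verified here.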