jadore801120/attention-is-all-you-need-pytorch

Question About Attention Score Computation Process & Intuition

rezhv opened this issue

rezhv commented

When it comes to transformers, the Query and Key matrices are what determine the attention scores. Here is a nice visual taken from Jay Alammar's blog post on transformers that illustrates how attention scores are computed:
[Figure: self-attention score and softmax computation, from Jay Alammar's blog post]

As you can see, the attention score depends only on the product of the q_i and k_j vectors (scaled by 1/√d_k before the softmax), with no additional parameters. However, each of these two vectors is computed by a linear layer whose input is the embedding (plus positional encoding) of just one word. My question is: how can the network assign meaningful attention scores if q and k are computed without looking at any part of the sentence other than their own word? In other words, how can a q vector and a k vector, each derived from a single word embedding, multiply to give a meaningful attention score?
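For reference, the computation in the figure can be sketched in plain Python. This is a minimal single-head illustration with random weights and made-up sizes (`d_model = 4`, `d_k = 3`, five words); the names `attention_weights`, `W_q`, `W_k` and `X` are hypothetical and not taken from this repository. It uses the scaled dot-product form from the paper: score = q_i · k_j / √d_k, followed by a softmax over j. The comment on `scores` makes the point of the question explicit: each raw score is a function of two input rows only.

```python
import math
import random

def matvec(W, x):
    """Multiply a matrix W (list of rows) by a vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))

def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention_weights(X, W_q, W_k):
    """Single-head scaled dot-product attention weights (illustrative sketch).

    X holds one vector per word (embedding + positional encoding).
    Each q_i and k_j is a linear projection of a single row of X.
    """
    d_k = len(W_q)
    Q = [matvec(W_q, x) for x in X]
    K = [matvec(W_k, x) for x in X]
    # scores[i][j] = q_i . k_j / sqrt(d_k) depends only on X[i] and X[j]
    scores = [[dot(q, k) / math.sqrt(d_k) for k in K] for q in Q]
    # softmax over j turns each row into attention weights that sum to 1
    return scores, [softmax(row) for row in scores]

# Hypothetical random weights and inputs, for illustration only.
random.seed(0)
d_model, d_k, n_words = 4, 3, 5
W_q = [[random.gauss(0, 1) for _ in range(d_model)] for _ in range(d_k)]
W_k = [[random.gauss(0, 1) for _ in range(d_model)] for _ in range(d_k)]
X = [[random.gauss(0, 1) for _ in range(d_model)] for _ in range(n_words)]
scores, weights = attention_weights(X, W_q, W_k)
```

Real implementations do this with batched matrix products rather than Python loops; the loops here only keep the per-pair dependence visible.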

Let's say I want to process this sentence:
The man ate the apple; it didn't taste good.

When calculating the attention scores for the word 'it', how would the model know to assign a higher score to 'apple' (which 'it' refers to) than to 'man' or any other word? The model has no way of understanding the context of the sentence, because q and k are each calculated from the embedding of a single word rather than from the sentence as a whole: q for 'it' is computed from the embedding of 'it', and k for 'apple' from the embedding of 'apple'. The two vectors are then multiplied to get the attention score. Wouldn't this mean that if the same two words appeared in a different sentence, the same distance apart, the attention score between them would be identical?
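The reasoning above can be checked directly for a single attention layer. The sketch below rests on stated assumptions: a made-up ten-word vocabulary with random embeddings, random projection matrices `W_Q` and `W_K`, and the sinusoidal positional encoding from the paper; `first_layer_scores` and the other names are hypothetical. It keeps 'apple' and 'it' at the same positions in two different sentences (with absolute positional encodings like these, the raw score depends on the two positions themselves, not only on their distance).

```python
import math
import random

D_MODEL, D_K = 8, 4
random.seed(1)

def rand_matrix(rows, cols):
    """Random Gaussian matrix stored as a list of rows."""
    return [[random.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]

def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))

def positional_encoding(pos, d_model=D_MODEL):
    """Sinusoidal encoding: sin on even dimensions, cos on odd dimensions."""
    pe = []
    for i in range(d_model):
        angle = pos / 10000 ** ((i - i % 2) / d_model)
        pe.append(math.sin(angle) if i % 2 == 0 else math.cos(angle))
    return pe

# Hypothetical toy vocabulary with one fixed random embedding per word.
VOCAB = ["the", "man", "woman", "ate", "sold", "apple", "it", "was", "bad", "good"]
EMBED = {w: [random.gauss(0, 1) for _ in range(D_MODEL)] for w in VOCAB}
W_Q, W_K = rand_matrix(D_K, D_MODEL), rand_matrix(D_K, D_MODEL)

def first_layer_scores(tokens):
    """Raw scores q_i . k_j / sqrt(d_k) and their row-wise softmax weights."""
    X = [[e + p for e, p in zip(EMBED[w], positional_encoding(pos))]
         for pos, w in enumerate(tokens)]
    Q = [matvec(W_Q, x) for x in X]
    K = [matvec(W_K, x) for x in X]
    raw = [[dot(q, k) / math.sqrt(D_K) for k in K] for q in Q]
    weights = []
    for row in raw:
        m = max(row)  # subtract the max for numerical stability
        exps = [math.exp(s - m) for s in row]
        total = sum(exps)
        weights.append([x / total for x in exps])
    return raw, weights

s1 = ["the", "man", "ate", "the", "apple", "it", "was", "bad"]
s2 = ["the", "woman", "sold", "the", "apple", "it", "was", "good"]
IT, APPLE = 5, 4  # same positions in both sentences
raw1, w1 = first_layer_scores(s1)
raw2, w2 = first_layer_scores(s2)
print(raw1[IT][APPLE] == raw2[IT][APPLE])  # True: same words at the same positions
print(w1[IT][APPLE], w2[IT][APPLE])        # differ: softmax normalizes over the whole row
```

In this setup the raw score from 'it' to 'apple' is identical in both sentences, while the softmax weight is not, because the softmax normalizes over the scores for every word in the sentence. The sketch only models one layer whose inputs are embeddings plus positional encodings.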

What makes sense to me is the classic approach to attention. Look at the following visual from Andrew Ng's Deep Learning Specialization:

[Figure: attention model slide (C5W3L08) from Andrew Ng's Deep Learning Specialization]

Here the attention scores are computed by a small fully connected network from the hidden states at each time step, and those hidden states come from a bidirectional RNN. In other words, the hidden state at a given time step is influenced by the words that come both before and after it, so it makes sense that the model can compute meaningful attention scores there.
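The classic mechanism described above can be sketched in plain Python as well. This is a minimal illustration, not the course's code: it assumes a tiny tanh RNN run in both directions to produce the annotations a<t'>, and an additive (Bahdanau-style) scoring network e<t,t'> = v · tanh(W [s<t-1>; a<t'>]) followed by a softmax over t'. All sizes, weights, the decoder state `s_prev`, and the function names are hypothetical.

```python
import math
import random

random.seed(2)
D_IN, D_H = 3, 4

def rand_matrix(rows, cols):
    return [[random.gauss(0, 0.5) for _ in range(cols)] for _ in range(rows)]

def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def rnn_pass(xs, W_x, W_h):
    """Simple tanh RNN: h_t = tanh(W_x x_t + W_h h_{t-1}); returns every h_t."""
    h = [0.0] * D_H
    states = []
    for x in xs:
        h = [math.tanh(a + b) for a, b in zip(matvec(W_x, x), matvec(W_h, h))]
        states.append(h)
    return states

def bidirectional_states(xs, fwd, bwd):
    """a<t> = [forward h_t ; backward h_t], so each a<t> sees the whole input."""
    forward = rnn_pass(xs, *fwd)
    backward = rnn_pass(xs[::-1], *bwd)[::-1]
    return [f + b for f, b in zip(forward, backward)]

def additive_attention(s_prev, annotations, W_att, v):
    """e<t,t'> = v . tanh(W_att [s<t-1> ; a<t'>]); alpha = softmax over t'."""
    energies = []
    for a in annotations:
        hidden = [math.tanh(z) for z in matvec(W_att, s_prev + a)]
        energies.append(sum(vi * hi for vi, hi in zip(v, hidden)))
    m = max(energies)
    exps = [math.exp(e - m) for e in energies]
    total = sum(exps)
    return [e / total for e in exps]

fwd = (rand_matrix(D_H, D_IN), rand_matrix(D_H, D_H))
bwd = (rand_matrix(D_H, D_IN), rand_matrix(D_H, D_H))
W_att = rand_matrix(8, D_H + 2 * D_H)  # one hidden layer of size 8 in the scorer
v = [random.gauss(0, 1) for _ in range(8)]
s_prev = [0.1] * D_H  # hypothetical previous decoder state s<t-1>

sentence = [[random.gauss(0, 1) for _ in range(D_IN)] for _ in range(4)]
changed = sentence[:3] + [[random.gauss(0, 1) for _ in range(D_IN)]]  # new last word

a1 = bidirectional_states(sentence, fwd, bwd)
a2 = bidirectional_states(changed, fwd, bwd)
alphas = additive_attention(s_prev, a1, W_att, v)
print(a1[0][D_H:] != a2[0][D_H:])  # backward half of the first word's annotation
```

Changing only the last word leaves the forward half of the first word's annotation untouched but changes its backward half, which is the context dependence the question points to.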