This is a PyTorch implementation of self-attention / the transformer from scratch. In this implementation, each word is predicted rather than each character.
The image above shows the architecture of the transformer. It essentially consists of two blocks:
- Encoder
- Decoder
The encoder is used to encode the information from all the time steps into a single representation. If we look closely at the block, the encoder consists of these sections:
- Input Encoding: The entire training and validation dataset is used to create a dictionary of words, and a unique ID is assigned to each word. This is achieved by using a tokenizer package from Hugging Face. Each unique ID is then mapped to a feature space of size 'd_model', and the model learns this mapping during training. This way, a word such as 'a' becomes a vector of size 'd_model'.
- Positional Encoding: 'sin' and 'cos' positional encodings are added to the input encoding before inference/training starts. This encoding is constant, not learnt, and is used to make the model understand the position of each word in a sentence (a sketch of the input encoding and positional encoding follows this list).
- Self-attention: One way to understand it is that correlation scores are calculated between the input words, i.e. how much a word is related to another word in the input. To calculate the scores, the following formula is used (a sketch of this formula also follows the list):

  $$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q \, K^T}{\sqrt{d_k}}\right) V$$

  where $K$, $Q$, $V$ represent Key, Query and Value.
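As a concrete illustration of the input encoding and positional encoding described above, here is a minimal PyTorch sketch. The class and argument names (`InputEmbedding`, `PositionalEncoding`, `vocab_size`, `max_len`) are illustrative and not necessarily the ones used in this repository:

```python
import math
import torch
import torch.nn as nn

class InputEmbedding(nn.Module):
    """Maps each token ID to a learned vector of size d_model."""
    def __init__(self, vocab_size: int, d_model: int):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, token_ids):                      # token_ids: (batch, seq_len)
        # Scaling by sqrt(d_model) follows the original transformer paper.
        return self.embedding(token_ids) * math.sqrt(self.d_model)

class PositionalEncoding(nn.Module):
    """Adds constant sin/cos positional information to the embeddings (d_model assumed even)."""
    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)                # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)   # even feature indices -> sin
        pe[:, 1::2] = torch.cos(position * div_term)   # odd feature indices  -> cos
        self.register_buffer("pe", pe)                 # constant buffer, not a learnable parameter

    def forward(self, x):                              # x: (batch, seq_len, d_model)
        return x + self.pe[: x.size(1)]
```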
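And here is the attention formula written as a minimal standalone function, a sketch for illustration only; the implementation in this repository may organise it differently:

```python
import math
import torch

def scaled_dot_product_attention(query, key, value, mask=None):
    """Computes softmax(Q @ K^T / sqrt(d_k)) @ V for tensors of shape (..., seq_len, d_k)."""
    d_k = query.size(-1)
    scores = query @ key.transpose(-2, -1) / math.sqrt(d_k)    # (..., seq_len_q, seq_len_k)
    if mask is not None:
        # Positions where the mask is 0 are excluded from the softmax.
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = torch.softmax(scores, dim=-1)                    # each row sums to 1
    return weights @ value, weights
```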
The decoder block is very similar to the encoder block. Something to note, though, is that it not only uses the encoded output, which carries the feature information of the input time steps, but also the outputs up to time step n-1 (also called the output shifted right). In the decoder, positional encoding is applied along with a self-attention block to find the relations among the outputs as well. Something extra to note here is that along with self-attention there is also cross-attention, i.e. the relation between the outputs and the inputs (self-attention is only among the inputs), and this is where the understanding of query, key and value becomes important.
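A rough sketch of the two attention steps inside the decoder, reusing the `scaled_dot_product_attention` helper from the sketch above. The tensor names `dec` and `enc_out` and all shapes are illustrative; the causal mask is what prevents each output position from looking at future positions:

```python
import torch

def causal_mask(seq_len: int) -> torch.Tensor:
    """Lower-triangular mask: position i may only attend to positions <= i."""
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

# Illustrative shapes: `dec` stands for the shifted-right output sequence after
# embedding + positional encoding, `enc_out` for the encoder's output.
batch, src_len, tgt_len, d_model = 2, 7, 5, 16
dec = torch.randn(batch, tgt_len, d_model)
enc_out = torch.randn(batch, src_len, d_model)

# 1) Masked self-attention among the outputs: Q, K and V all come from `dec`.
self_ctx, _ = scaled_dot_product_attention(dec, dec, dec, mask=causal_mask(tgt_len))

# 2) Cross-attention: queries come from the decoder side, while keys and values
#    come from the encoder output, relating the outputs to the inputs.
cross_ctx, _ = scaled_dot_product_attention(self_ctx, enc_out, enc_out)
```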
For the sake of understanding, let's assume that query, key and value are the same values, i.e. the inputs converted to input embeddings with the positional encoding added to them. Let's call it $Z$. Looking at the formula for attention, it then becomes (the scaling by $\sqrt{d_k}$ is dropped here for simplicity):

$$\text{Attention}(Z, Z, Z) = \text{softmax}(Z \, @ \, Z^T) \, @ \, Z$$

where @ means dot product.
The dot product signifies the projection of one vector over another vector:

$$\vec{a} \, @ \, \vec{b} = |\vec{a}| \, |\vec{b}| \cos\theta$$

where $\theta$ is the angle between the two vectors.
Let's represent two words, $a$ and $b$, as encoded vectors:

$$\vec{a} = [a_1, a_2, \ldots, a_{d_{model}}], \qquad \vec{b} = [b_1, b_2, \ldots, b_{d_{model}}]$$

To perform the dot product in matrix form, we have to take the transpose of one of the vectors:

$$\vec{a} \, @ \, \vec{b}^{\,T} = a_1 b_1 + a_2 b_2 + \cdots + a_{d_{model}} b_{d_{model}}$$
Thus, we know that the dot product of the words in their embedding form gives us the projection of one word over the other, i.e., how much one word is affected by another. Now let's stack all the word embeddings as rows of one matrix $Z$ and perform the dot product as above:

$$Z \, @ \, Z^T$$
This way, we get the effect of each word on every other word, including itself. A softmax is applied in the horizontal direction (along each row) so that the effects become weights. This way, the nth row represents the weight of every word on the nth word. Let's represent the attention calculated as:

$$A = \text{softmax}(Z \, @ \, Z^T)$$
Note: Sum of values along each row is 1
Then another dot product is taken between the attention matrix $A$ and $Z$:

$$\text{output} = A \, @ \, Z$$
If we look closely, we can see that each feature of the vector is modified as per the attention calculated. For example, the nth row of the output is a weighted sum of all the word embeddings, with the weights taken from the nth row of $A$ (a small numeric sketch follows).
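A tiny numeric sketch of this walkthrough, using three made-up 2-dimensional "word embeddings" (the $\sqrt{d_k}$ scaling is left out here, as in the discussion above):

```python
import torch

# Three toy word embeddings stacked as rows of Z (3 words, d_model = 2).
Z = torch.tensor([[1.0, 0.0],
                  [0.0, 1.0],
                  [1.0, 1.0]])

scores = Z @ Z.T                    # (3, 3): effect of each word on every other word
A = torch.softmax(scores, dim=-1)   # softmax along each row
print(A.sum(dim=-1))                # each row sums to (approximately) 1

out = A @ Z                         # each output row is a weighted sum of the rows of Z
print(out)
```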
During implementation, the embedding vector is not used directly as query, key and value; it is projected by separate learned weight matrices into three vectors with different roles (a projection sketch follows this list):

- Query $(\vec{Q})$: Represents the word for which we are calculating attention scores.
- Key $(\vec{K})$: Represents the word against which the attention scores are calculated.
- Value $(\vec{V})$: Represents the word's contribution to the output based on the attention scores.
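In code, these three roles typically come from three separate learned linear projections of the same embedding. A short sketch (the names `w_q`, `w_k`, `w_v` and the shapes are illustrative):

```python
import torch
import torch.nn as nn

d_model = 16
w_q = nn.Linear(d_model, d_model, bias=False)   # learned projection for queries
w_k = nn.Linear(d_model, d_model, bias=False)   # learned projection for keys
w_v = nn.Linear(d_model, d_model, bias=False)   # learned projection for values

z = torch.randn(2, 5, d_model)    # embeddings + positional encoding, (batch, seq_len, d_model)
q, k, v = w_q(z), w_k(z), w_v(z)  # no longer identical, even though they start from the same z
```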
In this implementation, we have used multi-headed attention. It enhances the model's ability to focus on different parts of the input sequence simultaneously. Instead of having a single set of query, key, and value weight matrices, multi-headed attention employs multiple sets, or "heads," each with its own set of weights. Each head performs its own attention operation, allowing the model to capture various aspects of relationships and dependencies in the data. The outputs of all heads are then concatenated and linearly transformed to produce the final attention output. This approach enables the model to aggregate diverse types of information and understand complex patterns more effectively.
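A compact sketch of how such a multi-headed attention block can be written in PyTorch, simplified for illustration; the module in this repository may differ in details such as dropout, bias terms, or how the heads are split:

```python
import math
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)   # final linear transform after concatenation

    def split_heads(self, x):
        # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
        batch, seq_len, _ = x.shape
        return x.view(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        q = self.split_heads(self.w_q(query))
        k = self.split_heads(self.w_k(key))
        v = self.split_heads(self.w_v(value))

        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        weights = torch.softmax(scores, dim=-1)    # one attention pattern per head
        context = weights @ v                      # (batch, num_heads, seq_len, d_k)

        # Concatenate the heads back together and apply the output projection.
        batch, _, seq_len, _ = context.shape
        context = context.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.w_o(context)

# Usage: self-attention when query, key and value are the same tensor.
x = torch.randn(2, 10, 512)
mha = MultiHeadAttention(d_model=512, num_heads=8)
out = mha(x, x, x)
```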