The complex math behind transformer models, in simple words
It is no secret that the transformer architecture was a breakthrough in the field of Natural Language Processing (NLP). It overcame a key limitation of recurrent seq-to-seq models such as RNNs: they struggle to capture long-term dependencies in text. The transformer architecture became the foundation of revolutionary models like BERT, GPT, and T5 and their variants. As many say, NLP is in the midst of a golden era, and it wouldn’t be wrong to say that the transformer model is where it all started.
Need for Transformer Architecture
As the saying goes, necessity is the mother of invention. Traditional seq-to-seq models were not good at working with long texts: the model tends to forget what it learned from the earlier parts of the input sequence as it moves on to process the later parts. This loss of information is undesirable.
Gated architectures like LSTMs and GRUs improved the handling of long-term dependencies by discarding useless information along the way and retaining important information, but it still wasn’t enough. The world needed something more powerful, and in 2015 “attention mechanisms” were introduced by Bahdanau et al. They were used in combination with RNNs/LSTMs to mimic the human ability to focus on selected things while ignoring the rest. Bahdanau et al. suggested assigning a relative importance to each word in a sentence so that the model focuses on the important words and ignores the rest. This turned out to be a massive improvement over plain encoder-decoder models for neural machine translation, and soon enough, attention was applied to other tasks as well.
The Era of Transformer Models
Transformer models are built entirely on attention, specifically a form of it known as “self-attention”, with no recurrence involved. This architecture was introduced to the world in the 2017 paper “Attention Is All You Need”. It uses an encoder-decoder architecture.
On a high level,
- The encoder is responsible for accepting the input sentence and converting it into a hidden, context-aware representation.
- The decoder accepts this hidden representation and tries to generate the target sentence.
In this article, we will delve into a detailed breakdown of the Encoder component of the Transformer model. In the next article, we shall look at the Decoder component in detail. Let’s start!
The encoder block of the transformer consists of a stack of N encoders (N = 6 in the original paper) that work sequentially. The output of one encoder is the input to the next encoder, and so on. The output of the last encoder is the final representation of the input sentence, which is fed to the decoder block.
Each encoder can be further split into two components, a multi-head attention sub-layer and a feedforward network sub-layer, as shown in the figure below.
Let us look into each of these components one by one to understand how the encoder block works. The first component in the encoder block is multi-head attention, but before we hop into the details, let us first understand an underlying concept: self-attention.
Self-Attention Mechanism
The first question that might pop up in everyone’s mind: Are attention and self-attention different concepts? Yes, they are. (Duh!)
Traditionally, attention mechanisms came into existence for the task of neural machine translation, as discussed in the previous section. Essentially, the attention mechanism was applied to map the source sentence to the target sentence. As seq-to-seq models perform translation token by token, the attention mechanism helps us identify which token(s) of the source sentence to focus on more while generating token x of the target sentence. For this, it uses the hidden state representations of the encoder and decoder to calculate attention scores and, based on these scores, generates context vectors as input for the decoder. If you wish to learn more about the attention mechanism, please hop on to this article (brilliantly explained!).
Coming back to self-attention, the main idea is to calculate the attention scores while mapping the source sentence to itself. If you have a sentence like,
“The boy did not cross the road because it was too wide.”
It is easy for us humans to understand that the word “it” refers to “road” in the above sentence, but how do we make our language model understand this relationship as well? That’s where self-attention comes into the picture!
On a high level, every word in the sentence is compared against every other word in the sentence to quantify their relationships and understand the context. For an illustration, refer to the figure below.
Let us see in detail how self-attention is actually calculated.
- Generate embeddings for the input sentence
Find the embeddings of all the words and stack them into an input matrix. These embeddings can be as simple as one-hot encodings of the tokens. They can also come from a learned embedding layer (as in the original Transformer) or from pre-trained embeddings such as word2vec or GloVe. The shape of the input matrix is sentence length x embedding dimension. Let us call this input matrix X for future reference.
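As a minimal sketch of this step, the snippet below splits a sentence on whitespace and looks each token up in a table of random vectors. The random table is an assumption that stands in for a learned embedding layer. `build_input_matrix` is a hypothetical helper name, not part of any library. Note that "The" and "the" map to the same row, because no position information has been added yet.

```python
import random

def build_input_matrix(sentence, embedding_dim, seed=0):
    """Hypothetical helper: map each whitespace token to a vector and stack them into X."""
    rng = random.Random(seed)
    tokens = sentence.lower().split()
    # Vocabulary in first-seen order: token -> row index in the embedding table.
    vocab = {tok: idx for idx, tok in enumerate(dict.fromkeys(tokens))}
    # Assumption: random vectors stand in for a learned (trained) embedding table.
    table = [[rng.uniform(-1.0, 1.0) for _ in range(embedding_dim)] for _ in vocab]
    # X has one row per token: shape (sentence length x embedding dimension).
    X = [table[vocab[tok]] for tok in tokens]
    return tokens, X

tokens, X = build_input_matrix("The boy did not cross the road because it was too wide", 4)
print(len(X), len(X[0]))  # 12 4
```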
- Transform input matrix into Q, K & V
For calculating self-attention, we need to transform X (input matrix) into three new matrices:
– Query (Q)
– Key (K)
– Value (V)
To calculate these three matrices, we randomly initialise three weight matrices, namely Wq, Wk, & Wv. The input matrix X is multiplied by Wq, Wk, & Wv to obtain Q, K & V respectively. The values of the weight matrices are learned during training, so that Q, K & V become more useful representations over time.
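Here is a plain-Python sketch of the projection step, using toy sizes and random weights. `matmul` and `random_matrix` are hypothetical helpers written for this example. The key dimension d_k is deliberately smaller than the embedding dimension, to show that Q, K and V need not have the same width as X.

```python
import random

def matmul(A, B):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def random_matrix(rows, cols, rng):
    """Small random values, standing in for a random weight initialisation."""
    return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(42)
seq_len, d_model, d_k = 3, 4, 2            # toy sizes, chosen for illustration
X = random_matrix(seq_len, d_model, rng)   # stand-in for the input matrix X

# Randomly initialised projection weights; training would update these.
W_q = random_matrix(d_model, d_k, rng)
W_k = random_matrix(d_model, d_k, rng)
W_v = random_matrix(d_model, d_k, rng)

# Each row of Q, K and V belongs to one word of the sentence.
Q, K, V = matmul(X, W_q), matmul(X, W_k), matmul(X, W_v)
print(len(Q), len(Q[0]))  # 3 2
```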
- Calculate the dot product of Q and K-transpose
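From the figure above, we can infer that qi, ki, and vi represent the rows of Q, K, and V for the i-th word in the sentence. Multiplying Q by K-transpose produces a square matrix of size sentence length x sentence length. The entry in row i and column j of this matrix is the dot product of qi and kj.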
The first row of this output matrix tells you how word1, represented by q1, relates to every word in the sentence (including itself) via the dot product. The higher the dot product, the more related the words are. To build an intuition for why this dot product is calculated, think of the Q (query) and K (key) matrices in terms of information retrieval. Here,
– Q or Query = Term you are searching for
– K or Key = a set of keywords in your search engine against which Q is compared and matched.
Each entry of Q·K-transpose is a sum of dk products, so its magnitude tends to grow with the dimension of the key vectors, dk. Large values push the softmax in the next step into regions where its gradients are extremely small. To keep the values in a reasonable range and stabilise the gradients, we divide the dot product of Q and K-transpose by the square root of dk. When a single attention head is used, dk equals the embedding dimension.
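The sketch below first computes the scaled scores for a toy 3-word example. It then checks the scaling argument empirically: for random vectors with ±1 entries (an assumed distribution chosen for illustration), the raw dot product has a standard deviation of about √dk, and dividing by √dk brings it back to about 1. The helper names are hypothetical.

```python
import math
import random
import statistics

def scaled_scores(Q, K):
    """Return Q·K^T / sqrt(d_k): entry (i, j) compares query i with key j."""
    d_k = len(K[0])
    return [[sum(q * k for q, k in zip(q_row, k_row)) / math.sqrt(d_k) for k_row in K]
            for q_row in Q]

# Toy 3-word example with d_k = 2 (numbers made up for illustration).
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
scores = scaled_scores(Q, K)   # 3 x 3: sentence length x sentence length

def dot_product_spread(d, trials=500, seed=0):
    """Std. dev. of raw and scaled dot products of random ±1 vectors of length d."""
    rng = random.Random(seed)
    raw = []
    for _ in range(trials):
        q = [rng.choice((-1.0, 1.0)) for _ in range(d)]
        k = [rng.choice((-1.0, 1.0)) for _ in range(d)]
        raw.append(sum(a * b for a, b in zip(q, k)))
    return statistics.pstdev(raw), statistics.pstdev([x / math.sqrt(d) for x in raw])

for d in (4, 64, 512):
    raw_std, scaled_std = dot_product_spread(d)
    print(f"d_k={d}: raw std ~ {raw_std:.1f}, scaled std ~ {scaled_std:.2f}")
```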
- Normalise the values using softmax
The softmax function is applied to each row, turning the scores into values between 0 and 1 that sum to 1 across the row. Because softmax exponentiates its inputs, cells with a high scaled dot product are amplified further while low values shrink. This makes the distinction between matched word pairs clearer. The resulting matrix can be regarded as a score matrix S.
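A row-wise softmax can be sketched as follows. Subtracting the row maximum is a common numerical-stability trick and does not change the result. `softmax_rows` is a hypothetical helper name.

```python
import math

def softmax_rows(M):
    """Apply softmax to each row independently so that every row sums to 1."""
    out = []
    for row in M:
        m = max(row)                          # subtract the max for numerical stability
        exps = [math.exp(x - m) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

S = softmax_rows([[2.0, 1.0, 0.1],
                  [0.0, 0.0, 0.0]])
print([round(x, 3) for x in S[0]])  # [0.659, 0.242, 0.099]
```

In the first row, the raw scores 2.0 and 1.0 differ by a factor of 2. After softmax their ratio is e ≈ 2.72, which is the amplifying effect described above. A row of equal scores becomes a uniform row of 1/3.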
- Calculate the attention matrix Z
The score matrix S obtained in the previous step is multiplied by the value matrix V to calculate the attention matrix Z, i.e. Z = S · V.
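With toy numbers (made up for illustration, not taken from a trained model), the multiplication Z = S · V looks like this. Each row of Z is a weighted mix of the rows of V, using that word's row of S as the weights.

```python
def matmul(A, B):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

# Toy score matrix S (each row sums to 1) and value matrix V for a 3-word sentence.
S = [[0.9, 0.07, 0.03],
     [0.2, 0.6, 0.2],
     [0.1, 0.1, 0.8]]
V = [[1.0, 0.0],
     [0.0, 1.0],
     [1.0, 1.0]]

Z = matmul(S, V)   # row i of Z is the attention output for word i
```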
But wait, why multiply?
Suppose Si = [0.9, 0.07, 0.03] is the row of the score matrix for the i-th word of a three-word sentence. This row is multiplied by the V matrix to calculate Zi, the attention vector for the i-th word:
Zi = [0.9 * V1 + 0.07 * V2 + 0.03 * V3]
So, to understand the context of the i-th word, the model should focus mostly on word1 (i.e. V1), since 90% of the attention weight goes to it. The scores clearly identify which words deserve more attention when building the context of the i-th word.
Hence, we can conclude that the higher a word’s contribution to the Zi representation, the more relevant that word is to the i-th word.
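The same arithmetic for a single word can be written out in a few lines. It uses the weights Si = [0.9, 0.07, 0.03] from the example and two-dimensional value vectors V1, V2, V3 that are made up for illustration.

```python
# Row of the score matrix for word i (from the example) and made-up value vectors.
S_i = [0.9, 0.07, 0.03]
V = [[1.0, 2.0],   # V1
     [3.0, 0.0],   # V2
     [0.0, 4.0]]   # V3

# Z_i = 0.9 * V1 + 0.07 * V2 + 0.03 * V3, computed dimension by dimension.
Z_i = [sum(w * v[d] for w, v in zip(S_i, V)) for d in range(len(V[0]))]

# The largest weight tells us which word dominates the context of word i.
most_attended = max(range(len(S_i)), key=lambda j: S_i[j])
print(Z_i, "-> most attended: word", most_attended + 1)
```

Here Zi comes out as roughly [1.11, 1.92]. Most of it is supplied by V1, since 0.9 · [1, 2] = [0.9, 1.8].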
Now that we know how to calculate the self-attention matrix, let us understand the concept of the multi-head attention mechanism.
Multi-head attention Mechanism
What will happen if your score matrix is biased toward a specific word representation? It will mislead your model and the results will not be as accurate as we expect. Let us see an example to understand this better.
S1: “All is well”
Z(well) = 0.6 * V(all) + 0.0 * V(is) + 0.4 * V(well)
S2: “The dog ate the food because it was hungry”
Z(it) = 0.0 * V(the) + 1.0 * V(dog) + 0.0 * V(ate) + …… + 0.0 * V(hungry)
In the S1 case, while calculating Z(well), more importance is given to V(all) than to V(well) itself. There is no guarantee that this weighting is accurate.
In the S2 case, while calculating Z(it), all the importance is given to V(dog). The scores for the rest of the words, including V(it) itself, are 0.0. This looks acceptable because the word “it” is ambiguous, so it makes sense to relate it more to another word than to itself. That was the whole purpose of calculating self-attention: to handle the context of ambiguous words in the input sentence.
In other words, if the current word is ambiguous, it is okay to give more importance to some other word while calculating self-attention. In other cases, however, this can mislead the model. So, what do we do now?
What if we calculate multiple attention matrices instead of calculating one attention matrix and derive the final attention matrix from these?
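That is precisely what multi-head attention is all about! We calculate multiple attention matrices z1, z2, z3, …, zm, one per head, and each head has its own weight matrices Wq, Wk, & Wv. The outputs of the heads are concatenated and multiplied by another learned weight matrix (Wo in the paper) to derive the final attention matrix. Because each head can learn to attend to different words, the final attention matrix does not depend on a single, possibly misleading, score matrix.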
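Here is a compact sketch of multi-head attention in plain Python. It assumes toy sizes (2 heads and an embedding dimension of 4, so each head uses d_k = 2) and random, untrained weights. The function names are illustrative. The final projection with W_o follows the original paper, but framework implementations usually batch all heads into single matrix multiplications rather than looping over them.

```python
import math
import random

def matmul(A, B):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(X, W_q, W_k, W_v):
    """Single-head scaled dot-product self-attention: softmax(Q K^T / sqrt(d_k)) V."""
    Q, K, V = matmul(X, W_q), matmul(X, W_k), matmul(X, W_v)
    d_k = len(K[0])
    S = [softmax([sum(q * k for q, k in zip(q_row, k_row)) / math.sqrt(d_k) for k_row in K])
         for q_row in Q]
    return matmul(S, V)

def multi_head_attention(X, heads, W_o):
    """Run every head with its own weights, concatenate per word, then project with W_o."""
    outputs = [attention(X, W_q, W_k, W_v) for W_q, W_k, W_v in heads]
    concat = [[v for z in outputs for v in z[i]] for i in range(len(X))]
    return matmul(concat, W_o)

rng = random.Random(0)

def rand(rows, cols):
    return [[rng.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]

seq_len, d_model, num_heads = 3, 4, 2
d_k = d_model // num_heads                 # each head works in a smaller space
X = rand(seq_len, d_model)
heads = [(rand(d_model, d_k), rand(d_model, d_k), rand(d_model, d_k)) for _ in range(num_heads)]
W_o = rand(num_heads * d_k, d_model)       # projects the concatenation back to d_model
Z = multi_head_attention(X, heads, W_o)
```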
Moving on to the next important concept,
Positional Encoding
In recurrent seq-to-seq models, the input sentence is fed to the network word by word, which allows the model to keep track of the position of each word relative to the others.
But transformer models follow a different approach. Instead of being fed word by word, all the inputs are fed in parallel, which reduces training time and helps in learning long-term dependencies. With this approach, however, the word order is lost, and word order is extremely important for understanding the meaning of a sentence correctly. To overcome this problem, a new matrix called “positional encoding” (P) is introduced.
This matrix P is added to the input matrix X to include information about the word order. Since the addition is element-wise, X and P must have the same dimensions.
To calculate positional encoding, the formula given below is used.
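For reference, this is the sinusoidal encoding defined in “Attention Is All You Need”, written with the symbols listed below (the paper calls d `d_model`):

```latex
PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d}}\right), \qquad
PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d}}\right)
```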
In the above formula,
- pos = position of the word in the sentence
- d = dimension of the word/token embedding
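- i = index of a pair of dimensions in the embedding: dimension 2i uses the sine and dimension 2i + 1 uses the cosine
In these calculations, d is fixed while pos and i vary. If d = 512, then i ∈ [0, 255], since each i covers the two dimensions 2i and 2i + 1.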
This video covers positional encoding in depth if you wish to know more about it:
Visual Guide to Transformer Neural Networks — (Part 1) Position Embeddings
I am using some visuals from the above video to explain this concept in my words.
The above figure shows an example of a positional encoding vector along with different variable values.
The above figure shows how the value of PE(pos, 2i) varies when i is constant and only pos varies. A sinusoidal wave is a periodic function that repeats itself after a fixed interval. In the figure, the values for pos = 0 and pos = 6 come out identical. This is not desirable, as we want different positional encodings for different values of pos.
This can be achieved by varying the frequency of the sinusoidal wave.
As the value of i varies, the frequency of the sinusoidal wave also varies, so every dimension pair follows a different wave. The wavelengths grow geometrically from 2π up to 10000 · 2π. Even if two positions happen to share a value in one dimension, their values across all dimensions together differ, giving every position its own positional encoding vector. This is exactly what we wanted to achieve.
The positional encoding matrix (P) is added to the input matrix (X) and fed to the encoder.
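The following sketch computes P with the sinusoidal formula and adds it to X element-wise. It assumes an even embedding dimension d, and the function names are hypothetical. Two identical rows of X (think of a repeated word like "the") become different once their positions are added.

```python
import math

def positional_encoding(seq_len, d):
    """Sinusoidal encodings: sin on even dimensions 2i, cos on odd dimensions 2i + 1 (d assumed even)."""
    P = [[0.0] * d for _ in range(seq_len)]
    for pos in range(seq_len):
        for i in range(d // 2):
            angle = pos / (10000 ** (2 * i / d))
            P[pos][2 * i] = math.sin(angle)
            P[pos][2 * i + 1] = math.cos(angle)
    return P

def add_positional_encoding(X):
    """Element-wise X + P, which is why X and P must have the same shape."""
    P = positional_encoding(len(X), len(X[0]))
    return [[x + p for x, p in zip(x_row, p_row)] for x_row, p_row in zip(X, P)]

P = positional_encoding(seq_len=10, d=8)
print([round(v, 3) for v in P[1][:2]])  # [0.841, 0.54]
```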
The next component of the encoder is the feedforward network.
Feedforward Network
This sub-layer of the encoder block is a classic fully-connected neural network with two linear (dense) layers and a ReLU activation in between. In the original paper, the inner layer has 2048 units and the embedding dimension is 512. It accepts the output of the multi-head attention layer, applies a non-linear transformation to it and generates contextualised vectors. The same network is applied to each position (word) separately and identically. The vectors for different positions therefore do not interact inside this sub-layer and can be processed in parallel.
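A position-wise feedforward network can be sketched like this. It assumes toy sizes (d_model = 4, inner size 16), random untrained weights and zero biases, and the helper names are hypothetical.

```python
import random

def dense(x, W, b):
    """x·W + b for one vector x; W is stored as a list of rows (in_dim x out_dim)."""
    return [sum(xi * w for xi, w in zip(x, col)) + bj for col, bj in zip(zip(*W), b)]

def feed_forward(x, W1, b1, W2, b2):
    """FFN(x) = max(0, x·W1 + b1)·W2 + b2: two dense layers with a ReLU in between."""
    hidden = [max(0.0, h) for h in dense(x, W1, b1)]
    return dense(hidden, W2, b2)

def position_wise_ffn(Z, params):
    """Apply the same FFN independently to every row (word position) of Z."""
    return [feed_forward(row, *params) for row in Z]

rng = random.Random(7)

def rand(rows, cols):
    return [[rng.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]

d_model, d_ff = 4, 16   # toy sizes; the paper uses 512 and 2048
params = (rand(d_model, d_ff), [0.0] * d_ff, rand(d_ff, d_model), [0.0] * d_model)
Z = rand(3, d_model)    # stand-in for the multi-head attention output
out = position_wise_ffn(Z, params)
```

Feeding a single row through `position_wise_ffn` gives exactly the same result as feeding it together with the other rows. This independence is what makes the step parallelisable.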
The last and final component of the Encoder block is Add & Norm component.
Add & Norm component
This is a residual (skip) connection followed by layer normalisation: the output is LayerNorm(x + Sublayer(x)), where x is the input of the sub-layer. The residual connection helps ensure that important information from the sub-layer’s input is not lost during processing. Layer normalisation speeds up training and keeps the values from changing too drastically.
Within the encoder, there are two add & norm layers:
- one adds the input of the multi-head attention sub-layer to its output and then normalises
- one adds the input of the feedforward network sub-layer to its output and then normalises
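Here is a minimal sketch of Add & Norm computed as LayerNorm(x + Sublayer(x)). It omits the learned gain and bias that layer normalisation normally has, and the epsilon of 1e-6 is an assumed value.

```python
import math

def layer_norm(x, eps=1e-6):
    """Normalise one vector to zero mean and unit variance (learned gain/bias omitted)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def add_and_norm(X, sublayer_out):
    """Residual connection followed by layer normalisation, applied row by row."""
    return [layer_norm([a + b for a, b in zip(x, s)]) for x, s in zip(X, sublayer_out)]

X = [[1.0, 2.0, 3.0, 4.0]]          # input of the sub-layer
sub = [[0.5, -0.5, 0.5, -0.5]]      # stand-in for the sub-layer's output
out = add_and_norm(X, sub)
```

For x + Sublayer(x) = [1.5, 1.5, 3.5, 3.5], the mean is 2.5 and the variance is 1, so the output is approximately [-1, -1, 1, 1].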
With this, we conclude the internal working of the Encoders. To summarize the article, let us quickly go over the steps that the encoder uses:
- Generate embeddings or tokenized representations of the input sentence. This will be our input matrix X.
- Generate the positional encodings to preserve the word order of the input sentence, and add them to the input matrix X.
- Randomly initialize three matrices: Wq, Wk, & Wv i.e. weights of query, key & value. These weights will be updated during the training of the transformer model.
- Multiply the input matrix X with each of Wq, Wk, & Wv to generate Q (query), K (key) and V (value) matrices.
- Calculate the dot product of Q and K-transpose, scale the product by dividing it by the square root of dk (the key dimension) and finally normalize it row-wise using the softmax function.
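- Calculate the attention matrix Z by multiplying the output of the softmax function by the V (value) matrix.
- For multi-head attention, repeat the previous steps with separate Wq, Wk, & Wv for each head, concatenate the heads and project the result with another learned weight matrix.
- Add the input of the attention sub-layer to its output and apply layer normalisation (Add & Norm).
- Pass the result to the feedforward network to perform non-linear transformations and generate contextualized embeddings, followed by a second Add & Norm.
- Feed the output to the next encoder in the stack; the output of the last encoder goes to the decoder.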
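Putting the summary steps together, here is a compact single-head encoder sketch in plain Python. It is an illustration under simplifying assumptions, not the reference implementation from the paper. It uses one head with d_k equal to the embedding dimension (so no output projection is needed), no biases, no learned layer-norm parameters, and random untrained weights.

```python
import math
import random

def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def layer_norm(x, eps=1e-6):
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def add_norm(X, Y):
    # Residual connection followed by layer normalisation, row by row.
    return [layer_norm([a + b for a, b in zip(x, y)]) for x, y in zip(X, Y)]

def positional_encoding(seq_len, d):
    # Even dimension j = 2i uses sin, odd dimension j = 2i + 1 uses cos, both with exponent 2i/d.
    return [[math.sin(pos / 10000 ** (j / d)) if j % 2 == 0
             else math.cos(pos / 10000 ** ((j - 1) / d)) for j in range(d)]
            for pos in range(seq_len)]

def encoder_layer(X, p):
    Q, K, V = matmul(X, p["W_q"]), matmul(X, p["W_k"]), matmul(X, p["W_v"])
    d_k = len(K[0])
    S = [softmax([sum(q * k for q, k in zip(qr, kr)) / math.sqrt(d_k) for kr in K]) for qr in Q]
    X = add_norm(X, matmul(S, V))                                    # attention + Add & Norm
    H = [[max(0.0, v) for v in row] for row in matmul(X, p["W_1"])]  # dense + ReLU
    return add_norm(X, matmul(H, p["W_2"]))                          # dense + Add & Norm

rng = random.Random(3)

def rand(rows, cols):
    return [[rng.gauss(0.0, 0.3) for _ in range(cols)] for _ in range(rows)]

def init_layer(d_model, d_ff):
    # Single head with d_k = d_model (assumption), so no output projection is needed.
    return {"W_q": rand(d_model, d_model), "W_k": rand(d_model, d_model),
            "W_v": rand(d_model, d_model), "W_1": rand(d_model, d_ff), "W_2": rand(d_ff, d_model)}

seq_len, d_model, d_ff, n_layers = 5, 8, 16, 2
X = rand(seq_len, d_model)                                       # step 1: stand-in embeddings
P = positional_encoding(seq_len, d_model)
out = [[x + p for x, p in zip(xr, pr)] for xr, pr in zip(X, P)]  # step 2: add positions
for layer in [init_layer(d_model, d_ff) for _ in range(n_layers)]:
    out = encoder_layer(out, layer)                              # each encoder feeds the next
```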
In the next article, we will understand how the Decoder component of the Transformer model works.
This would be all for this article. I hope you found it useful. If you did, please don’t forget to clap and share it with your friends.