Transformers
A broad overview of Transformers research
The pace of research in deep learning has accelerated significantly in recent years, making it increasingly difficult to keep abreast of all the latest developments. Even so, one direction of investigation has attracted particular attention because of its demonstrated success across a diverse range of domains, including natural language processing, computer vision, and audio processing. This success is due in large part to its highly adaptable architecture. The model is called the Transformer, and it builds on several mechanisms and techniques from the field, most notably attention mechanisms.
A comprehensive range of models has been explored based on the vanilla Transformer to date, which can broadly be broken down into three categories:
- Architectural modifications
- Pretraining methods
- Applications
Each category above contains several sub-categories, which I will investigate thoroughly in the next sections. Fig. 2 illustrates the categories along which researchers have modified Transformers.
Self-attention plays a fundamental role in the Transformer, but it suffers from two main disadvantages in practice [1].
- Complexity: For long sequences, this module becomes a bottleneck, since its computational complexity is O(T²·D), where T is the sequence length and D is the representation dimension.
- Structural prior: Self-attention assumes no structural bias over its inputs, so any structure has to be learned from the training data or injected through additional mechanisms (e.g., the order information of the input sequence).
Therefore, researchers have explored various techniques to overcome these drawbacks.
- Sparse attention: This technique tries to lower the computation time and the memory requirements of the attention mechanism by taking a smaller portion of the inputs into account instead of the entire input sequence, producing a sparse matrix in contrast to a full matrix.
- Linearized attention: Disentangling the attention matrix using kernel feature maps, this method tries to compute the attention in the reverse order to reduce the resource requirements to linear complexity.
- Prototype and memory compression: This line of modification reduces the number of queries or key-value pairs to obtain a smaller attention matrix, which in turn reduces the time and memory cost.
- Low-rank self-attention: This approach captures the low-rank property of the self-attention matrix, either by modeling it explicitly through parameterization or by replacing the matrix with a low-rank approximation.
- Attention with prior: This approach combines prior attention distributions obtained from other sources with the distribution computed from the inputs.
- Modified multi-head mechanism: There are various ways to modify and improve the performance of the multi-head mechanism which can be categorized under this research direction.
3.1. Sparse attention
The standard self-attention mechanism in a Transformer requires every token to attend to all other tokens. However, it has been observed that the learned attention matrix is often very sparse, meaning that each token places meaningful weight on only a small number of other tokens [2]. This suggests that the computational complexity of self-attention can be reduced by limiting the number of query-key pairs that each query attends to. By computing similarity scores only for pre-defined patterns of query-key pairs, the amount of computation can be reduced significantly, often with little loss in performance.
In the un-normalized attention matrix Â, the entry for a query-key pair is set to -∞ whenever the query is not allowed to attend to that key, so the pair receives zero weight after the softmax. In practice, these -∞ entries are usually not stored at all, which reduces the memory footprint of the matrix.
The attention matrix can be viewed as a bipartite graph. Standard attention corresponds to a complete bipartite graph, in which each query receives information from all of the memory nodes and uses it to update its representation. Sparse attention corresponds to a sparse graph in which some of these connections are removed. Limiting the number of connections lowers the computational cost, while the remaining edges can still capture the important relationships and dependencies.
There are two main classes of approaches to sparse attention, based on the metrics used to determine the sparse connections between nodes [1]. These are position-based and content-based sparse attention.
3.1.1. Position-based sparse attention
In this type of attention, the connections in the attention matrix are limited according to predetermined patterns. They can be expressed as combinations of simpler patterns, which can be useful for understanding and analyzing the behavior of the attention mechanism.
3.1.1.1. Atomic sparse attention: There are five basic atomic sparse attention patterns that can be used to construct a variety of different sparse attention mechanisms that have different trade-offs between computational complexity and performance as shown in Fig. 4.
- Global attention: A small number of global nodes act as information hubs: they attend to all other nodes in the sequence, and all other nodes attend to them, as in Fig. 4 (a).
- Band attention (also sliding window attention or local attention): The relationships and dependencies between different parts of the data are often local rather than global. In the band attention, the attention matrix is a band matrix, with the queries only attending to a certain number of neighboring nodes on either side as shown in Fig. 4 (b).
- Dilated attention: Similar to how dilated convolutional neural networks (CNNs) can increase the receptive field without increasing computational complexity, band attention can use a dilated window with gaps of dilation w_d ≥ 1, as shown in Fig. 4 (c). This can also be extended to strided attention, where the dilation w_d is set to a large value.
- Random attention: To improve the ability of the attention mechanism to capture non-local interactions, a few edges can be randomly sampled for each query, as depicted in Fig. 4 (d).
- Block local attention: The input sequence is segmented into several non-intersecting query blocks, each of which is associated with a local memory block. The queries within each query block attend only to the keys in the corresponding memory block, as shown in Fig. 4 (e).
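The five atomic patterns above can be sketched as boolean masks over a T×T attention matrix. This is a minimal NumPy illustration, not taken from any of the cited papers: the function names, the window and block sizes, and the dense T×T computation are all assumptions chosen for readability (an efficient implementation would avoid materializing the masked entries).

```python
import numpy as np

def band_mask(T, w):
    """Each query attends to keys within w positions on either side."""
    idx = np.arange(T)
    return np.abs(idx[:, None] - idx[None, :]) <= w

def dilated_mask(T, w, dilation):
    """Band pattern with gaps: offsets are multiples of `dilation`, up to w hops."""
    idx = np.arange(T)
    offset = np.abs(idx[:, None] - idx[None, :])
    return (offset % dilation == 0) & (offset <= w * dilation)

def global_mask(T, global_idx):
    """Global nodes attend to every node, and every node attends to them."""
    mask = np.zeros((T, T), dtype=bool)
    mask[global_idx, :] = True
    mask[:, global_idx] = True
    return mask

def random_mask(T, n_edges, rng):
    """Each query samples n_edges distinct random keys."""
    mask = np.zeros((T, T), dtype=bool)
    for i in range(T):
        mask[i, rng.choice(T, size=n_edges, replace=False)] = True
    return mask

def block_local_mask(T, block):
    """Queries attend only to keys in their own non-overlapping block."""
    blk = np.arange(T) // block
    return blk[:, None] == blk[None, :]

def masked_attention(Q, K, V, mask):
    """Softmax attention where disallowed pairs get a score of -inf.
    Assumes every row of `mask` has at least one allowed key."""
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    scores = np.where(mask, scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

rng = np.random.default_rng(0)
T, D = 8, 4
Q, K, V = rng.normal(size=(3, T, D))
# A compound pattern: band attention plus one global node at position 0
compound = band_mask(T, 1) | global_mask(T, [0])
out, weights = masked_attention(Q, K, V, compound)
```

Combining masks with a logical OR, as in the last lines, is one simple way to build compound patterns from atomic ones.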
3.1.1.2. Compound sparse attention: As illustrated in Fig. 5, many existing sparse attention mechanisms are composed of more than one of the atomic patterns described above.
3.1.1.3. Extended sparse attention: There are also other types of patterns that have been explored for specific data types. By way of example, BP-Transformer [3] uses a binary tree to capture a combination of global and local attention across the input sequence. Tokens are leaf nodes and the internal nodes are span nodes containing multiple tokens. Fig. 6 shows a number of extended sparse attention patterns.
3.1.2. Content-based sparse attention
In this approach, a sparse graph is constructed where the sparse connections are based on the inputs. It selects the keys that have high similarity scores with the given query. An efficient way to build this graph is to use Maximum Inner Product Search (MIPS) which finds the maximum dot-product between keys and the query without calculating all dot-products.
Routing Transformer [4], shown in Fig. 7, equips the self-attention mechanism with a sparse routing module that uses online k-means clustering to cluster keys and queries on the same set of centroid vectors. Each query then attends only to keys within the same cluster. Reformer [5] uses locality-sensitive hashing (LSH) to select the key-value pairs for each query: queries and keys are hashed into buckets, and each query attends only to tokens in the same bucket. Sparse Adaptive Connection (SAC) [6] uses an LSTM edge predictor to construct a graph from the input sequence and learns attention edges that improve task-specific performance through adaptive sparse connections.
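To make the bucketing idea concrete, here is a simplified sketch of hash-based content sparsity. It assumes a random-projection (angular) hash and a single shared representation X for queries and keys; the function names are hypothetical, and the sorting and chunking steps that a real implementation such as Reformer relies on are omitted.

```python
import numpy as np

def lsh_buckets(x, n_buckets, rng):
    """Angular LSH sketch: project onto random directions and take the
    argmax over [xR, -xR]. Nearby vectors tend to share a bucket."""
    R = rng.normal(size=(x.shape[-1], n_buckets // 2))
    proj = x @ R
    return np.argmax(np.concatenate([proj, -proj], axis=-1), axis=-1)

def bucketed_attention(X, V, buckets):
    """Each position attends only to positions in its own bucket.
    X plays the role of both queries and keys (a simplifying assumption)."""
    T, D = X.shape
    out = np.zeros_like(V)
    for b in np.unique(buckets):
        idx = np.where(buckets == b)[0]
        scores = X[idx] @ X[idx].T / np.sqrt(D)
        scores = scores - scores.max(axis=-1, keepdims=True)
        w = np.exp(scores)
        w = w / w.sum(axis=-1, keepdims=True)
        out[idx] = w @ V[idx]
    return out

rng = np.random.default_rng(0)
T, D = 16, 8
X = rng.normal(size=(T, D))
V = rng.normal(size=(T, D))
buckets = lsh_buckets(X, n_buckets=4, rng=rng)
out = bucketed_attention(X, V, buckets)
```

Only dot products inside each bucket are computed, so the cost depends on the bucket sizes rather than on T².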
3.2. Linearized attention
The computational complexity of the dot-product attention mechanism, softmax(QK^⊤)V, grows quadratically with the spatiotemporal size (length) of the input. This impedes its use on large inputs such as videos, long sequences, or high-resolution images. By disentangling softmax(QK^⊤) into Q′K′^⊤, the product Q′K′^⊤V can be computed in reverse order as Q′(K′^⊤V), resulting in linear complexity O(T).
Assuming Â = exp(QK^⊤) denotes the un-normalized attention matrix, where exp(·) is applied element-wise, linearized attention approximates Â with 𝜙(Q)𝜙(K)^⊤, where 𝜙 is a row-wise feature map. With this approximation, we can compute 𝜙(Q)(𝜙(K)^⊤V), which is a linearized computation of the un-normalized attention, as illustrated in Fig. 8.
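The reordering can be checked numerically. The sketch below assumes the feature map 𝜙(x) = elu(x) + 1 and adds the row normalization explicitly; the function names are illustrative. It compares the linear-order computation 𝜙(Q)(𝜙(K)^⊤V) with the quadratic-order computation (𝜙(Q)𝜙(K)^⊤)V.

```python
import numpy as np

def elu_plus_one(x):
    """Feature map elu(x) + 1, which is strictly positive."""
    return np.where(x > 0, x + 1.0, np.exp(np.minimum(x, 0.0)))

def linear_attention(Q, K, V, phi=elu_plus_one):
    """Computes phi(Q) (phi(K)^T V) with row normalization; no T x T matrix is formed."""
    Qf, Kf = phi(Q), phi(K)
    KV = Kf.T @ V                  # (D, Dv) summary of all keys and values
    Z = Qf @ Kf.sum(axis=0)        # (T,) normalizer for each query
    return (Qf @ KV) / Z[:, None]

def quadratic_reference(Q, K, V, phi=elu_plus_one):
    """Same quantity, but forming the full T x T matrix first."""
    A = phi(Q) @ phi(K).T
    return (A / A.sum(axis=-1, keepdims=True)) @ V

rng = np.random.default_rng(0)
T, D = 32, 8
Q, K, V = rng.normal(size=(3, T, D))
out_linear = linear_attention(Q, K, V)
out_quadratic = quadratic_reference(Q, K, V)
```

Both functions return the same values up to floating-point error; only the order of the matrix products, and therefore the cost in T, differs.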
To achieve a deeper understanding of linearized attention, I will explore the formulation in vector form, starting from the general form of attention:
z_i = Σ_j [ sim(q_i, k_j) / Σ_j′ sim(q_i, k_j′) ] v_j   (1)
In this context, sim(·, ·) is a scoring function that measures the similarity between input vectors. In the vanilla Transformer, the scoring function is the exponential of the inner product, exp(⟨·, ·⟩). A suitable selection for sim(·, ·) is a kernel function, K(x, y) = 𝜙(x)𝜙(y)^⊤, which leads to further insights into linearized attention:
z_i = Σ_j [ 𝜙(q_i)𝜙(k_j)^⊤ / Σ_j′ 𝜙(q_i)𝜙(k_j′)^⊤ ] v_j = 𝜙(q_i) (Σ_j 𝜙(k_j) ⊗ v_j) / (𝜙(q_i) Σ_j′ 𝜙(k_j′)^⊤)   (2)
In this formulation, the outer product of vectors is denoted by ⊗. Attention can be linearized by first computing the two sums Σ_j 𝜙(k_j) ⊗ v_j and Σ_j′ 𝜙(k_j′)^⊤, which allows autoregressive models such as Transformer decoders to run like RNNs.
Eq. 2 shows that linearized attention keeps a memory matrix by aggregating associations, represented as outer products of (feature-mapped) keys and values. It later retrieves a value by multiplying the memory matrix with the feature-mapped query, with proper normalization.
This approach consists of two foundational components:
- Feature map 𝜙(·): the kernel feature map chosen by each attention implementation (e.g., 𝜙_i(x) = elu(x_i) + 1, proposed in Linear Transformer).
- Aggregation rule: aggregating the associations {𝜙(k_j) ⊗ v_j} into the memory matrix by simple summation.
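The two components can be put together in a recurrent, decoder-style sketch. It assumes the elu(x) + 1 feature map and plain summation as the aggregation rule; the variable names S (memory matrix) and u (normalizer) are illustrative. A masked quadratic version is included only as a reference for comparison.

```python
import numpy as np

def phi(x):
    """Feature map elu(x) + 1."""
    return np.where(x > 0, x + 1.0, np.exp(np.minimum(x, 0.0)))

def causal_linear_attention_rnn(Q, K, V):
    """Processes one position at a time, like an RNN with a matrix-valued state."""
    T, D = Q.shape
    S = np.zeros((D, V.shape[1]))   # memory matrix: running sum of phi(k_j) outer v_j
    u = np.zeros(D)                 # normalizer: running sum of phi(k_j)
    out = np.zeros_like(V)
    for i in range(T):
        k, q = phi(K[i]), phi(Q[i])
        S += np.outer(k, V[i])      # aggregation by simple summation
        u += k
        out[i] = (q @ S) / (q @ u)  # retrieval with normalization
    return out

def causal_quadratic_reference(Q, K, V):
    """Same result via an explicit lower-triangular T x T matrix."""
    A = np.tril(phi(Q) @ phi(K).T)
    return (A / A.sum(axis=-1, keepdims=True)) @ V

rng = np.random.default_rng(0)
T, D = 16, 4
Q, K, V = rng.normal(size=(3, T, D))
out_rnn = causal_linear_attention_rnn(Q, K, V)
out_ref = causal_quadratic_reference(Q, K, V)
```

The state (S, u) has a fixed size regardless of how many positions have been processed, which is what lets decoding proceed with constant cost per step.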
3.3. Query prototyping and memory compression
Besides sparse attention and kernel-based linearized attention, the complexity of attention can also be reduced by decreasing the number of queries or the number of key-value pairs, which leads to query prototyping and memory compression methods, respectively.
3.3.1. Attention with prototype queries: Here, a set of query prototypes serves as the main source for computing attention distributions. The model then either copies the computed distributions to the positions of the queries that each prototype represents, or fills those positions with discrete uniform distributions. The flow of computation is depicted in Fig. 9(a).
Clustered Attention, as described in [7], involves the aggregation of queries into several clusters, with attention distributions being computed for the centroids of these clusters. All queries within a cluster are assigned the attention distribution calculated for its corresponding centroid.
Informer, as outlined in [8], employs a methodology of explicit query sparsity measurement, derived from an approximation of the Kullback-Leibler divergence between the query’s attention distribution and the discrete uniform distribution, to select query prototypes. Attention distributions are then calculated only for the top-𝑢 queries as determined by the query sparsity measurement, with the remaining queries being assigned discrete uniform distributions.
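The selection of prototype queries can be sketched as follows. This assumes a sparsity measure of the form log-sum-exp of a query's scores minus their mean, and it computes that measure exactly over all keys, which is itself quadratic; the sampling tricks that make the selection cheap in practice are not shown. The function names and the value of u are illustrative.

```python
import numpy as np

def sparsity_measure(Q, K):
    """Assumed measure: log-sum-exp of each query's scores minus their mean.
    It is smallest (log T) when the attention distribution is uniform."""
    s = Q @ K.T / np.sqrt(Q.shape[-1])
    m = s.max(axis=-1, keepdims=True)
    lse = m[:, 0] + np.log(np.exp(s - m).sum(axis=-1))
    return lse - s.mean(axis=-1)

def prototype_attention(Q, K, V, u):
    """Full attention for the top-u queries; uniform attention for the rest."""
    T, D = Q.shape
    M = sparsity_measure(Q, K)
    top = np.argsort(-M)[:u]
    # A discrete uniform distribution over keys yields the mean of the values
    out = np.tile(V.mean(axis=0), (T, 1))
    s = Q[top] @ K.T / np.sqrt(D)
    s = s - s.max(axis=-1, keepdims=True)
    w = np.exp(s)
    w = w / w.sum(axis=-1, keepdims=True)
    out[top] = w @ V
    return out, top, M

rng = np.random.default_rng(0)
T, D, u = 12, 4, 3
Q, K, V = rng.normal(size=(3, T, D))
out, top, M = prototype_attention(Q, K, V, u)
```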
3.3.2. Attention with compressed key-value memory: This technique reduces the complexity of the attention mechanism in the Transformer by reducing the number of key-value pairs before applying attention as shown in Fig. 9(b). This is achieved by compressing the key-value memory. The compressed memory is then used to compute attention scores. This technique can significantly reduce the computational cost of attention while maintaining good performance on various NLP tasks.
Liu et al. [9] suggest a technique called Memory Compressed Attention (MCA) in their paper. MCA involves using strided convolution to decrease the number of keys and values. MCA is utilized alongside local attention, which is also proposed in the same paper. By reducing the number of keys and values by a factor of the kernel size, MCA is able to capture global context and process longer sequences than the standard Transformer model with the same computational resources.
Set Transformer [10] and Luna [11] are two models that utilize external trainable global nodes to condense information from inputs. The condensed representations then function as a compressed memory that the inputs attend to, effectively reducing the quadratic complexity of self-attention to linear complexity concerning the length of the input sequence.
Linformer [12] reduces the computational complexity of self-attention to linear by linearly projecting keys and values from length n to a smaller length n_k. The drawback of this approach is that the input sequence length must be fixed in advance, which makes it unsuitable for autoregressive attention.
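A minimal sketch of length-wise projection follows. The projection matrices E and F are random placeholders standing in for learned parameters, and their names and the chosen sizes are assumptions. Because E and F have shape (n_k, n), the sequence length n is baked into the parameters, which is the limitation mentioned above.

```python
import numpy as np

def projected_attention(Q, K, V, E, F):
    """Attention over keys and values compressed along the length axis."""
    K_proj = E @ K                                   # (n_k, D)
    V_proj = F @ V                                   # (n_k, D)
    scores = Q @ K_proj.T / np.sqrt(Q.shape[-1])     # (n, n_k) instead of (n, n)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V_proj, weights

rng = np.random.default_rng(0)
n, n_k, D = 64, 8, 16
Q, K, V = rng.normal(size=(3, n, D))
# Placeholder projections; in a trained model these would be learned
E = rng.normal(size=(n_k, n)) / np.sqrt(n)
F = rng.normal(size=(n_k, n)) / np.sqrt(n)
out, weights = projected_attention(Q, K, V, E, F)
```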
Poolingformer [13] employs a two-level attention mechanism that combines sliding window attention with compressed memory attention. Compressed memory attention helps with enlarging the receptive field. To reduce the number of keys and values, several pooling operations are explored, including max pooling and Dynamic Convolution-based pooling.
3.4. Low-rank self-attention
According to empirical and theoretical analyses conducted by various researchers [14, 12], the self-attention matrix A ∈ R^{T×T} exhibits low-rank characteristics in many cases. This observation offers two implications: Firstly, the low-rank nature can be explicitly modeled using parameterization. This could lead to the development of new models that leverage this property to improve performance. Secondly, instead of using the full self-attention matrix, a low-rank approximation could be used in its place. This approach could enable more efficient computations and further enhance the scalability of self-attention-based models.
3.4.1. Low-rank parameterization: When the rank of the attention matrix is lower than the sequence length, it suggests that over-parameterizing the model by setting 𝐷𝑘 > 𝑇 would lead to overfitting in situations where the input is typically short. Therefore, it is sensible to restrict the dimension of 𝐷𝑘 and leverage the low-rank property as an inductive bias. To this end, Guo et al. [14] propose decomposing the self-attention matrix into a low-rank attention module with a small 𝐷𝑘 that captures long-range non-local interactions, and a band attention module that captures local dependencies. This approach can be beneficial in scenarios where the input is short and requires effective modeling of both local and non-local dependencies.
3.4.2. Low-rank approximation: The low-rank property of the attention matrix can also be leveraged to reduce the complexity of self-attention by using a low-rank matrix approximation. This methodology is closely related to the low-rank approximation of kernel matrices, and some existing works are inspired by kernel approximation. For instance, Performer, as discussed in Section 3.2, uses a random feature map originally proposed to approximate Gaussian kernels to decompose the attention distribution matrix A into C_Q G C_K, where G is a Gaussian kernel matrix and the random feature map approximates G.
An alternative way to exploit the low-rank property of attention matrices is to use Nyström-based methods [15, 16]. In these methods, a set of landmark nodes is first selected from the input sequence using down-sampling techniques such as strided average pooling. With Q̃ and K̃ denoting the selected landmark queries and keys, the attention matrix is approximated by three softmax-normalized products: the original queries with the landmark keys, the landmark queries with the landmark keys (inverted), and the landmark queries with the original keys. This can be expressed as:
Ã = softmax(QK̃^⊤) (softmax(Q̃K̃^⊤))^-1 softmax(Q̃K^⊤)
Note that the inverse M^-1 of the matrix M = softmax(Q̃K̃^⊤) may not always exist, but this issue can be mitigated in various ways. For example, CSALR [15] adds an identity matrix to M to ensure the inverse always exists, while Nyströmformer [16] uses the Moore-Penrose pseudoinverse of M to handle singular cases.
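A small sketch of the landmark-based approximation, assuming segment means as the down-sampling step and NumPy's Moore-Penrose pseudoinverse for M. The function names are illustrative, and the sequence length is assumed to be divisible by the number of landmarks.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def nystrom_attention(Q, K, V, n_landmarks):
    """Approximates softmax(QK^T)V with three smaller softmax matrices."""
    T, D = Q.shape
    seg = T // n_landmarks   # assumes T is divisible by n_landmarks
    # Landmarks as segment means (a form of strided average pooling)
    Q_l = Q.reshape(n_landmarks, seg, D).mean(axis=1)
    K_l = K.reshape(n_landmarks, seg, D).mean(axis=1)
    scale = np.sqrt(D)
    F1 = softmax(Q @ K_l.T / scale)      # (T, m): queries vs. landmark keys
    M = softmax(Q_l @ K_l.T / scale)     # (m, m): landmarks vs. landmarks
    F2 = softmax(Q_l @ K.T / scale)      # (m, T): landmark queries vs. keys
    # The pseudoinverse handles the case where M is singular
    return F1 @ np.linalg.pinv(M) @ (F2 @ V)

rng = np.random.default_rng(0)
T, D = 8, 4
Q, K, V = rng.normal(size=(3, T, D))
approx = nystrom_attention(Q, K, V, n_landmarks=4)
exact = softmax(Q @ K.T / np.sqrt(D)) @ V
# With one landmark per position, the approximation reduces to exact attention
full = nystrom_attention(Q, K, V, n_landmarks=T)
```

Multiplying F2 by V before the other factors keeps every intermediate product at most T×m in size.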
3.5. Attention with prior
The attention mechanism is a way of focusing on specific parts of an input sequence. It does this by generating a weighted sum of the vectors in the sequence, where the weights are determined by an attention distribution. The attention distribution can be generated from the inputs, or it can come from other sources, such as prior knowledge. In most cases, the attention distribution from the inputs and the prior attention distribution are combined by computing a weighted sum of their scores before applying softmax, thus, allowing the neural network to learn from both the inputs and the prior knowledge.
3.5.1. Prior that models locality: To model the locality of certain types of data like text, a Gaussian distribution over positions can be used as prior attention. This involves multiplying the generated attention distribution by a Gaussian density and renormalizing, which is equivalent to adding a bias term G to the generated attention scores before the softmax, where a higher G indicates a higher prior probability of attending to a specific input.
Yang et al. [17] propose predicting a central position P_i for each input q_i and defining the Gaussian bias accordingly:
G_ij = −(j − P_i)² / (2𝜎²)
where 𝜎 denotes the standard deviation of the Gaussian. That is, the bias is the negative squared distance between the input position and the central position, divided by twice the variance. The standard deviation can be set as a hyperparameter or predicted from the inputs.
The Gaussian Transformer [18] assumes that the central position for each input query q_i is i, and defines the bias term G_ij for the generated attention scores as
G_ij = −|w (i − j)² + b|
where 𝑤 is a non-negative scalar parameter controlling the deviation and 𝑏 is a negative scalar parameter reducing the weight for the central position.
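The effect of such a bias can be illustrated with a short sketch. It assumes the bias takes the form G_ij = −|w (i − j)² + b| with w ≥ 0 and b ≤ 0, and that it is added to the scores before the softmax; the parameter values are arbitrary. Uniform (all-zero) scores are used so that the resulting weights show the prior alone.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def gaussian_bias(T, w, b):
    """Assumed bias: G_ij = -|w * (i - j)^2 + b|, centered on position i."""
    idx = np.arange(T)
    d2 = (idx[:, None] - idx[None, :]) ** 2
    return -np.abs(w * d2 + b)

T = 8
G = gaussian_bias(T, w=0.5, b=-0.1)
scores = np.zeros((T, T))          # content scores with no preference
weights = softmax(scores + G)      # the prior alone shapes the distribution
```

With these values, each row of `weights` peaks at the query's own position and decays with distance.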
3.5.2. Prior from lower modules: In the Transformer architecture, attention distributions in adjacent layers are often found to be similar. Therefore, it is reasonable to use the attention distribution from a lower layer as a prior for computing attention in a higher layer. This can be achieved by taking a weighted sum of the current layer’s attention scores and a translated version of the previous layer’s attention scores:
Â^(l) = w_1 · A^(l) + w_2 · g(A^(l−1))   (6)
where A^(l) represents the attention scores of the l-th layer, while w_1 and w_2 control the relative importance of the current and the previous attention scores. The function g: R^{n×n} → R^{n×n} translates the previous attention scores into the prior to be applied to the current attention scores.
The Predictive Attention Transformer proposed in the paper [19] suggests using a 2D-convolutional layer on the previous attention scores to compute the final attention scores as a convex combination of the generated attention scores and the convolved scores. In other words, the weight parameters for the generated and convolved scores are set to 𝛼 and 1-𝛼, respectively, and the function 𝑔(·) in Eq. (6) is a convolutional layer. The paper presents experiments showing that training the model from scratch and fine-tuning it after adapting a pre-trained BERT model both lead to improvements over baseline models.
The Realformer model proposed in [20] introduces a residual skip connection on attention maps by directly adding the previous attention scores to the newly generated ones. This can be seen as setting w_1 = w_2 = 1 and g(·) to be the identity map in Eq. (6). The authors conduct pre-training experiments on this model and report that it outperforms the baseline BERT model in multiple datasets, even with significantly lower pre-training budgets.
Lazyformer [21] proposes an innovative approach where attention maps are shared between adjacent layers to reduce computational costs. This is achieved by setting g(·) to identity and alternately switching between the settings of w_1 = 0, w_2 = 1 and w_1 = 1, w_2 = 0. This method enables the computation of attention maps only once and reuses them in succeeding layers. The pre-training experiments conducted by Lazyformer show that their model is not only efficient but also effective, outperforming the baseline models with significantly lower computation budgets.
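The three settings discussed in this subsection can be written as special cases of one combination rule. The sketch assumes the rule has the form w_1 · A^(l) + w_2 · g(A^(l−1)), with w_1 weighting the current scores and w_2 the previous ones; `combine_scores` is a hypothetical helper, the value of alpha is arbitrary, and the identity map stands in for the convolutional layer of [19].

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def combine_scores(curr, prev, w1, w2, g=lambda a: a):
    """Assumed form: w1 * A^(l) + w2 * g(A^(l-1))."""
    return w1 * curr + w2 * g(prev)

rng = np.random.default_rng(0)
T = 4
prev_scores = rng.normal(size=(T, T))   # scores from layer l-1
curr_scores = rng.normal(size=(T, T))   # scores generated at layer l

# Residual connection on scores: w1 = w2 = 1, g = identity
residual = combine_scores(curr_scores, prev_scores, w1=1.0, w2=1.0)
# Reuse of the lower layer's map: w1 = 0, w2 = 1, g = identity
reused = combine_scores(curr_scores, prev_scores, w1=0.0, w2=1.0)
# Convex combination; here g is left as identity for brevity
alpha = 0.7
convex = combine_scores(curr_scores, prev_scores, w1=alpha, w2=1 - alpha)

attention = softmax(residual)
```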
3.5.3. Prior as multi-task adapters: The Prior as Multi-task Adapters approach uses trainable attention priors that enable efficient parameter sharing across tasks [22]. The Conditionally Adaptive Multi-Task Learning (CAMTL) [23] framework is a technique for multi-task learning that enables the efficient sharing of pre-trained models between tasks. CAMTL uses trainable attention prior, which depends on task encoding, to act as an adapter for multi-task inductive knowledge transfer. Specifically, the attention prior is represented as a block diagonal matrix that is added to the attention scores of upper layers in pre-trained Transformers:
in which ⊕ represents the direct sum, A_j are trainable parameters with dimensions (n/m)×(n/m), and 𝛾_j and 𝛽_j are Feature-wise Linear Modulation functions that map R^{D_z} to R^{(n/m)×(n/m)} [24]. In implementation, the CAMTL framework specifies a maximum sequence length n_max. Organizing the prior as a block diagonal matrix keeps the computation efficient.
3.5.4. Attention with only prior: Zhang et al. [25] have developed an alternative approach to attention distribution that does not rely on pair-wise interaction between inputs. Their method is called the “average attention network,” and it uses a discrete uniform distribution as the sole source of attention distribution. The values are then aggregated as a cumulative average of all values. To enhance the network’s expressiveness, a feed-forward gating layer is added on top of the average attention module. The benefit of this approach is that the modified Transformer decoder can be trained in a parallel manner, and it can decode like an RNN, avoiding the O(𝑇²) complexity associated with decoding.
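The parallel-training and RNN-style-decoding views of the cumulative average can be compared directly. This sketch covers only the averaging step and omits the feed-forward gating layer; the function names are illustrative.

```python
import numpy as np

def average_attention_parallel(V):
    """Training view: a uniform distribution over positions <= i, as one matrix product."""
    T = V.shape[0]
    mask = np.tril(np.ones((T, T)))
    A = mask / mask.sum(axis=-1, keepdims=True)
    return A @ V

def average_attention_recurrent(V):
    """Decoding view: keep a running sum and divide by the step count."""
    running = np.zeros(V.shape[1])
    out = np.zeros_like(V)
    for i in range(V.shape[0]):
        running += V[i]
        out[i] = running / (i + 1)
    return out

rng = np.random.default_rng(0)
V = rng.normal(size=(10, 4))
out_parallel = average_attention_parallel(V)
out_recurrent = average_attention_recurrent(V)
```

The recurrent form needs only the running sum from the previous step, so each decoding step has constant cost.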
Similar in spirit to Yang et al. [17] and Guo et al. [18], who focus the attention distribution on a local window, You et al. [26] use a hardcoded Gaussian distribution for attention calculation. However, they discard the generated attention completely and rely solely on the Gaussian distribution, whose mean and variance are hyperparameters. Provided it is applied to self-attention, this approach can produce results close to the baseline models in machine translation tasks.
Synthesizer [27] has proposed a novel way of generating attention scores in Transformers. Instead of using the traditional method of generating attention scores, they replace them with two variants: (1) learnable, randomly initialized attention scores, and (2) attention scores output by a feed-forward network that is only conditioned on the input being queried. The results of their experiments on machine translation and language modeling tasks demonstrate that these variants perform comparably to the standard Transformer model. However, the reason why these variants work is not fully explained, leaving room for further investigation.
3.6. Improved multi-head mechanism
Multi-head attention is a powerful technique because it allows a model to attend to different parts of the input simultaneously. However, it is not guaranteed that each attention head will learn unique and complementary features. As a result, some researchers have explored methods to ensure that each attention head captures distinct information.
3.6.1. Head behavior modeling: Multi-head attention is a useful tool in natural language processing models as it enables the simultaneous processing of multiple inputs and feature representations [28]. However, the vanilla Transformer model lacks a mechanism to ensure that different attention heads capture distinct and non-redundant features. Additionally, there is no provision for interaction among the heads. To address these limitations, recent research has focused on introducing novel mechanisms that guide the behavior of attention heads or enable interaction between them.
In order to promote diversity among different attention heads, Li et al. [29] propose an additional regularization term in the loss function. This regularization consists of three parts: the first two aim to maximize the cosine distances of the input subspaces and of the output representations, while the third encourages dispersion of the positions attended by multiple heads through element-wise multiplication of their corresponding attention matrices. By adding this auxiliary term, the model is encouraged to learn a more diverse set of attention patterns across different heads, which can improve its performance on various tasks.
Several studies have shown that pre-trained Transformer models exhibit certain self-attention patterns that have little linguistic backing. Kovaleva et al. [30] identify several of these patterns in BERT, including attention heads that focus exclusively on the special tokens [CLS] and [SEP]. To improve training, Deshpande and Narasimhan [31] suggest using an auxiliary loss function defined as the Frobenius norm of the difference between the attention distribution maps and predefined attention patterns. This approach introduces constraints to encourage more meaningful attention patterns.
Shazeer et al. [32] introduce a mechanism called Talking-heads Attention, which aims to encourage the model to transfer information between different attention heads in a learnable manner. This mechanism linearly projects the generated attention scores from h_k heads to h heads, applies softmax in that space, and then projects the results to h_v heads for value aggregation. In this way, the attention mechanism can learn to move information between the different attention heads.
Collaborative Multi-head Attention, proposed in [33], uses shared query and key projections, W^Q and W^K, together with a mixing vector m_i that filters the projection parameters for the i-th head. The attention computation of each head is adapted accordingly:
head_i = Attention(QW^Q diag(m_i), KW^K, VW_i^V)
where all heads share W^Q and W^K.
3.6.2. Multi-head with restricted spans:
The vanilla attention mechanism typically assumes full attention spans, allowing a query to attend to all key-value pairs. However, it has been observed that some attention heads tend to focus more on local contexts, while others attend to broader contexts. As a result, it may be advantageous to impose constraints on attention spans for specific purposes:
- Locality: Restricting attention spans can explicitly impose local constraints, which can be beneficial in scenarios where locality is an important consideration.
- Efficiency: Appropriately implemented, such a model can scale to longer sequences without introducing additional memory usage or computational time.
Restricting attention spans involves multiplying each attention distribution value with a mask value, followed by re-normalization. The mask value can be determined by a non-increasing function that maps a distance to a value in the range [0, 1]. In vanilla attention, a mask value of 1 is assigned for all distances, as illustrated in Figure 12(a).
Sukhbaatar et al. [34] propose a learnable attention span, depicted in Figure 12(b). The mask is parameterized by a learnable scalar z and a hyperparameter R, which together modulate the attention span adaptively. Experiments on character-level language modeling show that these adaptive-span models outperform the baseline models while requiring significantly fewer FLOPs. It was also observed that lower layers tend to learn smaller spans, whereas higher layers learn larger spans, which suggests that the model can learn a hierarchical composition of features.
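A span mask of this kind can be sketched as a soft, piecewise-linear function of distance. The exact form used here, m(x) = min(max((R + z − x)/R, 0), 1), is an assumption for illustration, as are the function names and the values of z and R; in a real model z would be learned, while here it is fixed. The mask multiplies the exponentiated scores and the result is renormalized, in a causal (decoder-style) setting.

```python
import numpy as np

def span_mask(distance, z, R):
    """Assumed soft mask: 1 up to distance z, then a linear ramp of width R down to 0."""
    return np.clip((R + z - distance) / R, 0.0, 1.0)

def span_masked_weights(scores, z, R):
    """Causal attention weights after applying the span mask and renormalizing."""
    T = scores.shape[0]
    idx = np.arange(T)
    dist = idx[:, None] - idx[None, :]        # how far back key j is from query i
    causal = dist >= 0                        # no attention to future positions
    m = span_mask(dist, z, R) * causal
    e = np.exp(scores - scores.max(axis=-1, keepdims=True)) * m
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
T = 8
scores = rng.normal(size=(T, T))
weights = span_masked_weights(scores, z=2.0, R=2.0)
mask_values = span_mask(np.array([0, 2, 3, 4, 10]), z=2.0, R=2.0)
```

Because the mask is non-increasing in distance and reaches exactly zero, keys beyond distance z + R never need to be computed or stored.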
The Multi-Scale Transformer [35] takes a different approach to attention spans. Unlike vanilla attention, which assumes a uniform attention span across all heads, this model uses a fixed attention span whose scale varies across layers. As illustrated in Fig. 12(c), the fixed attention span acts as a window that can be scaled up or down, controlled by a scale value denoted as w.
The scale values vary, with higher layers favoring larger scales for broader contextual dependencies and lower layers opting for smaller scales for more localized attention, as shown in Figure 13. In the reported experiments, the Multi-Scale Transformer outperforms baseline models on various tasks.
3.6.3. Multi-head with refined aggregation:
The vanilla multi-head attention mechanism, as proposed by Vaswani et al. [28], computes multiple attention heads in parallel, each generating its own output representation. These representations are then concatenated and passed through a linear transformation, as defined in Eq. (11), to obtain the final output representation. By combining Eqs. (10), (11), and (12), it can be observed that this concatenate-and-project formulation is equivalent to a summation over re-parameterized attention outputs.
Attention(Q, K, V) = softmax(QK^⊤ / √D_k) V   (10)
and
MultiHeadAttn(Q, K, V) = Concat(head_1, …, head_H) W^O   (11)
where
head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)   (12)
To see this, the weight matrix W^O ∈ R^{D_m×D_m} used for the linear transformation is partitioned row-wise into H blocks, where H represents the number of attention heads:
W^O = [W_1^O; W_2^O; …; W_H^O]   (13)
Each block W_i^O, with dimension D_v × D_m, acts as the linear transformation for the i-th attention head, so the concatenate-and-project formulation can be rewritten as a sum of re-parameterized attention outputs, as defined in Eq. (14):
MultiHeadAttn(Q, K, V) = Σ_{i=1}^{H} Attention(QW_i^Q, KW_i^K, VW_i^V W_i^O)   (14)
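The equivalence between concatenate-and-project and summation can be checked numerically. The sketch below uses random weights and small, arbitrary dimensions; the variable names are illustrative, and W^O is split row-wise into H blocks of shape D_v × D_m.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def attention(Q, K, V):
    return softmax(Q @ K.T / np.sqrt(Q.shape[-1])) @ V

rng = np.random.default_rng(0)
T, Dm, H = 5, 8, 2
Dk = Dv = Dm // H
X = rng.normal(size=(T, Dm))
Wq, Wk, Wv = rng.normal(size=(3, H, Dm, Dk))   # per-head projections
Wo = rng.normal(size=(H * Dv, Dm))             # output projection

# Concatenate-and-project formulation
heads = [attention(X @ Wq[i], X @ Wk[i], X @ Wv[i]) for i in range(H)]
concat_project = np.concatenate(heads, axis=-1) @ Wo

# Summation formulation: fold each row block of Wo into the value projection
Wo_blocks = Wo.reshape(H, Dv, Dm)
summed = sum(
    attention(X @ Wq[i], X @ Wk[i], X @ Wv[i] @ Wo_blocks[i]) for i in range(H)
)
```

The two results agree because the attention weights of each head do not depend on the value projection, so the output block can be applied before or after the weighted sum.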
Some researchers may argue that the straightforward aggregate-by-summation approach may not fully leverage the expressive power of multi-head attention and that a more complex aggregation scheme could be more desirable.
Gu and Feng [36] and Li et al. [37] propose employing routing methods originally conceived for capsule networks [38] as a means to further aggregate information derived from distinct attention heads. Through a process of transforming the outputs of attention heads into input capsules and subsequently undergoing an iterative routing procedure, output capsules are obtained. These output capsules are then concatenated to serve as the final output of the multi-head attention mechanism. Notably, the dynamic routing [38] and EM routing [39] mechanisms employed in these works introduce additional parameters and computational overhead. Nevertheless, Li et al. [37] empirically demonstrate that selectively applying the routing mechanism to the lower layers of the model achieves an optimal balance between translation performance and computational efficiency.
3.6.4. Other multi-head modifications:
In addition to the aforementioned modifications, several other approaches have been proposed to enhance the performance of the multi-head attention mechanism. Shazeer [40] introduced the concept of multi-query attention, where key-value pairs are shared among all attention heads. This reduces the memory bandwidth requirements during decoding and leads to faster decoding, albeit with minor quality degradation compared to the baseline. On the other hand, Bhojanapalli et al. [41] identified that the size of attention keys could impact their ability to represent arbitrary distributions. To address this, they proposed disentangling the head size from the number of heads, contrary to the conventional practice of setting the head size as 𝐷𝑚/ℎ, where 𝐷𝑚 is the model dimension and ℎ is the number of heads.
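As a rough illustration of multi-query attention, the sketch below keeps one query projection per head but a single key projection and a single value projection shared by all heads. The function name, the weight shapes, and the cache-size comparison are assumptions made for this example.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def multi_query_attention(X, Wq, Wk, Wv, Wo):
    """Wq: (H, Dm, Dk), one per head. Wk: (Dm, Dk) and Wv: (Dm, Dv), shared."""
    K = X @ Wk                      # computed once for all heads
    V = X @ Wv
    heads = []
    for Wq_i in Wq:
        Q_i = X @ Wq_i
        heads.append(softmax(Q_i @ K.T / np.sqrt(K.shape[-1])) @ V)
    return np.concatenate(heads, axis=-1) @ Wo

rng = np.random.default_rng(0)
T, Dm, H = 6, 16, 4
Dk = Dv = Dm // H
X = rng.normal(size=(T, Dm))
Wq = rng.normal(size=(H, Dm, Dk))
Wk = rng.normal(size=(Dm, Dk))
Wv = rng.normal(size=(Dm, Dv))
Wo = rng.normal(size=(H * Dv, Dm))
out = multi_query_attention(X, Wq, Wk, Wv, Wo)

# Numbers stored per sequence in a decoding cache under each scheme
cache_multi_query = T * (Dk + Dv)
cache_multi_head = T * H * (Dk + Dv)
```

The cached keys and values are H times smaller than with per-head projections, which is where the reduced memory bandwidth during decoding comes from.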