Flash Attention is an IO-aware optimization of the transformer attention mechanism that provides roughly a 15% end-to-end wall-clock speedup without any approximation.
Since transformer models are slow and memory-intensive on long sequences (time and memory complexity are quadratic in sequence length), the Flash Attention paper reports a 15% end-to-end wall-clock speedup on BERT-large and a 3x speedup on GPT-2.
Considering the huge amount of energy consumed in training these large models, Flash Attention, with its combined software and hardware-aware optimization, can provide that 15% efficiency gain, which is a significant improvement.
The discussion below explains some of the basic concepts behind Flash Attention and how it is implemented.
Computing and memory basics
Before we go further, let's review what compute and memory mean here:
What is computing?
- Time spent on your GPU performing the actual floating-point operations (FLOPs)
What is memory?
- Time spent transferring tensors within the GPU (for example, between HBM and on-chip SRAM)
Ideally, we want our GPU to perform matrix multiplications all the time and never be restricted by memory. But in reality, compute has progressed faster than memory bandwidth, and we live in a world where the GPU often sits idle waiting for data to be loaded. This is usually called a memory-bound operation. Please see below for the illustrative diagram: matrix multiplication is the compute, and memory stores the data; compute needs data to process, and memory bandwidth must keep up with that demand.
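To make the distinction concrete, here is a rough roofline-style sketch in Python; the peak-FLOPS and bandwidth numbers are illustrative assumptions (roughly A100-class), not measurements.

```python
# Rough roofline-style estimate: is an operation compute-bound or memory-bound?
# Hardware numbers below are assumptions (roughly A100-class), not measurements.
PEAK_FLOPS = 312e12   # ~312 TFLOPS of fp16 tensor-core compute (assumed)
HBM_BW = 1.5e12       # ~1.5 TB/s of HBM bandwidth (assumed)

def bound(flops, bytes_moved):
    """Compare time spent on math vs. time spent moving data."""
    t_compute = flops / PEAK_FLOPS
    t_memory = bytes_moved / HBM_BW
    return "compute-bound" if t_compute > t_memory else "memory-bound"

n = 4096
# Big matmul: ~2*n^3 FLOPs but only ~3*n^2 fp16 values moved -> compute-bound
print(bound(2 * n**3, 3 * n**2 * 2))
# Softmax-like elementwise pass: few FLOPs per value moved -> memory-bound
print(bound(5 * n**2, 2 * n**2 * 2))
```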
What is memory hierarchy?
The A100 GPU has 40-80 GB of high-bandwidth memory (HBM) with a bandwidth of 1.5-2.0 TB/s, and 192 KB of on-chip SRAM on each of its 108 streaming multiprocessors, with SRAM bandwidth estimated at around 19 TB/s.
Considering the above context, the self-attention computation is memory-bound.
Within the attention math, it is the softmax operation that causes it to be memory-bound.
- Quantitative evidence: as you can see below, operations like softmax, dropout and masking take most of the time compared to matrix multiplication (matmul).
Why does softmax become a memory-bound operation?
The scale at which softmax operates is the biggest bottleneck. In the following diagram:
- N -> number of tokens
- d -> number of embedding dimensions
- When Query and Key are multiplied, the attention matrix blows up to N*N, which requires a lot of memory. For reference: d ~ 128; N ~ 128k tokens; Google Gemini: ~1 million tokens (see the sizing sketch below).
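To see why the N*N score matrix dominates memory, here is a small, hypothetical sizing sketch (fp16 values and d = 128 assumed):

```python
# Hypothetical sizing: how big is the N*N score matrix vs. the N*d inputs?
def attention_sizes(n_tokens, d=128, bytes_per_el=2):  # fp16
    qkv_bytes = 3 * n_tokens * d * bytes_per_el        # Q, K and V together
    score_bytes = n_tokens * n_tokens * bytes_per_el   # S = Q @ K'
    return qkv_bytes, score_bytes

for n in (1_024, 16_384, 131_072):  # up to ~128k tokens
    qkv, scores = attention_sizes(n)
    print(f"N={n:>7}: Q/K/V ~{qkv / 2**20:7.1f} MiB, S ~{scores / 2**30:6.2f} GiB")
```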
Below is the algorithm for the standard self-attention mechanism.
As stated in the previous section, transferring information to HBM (writing S to HBM), then loading it back from HBM into the GPU to compute the softmax, and then writing the result back to HBM again, is a lot of data movement, making this a memory-bound operation.
Along with the diagram, the steps below explain how self-attention is calculated using matrix multiplication.
Step 1:
- I have simplified this. In practice, each token embedding is combined with a positional encoding and fed through linear layers to generate the Query, Key and Value matrices. For illustration purposes, I used a dimension of 3 (it generally ranges from 64 to 128). This is the standard transformer architecture input.
Step 2
- Key -> Key' (transpose) is computed and multiplied by Query to get QK', which is N*N. This contains the attention of each token with respect to every other token. The following diagram also shows this relationship. Since these are tokens and we need to compute the importance of each token relative to the others, the softmax operation is applied row by row to normalize the scores between 0 and 1.
- This step requires transferring the N*N matrix to HBM and, as we discussed, is the most expensive operation. The entire Flash Attention paper is about how to optimize this step.
Step 3
- Softmax(QK') * V is calculated as the final output matrix. The dimension here is the same as the key, query, and value input embeddings.
- Looking at the rows of the output matrix:
- Row 1 means that the embedding of “this” is updated to incorporate its relationships with the other tokens.
- Row 2 means that the embedding of “is” is updated to incorporate its relationships with the other tokens.
- The same applies to the rest of the rows. A minimal code sketch of steps 1 to 3 follows this list.
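Here is a minimal NumPy sketch of steps 1 to 3 (standard self-attention, before any Flash Attention tricks). The 4-token, 3-dimension sizes mirror the toy example, and the usual 1/sqrt(d) scaling is included even though the diagrams omit it:

```python
import numpy as np

# Standard self-attention on toy sizes: 4 tokens, embedding dimension 3.
np.random.seed(0)
N, d = 4, 3
Q = np.random.randn(N, d)  # queries
K = np.random.randn(N, d)  # keys
V = np.random.randn(N, d)  # values

S = Q @ K.T / np.sqrt(d)                       # step 2: N*N attention scores
P = np.exp(S - S.max(axis=1, keepdims=True))   # row-wise softmax ...
P = P / P.sum(axis=1, keepdims=True)           # ... normalized between 0 and 1
O = P @ V                                      # step 3: output, same shape as V
print(O.shape)  # (4, 3): one updated embedding per token
```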
The basic idea is explained through the following diagram, where the Key, Query and Value blocks are moved from HBM to SRAM, and through some mathematical tricks (explained below), the calculation done here is not an approximation but the exact answer.
With this implementation, the paper reduces wall-clock time by accessing information in blocks, without sacrificing correctness.
Algorithm behind the paper: how is Flash Attention implemented?
This is the most complex part of the paper. Let's break the problem down into sub-parts and dig deeper.
The following diagram shows how the matrices are divided into blocks and how each block is used to calculate a partial softmax that is later corrected into the exact softmax.
- Initial input: the tokens of the example sentence “this is flash attention paper”
- Key: 4 (tokens) x 3 (dimensions), Query: 4 (tokens) x 3 (dimensions) and Value: 4 (tokens) x 3 (dimensions)
Step 0
- Let's assume the SRAM size M is 24 bytes.
- SRAM will be divided into 4 blocks (Query, Key, Value and the output matrix)
- Query, Key, Value and Output each get 6 bytes to store their information (24 bytes / 4)
- Each embedding has dimension 3 and cannot be split across blocks, so
- Query: 6 bytes / 3 (dimensions) = 2 rows. The same applies to Key, Value and Output
- Therefore, M/(4d) gives the size of each block. In this case, the block size is 2, meaning 2 rows can be fetched into SRAM at a time.
- In general, the block size is M/(4d) and the number of blocks is N*4d/M. A small sketch of this arithmetic follows this list.
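Here is a tiny sketch of the Step 0 block-size arithmetic; the numbers are the toy values above plus one more realistic (assumed) setting:

```python
# Step 0 arithmetic: how many rows fit per block, and how many blocks are needed?
# Follows the article's toy accounting (1 element = 1 byte).
def block_layout(M, d, N):
    """M: on-chip memory budget, d: embedding dimension, N: number of tokens."""
    per_matrix = M // 4               # Q, K, V and O each get a quarter of SRAM
    block_rows = per_matrix // d      # rows per block, i.e. M / (4d)
    num_blocks = -(-N // block_rows)  # ceil(N / block_rows), roughly N*4d/M
    return block_rows, num_blocks

print(block_layout(M=24, d=3, N=4))             # toy example above -> (2, 2)
print(block_layout(M=100_000, d=128, N=8_192))  # assumed, more realistic sizes
```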
Steps 1 and 2: the table below illustrates steps 1 and 2 of how Flash Attention works, comparing the memory and compute aspects.
The following diagram helps visualize the matrix multiplication (block by block) used in flash attention.
What is the math behind the softmax decomposition?
One of the most critical aspects of the paper is how decomposing the matrices into blocks still yields an exact softmax. The mathematical example below shows how the statistics of two different blocks can be combined to recompute the exact softmax.
Intuition
- This is the beautiful property of exponents that is taken advantage of here.
- The softmax of each block is calculated individually, and the row-wise maximum is stored along with the running sum of exponentials.
- When merging with another block, we need to check how much each block's maximum differs from the global maximum of the two blocks. Because of the exponential, both the numerator and the denominator are rescaled by e^(current_max - global_max) to account for this.
The logic is somewhat involved, so a worked example is given below. Once you are familiar with the example, the intuition above will make a lot of sense.
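Here is a minimal NumPy sketch (not the paper's kernel) of the merge: each block keeps its running max and sum of exponentials, and the two are combined with the e^(current_max - global_max) correction to recover the exact softmax:

```python
import numpy as np

# Merge the softmax statistics of two blocks of one row, without approximation.
np.random.seed(0)
row = np.random.randn(8)            # one row of attention scores
block1, block2 = row[:4], row[4:]   # processed separately, block by block

# Per-block statistics: running max and sum of exponentials (shifted by that max)
m1, m2 = block1.max(), block2.max()
l1 = np.exp(block1 - m1).sum()
l2 = np.exp(block2 - m2).sum()

# Merge: rescale each block's sum by e^(current_max - global_max)
m = max(m1, m2)
l = l1 * np.exp(m1 - m) + l2 * np.exp(m2 - m)

# The exact softmax reconstructed from block statistics matches the direct one
softmax_blocked = np.exp(row - m) / l
softmax_direct = np.exp(row - row.max()) / np.exp(row - row.max()).sum()
assert np.allclose(softmax_blocked, softmax_direct)
```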
Let's look at the complexity analysis to get an idea of how things changed.
Self-attention
- While calculating S = QK', the result is an N*N matrix that must be written back to HBM and then read back from HBM.
- Therefore, HBM access is O(N*N + N*N) = O(N*N).
Flash attention
- Outer loop: Key and Value blocks are loaded from HBM once, requiring O(Nd) accesses in total
- Inner loop: Query and Output must be re-loaded from HBM once per outer iteration, and there are O(Nd/M) outer iterations since only blocks fit in SRAM
- Total: O(N*N*d*d/M) HBM accesses
- In practice, d is much smaller than M: d ranges from 64 to 128, while M is on the order of 100 KB, so HBM access is greatly reduced
- We started with the goal of optimizing access to HBM, and with this complexity analysis we see that the paper scales HBM accesses down by a factor of d*d/M (from O(N*N) to O(N*N*d*d/M)) without any approximation; a quick numeric illustration follows below
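As a quick numeric illustration of that factor, with assumed values of N, d and M:

```python
# Illustrative comparison of the two HBM-access counts derived above.
N = 16_384     # sequence length (assumed)
d = 64         # head dimension (assumed)
M = 100_000    # on-chip SRAM budget, order of 100 KB (assumed)

standard = N * N
flash = N * N * d * d // M
print(f"standard attention: ~{standard:,} HBM accesses")
print(f"flash attention   : ~{flash:,} HBM accesses ({standard / flash:.0f}x fewer)")
```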
It is a complex paper with a huge improvement in efficiency. I hope the explanation above provides some intuition on how Flash Attention optimizes memory access and improves performance. I haven't covered block-sparse Flash Attention, how it compares to other optimization techniques, forward pass optimization, etc. I hope to cover these in a future post.