How To Reduce The Memory Usage Of The Self-Attention
With a bit of magic, we can take a very memory-inefficient computation like self-attention and optimize it for the specific hardware we use for training and inference. And we all need a bit of magic!
Self-attention Does Not Need O(N²) Memory
The GPU Architecture
The FlashAttention-1
The FlashAttention-2
The FlashAttention-3
So far, we have mainly explored how to reduce the complexity of the attention mechanism by approximating the vanilla attention. The vanilla attention has a strict O(N²) time complexity, but the space complexity does not need to be O(N²)! Computing QKᵀ requires ~O(N²) operations, but the full N x N alignment score and attention matrices do not need to be fully materialized in memory all at once. As models and sequence lengths scale, it becomes essential to minimize the memory requirements at training and inference time to better utilize the underlying hardware.
Self-attention Does Not Need O(N²) Memory
Let's consider again the computation of the context vectors:

\( C = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_\text{model}}}\right)V \)
Here is how we arrive at the typical O(N²) space complexity:
The typical assumption is that we first compute the dot products between the query qi and all the keys [k1, …, kN]:
\( \mathbf{e}_i = \left[\frac{\mathbf{q}_i^\top \mathbf{k}_1}{\sqrt{d_\text{model}}}, \ldots, \frac{\mathbf{q}_i^\top \mathbf{k}_N}{\sqrt{d_\text{model}}}\right] \)

where ei is the alignment score vector of size N for the query qi. For the N queries, this leads to the N x N alignment score matrix.
We then perform the softmax transformation:
\( \mathbf{a}_i = \frac{1}{\sum_{j=1}^N\exp\left(\frac{\mathbf{q}_i^\top \mathbf{k}_j}{\sqrt{d_\text{model}}}\right)}\left[\exp\left(\frac{\mathbf{q}_i^\top \mathbf{k}_1}{\sqrt{d_\text{model}}}\right), \ldots, \exp\left(\frac{\mathbf{q}_i^\top \mathbf{k}_N}{\sqrt{d_\text{model}}}\right)\right] \)

Here, ai is the attention vector of size N for the query qi. Again, for the N queries, this leads to the typical N x N attention matrix.
And finally, we project ai onto the different values [v1, …, vN], which leads to ci:
\(\begin{align} \mathbf{c}_i &= \mathbf{a}_i^\top V\nonumber\\ &=\sum_{j=1}^Na_{ij}\mathbf{v}_j \end{align}\)
Therefore, naively computing the alignment score and attention matrices first forces the materialization of those N x N matrices in memory, which leads to the O(N²) space complexity.
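To make the baseline concrete, here is a minimal NumPy sketch of this naive computation (function and variable names are illustrative), in which the full N x N matrices are materialized:

```python
import numpy as np

def naive_attention(Q, K, V):
    """Vanilla attention: Q, K, V are (N, d_model) matrices."""
    d_model = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_model)             # (N, N) alignment scores -> O(N^2) memory
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # (N, N) attention matrix
    return weights @ V                              # (N, d_model) context vectors
```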
However, we do not need to order the computations in this manner! In 2021, Rabe and Staats realized that by reordering the operations, we can greatly reduce the memory requirements. The idea is to track the unnormalized context vector c̃i and the normalization constant si of the softmax transformation separately:

\( \tilde{\mathbf{c}}_i = \sum_{j=1}^N\exp\left(\frac{\mathbf{q}_i^\top \mathbf{k}_j}{\sqrt{d_\text{model}}}\right)\mathbf{v}_j, \qquad s_i = \sum_{j=1}^N\exp\left(\frac{\mathbf{q}_i^\top \mathbf{k}_j}{\sqrt{d_\text{model}}}\right), \qquad \mathbf{c}_i = \frac{\tilde{\mathbf{c}}_i}{s_i} \)
Because c̃i and si are just sums, we can simply loop through the key-value pairs to accumulate them and compute the context vector. At any point during the loop, we only need to store the intermediate values of c̃i and si: c̃i is a vector of size dmodel (ignoring heads for simplicity), and si is a scalar. Therefore, for one query, we only need constant space O(1) to compute one context vector. Even when iterating through all the queries, we never need to hold more than the intermediate values of c̃i and si, so we can compute the full attention mechanism in O(1) space complexity.
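Here is a minimal NumPy sketch of this idea (ignoring numerical stabilization for now; names are illustrative): for each query, we stream over the key-value pairs and only ever keep c̃i and si around.

```python
import numpy as np

def streaming_attention(Q, K, V):
    """Attention computed one key-value pair at a time; Q, K, V are (N, d_model)."""
    N, d_model = Q.shape
    out = np.empty_like(Q)
    for i in range(N):                     # one query at a time
        c_tilde = np.zeros(d_model)        # unnormalized context vector
        s = 0.0                            # softmax normalization constant
        for j in range(N):                 # stream over the key-value pairs
            w = np.exp(Q[i] @ K[j] / np.sqrt(d_model))
            c_tilde += w * V[j]
            s += w
        out[i] = c_tilde / s               # normalize at the end
    return out
```

Only c̃i (a vector of size dmodel) and si (a scalar) live across loop iterations; the N x N score matrix is never formed.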
In reality, this is not a practical solution because purely sequential operations do not map well to the parallelization capabilities of the CPU, GPU, or TPU hardware commonly used for neural network computations. In practice, the queries, keys, and values are partitioned into chunks to allow for a high degree of parallelization while keeping the memory requirements low. Let's assume that we partition the queries into nq chunks and the keys and values into nk chunks:

\( Q = \begin{bmatrix}Q_1\\ \vdots\\ Q_{n_q}\end{bmatrix}, \qquad K = \begin{bmatrix}K_1\\ \vdots\\ K_{n_k}\end{bmatrix}, \qquad V = \begin{bmatrix}V_1\\ \vdots\\ V_{n_k}\end{bmatrix} \)
where each Qi is an N / nq x dmodel matrix and each Ki, Vi is an N / nk x dmodel matrix. Let's call Nq = N / nq the number of queries per chunk, and Nk = N / nk the number of key-value pairs per chunk. We can now iterate through the chunks in exactly the same way (see the sketch below).
As before, we only need to store the intermediate values of C̃i and Si. In this setting, QiKjᵀ is a matrix of size Nq x Nk, and so is Aij. C̃i is a matrix of size Nq x dmodel and Si is a vector of size Nq. Therefore, the space complexity is O(Nq x Nk + Nq x dmodel). To balance the number of chunks against the number of key-value pairs per chunk, Rabe and Staats chose Nk = nk = √N and fixed Nq = 1024. This results in a space complexity of:

\( O\left(N_q\sqrt{N} + N_q d_\text{model}\right) = O\left(\sqrt{N}\right) \)
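Below is a minimal NumPy sketch of this chunked variant (names are illustrative, and the numerical stabilization discussed later is still omitted): each query chunk accumulates its C̃i and Si while looping over the key-value chunks, so only an Nq x Nk block of scores exists at any time.

```python
import numpy as np

def chunked_attention(Q, K, V, Nq=1024, Nk=1024):
    """Chunked attention: Q, K, V are (N, d_model); Nq, Nk are chunk sizes."""
    N, d_model = Q.shape
    out = np.empty_like(Q)
    for qs in range(0, N, Nq):                          # loop over query chunks
        Qi = Q[qs:qs + Nq]                              # (Nq, d_model)
        c_tilde = np.zeros((Qi.shape[0], d_model))      # unnormalized context vectors
        s = np.zeros(Qi.shape[0])                       # normalization constants
        for ks in range(0, N, Nk):                      # loop over key-value chunks
            Kj, Vj = K[ks:ks + Nk], V[ks:ks + Nk]
            Aij = np.exp(Qi @ Kj.T / np.sqrt(d_model))  # (Nq, Nk) block of scores
            c_tilde += Aij @ Vj
            s += Aij.sum(axis=-1)
        out[qs:qs + Nq] = c_tilde / s[:, None]
    return out
```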
This approach allows for efficient tensor operations within each chunk while dramatically reducing the peak memory requirements. Note that no approximation has been made: it is mathematically equivalent to the vanilla attention mechanism. It is slower (8-13% slower during the forward pass and 30-35% slower during the backward pass) due to the sequential computations, but it enables the processing of much longer sequences that would otherwise be impossible due to memory constraints. Along with FlashAttention, it is one of the memory optimization strategies implemented in the xFormers package developed by Meta and used in the development of the Llama models.
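For reference, a memory-efficient attention of this kind can be called through xFormers; here is a minimal usage sketch (the expected tensor shapes and defaults are assumptions on my part and may vary across xFormers versions, so check the documentation):

```python
import torch
from xformers.ops import memory_efficient_attention

# Assumed layout: (batch, seq_len, num_heads, head_dim), half precision on GPU
q = torch.randn(1, 4096, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 4096, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 4096, 8, 64, device="cuda", dtype=torch.float16)

# The N x N attention matrix is never fully materialized in memory
out = memory_efficient_attention(q, k, v)
```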
Stabilizing The Computations
Until now, we have been ignoring the numerical stability of the softmax computation, but most implementations (PyTorch, TensorFlow, ...) use a couple of tricks to ensure it. Let's remind ourselves of the softmax function:

\( \text{softmax}(\mathbf{x})_i = \frac{e^{x_i}}{\sum_{j=1}^N e^{x_j}} \)
Computing \( e^{x_i} \) can be tricky: if \( x_i \geq 89 \), then \( e^{x_i} \approx 4.4\text{e}38 \), which exceeds the limit of 3.4e38 for a 32-bit float, potentially leading to overflow errors. To prevent this from happening, we typically shift the exponents by subtracting the maximum \( x_i \) value:

\( \text{softmax}(\mathbf{x})_i = \frac{e^{x_i - \max_j x_j}}{\sum_{j=1}^N e^{x_j - \max_j x_j}} \)
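As a quick illustration (a minimal NumPy sketch, not the actual PyTorch or TensorFlow implementation), subtracting the maximum keeps every exponent at or below 0, so no term can overflow:

```python
import numpy as np

def stable_softmax(x):
    """Numerically stable softmax along the last axis."""
    x = np.asarray(x, dtype=np.float32)
    m = x.max(axis=-1, keepdims=True)        # max_j x_j
    e = np.exp(x - m)                        # every exponent is <= 0, so exp(...) <= 1
    return e / e.sum(axis=-1, keepdims=True)

print(np.exp(np.float32(89.0)))              # overflows to inf in float32
print(stable_softmax([89.0, 0.0]))           # finite result, no overflow
```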