In LLMs, handling long sequences is not enough; we also need to make sure the decoding process is fast. Here we explore three typical approaches used to speed up decoding:
Multi‑Query Attention
Grouped‑Query Attention
DeepSeek Multi-head Latent Attention
When generating text one token at a time, transformers face a significant performance challenge. In the original transformer architecture with Multi-Head Attention (MHA), each time we generate a new token, we need to:
Load the entire history of key (K) and value (V) matrices for each attention head
Process them against the new query
Generate the next token
Repeat
This creates a memory bandwidth bottleneck. For long sequences, we're constantly reloading massive tensors from memory, which becomes the limiting factor in generation speed. Specifically, with n_head attention heads of dimension d_head and a sequence of length N, the original approach requires loading K and V tensors of size approximately n_head x d_head x N = d_model x N per layer, which can amount to gigabytes of data for long contexts. Since those K and V tensors are the bottleneck of the decoding process, we are going to look at strategies to minimize their memory requirements.
Strategies like multi‑query attention, grouped-query attention, and multi-head latent attention are tightly coupled with the KV-caching technique. With KV-caching, instead of recomputing the same K and V matrices over and over, we cache them in memory and update them at each iteration. Loading the KV-cache from memory then becomes the dominant bottleneck instead of the compute itself.
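To make this concrete, here is a minimal, single-head sketch of one decoding step with a KV-cache. It is illustrative only: the function and variable names are made up, and real implementations batch this across heads, layers, and sequences.

```python
import torch

# Minimal single-head KV-cache decoding step (illustrative sketch, not a real library API).
# W_q, W_k, W_v: (d_model, d_head) projection matrices; h_t: (d_model,) hidden state of the new token.
def decode_step(h_t, W_q, W_k, W_v, k_cache, v_cache):
    d_head = W_k.shape[1]
    q_t = h_t @ W_q                      # (d_head,) query for the new token only
    k_t = (h_t @ W_k).unsqueeze(0)       # (1, d_head) key for the new token
    v_t = (h_t @ W_v).unsqueeze(0)       # (1, d_head) value for the new token

    # Append to the cache instead of recomputing K and V for the whole prefix.
    k_cache = torch.cat([k_cache, k_t], dim=0)   # (t, d_head)
    v_cache = torch.cat([v_cache, v_t], dim=0)   # (t, d_head)

    # Attention of the new query over all cached keys/values.
    scores = (k_cache @ q_t) / d_head ** 0.5     # (t,)
    weights = torch.softmax(scores, dim=0)       # (t,)
    context = weights @ v_cache                  # (d_head,)
    return context, k_cache, v_cache
```

The compute per step stays small; what grows with the sequence length is the k_cache and v_cache that must be read from memory at every step, which is exactly the traffic MQA, GQA, and MLA try to shrink.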
Multi‑Query Attention
In 2019, Noam Shazeer published a variant of multi-head attention: multi-query attention (MQA). He realized that the model could maintain most of its capability while sharing a single set of keys and values across all heads. This insight might seem counterintuitive since the whole point of multi-head attention was to have different "perspectives" on the same information. What Shazeer discovered is that the diversity of the query projections still allows different heads to extract different information: the shared key-value store acts as a common knowledge repository, and each head can still attend to different parts of it.
Formally, this means that the projection matrix W_Q is still a d_model x d_model matrix, but the projections W_K and W_V are d_model x d_head matrices. With an incoming hidden state H of dimension d_model x N, we have, at training time, the initial projections:
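In LaTeX-style notation, a plausible form of these projections (the exact transpose convention is assumed; W_K and W_V are shared across all heads):

```latex
Q = W_Q^\top H \in \mathbb{R}^{d_{model} \times N}, \qquad
K = W_K^\top H \in \mathbb{R}^{d_{head} \times N}, \qquad
V = W_V^\top H \in \mathbb{R}^{d_{head} \times N}
```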
We then reshape the matrices into tensors to highlight the number of heads:
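Only the queries are split across heads; the shared keys and values keep a single head dimension. One way to write the reshaping:

```latex
Q \in \mathbb{R}^{d_{model} \times N} \;\longrightarrow\; Q' \in \mathbb{R}^{d_{head} \times n_{head} \times N},
\qquad K' = K \in \mathbb{R}^{d_{head} \times N}, \quad V' = V \in \mathbb{R}^{d_{head} \times N}
```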
At inference time, we only consider the last query q'_N (of size d_head x n_head) in the input sequence, since we only need to predict the next token. The alignment scores are computed by broadcasting the matrix K' to all the heads:
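Assuming the usual scaled dot-product, the scores for the last position can plausibly be written as:

```latex
S_N = \frac{K'^\top q'_N}{\sqrt{d_{head}}} \in \mathbb{R}^{N \times n_{head}}
```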
We perform the softmax transformation:
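The softmax is taken over the sequence dimension (the N rows), independently for each head:

```latex
A_N = \mathrm{softmax}(S_N) \in \mathbb{R}^{N \times n_{head}}
```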
The value is again broadcasted to all heads to compute the context vector corresponding to the prediction:
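Multiplying the shared values by the attention weights gives one context column per head:

```latex
c'_N = V' A_N \in \mathbb{R}^{d_{head} \times n_{head}}
```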
The context vector is reshaped into the original dimensions:
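The n_head columns are stacked back into a single vector:

```latex
c'_N \in \mathbb{R}^{d_{head} \times n_{head}} \;\longrightarrow\; c_N \in \mathbb{R}^{d_{model}}
```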
Finally, the context vector is projected one last time to mix the information from the different heads:
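With W_O a d_model x d_model matrix, this plausibly reads:

```latex
o_N = W_O^\top c_N \in \mathbb{R}^{d_{model}}
```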
Let's consider the memory access complexity (or memory bandwidth complexity): the total amount of data that must be transferred between memory and compute units over the entire sequence of operations. For MHA, at each decoding step, we need to load W_Q, W_K, W_V, and W_O. They are all d_model x d_model matrices, so over N decoding steps, the memory access complexity is ~O(N d_model^2). As we generate each token, we must also reload the entire history: for the i-th token, we load keys and values of size i x d_model. Summing over all N steps:
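Writing the sum out:

```latex
\sum_{i=1}^{N} i \, d_{model} = d_{model} \, \frac{N(N+1)}{2} \;\sim\; O(d_{model} N^2)
```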
We also need to load the N input hidden states of size d_model, which adds ~O(N d_model). So, the overall memory access complexity for MHA is:
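Putting the three terms together:

```latex
O(N d_{model}^2) + O(d_{model} N^2) + O(N d_{model}) \;\sim\; O(N d_{model}^2 + d_{model} N^2)
```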
For long sequences, the d_model N^2 term dominates, and this is the problematic bottleneck.
For MQA, loading W_Q, W_K, W_V, and W_O has the same asymptotic behavior, ~O(N d_model^2), and loading the input hidden states is still ~O(N d_model). However, the K and V matrices now have size i x d_head at the i-th decoding step, which leads to:
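The sum now scales with d_head instead of d_model:

```latex
\sum_{i=1}^{N} i \, d_{head} = d_{head} \, \frac{N(N+1)}{2} \;\sim\; O(d_{head} N^2)
```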
So, the overall memory access complexity is:
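Combining the terms as before:

```latex
O(N d_{model}^2) + O(d_{head} N^2) + O(N d_{model}) \;\sim\; O(N d_{model}^2 + d_{head} N^2)
```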
For long sequences, the quadratic term now scales with d_head instead of d_model (a reduction by a factor of n_head, since d_model = n_head x d_head), which has a substantial impact: less memory access means lower latency. In the specific experiments run in Shazeer's paper, MQA achieved ~12x faster decoder inference with minimal performance loss. For inference with a sequence length of 128 tokens, MHA's decoding time per token was 46 μs whereas MQA's was 3.8 μs.
Grouped‑Query Attention
With multi-query attention, we gain decoding speed, but we lose some predictive performance compared to multi-head attention. Grouped-query attention (GQA) provides a middle ground between MQA and MHA, keeping quality high while still improving decoding speed. GQA divides the query heads into G groups, where each group shares a single key head and value head (a minimal code sketch follows the list below). This creates a configurable spectrum:
When G = 1, it is equivalent to MQA (single key-value head for all queries)
When G = n_head, it is equivalent to standard MHA (a separate key-value head for each query head)
When 1 < G < n_head, we are in the GQA sweet spot that balances efficiency and quality
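Here is a minimal sketch of a GQA attention step, assuming the KV-cache already holds G key/value heads; the function name and shapes are illustrative, not any particular library's API.

```python
import torch

# Grouped-query attention sketch: n_head query heads share G key/value heads.
# q: (n_head, d_head) queries for the new token; K, V: (G, t, d_head) cached keys/values.
def gqa_step(q, K, V, n_head, G):
    d_head = q.shape[-1]
    heads_per_group = n_head // G

    # Each query head attends to the single K/V head of its group.
    K = K.repeat_interleave(heads_per_group, dim=0)   # (n_head, t, d_head)
    V = V.repeat_interleave(heads_per_group, dim=0)   # (n_head, t, d_head)

    scores = torch.einsum("hd,htd->ht", q, K) / d_head ** 0.5   # (n_head, t)
    weights = torch.softmax(scores, dim=-1)                     # (n_head, t)
    context = torch.einsum("ht,htd->hd", weights, V)            # (n_head, d_head)
    return context.reshape(-1)                                  # (d_model,) before the output projection
```

With G = 1 this collapses to MQA, and with G = n_head it is standard MHA; only the G key/value heads need to live in the cache, and the expansion to n_head heads is done on the fly.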
It is important to understand that latency is not proportional to memory access. Memory-level parallelism refers to a computer system's ability to serve multiple memory operations simultaneously rather than sequentially. When we load too few key-value groups, the hardware's memory-level parallelism is underutilized; when we load too many, we saturate it. Because of this parallelism, it is quite likely that G = 1 (MQA) induces about as much latency as G = 4, even though G = 4 offers better predictive performance. This relationship between the number of groups and latency is highly hardware-specific, which is why empirical testing is necessary to find the optimal configuration for any given system. The TPUs used in the original paper showed this saturation around G = 8, but different architectures might have different saturation points.
GQA achieves quality close to MHA with speed comparable to MQA. GQA is more stable during training than pure MQA, which can sometimes exhibit training instability. For example, on the T5-XXL model, the GQA-8 version presented in the original paper achieved an average performance of 47.1 across key benchmarks compared to 47.2 for MHA and 46.6 for MQA, while maintaining inference speeds much closer to MQA.
From MHA to GQA
One of the main appeals of GQA is the ability to convert a vanilla MHA model into GQA with minimal effort. The ability to convert existing models rather than train new ones has accelerated the adoption of GQA across the industry. Large models like PaLM 2 and LLaMA 2 have incorporated GQA, partly because this conversion process made it practical to do so. Converting from MHA to GQA is a two-step process:
Checkpoint Conversion: We first convert the model's weights by mean-pooling the key and value projection matrices within each group (see the code sketch after this list).
Additional Pre-training: After conversion, we continue pre-training the model for a small fraction (about 5%) of the original training steps, using the same pre-training dataset and objectives.
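A sketch of the mean-pooling step, under the assumption that the key/value projections are stored with an explicit head dimension; the query and output projections are left untouched, and the uptraining step is not shown.

```python
import torch

# MHA -> GQA checkpoint conversion sketch: average the per-head K/V projection
# weights within each group. W_k, W_v assumed stored as (n_head, d_head, d_model).
def mean_pool_kv(W_k, W_v, G):
    n_head, d_head, d_model = W_k.shape
    heads_per_group = n_head // G

    # Group adjacent heads, then mean-pool the projection weights inside each group.
    W_k_gqa = W_k.reshape(G, heads_per_group, d_head, d_model).mean(dim=1)  # (G, d_head, d_model)
    W_v_gqa = W_v.reshape(G, heads_per_group, d_head, d_model).mean(dim=1)  # (G, d_head, d_model)
    return W_k_gqa, W_v_gqa
```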
This approach requires only about 5% of the original training compute to achieve comparable performance to the original MHA model. For large models that cost millions to train, this represents enormous savings.
DeepSeek’s Multi-head Latent Attention
Multi-Head Latent Attention (MLA) was introduced in DeepSeek-V2 as a way to optimize training and inference speed while preserving predictive performance. Even though GQA keeps quality high, it still trades some quality for efficiency. MLA aims to improve on both fronts at the same time.