2. Attention Mechanism
📅 2026-05-17 (created during knowledge-base reorganization) 👉 #AI #LLM #Attention #Architecture #DeepLearning 📎 Attention Is All You Need (Vaswani et al., 2017) 📎 The Illustrated Transformer (Jay Alammar) 📎 GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints 📎 DeepSeek-V2: Multi-Head Latent Attention
1. Overview
1.1. Why Attention?
The attention mechanism is the single most important computational primitive in modern AI. It is the engine that makes Transformers (and therefore nearly every modern LLM) work.
The intuition: when a human reads a sentence, they don't process every word with equal weight. To answer "What did Mary tell John yesterday?", we focus on Mary, tell, John, and yesterday, not on filler words. Attention is a learnable mechanism that lets the model decide, for every token it processes, which other tokens to focus on and how much.
This solves a fundamental problem RNNs struggled with: long-range dependencies. In an RNN, the influence of a token 100 positions ago must travel through 100 hidden-state updates and is mostly washed out. With attention, any two tokens are one operation apart.
1.2. Where it sits in the stack
- This note focuses on attention as a mechanism. The broader Transformer architecture is in
1.Transformer_Architecture.md. - Many LLM-industry concepts touch on attention variants: MLA (DeepSeek), GQA (Llama, Mistral), Ring Attention (Gemini long-context). All are covered here.
2. Concept, Component, & Architecture
2.1. The Three Vectors: Query, Key, Value
For every token, the model produces three vectors via learned linear projections:
- Query (Q): "what am I looking for?"
- Key (K): "what do I contain?"
- Value (V): "what should I contribute if you decide to attend to me?"
Mechanically: $$Q = X W_Q, \quad K = X W_K, \quad V = X W_V$$
where $X$ is the matrix of token embeddings (shape (seq_len, d_model)) and $W_Q, W_K, W_V$ are learned weight matrices.
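For concreteness, here is a minimal NumPy sketch of the three projections. It assumes toy sizes and uses random weights in place of learned ones; all sizes are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_model, d_k, d_v = 4, 8, 4, 4   # toy sizes, chosen only for illustration

X = rng.standard_normal((seq_len, d_model))   # one embedding row per token
W_Q = rng.standard_normal((d_model, d_k))     # learned in a real model; random here
W_K = rng.standard_normal((d_model, d_k))
W_V = rng.standard_normal((d_model, d_v))

Q, K, V = X @ W_Q, X @ W_K, X @ W_V           # row t holds token t's query, key, value
print(Q.shape, K.shape, V.shape)  # (4, 4) (4, 4) (4, 4)
```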
2.2. Scaled Dot-Product Attention — the Core Equation
Given Q, K, V, the attention output is:
$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{Q K^T}{\sqrt{d_k}}\right) V$$
Step-by-step:
1. Compute compatibility scores: $QK^T$ — for every pair (query position $i$, key position $j$), the dot product measures how well they match. Result is a (seq_len × seq_len) matrix.
2. Scale by $\sqrt{d_k}$: divides by the square root of the key dimension. Without this, dot products grow large with dimension and push softmax into saturation regions where gradients vanish.
3. Softmax: turns each row of scores into a probability distribution — these are the attention weights, summing to 1 across the sequence.
4. Weighted sum of Values: multiply the attention-weight matrix by V. Each token's output is a weighted average of all tokens' Values, where the weights are determined by Q-K similarity.
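The four steps map directly onto a few lines of NumPy. This is a minimal single-head, single-sequence sketch that works on 2-D arrays with no batching or masking. The helper names are illustrative.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtracting the row max improves numerical stability without changing the result.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V):
    d_k = Q.shape[-1]
    scores = Q @ K.T                     # step 1: (L_q, L_k) compatibility scores
    scores = scores / np.sqrt(d_k)       # step 2: scale by sqrt(d_k)
    weights = softmax(scores, axis=-1)   # step 3: each row sums to 1
    return weights @ V, weights          # step 4: weighted sum of values

rng = np.random.default_rng(0)
L, d_k = 5, 16
Q, K, V = (rng.standard_normal((L, d_k)) for _ in range(3))
out, weights = scaled_dot_product_attention(Q, K, V)
print(out.shape, weights.shape)  # (5, 16) (5, 5)
```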
2.3. Why "Self-Attention"?
The "self" in self-attention means Q, K, V all come from the same sequence. When processing the sentence "the animal didn't cross the street because it was too tired", the model can let the token "it" attend strongly to "animal" — figuring out what "it" refers to.
In encoder-decoder models (e.g., translation), the decoder also uses cross-attention: Q comes from the decoder's current state, but K and V come from the encoder's output.
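The following NumPy sketch contrasts the two. It assumes toy sizes. For brevity it reuses one set of projection weights, whereas a real encoder-decoder learns separate weights for its self-attention and cross-attention layers. The structural difference is only where Q comes from versus K/V, and it shows up in the weight matrix's shape: (L_dec, L_enc) instead of square.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    weights = softmax(Q @ K.T / np.sqrt(Q.shape[-1]))
    return weights @ V, weights

rng = np.random.default_rng(1)
d_model, d_k = 8, 4
enc_out = rng.standard_normal((7, d_model))    # 7 source tokens (encoder output)
dec_state = rng.standard_normal((3, d_model))  # 3 target tokens processed so far
# Simplification: one shared set of projections; real models learn separate ones per layer.
W_Q, W_K, W_V = (rng.standard_normal((d_model, d_k)) for _ in range(3))

# Self-attention: Q, K, V all come from the same sequence.
_, self_w = attention(enc_out @ W_Q, enc_out @ W_K, enc_out @ W_V)
# Cross-attention: Q from the decoder, K and V from the encoder output.
cross_out, cross_w = attention(dec_state @ W_Q, enc_out @ W_K, enc_out @ W_V)
print(self_w.shape, cross_w.shape, cross_out.shape)  # (7, 7) (3, 7) (3, 4)
```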
2.4. Multi-Head Attention (MHA)
A single attention "head" gives one set of attention weights — one perspective. Multi-head attention runs h heads in parallel, each with its own $W_Q, W_K, W_V$, then concatenates and projects:
$$\text{MultiHead}(X) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W_O, \qquad \text{head}_i = \text{Attention}\big(X W_Q^{(i)},\, X W_K^{(i)},\, X W_V^{(i)}\big)$$
Each head can specialize. Empirically, different heads attend to different linguistic phenomena — some track syntactic relationships, some track entity references, some track positional patterns.
The trick is not to use h full-sized attentions — instead, project Q/K/V to dimension d_model / h per head. So total compute is roughly the same as single-head attention, but with more flexibility.
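Here is a NumPy sketch of the head split. It assumes full (d_model × d_model) projection matrices whose column slices act as the per-head projections. This is one common way to lay out the weights, not the only one.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, W_Q, W_K, W_V, W_O, n_heads):
    L, d_model = X.shape
    head_dim = d_model // n_heads                 # each head works in d_model / h dims

    def split(M):                                 # (L, d_model) -> (n_heads, L, head_dim)
        return M.reshape(L, n_heads, head_dim).transpose(1, 0, 2)

    Q, K, V = split(X @ W_Q), split(X @ W_K), split(X @ W_V)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(head_dim)  # (n_heads, L, L): one map per head
    heads = softmax(scores) @ V                             # (n_heads, L, head_dim)
    concat = heads.transpose(1, 0, 2).reshape(L, d_model)   # Concat(head_1, ..., head_h)
    return concat @ W_O

rng = np.random.default_rng(0)
L, d_model, n_heads = 6, 16, 4
X = rng.standard_normal((L, d_model))
W_Q, W_K, W_V, W_O = (rng.standard_normal((d_model, d_model)) for _ in range(4))
out = multi_head_attention(X, W_Q, W_K, W_V, W_O, n_heads)
print(out.shape)  # (6, 16)
```

Head `i` only uses columns `i*head_dim:(i+1)*head_dim` of each projection. As a result, the projections have the same parameter count and FLOPs as a single head of width d_model.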
2.5. Causal (Masked) Self-Attention
For autoregressive language modeling, a token at position $t$ must not see tokens at positions $> t$ (otherwise, "predicting" the next token is trivial — just copy it).
Causal masking blocks the upper-triangular part of $QK^T$ (the future positions). Before softmax, add $-\infty$ to those scores, so softmax assigns them a weight of exactly 0:
```
score = QK^T / sqrt(d_k)
score = score + mask       # mask: 0 on and below the diagonal, -inf above it
weights = softmax(score)   # row-wise; future positions get weight 0
```
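Here is a runnable NumPy version of the masking step. It assumes a single head and a single sequence, and `causal_attention_weights` is a hypothetical helper name.

```python
import numpy as np

def causal_attention_weights(Q, K):
    L, d_k = Q.shape
    scores = Q @ K.T / np.sqrt(d_k)
    # Additive mask: 0 on and below the diagonal, -inf strictly above it (future positions).
    mask = np.triu(np.full((L, L), -np.inf), k=1)
    scores = scores + mask
    # Stable softmax; every row keeps its finite diagonal entry, so the row max is finite.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)   # exp(-inf) == 0 for masked entries
    return weights / weights.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
Q, K = rng.standard_normal((4, 8)), rng.standard_normal((4, 8))
weights = causal_attention_weights(Q, K)
print(np.round(weights, 2))   # lower-triangular; row 0 attends only to itself
```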
2.6. Computational Complexity
For sequence length $L$ and hidden dimension $d$:
- Computing $QK^T$: $O(L^2 \cdot d)$
- Memory for the attention matrix: $O(L^2)$
The quadratic term in $L$ is why long-context attention is expensive. Many recent innovations target it:
- Sparse attention: only attend to certain positions (Longformer, BigBird).
- Linear-complexity attention: approximate or restructure softmax attention so that cost grows linearly in $L$ (Performer, Linformer).
- Sliding-window attention: each token attends only to a local window (Mistral).
- Flash Attention: not a different mechanism. It uses the same math, but tiled and fused on the GPU so the $L \times L$ matrix is never materialized in HBM. Crucial for training and inference speed.
- Ring Attention (Liu et al., 2023; commonly associated with Gemini's long contexts): distributes the attention computation across multiple devices in a ring topology, enabling 1M+ token contexts.
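Here is a back-of-the-envelope calculation of the $O(L^2)$ memory term. It assumes the full score matrix is materialized in fp16/bf16 (2 bytes per element) for one head of one sequence, which is exactly what kernels like Flash Attention avoid. `attn_matrix_bytes` is a hypothetical helper.

```python
def attn_matrix_bytes(L, bytes_per_elem=2, n_heads=1):
    # One L x L score matrix per head, stored at bytes_per_elem each.
    return L * L * bytes_per_elem * n_heads

GiB = 2**30
for L in (4_096, 32_768, 131_072):
    # Prints 0.031, 2.000, 32.000 GiB: 32x more tokens costs 1024x more memory.
    print(f"L={L:>7}: {attn_matrix_bytes(L) / GiB:8.3f} GiB per head")
```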
3. Modern Attention Variants for LLMs
The 2023-2026 wave of LLM innovations is largely about reducing the KV-Cache memory footprint while preserving quality. (KV-Cache is the K and V tensors stored for already-generated tokens during inference; see 1.Transformer_Architecture.md §4.)
3.1. Multi-Head Attention (MHA) — the original
- Each head has its own K and V.
- KV-Cache size per sequence: `2 (K and V) × n_layers × n_heads × head_dim × seq_len × bytes_per_element`.
- Quality: best.
- Memory: highest.
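The formula translates into a small calculator. This sketch assumes fp16/bf16 storage (2 bytes per element) and a Llama-2-7B-like shape (32 layers, 32 heads, head_dim 128). `kv_cache_bytes` is a hypothetical helper.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, bytes_per_elem=2, batch=1):
    # 2 tensors (K and V) per layer, each of shape (batch, n_kv_heads, seq_len, head_dim).
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_elem * batch

# MHA: every query head has its own KV head, so n_kv_heads == n_heads == 32.
mha = kv_cache_bytes(n_layers=32, n_kv_heads=32, head_dim=128, seq_len=4096)
print(mha / 2**30, "GiB")  # 2.0 GiB
```

With MHA, `n_kv_heads` equals the number of query heads. Setting it to 1 or to the number of groups reproduces the MQA and GQA reductions below. For example, `n_kv_heads=8` gives 0.5 GiB for the same model shape.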
3.2. Multi-Query Attention (MQA)
- All Q heads share a single K and V head.
- KV-Cache shrinks by `n_heads`× (e.g., 32× smaller for a 32-head model).
- Quality: noticeable drop, especially for diverse content.
- Used in: PaLM, Falcon.
3.3. Grouped-Query Attention (GQA)
- A middle ground: split heads into `g` groups; heads within a group share K and V.
- Used by: Llama-2 70B and later, Llama-3 (all sizes), Mistral, Mixtral.
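- KV-Cache: `n_heads / g`-fold reduction vs. MHA (e.g., 8× with `n_heads=64`, `n_kv_heads=8` as in Llama-3 70B; 4× with 32 query heads and 8 KV heads as in Llama-3 8B and Mistral 7B).
- Quality: nearly indistinguishable from full MHA.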
- The current sweet spot for production LLMs.
```mermaid
flowchart LR
    subgraph MHA["MHA (Llama-1)"]
        Q1[Q1]-->K1[K1]
        Q2[Q2]-->K2[K2]
        Q3[Q3]-->K3[K3]
        Q4[Q4]-->K4[K4]
    end
    subgraph GQA["GQA (Llama-3)"]
        Q1g[Q1]-->KGroup1[K1]
        Q2g[Q2]-->KGroup1
        Q3g[Q3]-->KGroup2[K2]
        Q4g[Q4]-->KGroup2
    end
    subgraph MQA["MQA (PaLM, Falcon)"]
        Q1m[Q1]-->KSingle[K1]
        Q2m[Q2]-->KSingle
        Q3m[Q3]-->KSingle
        Q4m[Q4]-->KSingle
    end
```
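All three layouts in the diagram can be expressed by one function. This NumPy sketch assumes the shared KV heads are implemented by repeating each cached KV head across its group of query heads, which is a common implementation pattern. `grouped_query_attention` is a hypothetical name. With `n_kv_heads == n_heads` it is MHA, and with `n_kv_heads == 1` it is MQA.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def grouped_query_attention(q, k, v):
    """q: (n_heads, L, d); k, v: (n_kv_heads, L, d), n_heads divisible by n_kv_heads."""
    n_heads, n_kv_heads = q.shape[0], k.shape[0]
    group_size = n_heads // n_kv_heads
    # Query head i reads KV head i // group_size, so heads in a group share K and V.
    k = np.repeat(k, group_size, axis=0)
    v = np.repeat(v, group_size, axis=0)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
    return softmax(scores) @ v

rng = np.random.default_rng(0)
n_heads, L, d = 4, 5, 8
q = rng.standard_normal((n_heads, L, d))
k_mha, v_mha = rng.standard_normal((2, n_heads, L, d))   # MHA: 4 KV heads cached
gqa_k, gqa_v = k_mha[:2], v_mha[:2]                       # GQA: 2 KV heads (2 groups)
mqa_k, mqa_v = k_mha[:1], v_mha[:1]                       # MQA: 1 KV head
for name, kk, vv in [("MHA", k_mha, v_mha), ("GQA", gqa_k, gqa_v), ("MQA", mqa_k, mqa_v)]:
    out = grouped_query_attention(q, kk, vv)
    print(name, "KV heads cached:", kk.shape[0], "output:", out.shape)
```

Only `kk` and `vv` would live in the KV-Cache, so cache size scales with the number of KV heads while the output shape stays the same.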
3.4. Multi-Head Latent Attention (MLA) — DeepSeek's innovation
- Introduced in DeepSeek-V2 (2024) and carried over to DeepSeek-V3 (December 2024).
- Idea: compress K and V via a low-rank latent representation. Store only the small latent in the KV-Cache; decompress to full K and V on-the-fly during attention.
- Result: DeepSeek-V2 reports a 93.3% smaller KV-Cache than DeepSeek 67B. Per token, MLA caches about as much as GQA with only 2.25 groups, far less than typical GQA.
- Quality: matches or exceeds MHA.
- Trade-off: slightly more compute (decompression), but on the right side of the memory-vs-compute curve for long contexts.
- Mechanically:
```
c = X · W_DKV   # latent: small dimension, e.g., 512
K = c · W_UK    # decompress on-the-fly
V = c · W_UV
```
- See `1.Foundation/2.LLM_Industry_Overview.md` §3.3 (DeepSeek section).
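Here is a simplified NumPy sketch of the compress/decompress step. It assumes toy sizes and random weights. `W_DKV`, `W_UK`, and `W_UV` follow the pseudocode above, while the sizes and scaling factors are hypothetical. It omits the decoupled RoPE key that DeepSeek-V2 caches alongside the latent. It also skips the inference-time trick of folding `W_UK` into the query projection.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, n_heads, head_dim, d_latent = 1024, 8, 128, 64   # toy sizes (DeepSeek-V2 uses a 512-dim latent)
L = 16

# Hypothetical weights: down-projection to the latent, up-projections back to per-head K and V.
W_DKV = rng.standard_normal((d_model, d_latent)) / np.sqrt(d_model)
W_UK = rng.standard_normal((d_latent, n_heads * head_dim)) / np.sqrt(d_latent)
W_UV = rng.standard_normal((d_latent, n_heads * head_dim)) / np.sqrt(d_latent)

X = rng.standard_normal((L, d_model))
c = X @ W_DKV                                   # (L, d_latent): the only tensor cached per layer
K = (c @ W_UK).reshape(L, n_heads, head_dim)    # reconstructed on the fly at attention time
V = (c @ W_UV).reshape(L, n_heads, head_dim)

mha_cache = 2 * n_heads * head_dim * L          # elements MHA would cache (K and V)
mla_cache = d_latent * L                        # elements MLA caches (the latent only)
print(f"MLA caches {mla_cache / mha_cache:.1%} of MHA's KV elements")
```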
3.5. Sliding-Window Attention (Mistral)
- Each token attends only to the last `W` tokens (e.g., W=4096).
- Combined with stacked layers, the effective receptive field grows linearly with depth (each layer can see W more tokens back).
- Reduces attention compute to $O(L \cdot W \cdot d)$ and attention-matrix memory to $O(L \cdot W)$.
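Here is a sketch of the sliding-window mask in NumPy. It assumes the window of size `W` includes the token itself, and it uses the simple `n_layers × W` estimate for the stacked receptive field. `sliding_window_mask` and `approx_receptive_field` are hypothetical helpers.

```python
import numpy as np

def sliding_window_mask(L, W):
    """True where query i may attend to key j: causal, and only the last W positions (incl. i)."""
    i = np.arange(L)[:, None]
    j = np.arange(L)[None, :]
    return (j <= i) & (j > i - W)

def approx_receptive_field(n_layers, W):
    # Each stacked layer lets information travel roughly W more positions back.
    return n_layers * W

mask = sliding_window_mask(L=8, W=3)
print(mask.astype(int))                                         # a diagonal band of width 3
print("scored pairs:", mask.sum(), "vs full causal:", 8 * 9 // 2)
print(approx_receptive_field(n_layers=32, W=4096))              # 131072
```

The number of scored pairs grows as roughly `L × W` rather than `L² / 2`, which is where the $O(L \cdot W)$ cost comes from.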
3.6. Ring Attention (Gemini)
- Distributes the long sequence across multiple devices in a ring.
- Each device computes attention for its local chunk, passing K/V in ring fashion.
- Enables 1M+ and even 10M+ token contexts.
- Commonly cited as the kind of technique behind Gemini 1.5/2.5/3's long-context capability; Google has not published the exact mechanism.
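Here is a single-process NumPy simulation of the ring schedule. It assumes no causal mask and equal-sized blocks. Real implementations overlap the K/V transfer with computation on separate devices, which a loop cannot show. Each simulated "device" keeps its query block and attends to whichever K/V block it currently holds. It then merges the partial result exactly using log-sum-exp and passes the block along. The function names are hypothetical.

```python
import numpy as np

def block_attention(q, k, v):
    """Attention of q over one K/V block: normalized output plus per-row log-sum-exp."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    m = scores.max(axis=-1, keepdims=True)
    p = np.exp(scores - m)
    denom = p.sum(axis=-1, keepdims=True)
    return (p @ v) / denom, m + np.log(denom)

def ring_attention_sim(Q, K, V, n_devices):
    """Device r keeps query block r; K/V blocks move one hop around the ring per step."""
    q_blocks = np.split(Q, n_devices)
    kv = list(zip(np.split(K, n_devices), np.split(V, n_devices)))
    out, lse = [None] * n_devices, [None] * n_devices
    for _ in range(n_devices):
        for r in range(n_devices):               # runs in parallel on real hardware
            o_b, lse_b = block_attention(q_blocks[r], *kv[r])
            if out[r] is None:
                out[r], lse[r] = o_b, lse_b
            else:                                # exact merge of two partial softmaxes
                new_lse = np.logaddexp(lse[r], lse_b)
                out[r] = out[r] * np.exp(lse[r] - new_lse) + o_b * np.exp(lse_b - new_lse)
                lse[r] = new_lse
        kv = kv[-1:] + kv[:-1]                   # "send" each K/V block to the next device
    return np.concatenate(out)

rng = np.random.default_rng(0)
Q, K, V = rng.standard_normal((3, 12, 8))
ring_out = ring_attention_sim(Q, K, V, n_devices=4)

# Dense reference: ordinary softmax attention over the full sequence.
scores = Q @ K.T / np.sqrt(Q.shape[-1])
dense = np.exp(scores - scores.max(axis=-1, keepdims=True))
dense /= dense.sum(axis=-1, keepdims=True)
print(np.allclose(ring_out, dense @ V))  # True
```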
3.7. Flash Attention (Dao et al., 2022)
- Not a different mechanism — same math as standard attention.
- Implementation trick: tile the computation in SRAM, never materialize the full $L \times L$ matrix in HBM.
- Result: 2-4× training speedup, 5-10× memory savings, no quality loss.
- Available in PyTorch through `torch.nn.functional.scaled_dot_product_attention`, which dispatches to a FlashAttention kernel when the inputs and hardware allow, and widely used across major training frameworks.
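The core idea can be sketched in NumPy. This sketch assumes a single head and no masking. It iterates over K/V tiles and keeps a running row max and denominator (the "online softmax"), so the full L × L matrix is never stored. It only illustrates the math. The real speedup comes from doing this inside a fused GPU kernel that keeps tiles in SRAM, which NumPy cannot show. `tiled_attention` is a hypothetical name.

```python
import numpy as np

def tiled_attention(Q, K, V, block_size=4):
    """Online-softmax attention over K/V tiles; only an (L_q x block_size) score tile exists at once."""
    L_q, d = Q.shape
    out = np.zeros((L_q, V.shape[-1]))     # unnormalized running output
    row_max = np.full((L_q, 1), -np.inf)   # running max of scores per query row
    row_sum = np.zeros((L_q, 1))           # running softmax denominator per row
    for start in range(0, K.shape[0], block_size):
        k, v = K[start:start + block_size], V[start:start + block_size]
        scores = Q @ k.T / np.sqrt(d)
        new_max = np.maximum(row_max, scores.max(axis=-1, keepdims=True))
        correction = np.exp(row_max - new_max)   # rescale earlier tiles to the new max
        p = np.exp(scores - new_max)
        row_sum = row_sum * correction + p.sum(axis=-1, keepdims=True)
        out = out * correction + p @ v
        row_max = new_max
    return out / row_sum                   # normalize once at the end

rng = np.random.default_rng(0)
Q, K, V = rng.standard_normal((3, 10, 8))
tiled = tiled_attention(Q, K, V, block_size=4)

scores = Q @ K.T / np.sqrt(Q.shape[-1])
dense = np.exp(scores - scores.max(axis=-1, keepdims=True))
dense /= dense.sum(axis=-1, keepdims=True)
print(np.allclose(tiled, dense @ V))  # True
```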
4. Practical Implementation
4.1. Minimal Self-Attention in PyTorch
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, causal: bool = True):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)  # fused W_Q, W_K, W_V
        self.proj = nn.Linear(d_model, d_model, bias=False)     # output projection W_O
        self.causal = causal

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, L, D = x.shape
        # Project to Q, K, V and split heads: (B, L, 3D) -> 3 * (B, n_heads, L, head_dim)
        qkv = self.qkv(x).reshape(B, L, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, n_heads, L, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # PyTorch SDPA dispatches to a fused kernel (e.g., FlashAttention) when inputs and
        # hardware allow, avoiding the L x L matrix; otherwise it falls back to a math path.
        out = F.scaled_dot_product_attention(q, k, v, is_causal=self.causal)
        # Result shape: (B, n_heads, L, head_dim)
        # Recombine heads: (B, L, n_heads, head_dim) -> (B, L, D)
        out = out.transpose(1, 2).reshape(B, L, D)
        return self.proj(out)
```
4.2. Manual computation for understanding
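```python
import math

import torch.nn.functional as F


def manual_attention(q, k, v, mask=None):
    """Reference implementation, for intuition, not production.

    mask (optional): 1/True where attention is allowed, 0/False where blocked,
    e.g. torch.tril(torch.ones(L, L)) for causal attention.
    """
    d_k = q.size(-1)
    # Compatibility scores: (..., L_q, L_k)
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = F.softmax(scores, dim=-1)  # attention weights
    return weights @ v, weights
```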
5. Common Q & A
- Q: Why divide by $\sqrt{d_k}$?
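- A: If the components of $q$ and $k$ are roughly independent with zero mean and unit variance, each dot product $q \cdot k$ has variance $d_k$. Without scaling, these large logits push softmax into regions where one element dominates and gradients vanish. Dividing by $\sqrt{d_k}$ brings the variance back to ~1.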
- Q: What does an attention head actually "see"?
- A: It depends. In probing studies, you find heads that track syntactic dependencies (subject-verb agreement), heads that track coreference ("it" ↔ "the animal"), heads that look at the previous token, and many heads whose role is opaque.
- Q: Why is attention better than RNN?
- A: (1) Parallel: process the whole sequence at once. (2) Direct access: any two positions are one operation apart, so signals and gradients don't decay through many intermediate steps. (3) Interpretable: you can visualize where the model is looking.
- Q: What's the difference between MHA, MQA, GQA?
- A: MHA gives every Q head its own K and V. MQA shares one K/V across all Q heads. GQA shares K/V within groups of Q heads. This middle ground keeps most of MHA's quality at a fraction of its KV-Cache memory.
- Q: Will attention be replaced?
- A: Possibly. State-space models (Mamba), linear attention (RetNet), and various hybrid architectures (Jamba) are competitive on some benchmarks. As of 2026, attention still dominates but the field is exploring alternatives — especially for very long contexts where attention's $O(L^2)$ is painful.
- Q: Why does softmax saturate?
- A: Once one logit is much larger than the others, softmax assigns ~1 to that position and ~0 to the rest. Gradients with respect to the small-probability logits become tiny ("vanishing gradient" in softmax). Scaling, layer norm, and good initialization mitigate this.
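The scaling and saturation answers can both be checked numerically. This NumPy sketch assumes query and key components are independent with zero mean and unit variance, which is the idealization behind the $\sqrt{d_k}$ argument. The helper names are illustrative.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())   # shift by the max for numerical stability
    return e / e.sum()

def softmax_grad_norm(logits):
    # The Jacobian of softmax w.r.t. its logits is diag(p) - p p^T.
    p = softmax(logits)
    return np.linalg.norm(np.diag(p) - np.outer(p, p))

# 1) Why divide by sqrt(d_k): under the unit-variance assumption, var(q . k) grows like d_k.
rng = np.random.default_rng(0)
d_k = 256
q = rng.standard_normal((10_000, d_k))
k = rng.standard_normal((10_000, d_k))
dots = (q * k).sum(axis=-1)
print(f"var(q.k) = {dots.var():.1f}, var(q.k / sqrt(d_k)) = {(dots / np.sqrt(d_k)).var():.2f}")

# 2) Why saturation hurts: as one logit pulls ahead, softmax -> one-hot and its Jacobian -> 0.
for gap in (1.0, 5.0, 20.0):
    logits = np.array([gap, 0.0, 0.0, 0.0])
    print(f"gap={gap:>4}: max p={softmax(logits).max():.6f}, |Jacobian|={softmax_grad_norm(logits):.1e}")
```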