Attention is all you need
Attention is a differentiable lookup. Given a query, you compare it to a set of keys, normalize the scores, and use them to weight the corresponding values. That’s the whole mechanism — everything else (multi-head, masking, positions) is plumbing.
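One way to see the “differentiable lookup” claim in code (a toy sketch; the shapes and names are illustrative, not from any library): a hard lookup picks the single best-matching key, while attention takes a softmax-weighted average of all values.

```python
import torch
import torch.nn.functional as F

d = 4
keys = torch.randn(5, d)       # 5 stored keys
values = torch.randn(5, d)     # their associated values
query = torch.randn(d)

scores = keys @ query                       # how well the query matches each key
hard = values[scores.argmax()]              # hard lookup: best match only, not differentiable
soft = F.softmax(scores, dim=-1) @ values   # attention: weighted average of all values
```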
Scaled dot-product attention
For matrices $Q \in \mathbb{R}^{n \times d_k}$, $K \in \mathbb{R}^{m \times d_k}$, and $V \in \mathbb{R}^{m \times d_v}$:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$
The scaling matters: dot products grow in magnitude as $d_k$ increases, pushing the softmax into saturation regions where gradients vanish.
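A quick numerical check (a sketch; the shapes and seed are arbitrary): with unit-variance components the dot product has standard deviation around $\sqrt{d_k}$, so without the scaling the softmax saturates as $d_k$ grows.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
for d_k in (4, 64, 1024):
    q = torch.randn(d_k)
    K = torch.randn(16, d_k)
    raw = K @ q                        # unscaled scores, std roughly sqrt(d_k)
    scaled = raw / d_k ** 0.5          # scaled scores, std roughly 1
    print(d_k,
          F.softmax(raw, dim=-1).max().item(),      # tends toward 1.0 as d_k grows
          F.softmax(scaled, dim=-1).max().item())   # stays moderate
```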
What it looks like in practice
For a single sentence attending to itself, you get an $n \times n$ matrix where row $i$ shows where token $i$ attends. Bright cells = high weight.
Notice how “red” attends strongly to “mat” (it modifies it) and “soft” attends to “mat” for the same reason. “cat” attends to “sat” and “mat” (subject and location). These patterns aren’t programmed — they emerge from training.
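A sketch of how such a heatmap can be produced (the sentence comes from the example above; the random stand-in embeddings here will not reproduce the trained patterns just described):

```python
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

tokens = ["the", "cat", "sat", "on", "the", "soft", "red", "mat"]
n, d_k = len(tokens), 16

X = torch.randn(n, d_k)    # stand-in embeddings; in practice Q and K come from a trained model
Q, K = X, X

weights = F.softmax(Q @ K.T / d_k ** 0.5, dim=-1)   # (n, n): row i = where token i attends

plt.imshow(weights, cmap="viridis")
plt.xticks(range(n), tokens, rotation=90)
plt.yticks(range(n), tokens)
plt.colorbar(label="attention weight")
plt.show()
```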
The block diagram
```mermaid
flowchart LR
    X[Input embeddings] --> Q[Linear → Q]
    X --> K[Linear → K]
    X --> V[Linear → V]
    Q --> S["QKᵀ / √dₖ"]
    K --> S
    S --> SM[softmax]
    SM --> A[× V]
    V --> A
    A --> O[Output]
```
Multi-head attention runs this $h$ times in parallel with smaller per-head dimensions and concatenates:

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\,W^O$$

where each $\mathrm{head}_i = \mathrm{Attention}(QW_i^Q,\; KW_i^K,\; VW_i^V)$.
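PyTorch ships this composition (including the output projection $W^O$) as `nn.MultiheadAttention`; a minimal usage sketch with illustrative dimensions:

```python
import torch

d_model, h, n = 64, 8, 10
mha = torch.nn.MultiheadAttention(embed_dim=d_model, num_heads=h, batch_first=True)

x = torch.randn(1, n, d_model)    # (batch, seq, d_model)
out, weights = mha(x, x, x)       # self-attention: query = key = value = x
print(out.shape, weights.shape)   # (1, 10, 64) and (1, 10, 10), weights averaged over heads
```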
Reference implementation
Bare-bones, no batching tricks:
```python
import torch
import torch.nn.functional as F
from math import sqrt

def attention(Q, K, V, mask=None):
    """
    Q: (n, d_k)   K: (m, d_k)   V: (m, d_v)
    returns: (n, d_v)
    """
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / sqrt(d_k)               # (n, m)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))  # blocked positions get zero weight
    weights = F.softmax(scores, dim=-1)                        # (n, m)
    return weights @ V                                         # (n, d_v)
```
For causal (decoder) attention, mask is lower-triangular so each position
only sees previous tokens.
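A usage sketch for the function above, with made-up shapes, building the causal mask with `torch.tril`:

```python
import torch

n, d_k, d_v = 5, 16, 16
Q, K, V = torch.randn(n, d_k), torch.randn(n, d_k), torch.randn(n, d_v)

causal = torch.tril(torch.ones(n, n))    # 1s on and below the diagonal
out = attention(Q, K, V, mask=causal)    # position i attends only to positions <= i
print(out.shape)                         # torch.Size([5, 16])
```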
Complexity
Attention is $O(n^2 d)$ in time and $O(n^2)$ in memory, for sequence length $n$ and head dimension $d$. That quadratic term is the reason context windows are expensive and the reason every “long context” paper exists.
| Variant | Time | Memory | Notes |
|---|---|---|---|
| Vanilla | $O(n^2 d)$ | $O(n^2)$ | the baseline |
| FlashAttention | $O(n^2 d)$ | $O(n)$ | tiling + recomputation |
| Linear attention | $O(n d^2)$ | $O(n d)$ | approximation, weaker |
| Sliding window (size $w$) | $O(n w d)$ | $O(n w)$ | local only; commonly combined with global tokens |
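To make the quadratic term concrete, a back-of-envelope calculation (assuming fp16 scores and a single head; the context length is illustrative):

```python
n = 32_768                              # context length
bytes_per_score = 2                     # fp16
gib = n * n * bytes_per_score / 2**30
print(f"{gib:.1f} GiB")                 # 2.0 GiB for a single n x n score matrix
```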
Note — FlashAttention has the same asymptotic time as vanilla but ~2-4x wall-clock speedup because it’s IO-aware. The bottleneck on modern GPUs is usually memory bandwidth, not flops.
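Recent PyTorch releases expose a fused kernel through `torch.nn.functional.scaled_dot_product_attention`, which can dispatch to a FlashAttention-style implementation when the hardware and dtypes allow it (a sketch; the exact backend chosen depends on the PyTorch version and GPU):

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 1024, 64)   # (batch, heads, seq, head_dim)
k = torch.randn(1, 8, 1024, 64)
v = torch.randn(1, 8, 1024, 64)

# With a fused backend, the full (seq, seq) score matrix is never materialized.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)                  # torch.Size([1, 8, 1024, 64])
```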
Related
- distributions — softmax is a categorical distribution