
Attention is all you need

updated 2026-05-02 · 2 min read · #transformers #deep-learning #fundamentals

Attention is a differentiable lookup. Given a query, you compare it to a set of keys, normalize the scores, and use them to weight the corresponding values. That’s the whole mechanism — everything else (multi-head, masking, positions) is plumbing.
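A toy example makes the lookup concrete. This is a hypothetical single query against three key/value pairs with made-up numbers, just to show compare, normalize, weight:

import torch
import torch.nn.functional as F

# one query, three key/value pairs (toy numbers, d_k = d_v = 2)
q = torch.tensor([1.0, 0.0])                 # what we're looking for
K = torch.tensor([[1.0, 0.0],                # key 0: aligned with q
                  [0.0, 1.0],                # key 1: orthogonal to q
                  [0.7, 0.7]])               # key 2: in between
V = torch.tensor([[10.0, 0.0],
                  [0.0, 10.0],
                  [5.0, 5.0]])

scores = K @ q                               # compare the query to each key
weights = F.softmax(scores, dim=-1)          # normalize: weights sum to 1
output = weights @ V                         # weighted sum of the values
print(weights)                               # highest weight on key 0
print(output)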

Scaled dot-product attention

For matrices $Q \in \mathbb{R}^{n \times d_k}$, $K \in \mathbb{R}^{m \times d_k}$, and $V \in \mathbb{R}^{m \times d_v}$:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V$$

The $\sqrt{d_k}$ scaling matters: dot products grow in magnitude as $d_k$ increases, pushing the softmax into saturation regions where gradients vanish.
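A quick numerical check of that claim (a sketch with random unit-variance vectors, not from the post): the raw dot product has standard deviation about $\sqrt{d_k}$, and dividing by $\sqrt{d_k}$ keeps the score scale roughly constant as $d_k$ grows.

import torch

torch.manual_seed(0)
for d_k in (16, 256, 4096):
    q = torch.randn(10_000, d_k)
    k = torch.randn(10_000, d_k)
    dots = (q * k).sum(dim=-1)                    # raw dot products
    # std of the raw score grows like sqrt(d_k);
    # after dividing by sqrt(d_k) it stays near 1, keeping softmax unsaturated
    print(d_k, dots.std().item(), (dots / d_k ** 0.5).std().item())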

What it looks like in practice

For a single sentence attending to itself, you get an $n \times n$ matrix where row $i$ shows where token $i$ attends. Bright cells = high weight.

Self-attention heatmap on the sentence "the cat sat on the soft red mat"

Notice how red attends strongly to mat (it modifies it) and soft attends to mat for the same reason. cat attends to sat and to mat (its verb and its location). These patterns aren’t programmed — they emerge from training.
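If you want to make a plot like this yourself, here is one hedged sketch using random stand-in embeddings and matplotlib. Random, untrained embeddings will not reproduce the linguistic patterns above; this only shows the plotting mechanics.

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

tokens = "the cat sat on the soft red mat".split()
torch.manual_seed(0)
X = torch.randn(len(tokens), 16)            # stand-in embeddings (random, untrained)

scores = X @ X.T / 16 ** 0.5                # self-attention scores
weights = F.softmax(scores, dim=-1)         # each row sums to 1

plt.imshow(weights.numpy(), cmap="viridis")
plt.xticks(range(len(tokens)), tokens, rotation=45)
plt.yticks(range(len(tokens)), tokens)
plt.colorbar(label="attention weight")
plt.show()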

The block diagram

flowchart LR
    X[Input embeddings] --> Q[Linear → Q]
    X --> K[Linear → K]
    X --> V[Linear → V]
    Q --> S["QKᵀ / √dₖ"]
    K --> S
    S --> SM[softmax]
    SM --> A[× V]
    V --> A
    A --> O[Output]

Multi-head attention runs this $h$ times in parallel with smaller dimensions and concatenates:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W^O$$

where each $\text{head}_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V)$.
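A minimal sketch of that project / split / concatenate pattern as a module (module name, dimensions, and self-attention-only forward are my own choices; the bare single-head function follows in the next section):

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, h: int):
        super().__init__()
        assert d_model % h == 0
        self.h, self.d_k = h, d_model // h
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)      # the W^O projection

    def forward(self, x):                           # x: (batch, n, d_model)
        b, n, _ = x.shape

        def split(t):
            # split d_model into h heads of size d_k: (b, n, d_model) -> (b, h, n, d_k)
            return t.view(b, n, self.h, self.d_k).transpose(1, 2)

        Q, K, V = split(self.w_q(x)), split(self.w_k(x)), split(self.w_v(x))
        scores = Q @ K.transpose(-2, -1) / self.d_k ** 0.5      # (b, h, n, n)
        A = F.softmax(scores, dim=-1) @ V                       # (b, h, n, d_k)
        # concatenate heads back to d_model, then apply W^O
        return self.w_o(A.transpose(1, 2).reshape(b, n, self.h * self.d_k))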

Reference implementation

Bare-bones, no batching tricks:

import torch
import torch.nn.functional as F
from math import sqrt

def attention(Q, K, V, mask=None):
    """
    Q: (n, d_k)  K: (m, d_k)  V: (m, d_v)
    returns: (n, d_v)
    """
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / sqrt(d_k)   # (n, m)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = F.softmax(scores, dim=-1)             # (n, m)
    return weights @ V                              # (n, d_v)

For causal (decoder) attention, the mask is lower-triangular, so each position only sees itself and previous tokens.
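As a usage example of the function above, a causal mask built with torch.tril (sizes here are illustrative):

n = 5
causal_mask = torch.tril(torch.ones(n, n))   # 1 on and below the diagonal, 0 above
Q = torch.randn(n, 8)
K = torch.randn(n, 8)
V = torch.randn(n, 16)
out = attention(Q, K, V, mask=causal_mask)   # position i attends only to positions 0..i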

Complexity

Attention is $O(n^2 d)$ in time and $O(n^2)$ in memory. That quadratic term is the reason context windows are expensive and the reason every “long context” paper exists.

| Variant | Time | Memory | Notes |
| --- | --- | --- | --- |
| Vanilla | $O(n^2 d)$ | $O(n^2)$ | the baseline |
| FlashAttention | $O(n^2 d)$ | $O(n)$ | tiling + recomputation |
| Linear attention | $O(n d^2)$ | $O(n d)$ | approximation, weaker |
| Sliding window | $O(n w d)$ | $O(n w)$ | local only; often paired with global attention |

Note — FlashAttention has the same asymptotic time as vanilla but ~2-4x wall-clock speedup because it’s IO-aware. The bottleneck on modern GPUs is usually memory bandwidth, not flops.
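In PyTorch 2.x you usually don't write that kernel yourself: torch.nn.functional.scaled_dot_product_attention can dispatch to a fused FlashAttention-style kernel when dtype, shapes, and hardware allow. A sketch, assuming a CUDA GPU is available (whether the fused path is actually taken depends on your build and device):

import torch
import torch.nn.functional as F

# (batch, heads, n, d_k); fp16 on GPU is one configuration eligible for the fused kernel
q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)

# is_causal=True applies the lower-triangular mask inside the kernel,
# so the full n x n score matrix is never materialized in HBM
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)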