Transformer Architecture Visual Guide
Self-Attention Explained Step by Step
This guide steps through each stage of the transformer self-attention mechanism: input embedding, Q/K/V projection, attention score computation, softmax normalization, weighted aggregation, and multi-head attention.
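The stages listed above can be traced end to end in a few lines. The following is a minimal single-head sketch in NumPy. The toy vocabulary, the token ids, and the randomly initialized weights are all hypothetical stand-ins for learned parameters, and positional encodings are deliberately left out.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy setup: 10-token vocabulary, d_model = 8, a single head.
vocab_size, d_model = 10, 8
token_ids = np.array([3, 1, 4, 1])  # a 4-token input sequence

# 1. Input embedding: look up one row per token (no positional encoding here).
embedding = rng.normal(size=(vocab_size, d_model))
X = embedding[token_ids]  # (seq_len, d_model)

# 2. Q/K/V projection with stand-in (random) weight matrices.
W_q, W_k, W_v = (rng.normal(size=(d_model, d_model)) for _ in range(3))
Q, K, V = X @ W_q, X @ W_k, X @ W_v

# 3. Attention scores: scaled dot product of every query with every key.
#    With a single head, d_k equals d_model.
scores = Q @ K.T / np.sqrt(d_model)  # (seq_len, seq_len)

# 4. Softmax over the key axis (subtracting the row max for numerical stability).
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)

# 5. Weighted aggregation of the value vectors.
output = weights @ V  # (seq_len, d_model)

print(weights.shape, output.shape)  # (4, 4) (4, 8)
```

Because this sketch omits positional encodings, the two occurrences of token 1 (positions 1 and 3) end up with identical output rows; real transformers add position information to the embeddings precisely to break this symmetry.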
What Is Self-Attention?
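Self-attention allows every position in a sequence to attend to every other position in the same sequence. Unlike recurrent networks, which process tokens one after another, self-attention processes all positions in parallel, which makes training much faster on modern hardware. Because every pair of positions is connected directly, the model can also capture long-range dependencies without routing information through many intermediate steps.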
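To see why no position has to wait for another, compare an explicit per-position loop with a single matrix product. This NumPy sketch uses random stand-in Q, K, V matrices; the point is only that each loop iteration is independent, so the whole loop collapses into one batched operation.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(1)
seq_len, d_k = 5, 4
Q = rng.normal(size=(seq_len, d_k))
K = rng.normal(size=(seq_len, d_k))
V = rng.normal(size=(seq_len, d_k))

# Position by position: query i scores every key j. No iteration depends
# on the result of a previous one, unlike the hidden state of an RNN.
looped = np.empty_like(V)
for i in range(seq_len):
    row_scores = np.array([Q[i] @ K[j] for j in range(seq_len)]) / np.sqrt(d_k)
    looped[i] = softmax(row_scores) @ V

# All positions at once: one matrix product covers every (i, j) pair.
parallel = softmax(Q @ K.T / np.sqrt(d_k)) @ V

print(np.allclose(looped, parallel))  # True
```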
The Query, Key, Value Framework
Each input token is projected into three vectors using learned weight matrices:
- Query (Q) — What this token is looking for
- Key (K) — What this token offers to other tokens
- Value (V) — The actual content to be aggregated
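A small sketch of the three projections, with hypothetical random matrices standing in for learned weights. One consequence of using separate query and key matrices is that relevance is asymmetric: how strongly token i attends to token j need not equal how strongly j attends to i. Sharing a single matrix for both would force a symmetric score matrix.

```python
import numpy as np

rng = np.random.default_rng(2)
seq_len, d_model, d_k = 6, 16, 4

X = rng.normal(size=(seq_len, d_model))  # one embedding row per token

# Hypothetical learned weights, randomly initialized for illustration.
W_q = rng.normal(size=(d_model, d_k))
W_k = rng.normal(size=(d_model, d_k))
W_v = rng.normal(size=(d_model, d_k))

Q = X @ W_q  # what each token is looking for
K = X @ W_k  # what each token offers to other tokens
V = X @ W_v  # content that gets aggregated

# scores[i, j]: how well token i's query matches token j's key
scores = Q @ K.T
print(Q.shape, K.shape, V.shape)      # (6, 4) (6, 4) (6, 4)
print(np.allclose(scores, scores.T))  # False: i -> j differs from j -> i

# If queries and keys came from one shared matrix, scores would be symmetric.
shared = (X @ W_q) @ (X @ W_q).T
print(np.allclose(shared, shared.T))  # True
```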
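Attention scores are the dot products of each query with every key, divided by sqrt(d_k), where d_k is the key dimension: Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V. The scaling matters because, if query and key components are independent with zero mean and unit variance, their dot product has variance d_k. Without the division, scores grow with dimension and push the softmax into saturation zones where gradients are near zero.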
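The variance argument can be checked numerically. This NumPy sketch assumes, as the argument does, that query and key components are i.i.d. standard normal; the sample sizes and the choice of 16 keys per query are arbitrary. It reports the variance of raw and scaled scores and the average largest softmax weight per row, where values close to 1.0 indicate a nearly one-hot (saturated) distribution.

```python
import numpy as np

def mean_max_softmax(scores):
    # Average of each row's largest softmax weight; near 1.0 means saturation.
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (e / e.sum(axis=-1, keepdims=True)).max(axis=-1).mean()

rng = np.random.default_rng(3)
n_queries, n_keys = 1000, 16
results = {}

for d_k in (16, 256):
    # Assumption: query/key components are i.i.d. with mean 0 and variance 1.
    q = rng.normal(size=(n_queries, d_k))
    k = rng.normal(size=(n_queries, n_keys, d_k))
    raw = np.einsum("nd,nkd->nk", q, k)  # unscaled dot products
    scaled = raw / np.sqrt(d_k)          # divided by sqrt(d_k)

    raw_var, scaled_var = raw.var(), scaled.var()
    raw_max, scaled_max = mean_max_softmax(raw), mean_max_softmax(scaled)
    results[d_k] = (raw_var, scaled_var, raw_max, scaled_max)
    print(f"d_k={d_k}: var raw={raw_var:.1f}, var scaled={scaled_var:.2f}, "
          f"mean max weight raw={raw_max:.2f}, scaled={scaled_max:.2f}")
```

Raw-score variance tracks d_k, while the scaled scores stay near unit variance regardless of dimension, so the softmax keeps a spread-out distribution instead of collapsing onto one key.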
PyTorch Implementation
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledDotProductAttention(nn.Module):
    def forward(self, Q, K, V, mask=None):
        d_k = Q.size(-1)
        # Attention scores: (batch, heads, seq, seq)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # Positions where mask == 0 get -inf, i.e. zero weight after softmax
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = F.softmax(scores, dim=-1)
        return torch.matmul(attn_weights, V), attn_weights


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=512, num_heads=8):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.attention = ScaledDotProductAttention()

    def split_heads(self, x):
        # (B, S, d_model) -> (B, H, S, d_k)
        B, S, _ = x.shape
        return x.view(B, S, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, Q, K, V, mask=None):
        Q = self.split_heads(self.W_q(Q))  # (B, H, S, d_k)
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        x, weights = self.attention(Q, K, V, mask)
        # Concatenate heads: (B, H, S, d_k) -> (B, S, H * d_k)
        B, H, S, dk = x.shape
        x = x.transpose(1, 2).contiguous().view(B, S, H * dk)
        return self.W_o(x), weights  # (B, S, d_model)


# Usage
d_model, num_heads, seq_len, batch = 512, 8, 64, 2
x = torch.randn(batch, seq_len, d_model)
mha = MultiHeadAttention(d_model, num_heads)
out, attn = mha(x, x, x)  # Self-attention: Q = K = V = x
print(out.shape)   # torch.Size([2, 64, 512])
print(attn.shape)  # torch.Size([2, 8, 64, 64]): one attention map per head
```
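The usage above calls the module without a mask. The sketch below mirrors the `mask == 0 -> -inf` convention of `ScaledDotProductAttention` in plain NumPy, using a causal (look-ahead) mask as an assumed example, to show that blocked positions receive exactly zero weight after the softmax. The random Q, K, V are stand-ins.

```python
import numpy as np

rng = np.random.default_rng(4)
seq_len, d_k = 5, 8
Q, K, V = (rng.normal(size=(seq_len, d_k)) for _ in range(3))

# Causal mask: 1 where attention is allowed, 0 where it is blocked
# (token i may only attend to positions 0..i).
mask = np.tril(np.ones((seq_len, seq_len), dtype=int))

scores = Q @ K.T / np.sqrt(d_k)
scores = np.where(mask == 0, -np.inf, scores)  # same rule as masked_fill above

# Softmax: exp(-inf) = 0, so masked positions get exactly zero weight.
# Every row keeps its diagonal entry, so no row is entirely -inf.
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
out = weights @ V

print(np.round(weights, 2))
```

With the PyTorch class, the equivalent call would be `mha(x, x, x, mask=torch.tril(torch.ones(seq_len, seq_len)))`; a `(seq_len, seq_len)` mask broadcasts against the `(batch, heads, seq, seq)` scores. A row in which every position is masked would produce NaNs, since the softmax of all `-inf` values is undefined.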
Key Dimensions to Know
For a standard BERT-base transformer (d_model=768, num_heads=12):
- d_k = d_model / num_heads = 768 / 12 = 64
- Q, K, V matrices: (seq_len, 64) per head
- Attention matrix: (seq_len, seq_len) — quadratic in sequence length
- Multi-head output: (seq_len, 768) after concatenation + projection
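The list above can be turned into quick arithmetic. This sketch counts attention-map entries for one layer and one example (batch size ignored) and the parameters of one `MultiHeadAttention` block as defined earlier (four `nn.Linear(d_model, d_model)` layers, each with a bias). The sequence lengths 512 and 1024 are arbitrary choices for illustration.

```python
# BERT-base attention dimensions from the list above.
d_model, num_heads = 768, 12
d_k = d_model // num_heads  # 64


def attention_entries(seq_len, heads=num_heads):
    # Scores in one layer's attention maps for one example: heads * seq_len^2.
    return heads * seq_len * seq_len


# W_q, W_k, W_v, W_o: each a d_model x d_model weight plus a d_model bias.
mha_params = 4 * (d_model * d_model + d_model)

print(d_k)                                                # 64
print(attention_entries(512))                             # 3145728
print(attention_entries(1024) // attention_entries(512))  # 4
print(mha_params)                                         # 2362368
```

Doubling the sequence length quadruples the attention entries, which is what "quadratic in sequence length" means in practice: at 512 tokens, those 3,145,728 scores take 12 MiB in float32 for a single layer and a single example.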
Why Multi-Head?
Multiple attention heads allow the model to jointly attend to information from different representation subspaces, because each head works on its own d_k-dimensional projection of the input. One head might learn to capture syntactic dependencies, another might track coreference, and another might focus on positional proximity, all computed in parallel within a single layer.
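The "subspaces" are concrete slices of the projected features. This NumPy sketch reproduces the `split_heads` reshape from the PyTorch code for one sequence (batch dimension dropped) and checks that each head receives exactly one contiguous block of `d_k` features, and that concatenating the heads restores the original layout. The projected matrix is a random stand-in for the output of `W_q`.

```python
import numpy as np

rng = np.random.default_rng(5)
seq_len, d_model, num_heads = 4, 8, 2
d_k = d_model // num_heads

# Stand-in for the output of a W_q projection for one sequence.
projected = rng.normal(size=(seq_len, d_model))

# split_heads-style reshape: (S, D) -> (S, H, d_k) -> (H, S, d_k)
heads = projected.reshape(seq_len, num_heads, d_k).transpose(1, 0, 2)

# Head h sees exactly one contiguous slice of the feature dimension.
for h in range(num_heads):
    assert np.array_equal(heads[h], projected[:, h * d_k:(h + 1) * d_k])

# Concatenation (transpose back, then reshape) restores the original layout.
merged = heads.transpose(1, 0, 2).reshape(seq_len, d_model)
print(np.array_equal(merged, projected))  # True
```

Because each head's attention map is computed only from its own slice, the heads can learn different attention patterns without interfering with one another; `W_o` then mixes their outputs back into a single d_model-dimensional representation.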