Transformer Architecture Visual Guide
Self-Attention Explained Step by Step

This guide steps through each stage of the transformer self-attention mechanism: input embedding, Q/K/V projection, attention score computation, softmax normalization, weighted aggregation, and multi-head attention.

What Is Self-Attention?

Self-attention lets every position in a sequence attend to every other position in the same sequence. Unlike recurrent networks, which process tokens one at a time, a transformer processes all positions in parallel. This is the source of its training-speed advantage, and because any two positions are connected directly, it also helps the model capture long-range dependencies.

The Query, Key, Value Framework

Each input token is projected into three vectors, a query (Q), a key (K), and a value (V), using learned weight matrices. Attention is then computed as:

Attention(Q, K, V) = softmax( Q * K^T / sqrt(d_k) ) * V
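
To make the formula concrete, here is a minimal dependency-free sketch that evaluates it by hand for two tokens. The toy Q, K, and V values and the helper names `softmax` and `attention` are illustrative assumptions, not part of the original guide.

```python
import math

def softmax(row):
    # Subtract the max for numerical stability; the result is unchanged.
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    # Q, K, V are lists of row vectors (one row per token).
    d_k = len(Q[0])
    outputs, weights = [], []
    for q in Q:
        # Scaled dot product of this query with every key.
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in K]
        w = softmax(scores)
        weights.append(w)
        # Weighted sum of the value vectors.
        outputs.append([sum(wj * v[c] for wj, v in zip(w, V)) for c in range(len(V[0]))])
    return outputs, weights

# Hypothetical toy inputs: two tokens, d_k = 2.
Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[10.0, 0.0], [0.0, 10.0]]

out, weights = attention(Q, K, V)
print([round(w, 2) for w in weights[0]])  # roughly [0.67, 0.33]
print([round(o, 2) for o in out[0]])      # roughly [6.7, 3.3]
```

Each query matches its own key best, so token 0 puts about two thirds of its weight on value 0 and its output is pulled toward that value vector.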

Dividing by sqrt(d_k) keeps the dot products from growing with the dimension. If the components of q and k are independent with zero mean and unit variance, q · k has variance d_k, so unscaled scores would push softmax into saturated regions where gradients are near zero.
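
The effect of the scaling can be checked numerically. The sketch below assumes query and key components drawn independently from a standard normal distribution (an assumption made for illustration; trained projections need not behave this way) and uses only the standard library. The helper names are hypothetical.

```python
import math
import random

random.seed(0)

def dot_variance(d_k, trials=2000):
    # Empirical variance of q . k for random unit-variance vectors.
    dots = []
    for _ in range(trials):
        q = [random.gauss(0, 1) for _ in range(d_k)]
        k = [random.gauss(0, 1) for _ in range(d_k)]
        dots.append(sum(a * b for a, b in zip(q, k)))
    mean = sum(dots) / trials
    return sum((d - mean) ** 2 for d in dots) / trials

def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

# Variance grows roughly linearly with d_k; dividing the scores by
# sqrt(d_k) divides the variance by d_k, bringing it back to about 1.
ratios = {}
for d_k in (4, 64, 256):
    var = dot_variance(d_k)
    ratios[d_k] = var / d_k
    print(d_k, round(var, 1), round(ratios[d_k], 2))

# Saturation: the same scores multiplied by sqrt(256) = 16 give a
# near one-hot softmax.
moderate = softmax([1.0, 2.0, 3.0])
saturated = softmax([16.0, 32.0, 48.0])
print(round(max(moderate), 3), round(max(saturated), 6))
```

When one weight is essentially 1 and the rest are essentially 0, the softmax derivative with respect to every score is close to zero, which is the gradient problem the scaling avoids.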

PyTorch Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class ScaledDotProductAttention(nn.Module):
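    def forward(self, Q, K, V, mask=None):
        # Q, K, V: (batch, heads, seq, d_k)
        d_k = Q.size(-1)
        # Attention scores: (batch, heads, seq, seq)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # mask == 0 marks blocked positions; -inf becomes weight 0 after softmax.
            # A row that is blocked everywhere would yield NaN.
            scores = scores.masked_fill(mask == 0, float('-inf'))
        # Normalize over the key dimension so each query's weights sum to 1
        attn_weights = F.softmax(scores, dim=-1)
        return torch.matmul(attn_weights, V), attn_weights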


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=512, num_heads=8):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.attention = ScaledDotProductAttention()

    def split_heads(self, x):
        # (B, S, d_model) -> (B, H, S, d_k): each head gets a d_k-wide slice
        B, S, D = x.shape
        return x.view(B, S, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, Q, K, V, mask=None):
        # Project the inputs, then split each projection into heads
        Q = self.split_heads(self.W_q(Q))  # (B, H, S, d_k)
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        x, weights = self.attention(Q, K, V, mask)
        # Merge the heads back: (B, H, S, d_k) -> (B, S, d_model)
        B, H, S, dk = x.shape
        x = x.transpose(1, 2).contiguous().view(B, S, H * dk)
        return self.W_o(x), weights  # (B, S, d_model), (B, H, S, S)


# Usage
d_model, num_heads, seq_len, batch = 512, 8, 64, 2
x = torch.randn(batch, seq_len, d_model)
mha = MultiHeadAttention(d_model, num_heads)
out, attn = mha(x, x, x)  # Self-attention: Q=K=V=x
print(out.shape)   # torch.Size([2, 64, 512])
print(attn.shape)  # torch.Size([2, 8, 64, 64]): one attention map per head
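
The usage above never exercises the mask argument. As a hedged illustration, the sketch below applies the same masking convention (0 means blocked) to a causal, lower-triangular mask, so each position can attend only to itself and earlier positions. It repeats the scaled dot-product steps inline to stay self-contained; the tensor sizes are arbitrary assumptions.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, S, d_k = 1, 2, 4, 8  # hypothetical small sizes
Q = torch.randn(B, H, S, d_k)
K = torch.randn(B, H, S, d_k)
V = torch.randn(B, H, S, d_k)

# Lower-triangular mask: 1 where attention is allowed, 0 where it is blocked.
# Shape (S, S) broadcasts over the batch and head dimensions.
causal_mask = torch.tril(torch.ones(S, S))

scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
scores = scores.masked_fill(causal_mask == 0, float('-inf'))
attn_weights = F.softmax(scores, dim=-1)
out = torch.matmul(attn_weights, V)

print(attn_weights[0, 0])  # upper triangle is exactly zero
```

Because the diagonal is always allowed, no row is fully masked, so the NaN case from an all `-inf` row cannot occur here. The first position can only see itself, so its output equals its own value vector.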

Key Dimensions to Know

For a standard BERT-base transformer (d_model=768, num_heads=12), each head works with d_k = 768 / 12 = 64 dimensions.
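
The dimensions that follow from these two numbers can be worked out with plain arithmetic. The sketch below assumes the layer layout of the MultiHeadAttention class shown earlier (four d_model x d_model linear layers with biases); the batch size and sequence length are hypothetical values chosen for illustration.

```python
d_model, num_heads = 768, 12
d_k = d_model // num_heads  # per-head dimension

# W_q, W_k, W_v, W_o: each is a d_model x d_model weight plus a d_model bias.
params_per_projection = d_model * d_model + d_model
attention_params = 4 * params_per_projection

batch, seq_len = 2, 128  # hypothetical input size
qkv_shape = (batch, num_heads, seq_len, d_k)         # after split_heads
scores_shape = (batch, num_heads, seq_len, seq_len)  # attention maps
score_elements = batch * num_heads * seq_len * seq_len

print(d_k)               # 64
print(attention_params)  # 2362368
print(qkv_shape)         # (2, 12, 128, 64)
print(scores_shape)      # (2, 12, 128, 128)
print(score_elements)    # 393216
```

The score tensor grows with the square of the sequence length, while the parameter count does not depend on sequence length at all.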

Why Multi-Head?

Multiple attention heads allow the model to jointly attend to information from different representation subspaces. Head 1 might capture syntactic dependencies, head 5 might track coreference, and head 11 might focus on positional proximity — all in parallel within a single layer.
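
The "representation subspaces" are concrete in the implementation above: after projection, the split into heads is only a reshape, and each head receives its own contiguous d_k-wide slice of the d_model-dimensional vector. The sketch below checks this on a random tensor with hypothetical small sizes.

```python
import torch

torch.manual_seed(0)
B, S, d_model, num_heads = 2, 5, 16, 4  # hypothetical small sizes
d_k = d_model // num_heads

x = torch.randn(B, S, d_model)  # stands in for a projected Q, K, or V

# Same reshape as split_heads: (B, S, d_model) -> (B, H, S, d_k)
heads = x.view(B, S, num_heads, d_k).transpose(1, 2)

# Head h sees exactly features [h * d_k, (h + 1) * d_k) of every token.
slices_match = all(
    torch.equal(heads[:, h], x[:, :, h * d_k:(h + 1) * d_k])
    for h in range(num_heads)
)

# Merging the heads back restores the original layout.
merged = heads.transpose(1, 2).contiguous().view(B, S, d_model)

print(heads.shape)             # torch.Size([2, 4, 5, 4])
print(slices_match)            # True
print(torch.equal(merged, x))  # True
```

Since each head computes its own softmax over its own slice, the heads are free to learn different attention patterns, and the output projection W_o then mixes their results.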
