What Input Shape Does TransformerEncoder Expect?

TransformerEncoder expects input of shape (seq_len, batch, d_model) by default, or (batch, seq_len, d_model) when constructed with batch_first=True. The output shape always matches the input shape.

Default Mode (seq_len first)

import torch
import torch.nn as nn

encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)

# Input: (seq_len, batch, d_model)
x = torch.randn(100, 32, 512)  # 100 tokens, batch 32, 512 features
output = encoder(x)
print(output.shape)  # torch.Size([100, 32, 512])

With batch_first=True (Recommended)

encoder_layer = nn.TransformerEncoderLayer(
    d_model=512, nhead=8, batch_first=True
)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)

# Input: (batch, seq_len, d_model)
x = torch.randn(32, 100, 512)  # batch 32, 100 tokens, 512 features
output = encoder(x)
print(output.shape)  # torch.Size([32, 100, 512])
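Both layouts compute the same numbers, just transposed: batch_first changes the expected layout, not the learned parameters, so the state dicts are interchangeable. A minimal sketch verifying this (small d_model for speed, eval() to disable dropout so the comparison is deterministic):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Two single layers that differ only in batch_first
layer_sf = nn.TransformerEncoderLayer(d_model=64, nhead=4)                    # (seq, batch, d)
layer_bf = nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True)  # (batch, seq, d)
layer_bf.load_state_dict(layer_sf.state_dict())  # same parameter names either way

layer_sf.eval()  # disable dropout for a deterministic comparison
layer_bf.eval()

x = torch.randn(8, 10, 64)            # (batch=8, seq=10, d_model=64)
out_bf = layer_bf(x)
out_sf = layer_sf(x.transpose(0, 1))  # feed the same data as (seq, batch, d)

print(torch.allclose(out_bf, out_sf.transpose(0, 1), atol=1e-6))  # True
```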

Using Masks

# Padding mask: (batch, seq_len) — True = ignore this position
padding_mask = torch.zeros(32, 100, dtype=torch.bool)
padding_mask[:, 80:] = True  # mask out positions 80-99

# Causal mask: (seq_len, seq_len) — for autoregressive models
causal_mask = nn.Transformer.generate_square_subsequent_mask(100)

output = encoder(x,
    mask=causal_mask,
    src_key_padding_mask=padding_mask
)
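One way to convince yourself the padding mask does what it should: masked key positions are excluded from every attention computation, and the feed-forward and normalization layers are position-wise, so corrupting the padded inputs must leave the unpadded outputs unchanged. A small sketch of that check (smaller sizes than above; eval() disables dropout):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=2)
encoder.eval()  # disable dropout for a deterministic comparison

x = torch.randn(4, 10, 64)                       # (batch, seq_len, d_model)
padding_mask = torch.zeros(4, 10, dtype=torch.bool)
padding_mask[:, 8:] = True                       # last two positions are padding

out1 = encoder(x, src_key_padding_mask=padding_mask)

# Corrupt the padded positions; the unpadded outputs should not change,
# because masked keys never contribute to attention anywhere.
x2 = x.clone()
x2[:, 8:] = 99.0
out2 = encoder(x2, src_key_padding_mask=padding_mask)

print(torch.allclose(out1[:, :8], out2[:, :8], atol=1e-5))  # True
```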

Common Mistake

# ERROR: batch_first mismatch
encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
# Default: batch_first=False, so the encoder expects (seq_len, batch, d_model)
x = torch.randn(32, 100, 512)  # This is (batch, seq_len, d_model)
output = encoder(x)  # WRONG! Runs without error but silently produces bad results

# FIX: set batch_first=True, or transpose the input
x = x.transpose(0, 1)  # (100, 32, 512) ✓
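If you are unsure which layout a layer was built with, recent PyTorch versions record the setting on the layer's inner attention module. Note that self_attn is an implementation detail of TransformerEncoderLayer rather than a stable public API, so treat this as a debugging aid only:

```python
import torch.nn as nn

layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
# The layout expectation lives on the inner MultiheadAttention module
# (implementation detail; useful for a quick sanity check while debugging)
print(layer.self_attn.batch_first)  # False -> expects (seq_len, batch, d_model)
```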
