What Input Shape Does TransformerEncoder Expect?
TransformerEncoder expects (seq_len, batch, d_model) by default, or (batch, seq_len, d_model) with batch_first=True. Output shape matches input shape.
Default Mode (seq_len first)
import torch
import torch.nn as nn
encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
# Input: (seq_len, batch, d_model)
x = torch.randn(100, 32, 512) # 100 tokens, batch 32, 512 features
output = encoder(x)
print(output.shape) # torch.Size([100, 32, 512])
With batch_first=True (Recommended)
encoder_layer = nn.TransformerEncoderLayer(
d_model=512, nhead=8, batch_first=True
)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
# Input: (batch, seq_len, d_model)
x = torch.randn(32, 100, 512) # batch 32, 100 tokens, 512 features
output = encoder(x)
print(output.shape) # torch.Size([32, 100, 512])
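The two conventions hold the same data in a different axis order, so a single transpose converts between them. A quick sketch to confirm this (shapes only, no encoder involved):

```python
import torch

x_batch_first = torch.randn(32, 100, 512)    # (batch, seq_len, d_model)
x_seq_first = x_batch_first.transpose(0, 1)  # (seq_len, batch, d_model)

print(x_seq_first.shape)  # torch.Size([100, 32, 512])
# Row 0 of the seq-first tensor is token 0 across the whole batch.
assert torch.equal(x_seq_first[0], x_batch_first[:, 0])
```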
Using Masks
# Padding mask: (batch, seq_len) — True = ignore this position
padding_mask = torch.zeros(32, 100, dtype=torch.bool)
padding_mask[:, 80:] = True # mask out positions 80-99
# Causal mask: (seq_len, seq_len) — for autoregressive models
causal_mask = nn.Transformer.generate_square_subsequent_mask(100)
output = encoder(x,
mask=causal_mask,
src_key_padding_mask=padding_mask
)
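To see what the causal mask actually contains, print a small one. In the PyTorch versions I've checked, `generate_square_subsequent_mask` returns a float tensor with 0.0 at allowed positions and -inf above the diagonal, so position i can only attend to positions ≤ i:

```python
import torch
import torch.nn as nn

# A 4x4 causal mask: 0.0 = attend, -inf = blocked.
m = nn.Transformer.generate_square_subsequent_mask(4)
print(m)
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])
```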
Common Mistake
# ERROR: batch_first mismatch
encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
# Default: batch_first=False, so the encoder expects (seq_len, batch, d_model)
x = torch.randn(32, 100, 512)  # This is (batch, seq_len, d_model)
output = encoder(x)  # WRONG! Runs without error but treats the batch axis as seq_len
# FIX: set batch_first=True, or transpose the input
x = x.transpose(0, 1)  # (100, 32, 512) ✓
output = encoder(x)
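Because this mismatch fails silently (both layouts have d_model in the last position, which is all the encoder checks), a cheap guard can catch it early. A minimal sketch; the helper name and signature are my own, not part of PyTorch:

```python
import torch

def assert_layout(x: torch.Tensor, *, batch: int, seq_len: int,
                  d_model: int, batch_first: bool) -> None:
    """Hypothetical guard: verify x matches the layout the encoder expects."""
    expected = (batch, seq_len, d_model) if batch_first else (seq_len, batch, d_model)
    if tuple(x.shape) != expected:
        raise ValueError(f"expected shape {expected}, got {tuple(x.shape)}")

x = torch.randn(32, 100, 512)
assert_layout(x, batch=32, seq_len=100, d_model=512, batch_first=True)  # passes
# With batch_first=False the same tensor would raise, instead of
# silently producing garbage downstream.
```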