What Input Shape Does BatchNorm2d Expect?
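BatchNorm2d(64) expects a 4D input of shape (batch, 64, H, W). The argument (num_features) is the number of channels, not the batch size or the spatial dimensions, so H and W can be any size. It must match the channel count produced by the preceding layer, which is usually the out_channels of a Conv2d.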
How to Set the Argument
The BatchNorm2d argument must equal out_channels of the preceding Conv2d:
# The argument to BatchNorm2d = out_channels of Conv2d
nn.Conv2d(3, 64, kernel_size=3, padding=1) # outputs (batch, 64, H, W)
nn.BatchNorm2d(64) # expects (batch, 64, H, W) ✓
nn.Conv2d(64, 128, kernel_size=3, padding=1) # outputs (batch, 128, H, W)
nn.BatchNorm2d(128) # expects (batch, 128, H, W) ✓
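One way to avoid a mismatch is to read the channel count from the convolution itself rather than typing it twice. The following is a minimal sketch: the variable names and tensor sizes are illustrative assumptions, and the second half deliberately builds a mismatched layer to show that PyTorch reports the problem at the first forward pass, not when the layer is constructed.

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)
bn = nn.BatchNorm2d(conv.out_channels)  # derive the argument instead of hard-coding 64

x = torch.randn(8, 3, 32, 32)  # illustrative batch and image size
y = bn(conv(x))
print(y.shape)  # torch.Size([8, 64, 32, 32])

# A wrong channel count is only detected once data flows through the layer
wrong_bn = nn.BatchNorm2d(32)
mismatch_raised = False
try:
    wrong_bn(conv(x))
except RuntimeError as err:
    mismatch_raised = True
    print("Channel mismatch:", err)
```

Deriving the value from conv.out_channels keeps the two layers in sync if the convolution width is changed later.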
Parameters
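BatchNorm2d(C) has very few trainable parameters, and the count depends only on C, not on H or W (assuming the defaults affine=True and track_running_stats=True):
Learnable: 2 * C (gamma and beta, one of each per channel; PyTorch stores them as weight and bias)
Running stats: 2 * C (running_mean and running_var, updated during training but not by gradient descent)
BatchNorm2d(64):
Trainable parameters: 2 * 64 = 128
Running stats: 2 * 64 = 128 (buffers, not parameters), plus a scalar num_batches_tracked counter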
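These counts can be checked directly on a layer. This is a small sketch assuming the default constructor arguments; the variable names are illustrative.

```python
import torch.nn as nn

bn = nn.BatchNorm2d(64)

# Trainable parameters: weight (gamma) and bias (beta), 64 values each
trainable = sum(p.numel() for p in bn.parameters() if p.requires_grad)
print(trainable)  # 128

# Buffers are saved with the model but are not updated by the optimizer
buffer_sizes = {name: buf.numel() for name, buf in bn.named_buffers()}
print(buffer_sizes)  # includes running_mean, running_var and num_batches_tracked
```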
PyTorch Code
import torch
import torch.nn as nn
# Standard Conv -> BN -> ReLU block
block = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1),
    nn.BatchNorm2d(64),  # matches Conv2d out_channels
    nn.ReLU(inplace=True),
)
x = torch.randn(32, 3, 224, 224)
output = block(x)
print(output.shape) # torch.Size([32, 64, 224, 224])
# BatchNorm2d preserves the shape of its input: Conv2d changes the channel count from 3 to 64, then BN and ReLU leave (32, 64, 224, 224) unchanged
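To illustrate why the argument is a channel count, the sketch below reproduces the training-mode output of a freshly created BatchNorm2d by hand. It assumes the layer is in its default state (training mode, weight initialised to 1 and bias to 0); the tensor sizes and variable names are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(64)         # training mode by default
x = torch.randn(32, 64, 8, 8)   # illustrative sizes
y = bn(x)

# Statistics are computed per channel, across the batch and spatial dims (0, 2, 3)
mean = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
manual = (x - mean) / torch.sqrt(var + bn.eps)

same_shape = y.shape == x.shape
matches = torch.allclose(y, manual, atol=1e-5)
print(same_shape, matches)
```

Each of the 64 channels gets its own mean and variance, which is why the layer needs to know C but not H or W.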
BatchNorm Variants
- BatchNorm1d(C) — expects (batch, C) or (batch, C, L) for 1D data
- BatchNorm2d(C) — expects (batch, C, H, W) for 2D images
- BatchNorm3d(C) — expects (batch, C, D, H, W) for 3D volumes
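The shapes in the list above can be exercised with a short loop. This is a sketch with arbitrary, illustrative sizes; in every case the channel dimension is the second one and the output shape equals the input shape.

```python
import torch
import torch.nn as nn

C = 16  # illustrative channel count
cases = {
    "BatchNorm1d, (batch, C)": (nn.BatchNorm1d(C), (4, C)),
    "BatchNorm1d, (batch, C, L)": (nn.BatchNorm1d(C), (4, C, 10)),
    "BatchNorm2d, (batch, C, H, W)": (nn.BatchNorm2d(C), (4, C, 8, 8)),
    "BatchNorm3d, (batch, C, D, H, W)": (nn.BatchNorm3d(C), (4, C, 5, 8, 8)),
}

shapes_preserved = True
for name, (layer, shape) in cases.items():
    out = layer(torch.randn(*shape))
    shapes_preserved = shapes_preserved and tuple(out.shape) == shape
    print(name, "->", tuple(out.shape))
```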