What Shape Does nn.Flatten() Produce After Conv2d?
If the last Conv2d (or pooling) layer outputs a tensor of shape (batch, 64, 7, 7), nn.Flatten() produces (batch, 3136), because 64 × 7 × 7 = 3,136. The batch dimension is preserved; all remaining dimensions are collapsed into one.
How nn.Flatten() Works
nn.Flatten(start_dim=1, end_dim=-1) collapses the dimensions from start_dim through end_dim into a single dimension. With the defaults, everything after the batch dimension (dimension 0) is merged, so the batch dimension is preserved:
(batch, C, H, W) -> (batch, C * H * W)
(batch, 64, 7, 7) -> (batch, 64 * 7 * 7) -> (batch, 3136)
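The shape rule above can be checked directly on a random tensor. This is a minimal sketch: the batch size of 8 is an arbitrary assumption, and the tensor only stands in for a real Conv2d output. It also shows how start_dim and end_dim change the result.

```python
import torch
import torch.nn as nn

x = torch.randn(8, 64, 7, 7)  # stand-in for a Conv2d output; batch size 8 is arbitrary

flat = nn.Flatten()(x)                 # defaults: start_dim=1, end_dim=-1
print(flat.shape)                      # torch.Size([8, 3136])

# The functional form and a plain reshape give the same values
same_fn = torch.flatten(x, start_dim=1)
same_reshape = x.reshape(x.size(0), -1)
assert torch.equal(flat, same_fn) and torch.equal(flat, same_reshape)

# start_dim=0 also collapses the batch dimension, which is rarely what you want
all_flat = nn.Flatten(start_dim=0)(x)
print(all_flat.shape)                  # torch.Size([25088])

# end_dim limits where the collapse stops: here only C and H are merged
partial = nn.Flatten(start_dim=1, end_dim=2)(x)
print(partial.shape)                   # torch.Size([8, 448, 7])
```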
Common Examples
# After typical CNN feature extractors:
(batch, 512, 7, 7) -> Flatten -> (batch, 25088) # VGG-16
(batch, 2048, 1, 1) -> Flatten -> (batch, 2048) # ResNet-50 after AvgPool
(batch, 256, 6, 6) -> Flatten -> (batch, 9216) # AlexNet
(batch, 64, 7, 7) -> Flatten -> (batch, 3136) # Small CNN
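The sizes listed above are just C × H × W for each feature map, which can be confirmed with dummy tensors. In this sketch the dictionary name feature_shapes and the batch size of 2 are illustrative assumptions; no actual pretrained networks are loaded.

```python
import math
import torch
import torch.nn as nn

# Feature-map shapes (C, H, W) taken from the examples above
feature_shapes = {
    "VGG-16": (512, 7, 7),
    "ResNet-50 after AvgPool": (2048, 1, 1),
    "AlexNet": (256, 6, 6),
    "Small CNN": (64, 7, 7),
}

flatten = nn.Flatten()
flat_sizes = {}
for name, chw in feature_shapes.items():
    dummy = torch.zeros(2, *chw)               # batch of 2 is arbitrary
    out = flatten(dummy)
    assert out.shape == (2, math.prod(chw))    # flattened size equals C * H * W
    flat_sizes[name] = out.shape[1]
    print(f"{name}: {tuple(dummy.shape)} -> {tuple(out.shape)}")
```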
PyTorch Code
import torch
import torch.nn as nn
# CNN -> Flatten -> Linear pattern
model = nn.Sequential(
    nn.Conv2d(1, 32, 3, padding=1),   # (batch, 32, 28, 28)
    nn.ReLU(),
    nn.MaxPool2d(2),                  # (batch, 32, 14, 14)
    nn.Conv2d(32, 64, 3, padding=1),  # (batch, 64, 14, 14)
    nn.ReLU(),
    nn.MaxPool2d(2),                  # (batch, 64, 7, 7)
    nn.Flatten(),                     # (batch, 3136)
    nn.Linear(3136, 10),              # (batch, 10)
)
x = torch.randn(32, 1, 28, 28)
output = model(x)
print(output.shape) # torch.Size([32, 10])
Tip: Find the Right Size
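If you are not sure what in_features value to use for the Linear layer, run a dummy input through the convolutional part of the network and print the shape just before Flatten:
# Quick way to find the flatten size
feature_extractor = model[:6]  # the layers before nn.Flatten() in the model above
x = torch.randn(1, 1, 28, 28)  # one dummy sample with the real input shape
x = feature_extractor(x)
print(x.shape) # torch.Size([1, 64, 7, 7])
flatten_size = x.shape[1] * x.shape[2] * x.shape[3] # 3136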
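The manual check above can be wrapped in a small helper so the Linear size is computed rather than hard-coded. This is a sketch, not part of PyTorch: infer_flatten_size is a hypothetical helper name, and the feature extractor here is assumed to match the convolutional layers of the model shown earlier. As an alternative, nn.LazyLinear infers in_features on the first forward pass.

```python
import torch
import torch.nn as nn

def infer_flatten_size(feature_extractor: nn.Module, input_shape: tuple) -> int:
    """Return the per-sample feature count after flattening the extractor's output.

    Hypothetical helper for illustration; input_shape excludes the batch dimension.
    """
    was_training = feature_extractor.training
    feature_extractor.eval()               # keep BatchNorm/Dropout from reacting to the dummy input
    with torch.no_grad():
        dummy = torch.zeros(1, *input_shape)
        out = feature_extractor(dummy)
    feature_extractor.train(was_training)  # restore the original mode
    return out[0].numel()

# Assumed to mirror the conv layers of the example model
feature_extractor = nn.Sequential(
    nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
)

flatten_size = infer_flatten_size(feature_extractor, (1, 28, 28))
classifier = nn.Linear(flatten_size, 10)
print(flatten_size)

# Alternative: let PyTorch infer in_features on the first forward pass
lazy_model = nn.Sequential(feature_extractor, nn.Flatten(), nn.LazyLinear(10))
lazy_out = lazy_model(torch.randn(4, 1, 28, 28))
print(lazy_out.shape)
```

The helper keeps the size explicit, which is convenient when the model is built in a class constructor; LazyLinear is shorter but needs one forward pass before its weights exist.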