ResNet-50 Layer Shapes — Complete Shape Trace

ResNet-50: Input 224 → Conv7×7/s2 → 112 → MaxPool/s2 → 56 → Layer1 → 56 → Layer2 → 28 → Layer3 → 14 → Layer4 → 7 → AvgPool → 1 → FC → 1000

Complete Shape Trace

Input:              (batch,    3, 224, 224)

Conv1:   7x7, s2    (batch,   64, 112, 112)   # floor((224-7+2*3)/2)+1 = 112
BN + ReLU           (batch,   64, 112, 112)
MaxPool: 3x3, s2    (batch,   64,  56,  56)   # floor((112-3+2*1)/2)+1 = 56

Layer1 (3 blocks):  (batch,  256,  56,  56)   # no spatial change
Layer2 (4 blocks):  (batch,  512,  28,  28)   # first block has stride=2
Layer3 (6 blocks):  (batch, 1024,  14,  14)   # first block has stride=2
Layer4 (3 blocks):  (batch, 2048,   7,   7)   # first block has stride=2

AdaptiveAvgPool2d:  (batch, 2048,   1,   1)   # global average pooling
Flatten:            (batch, 2048)
FC:                 (batch, 1000)               # classification head
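
The spatial sizes in the trace all follow from the formula quoted in the comments, floor((size - kernel + 2*padding) / stride) + 1. The following is a minimal sketch in plain Python that reproduces them. The helper names `conv_out_size` and `resnet50_spatial_trace` are made up for this illustration, and the sketch assumes that the stride-2 step in Layer2-4 is a 3x3 convolution with padding 1, which is how torchvision implements it.

```python
def conv_out_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length along one spatial axis for a conv or pooling layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def resnet50_spatial_trace(size: int = 224) -> dict:
    """Trace the feature-map side length through the ResNet-50 stages."""
    trace = {"input": size}

    # Stem: 7x7 conv (stride 2, padding 3), then 3x3 max pool (stride 2, padding 1)
    size = conv_out_size(size, kernel=7, stride=2, padding=3)
    trace["conv1"] = size
    size = conv_out_size(size, kernel=3, stride=2, padding=1)
    trace["maxpool"] = size

    # Layer1 keeps the resolution; Layer2-4 halve it once, in their first block.
    # Assumption: the downsampling op is a 3x3 conv with padding 1.
    for name, stride in [("layer1", 1), ("layer2", 2), ("layer3", 2), ("layer4", 2)]:
        size = conv_out_size(size, kernel=3, stride=stride, padding=1)
        trace[name] = size

    return trace


trace = resnet50_spatial_trace(224)
for stage, side in trace.items():
    print(f"{stage:8s} {side} x {side}")
```

Calling `resnet50_spatial_trace(256)` instead gives a final 8 x 8 feature map, which is why the adaptive average pool (rather than a fixed 7x7 pool) lets the same network accept other input resolutions.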

Bottleneck Block (Used in ResNet-50)

Each "block" in Layer1-4 is a bottleneck with three convolutions:

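# Example: first block in Layer3 (the one with stride=2)
# Input: (batch, 512, 28, 28)
Conv1x1:  reduce channels     (batch,  256, 28, 28)   # 1x1 conv, stride 1
Conv3x3:  spatial processing  (batch,  256, 14, 14)   # 3x3 conv, stride 2 (torchvision places the stride here)
Conv1x1:  expand channels     (batch, 1024, 14, 14)   # 1x1 conv
+ skip connection (1x1 projection with stride 2, since channels and spatial size both change)

# Remaining Layer3 blocks: (batch, 1024, 14, 14) in and out, identity skip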
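
To make the per-block bookkeeping concrete, here is a small plain-Python sketch that lists the shape after each step of one bottleneck. The function `bottleneck_shapes` and its parameter names are hypothetical, not part of any library. It assumes the stride sits on the 3x3 convolution (the torchvision convention), and that the skip path needs a projection whenever the stride is not 1 or the input and output channel counts differ.

```python
def conv_out_size(size, kernel, stride=1, padding=0):
    """Output length along one spatial axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def bottleneck_shapes(in_shape, mid_channels, stride=1, expansion=4):
    """Return (step, (C, H, W)) pairs for one bottleneck block.

    Assumption: the stride is applied by the 3x3 conv, as in torchvision.
    """
    in_channels, height, width = in_shape
    out_channels = mid_channels * expansion
    steps = []

    # 1x1 reduce: changes the channel count only
    steps.append(("conv1x1_reduce", (mid_channels, height, width)))

    # 3x3 conv with padding 1: the only step that can change H and W
    height = conv_out_size(height, kernel=3, stride=stride, padding=1)
    width = conv_out_size(width, kernel=3, stride=stride, padding=1)
    steps.append(("conv3x3", (mid_channels, height, width)))

    # 1x1 expand: restores a wide channel dimension (4x the bottleneck width)
    steps.append(("conv1x1_expand", (out_channels, height, width)))

    # The skip path must match the main path before the addition
    needs_projection = stride != 1 or in_channels != out_channels
    skip = "skip_projection" if needs_projection else "skip_identity"
    steps.append((skip, (out_channels, height, width)))
    return steps


first = bottleneck_shapes((512, 28, 28), mid_channels=256, stride=2)   # Layer3, block 1
later = bottleneck_shapes((1024, 14, 14), mid_channels=256, stride=1)  # Layer3, blocks 2-6

for name, shape in first + later:
    print(f"{name:16s} {shape}")
```

The same function also covers the first block of Layer1: its input is (64, 56, 56) with stride 1, yet it still needs a projection because the channel count grows from 64 to 256.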

PyTorch Verification

import torchvision.models as models
import torch

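model = models.resnet50()  # randomly initialized; weights do not affect shapes
model.eval()               # BatchNorm uses running stats, so the trace does not update them
x = torch.randn(1, 3, 224, 224)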

# Trace through each stage
x = model.conv1(x);    print(f"conv1:  {x.shape}")  # [1, 64, 112, 112]
x = model.bn1(x);      x = model.relu(x)
x = model.maxpool(x);  print(f"pool:   {x.shape}")  # [1, 64, 56, 56]
x = model.layer1(x);   print(f"layer1: {x.shape}")  # [1, 256, 56, 56]
x = model.layer2(x);   print(f"layer2: {x.shape}")  # [1, 512, 28, 28]
x = model.layer3(x);   print(f"layer3: {x.shape}")  # [1, 1024, 14, 14]
x = model.layer4(x);   print(f"layer4: {x.shape}")  # [1, 2048, 7, 7]
x = model.avgpool(x);  print(f"avgpool:{x.shape}")  # [1, 2048, 1, 1]
x = torch.flatten(x, 1)
x = model.fc(x);       print(f"fc:     {x.shape}")  # [1, 1000]
