How to Fix "input and target batch size don't match" in PyTorch
The error "input and target batch size don't match" appears when the batch dimension of your model's predictions does not match the batch dimension of your targets, so PyTorch cannot compute the loss.
What Causes This Error
Loss functions compare model predictions with targets element-by-element (or batch-element-by-batch-element). When the batch dimensions do not match, PyTorch cannot compute the loss. This typically happens due to incorrect squeezing, reshaping, or when the DataLoader returns mismatched batches.
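To make the shape contract concrete, here is a minimal sketch of what nn.CrossEntropyLoss expects: logits of shape [B, C] and class-index targets of shape [B], where the batch dimension B must agree. The tensor sizes below are illustrative.

```python
import torch
import torch.nn as nn

batch_size, num_classes = 4, 10
logits = torch.randn(batch_size, num_classes)           # [4, 10]
targets = torch.randint(0, num_classes, (batch_size,))  # [4]

# Matching batch dims: the loss is a scalar
loss = nn.CrossEntropyLoss()(logits, targets)
print(loss.shape)  # torch.Size([])

# Mismatched batch dims fail before any loss math happens
bad_targets = torch.randint(0, num_classes, (batch_size + 1,))  # [5]
try:
    nn.CrossEntropyLoss()(logits, bad_targets)
except ValueError as e:
    print(e)  # Expected input batch_size (4) to match target batch_size (5).
```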
Scenario 1: Accidental Squeeze Removing Batch Dimension
Using .squeeze() on a single-sample batch removes the batch dimension.
The Error
```python
output = model(x)          # [1, 10] -- batch_size=1
output = output.squeeze()  # [10] -- batch dimension removed!
target = torch.tensor([3]) # [1]
loss = nn.CrossEntropyLoss()(output, target)
# ValueError: Expected input batch_size (10) to match target batch_size (1)
```
The Fix
```python
output = model(x)  # [1, 10]
# Squeeze only a specific dimension, never the batch dim
output = output.squeeze(dim=-1)  # Only squeeze the last dim if needed
# Or better: don't squeeze at all
target = torch.tensor([3])  # [1]
loss = nn.CrossEntropyLoss()(output, target)  # Works! [1, 10] vs [1]
```
Never use .squeeze() without specifying which dimension to squeeze. When batch_size=1, squeeze() removes the batch dimension, causing shape mismatches downstream.
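The difference is easy to see on small tensors. A short sketch, using made-up shapes, of how an argument-free squeeze() eats the batch dimension while a dimension-specific squeeze leaves it alone:

```python
import torch

# With batch_size=1, squeeze() with no arguments drops the batch dim too
out = torch.randn(1, 10)    # [1, 10]
print(out.squeeze().shape)  # torch.Size([10]) -- batch dim gone

# Naming the dimension only collapses that dim (and only if it has size 1),
# so the batch dimension survives
out3 = torch.randn(1, 10, 1)       # [1, 10, 1]
print(out3.squeeze(dim=-1).shape)  # torch.Size([1, 10]) -- batch dim intact
```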
Scenario 2: Wrong Reshape in Forward Pass
Hardcoding reshape dimensions instead of using batch_size from input.
The Error
```python
class Model(nn.Module):
    def forward(self, x):
        x = self.features(x)
        x = x.view(32, -1)  # Hardcoded batch_size=32!
        return self.classifier(x)

# Fails when batch_size != 32 (e.g., last batch in epoch)
```
The Fix
```python
class Model(nn.Module):
    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)  # Use actual batch size
        # Or equivalently:
        x = x.flatten(1)  # Flatten all dims except batch
        return self.classifier(x)
```
Always use x.size(0) or x.shape[0] for the batch dimension in reshape operations. This handles variable batch sizes including the last (often smaller) batch in an epoch.
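A quick sketch, with illustrative feature-map shapes, showing that x.view(x.size(0), -1) adapts to a short final batch while a hardcoded batch size would not:

```python
import torch

features = torch.randn(32, 8, 4, 4)  # a full batch of feature maps
last = torch.randn(5, 8, 4, 4)       # the smaller last batch of an epoch

flat_full = features.view(features.size(0), -1)
flat_last = last.view(last.size(0), -1)
print(flat_full.shape)  # torch.Size([32, 128])
print(flat_last.shape)  # torch.Size([5, 128])

# flatten(1) gives the same result and also works on non-contiguous tensors,
# where .view() would raise an error
assert torch.equal(flat_last, last.flatten(1))
```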
Scenario 3: DataLoader Returning Mismatched Batch Elements
Custom Dataset returning labels with wrong shape.
The Error
```python
class MyDataset(Dataset):
    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx:idx+2]  # Bug: returns 2 labels per image!
        return image, label

# DataLoader stacks images [B, C, H, W] but labels become [B, 2]
# CrossEntropyLoss sees batch mismatch
```
The Fix
```python
class MyDataset(Dataset):
    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]  # Single label per image
        return image, label

# Now DataLoader produces [B, C, H, W] images and [B] labels
# Always verify shapes:
for x, y in loader:
    print(f"Input: {x.shape}, Target: {y.shape}")
    break
```
Verify your Dataset returns exactly one label per sample. Add shape assertions in your training loop to catch mismatches early.
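One way to bake those checks in is a minimal sketch like the following; the ToyDataset class and its tensor shapes are hypothetical, but the assertions mirror the one-label-per-sample rule:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):  # hypothetical stand-in for your dataset
    def __init__(self):
        self.images = torch.randn(10, 3, 8, 8)
        self.labels = torch.randint(0, 5, (10,))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]  # exactly one label

loader = DataLoader(ToyDataset(), batch_size=4)
for x, y in loader:
    # Batch dims must agree, and class-index targets should be 1-D
    assert x.size(0) == y.size(0), f"batch mismatch: {x.shape} vs {y.shape}"
    assert y.dim() == 1, f"expected [B] targets, got {tuple(y.shape)}"
```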
Quick Debugging Checklist
- Print tensor .dtype and .device before operations
- Check for in-place operations: +=, *=, .add_(), .mul_()
- Verify shapes with print(tensor.shape) at each step
- Use torch.autograd.set_detect_anomaly(True) to pinpoint the exact operation
```python
# Enable anomaly detection to find the exact line
torch.autograd.set_detect_anomaly(True)

# Check tensor properties
print(f"dtype: {tensor.dtype}, device: {tensor.device}, shape: {tensor.shape}")
print(f"requires_grad: {tensor.requires_grad}")
```