How to Fix "input and target batch size don't match" in PyTorch

The error "Expected input batch_size (N) to match target batch_size (M)" is raised when a loss function receives predictions and targets whose batch dimensions disagree. This guide walks through the three most common causes (accidental squeezing, hardcoded reshapes, and Datasets that return mismatched elements) and shows how to fix each one.

What Causes This Error

Loss functions pair each prediction in the batch with its corresponding target. When the batch dimensions of the two tensors do not match, PyTorch cannot make that pairing and raises this error. The mismatch typically comes from incorrect squeezing, hardcoded reshapes, or a Dataset that returns the wrong number of labels per sample.
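A minimal reproduction (tensor sizes chosen arbitrarily for illustration) shows the failure at the loss call:

```python
import torch
import torch.nn as nn

logits = torch.randn(4, 10)           # predictions for a batch of 4
target = torch.randint(0, 10, (3,))   # targets for a batch of 3: mismatch

try:
    nn.CrossEntropyLoss()(logits, target)
except (ValueError, RuntimeError) as err:  # exception type varies across PyTorch versions
    print(err)
```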

Scenario 1: Accidental Squeeze Removing Batch Dimension

Using .squeeze() on a single-sample batch removes the batch dimension.

The Error

output = model(x)  # [1, 10] — batch_size=1
output = output.squeeze()  # [10] — batch dimension removed!
target = torch.tensor([3])  # [1]
loss = nn.CrossEntropyLoss()(output, target)
# ValueError: Expected input batch_size (10) to match target batch_size (1)

The Fix

output = model(x)  # [1, 10]
# Use squeeze only on specific dimensions, never the batch dim
output = output.squeeze(dim=-1)  # [1, 10] unchanged: last dim is 10, not 1
# Or better: don't squeeze at all
target = torch.tensor([3])  # [1]
loss = nn.CrossEntropyLoss()(output, target)  # Works! [1, 10] vs [1]

Never use .squeeze() without specifying which dimension to squeeze. When batch_size=1, squeeze() removes the batch dimension, causing shape mismatches downstream.
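The difference is easy to see with a toy tensor. squeeze() with no argument drops every size-1 dimension, while squeeze(dim) only touches the dimension you name:

```python
import torch

single = torch.zeros(1, 10)   # batch_size=1
batch = torch.zeros(4, 10)    # batch_size=4

print(single.squeeze().shape)    # torch.Size([10]) -- batch dim silently removed
print(batch.squeeze().shape)     # torch.Size([4, 10]) -- no size-1 dims, unchanged
print(single.squeeze(-1).shape)  # torch.Size([1, 10]) -- last dim is 10, left alone
```

Note that the bug only appears when batch_size happens to be 1, which is why it often surfaces at inference time rather than during training.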

Scenario 2: Wrong Reshape in Forward Pass

Hardcoding reshape dimensions instead of using batch_size from input.

The Error

class Model(nn.Module):
    def forward(self, x):
        x = self.features(x)
        x = x.view(32, -1)  # Hardcoded batch_size=32!
        return self.classifier(x)

# Fails when batch_size != 32 (e.g., last batch in epoch)

The Fix

class Model(nn.Module):
    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)  # Use actual batch size
        # Equivalent: x = x.flatten(1), which flattens all dims except batch
        return self.classifier(x)

Always use x.size(0) or x.shape[0] for the batch dimension in reshape operations. This handles variable batch sizes including the last (often smaller) batch in an epoch.
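A quick sketch of why the hardcoded version breaks on the last batch (sizes here are made up for illustration):

```python
import torch

# e.g. 20 leftover samples at the end of an epoch when batch_size=32
last_batch = torch.randn(20, 3, 2, 2)

flat = last_batch.flatten(1)  # or last_batch.view(last_batch.size(0), -1)
print(flat.shape)             # torch.Size([20, 12])

try:
    last_batch.view(32, -1)   # hardcoded batch size
except RuntimeError as err:
    print("hardcoded view failed:", err)
```

Worse, if the total element count happens to be divisible by the hardcoded batch size, view(32, -1) succeeds silently with a wrong batch dimension, and the mismatch only surfaces later at the loss, far from the actual bug.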

Scenario 3: DataLoader Returning Mismatched Batch Elements

Custom Dataset returning labels with wrong shape.

The Error

class MyDataset(Dataset):
    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx:idx+2]  # Bug: returns 2 labels per image!
        return image, label

# DataLoader stacks images into [B, C, H, W] but labels become [B, 2]
# CrossEntropyLoss rejects the [B, 2] target: it expects class indices of shape [B]

The Fix

class MyDataset(Dataset):
    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]  # Single label per image
        return image, label

# Now DataLoader produces [B, C, H, W] images and [B] labels
# Always verify shapes:
for x, y in loader:
    print(f"Input: {x.shape}, Target: {y.shape}")
    break

Verify your Dataset returns exactly one label per sample. Add shape assertions in your training loop to catch mismatches early.
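A minimal self-contained version of that check, with TensorDataset standing in for a custom Dataset and made-up tensor sizes:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

images = torch.randn(10, 3, 8, 8)    # 10 fake RGB images
labels = torch.randint(0, 5, (10,))  # exactly one class index per image

loader = DataLoader(TensorDataset(images, labels), batch_size=4)

for x, y in loader:
    # Catch mismatches before they reach the loss function
    assert x.size(0) == y.size(0), f"batch mismatch: {x.shape} vs {y.shape}"
    assert y.dim() == 1, f"expected class-index targets of shape [B], got {tuple(y.shape)}"
```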

Quick Debugging Checklist

# Shape errors are raised in the forward pass, so the traceback already points
# at the offending line. For errors raised during backward, anomaly detection
# maps them back to the forward operation that caused them:
torch.autograd.set_detect_anomaly(True)

# Check tensor properties
print(f"dtype: {tensor.dtype}, device: {tensor.device}, shape: {tensor.shape}")
print(f"requires_grad: {tensor.requires_grad}")
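To locate where a shape goes wrong inside a model, forward hooks that record each layer's output shape can help. A sketch with a throwaway model (the model and sizes are illustrative, not from the scenarios above):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Flatten(), nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 10))

shapes = []
def record_shape(module, inputs, output):
    # Hook signature: (module, inputs, output); called after each forward
    shapes.append((module.__class__.__name__, tuple(output.shape)))

handles = [m.register_forward_hook(record_shape) for m in model]

model(torch.randn(4, 1, 8, 8))
for name, shape in shapes:
    print(f"{name:10s} -> {shape}")

for h in handles:
    h.remove()  # Clean up so the hooks don't fire during training
```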
