
PyTorch Training Loop

Introduction

A training loop is a fundamental concept in deep learning that defines how models learn from data. While high-level libraries such as PyTorch Lightning can manage training for you, understanding how to build a custom training loop gives you complete control over the training process. This guide walks you through creating effective training loops in PyTorch, from basic implementations to advanced techniques.

What is a Training Loop?

A training loop is the iterative process where your neural network:

  1. Processes batches of data
  2. Calculates loss by comparing predictions with actual values
  3. Updates model parameters to minimize this loss
  4. Evaluates performance on validation data

Think of it as the "practice routine" your model follows to get better at its task.
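In code, those four steps map onto a skeleton like this (a minimal sketch; model, train_loader, val_loader, criterion, optimizer, and evaluate are placeholders for the objects and functions we build in the next section):

for epoch in range(num_epochs):
    for data, target in train_loader:       # 1. process a batch of data
        optimizer.zero_grad()
        output = model(data)                 # forward pass
        loss = criterion(output, target)     # 2. calculate loss
        loss.backward()                      # compute gradients
        optimizer.step()                     # 3. update model parameters
    evaluate(model, val_loader)              # 4. evaluate on validation data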

Basic Training Loop Structure

Let's build a simple training loop for a classification model:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Setup data loaders
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000)

# Initialize model, loss function and optimizer
model = SimpleModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Training loop
def train(model, train_loader, criterion, optimizer, epochs=5):
    model.train()  # Set the model to training mode

    for epoch in range(epochs):
        running_loss = 0.0

        for batch_idx, (data, target) in enumerate(train_loader):
            # Reset gradients
            optimizer.zero_grad()

            # Forward pass
            output = model(data)

            # Calculate loss
            loss = criterion(output, target)

            # Backward pass
            loss.backward()

            # Update weights
            optimizer.step()

            # Accumulate loss
            running_loss += loss.item()

            # Print statistics
            if batch_idx % 100 == 99:
                print(f'Epoch: {epoch+1}/{epochs}, Batch: {batch_idx+1}/{len(train_loader)}, Loss: {running_loss/100:.4f}')
                running_loss = 0.0

        # Evaluate after each epoch
        evaluate(model, test_loader)

    print('Training complete!')

# Evaluation function
def evaluate(model, test_loader):
    model.eval()  # Set the model to evaluation mode
    correct = 0
    total = 0

    with torch.no_grad():  # No need to track gradients
        for data, target in test_loader:
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    accuracy = 100 * correct / total
    print(f'Accuracy on test set: {accuracy:.2f}%')
    model.train()  # Set back to training mode

# Run the training
train(model, train_loader, criterion, optimizer)

Example Output:

Epoch: 1/5, Batch: 100/938, Loss: 2.2518
Epoch: 1/5, Batch: 200/938, Loss: 2.1082
Epoch: 1/5, Batch: 300/938, Loss: 1.7882
...
Accuracy on test set: 82.34%
...
Epoch: 5/5, Batch: 900/938, Loss: 0.4321
Training complete!
Accuracy on test set: 91.76%

Breaking Down the Training Loop

Let's analyze each component of the training loop:

1. Setting Up Training Mode

model.train()

This puts your model in training mode, so that layers such as dropout and batch normalization use their training-time behavior, which differs from how they act during evaluation.
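A quick way to see the difference (a small illustrative snippet, reusing the torch and nn imports from the example above):

drop = nn.Dropout(p=0.5)
x = torch.ones(5)

drop.train()           # training mode: roughly half the values are zeroed, the rest scaled by 2
print(drop(x))

drop.eval()            # evaluation mode: dropout is a no-op
print(drop(x))         # tensor([1., 1., 1., 1., 1.])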

2. The Epoch Loop

for epoch in range(epochs):
    # Training code

An epoch represents one complete pass through the entire training dataset. Multiple epochs are usually needed for the model to learn effectively.
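For example, MNIST has 60,000 training images; with batch_size=64 that gives ceil(60000 / 64) = 938 batches per epoch, which is exactly the .../938 shown in the sample output above.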

3. The Batch Loop

for batch_idx, (data, target) in enumerate(train_loader):
    # Process one batch

Processing data in batches enables efficient computation and stochastic optimization.
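With the MNIST loader defined earlier (batch_size=64), you can inspect one batch directly to see what the loop receives (a small illustrative check):

data, target = next(iter(train_loader))
print(data.shape)    # torch.Size([64, 1, 28, 28]) -- 64 grayscale 28x28 images
print(target.shape)  # torch.Size([64])            -- one label per image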

4. The Training Steps

The core of training happens in these essential steps:

# Step 1: Zero gradients
optimizer.zero_grad()

# Step 2: Forward pass
output = model(data)

# Step 3: Calculate loss
loss = criterion(output, target)

# Step 4: Backpropagation
loss.backward()

# Step 5: Update weights
optimizer.step()

This sequence is crucial:

  • Zero gradients: Prevent gradients from accumulating across batches (demonstrated in the short snippet after this list)
  • Forward pass: Generate predictions from input data
  • Loss calculation: Measure prediction error
  • Backpropagation: Calculate gradients for each parameter
  • Weights update: Apply gradients to improve the model
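Here is a tiny demonstration of why zeroing gradients matters (an illustrative snippet with a single parameter, not part of the main loop):

w = torch.ones(1, requires_grad=True)

loss = (w * 2).sum()
loss.backward()
print(w.grad)        # tensor([2.])

loss = (w * 2).sum()
loss.backward()
print(w.grad)        # tensor([4.]) -- the second backward added to the first

w.grad.zero_()       # this is what optimizer.zero_grad() does for every parameter
print(w.grad)        # tensor([0.])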

5. Evaluation

Periodically checking your model's performance on validation data helps monitor progress and detect overfitting.

evaluate(model, test_loader)

Advanced Training Loop Techniques

As you become more comfortable with basic training loops, you can enhance them with these techniques:

1. Using Device Agnostic Code (CPU/GPU)

Make your training loop work seamlessly on both CPU and GPU:

# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Move model to device
model = SimpleModel().to(device)

# Inside training loop
for batch_idx, (data, target) in enumerate(train_loader):
    # Move data to device
    data, target = data.to(device), target.to(device)

    # Rest of training code
    # ...

2. Learning Rate Scheduling

Adjust learning rates during training to improve convergence:

from torch.optim.lr_scheduler import StepLR

optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = StepLR(optimizer, step_size=1, gamma=0.7)

# In your training loop
for epoch in range(epochs):
    # Training code
    # ...

    # Step the scheduler after each epoch
    scheduler.step()
    print(f'Learning rate: {scheduler.get_last_lr()[0]}')
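With step_size=1 and gamma=0.7, the learning rate is multiplied by 0.7 after every epoch: 0.1, 0.07, 0.049, 0.0343, and so on.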

3. Early Stopping

Stop training when validation performance stops improving:

def train_with_early_stopping(model, train_loader, val_loader, patience=5):
    best_val_loss = float('inf')
    counter = 0

    for epoch in range(100):  # Maximum epochs
        # Train for one epoch
        train_loss = train_epoch(model, train_loader)

        # Validate
        val_loss = validate(model, val_loader)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            counter = 0
            # Save the best model
            torch.save(model.state_dict(), 'best_model.pth')
        else:
            counter += 1

        if counter >= patience:
            print(f'Early stopping at epoch {epoch}')
            break
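The train_epoch and validate calls above are assumed helpers rather than PyTorch functions; here is a minimal sketch of what they might look like, reusing the criterion and optimizer defined earlier (adapt the signatures to your own setup):

def train_epoch(model, loader):
    model.train()
    total_loss = 0.0
    for data, target in loader:
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.size(0)
    return total_loss / len(loader.dataset)  # average training loss for the epoch

def validate(model, loader):
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for data, target in loader:
            total_loss += criterion(model(data), target).item() * data.size(0)
    return total_loss / len(loader.dataset)  # average validation loss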

4. Gradient Clipping

Prevent exploding gradients, especially useful for recurrent neural networks:

# Inside training loop, after loss.backward() but before optimizer.step()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

5. Mixed Precision Training

Speed up training and reduce memory usage with mixed precision (requires PyTorch 1.6+):

from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()

# In training loop
for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()

    # Use autocast for mixed precision
    with autocast():
        output = model(data)
        loss = criterion(output, target)

    # Scale the loss and call backward (the resulting gradients are scaled)
    scaler.scale(loss).backward()

    # Unscale gradients and clip
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    # Step optimizer and update the scaler
    scaler.step(optimizer)
    scaler.update()

Real-World Example: Image Classification with Progress Bar and Metrics

Let's see a more complete example using tqdm for progress bars and tracking multiple metrics:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=5):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        # Training phase
        model.train()
        running_loss = 0.0
        running_corrects = 0

        # Wrap your data loader with tqdm for a progress bar
        train_pbar = tqdm(train_loader, desc='Training')

        for inputs, labels in train_pbar:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            # Backward + optimize
            loss.backward()
            optimizer.step()

            # Statistics
            batch_loss = loss.item() * inputs.size(0)
            batch_corrects = torch.sum(preds == labels.data).item()
            running_loss += batch_loss
            running_corrects += batch_corrects

            # Update progress bar
            train_pbar.set_postfix({
                'loss': batch_loss / inputs.size(0),
                'acc': batch_corrects / inputs.size(0)
            })

        epoch_loss = running_loss / len(train_loader.dataset)
        epoch_acc = running_corrects / len(train_loader.dataset)
        print(f'Training Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

        # Validation phase
        model.eval()
        running_loss = 0.0
        running_corrects = 0

        val_pbar = tqdm(val_loader, desc='Validation')

        # No gradient calculation for validation
        with torch.no_grad():
            for inputs, labels in val_pbar:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Forward
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                # Statistics
                batch_loss = loss.item() * inputs.size(0)
                batch_corrects = torch.sum(preds == labels.data).item()
                running_loss += batch_loss
                running_corrects += batch_corrects

                # Update progress bar
                val_pbar.set_postfix({
                    'loss': batch_loss / inputs.size(0),
                    'acc': batch_corrects / inputs.size(0)
                })

        epoch_loss = running_loss / len(val_loader.dataset)
        epoch_acc = running_corrects / len(val_loader.dataset)
        print(f'Validation Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

        # Save best model
        if epoch_acc > best_acc:
            best_acc = epoch_acc
            torch.save(model.state_dict(), 'best_model_weights.pth')
            print(f'New best accuracy: {best_acc:.4f}, model saved!')

    print(f'Best validation accuracy: {best_acc:.4f}')
    return model

To use this training loop with a real-world dataset:

# Setup data transformations
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

transform_val = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Load datasets (example with CIFAR-10)
train_dataset = datasets.CIFAR10('./data', train=True, download=True, transform=transform_train)
val_dataset = datasets.CIFAR10('./data', train=False, transform=transform_val)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4)

# Load pre-trained model
model = models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)  # CIFAR-10 has 10 classes

# Setup loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Train model
trained_model = train_model(model, train_loader, val_loader, criterion, optimizer)

Common Debugging Tips for Training Loops

When your training isn't working as expected, check for these common issues:

  1. Forgotten optimizer.zero_grad(): Gradients accumulate by default
  2. Model accidentally in eval mode: Check you've called model.train()
  3. Loss not decreasing:
    • Try a smaller learning rate
    • Check your data preprocessing
    • Verify loss function is appropriate
  4. Numerical instability:
    • Look for NaN values with torch.isnan(loss) (see the snippet after this list)
    • Use gradient clipping
  5. GPU out of memory:
    • Reduce batch size
    • Try mixed precision training
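For point 4, a simple guard you can drop into the batch loop (an illustrative sketch that skips the offending batch rather than crashing):

loss = criterion(output, target)
if torch.isnan(loss):
    print(f'NaN loss at batch {batch_idx}, skipping this update')
    continue
loss.backward()
optimizer.step()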

Summary

A well-designed PyTorch training loop is essential for effective deep learning. We've covered:

  • The fundamental structure of a training loop
  • How to implement critical components like forward/backward passes and optimization
  • Advanced techniques for improving training stability and performance
  • Real-world examples with progress tracking and metrics
  • Common pitfalls and debugging strategies

By mastering these concepts, you'll have the flexibility to implement custom training procedures for any deep learning task.

Exercises

  1. Modify the basic training loop to include a learning rate scheduler and observe how it affects training convergence.
  2. Implement a training loop with early stopping based on validation accuracy.
  3. Add weight regularization (L1 or L2) to the training loop to reduce overfitting.
  4. Create a training loop for a different task, such as image segmentation or text classification.
  5. Extend the training loop to log metrics to TensorBoard for visualization.

Happy training!
