PyTorch Model Versioning

When developing machine learning models with PyTorch, you'll often create multiple versions of your model during the experimentation and improvement process. Properly versioning your models is crucial for tracking progress, reproducing results, and deploying the best version to production.

Why Model Versioning Matters

Imagine spending weeks training various iterations of a model, only to lose track of which version performed best or which hyperparameters produced specific results. Model versioning helps you:

  1. Track progress across different training runs
  2. Reproduce results for scientific validation
  3. Roll back to previous versions if needed
  4. Share specific versions with team members
  5. Deploy the right model to production

Basic Model Versioning Techniques

Let's start with simple versioning approaches that require minimal setup.

1. Including Version Information in Filenames

The simplest approach is to include version information directly in your model's filename:

python
# Saving a model with version information in the filename
version = "v1.0"
epoch = 10
accuracy = 0.85

filename = f"model_{version}_epoch_{epoch}_acc_{accuracy:.2f}.pth"
torch.save(model.state_dict(), filename)

# Example output filename: model_v1.0_epoch_10_acc_0.85.pth

This method is straightforward but becomes unwieldy as you create more versions.
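
If you do adopt filename-based versioning, the embedded fields can at least be parsed back out when you need to pick a checkpoint later. The sketch below assumes the exact filename pattern used above; the glob pattern and parsing logic are illustrative helpers, not part of any PyTorch API:

python
import glob

# A minimal sketch: pick the checkpoint with the highest accuracy,
# assuming filenames follow the model_<version>_epoch_<N>_acc_<X>.pth pattern above
def best_checkpoint_by_accuracy(pattern="model_v*_epoch_*_acc_*.pth"):
    best_path, best_acc = None, -1.0
    for path in glob.glob(pattern):
        # The accuracy is the field between "_acc_" and ".pth"
        acc = float(path.rsplit("_acc_", 1)[1].removesuffix(".pth"))
        if acc > best_acc:
            best_path, best_acc = path, acc
    return best_path, best_acc

# Usage (hypothetical):
# path, acc = best_checkpoint_by_accuracy()
# print(f"Best checkpoint: {path} (accuracy {acc:.2f})")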

2. Saving Version Information in the Model Checkpoint

You can include version metadata within your saved model:

python
import datetime

# Create a dictionary with the model state and metadata
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': epoch,
    'loss': loss_value,
    'accuracy': accuracy,
    'version': "v1.0",
    'description': "Initial model with batch normalization",
    'date_created': str(datetime.datetime.now())
}

# Save the checkpoint
torch.save(checkpoint, "model_checkpoint.pth")

# Later, load the model with its metadata
checkpoint = torch.load("model_checkpoint.pth")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
version = checkpoint['version']
description = checkpoint['description']

print(f"Loaded model version {version}: {description}")
# Output: Loaded model version v1.0: Initial model with batch normalization
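
Because everything lives in one dictionary, you can also inspect a checkpoint's metadata without rebuilding the model first. A small sketch, using the checkpoint file saved above; map_location="cpu" simply avoids needing a GPU to read the file:

python
# Inspect a checkpoint's metadata without instantiating the model
ckpt = torch.load("model_checkpoint.pth", map_location="cpu")
for key in ("version", "description", "epoch", "loss", "accuracy"):
    print(f"{key}: {ckpt.get(key)}")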

Advanced Model Versioning

As your projects grow, you'll need more sophisticated versioning strategies.

1. Creating a Versioning Directory Structure

Organize your models in a structured directory format:

python
import datetime
import os

import torch

def save_model_version(model, optimizer, metrics, version, description):
    # Create the version directory if it doesn't exist
    version_dir = f"models/version_{version}"
    os.makedirs(version_dir, exist_ok=True)

    # Bundle the model state and metadata into a single checkpoint
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'metrics': metrics,
        'description': description,
        'date_created': str(datetime.datetime.now())
    }

    # Save the checkpoint
    torch.save(checkpoint, f"{version_dir}/model.pth")

    # Create a human-readable metadata file alongside the checkpoint
    with open(f"{version_dir}/metadata.txt", 'w') as f:
        f.write(f"Version: {version}\n")
        f.write(f"Description: {description}\n")
        f.write(f"Date: {checkpoint['date_created']}\n")
        f.write(f"Metrics: {metrics}\n")

    return version_dir

# Usage example
model_info = {
    'accuracy': 0.87,
    'loss': 0.34,
    'f1_score': 0.86
}

save_path = save_model_version(
    model=model,
    optimizer=optimizer,
    metrics=model_info,
    version="2.1",
    description="Added dropout layers to prevent overfitting"
)

print(f"Model saved to {save_path}")
# Output: Model saved to models/version_2.1
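
Loading a saved version is the mirror image of save_model_version. Here is a possible counterpart, assuming the same directory layout; the function name load_model_version is our own helper, not a PyTorch API:

python
def load_model_version(model, optimizer, version):
    # Mirror of save_model_version: restore states from models/version_<version>
    version_dir = f"models/version_{version}"
    checkpoint = torch.load(f"{version_dir}/model.pth", map_location="cpu")
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['metrics'], checkpoint['description']

# Usage (hypothetical):
# metrics, description = load_model_version(model, optimizer, version="2.1")
# print(f"Restored version 2.1: {description} with metrics {metrics}")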

2. Using Version Control Systems with Model Checkpoints

Git isn't ideal for large binary files like model checkpoints, but you can use Git LFS (Large File Storage) or track model metadata files while keeping the actual models in a designated directory:

python
# Create a JSON file with model metadata that can be tracked in Git
import datetime
import json
import os

def create_model_metadata(version, metrics, description, model_path):
    metadata = {
        'version': version,
        'metrics': metrics,
        'description': description,
        'date_created': str(datetime.datetime.now()),
        'model_path': model_path
    }

    # Write the metadata to a JSON file (create the directory if needed)
    os.makedirs("metadata", exist_ok=True)
    with open(f"metadata/model_v{version}.json", 'w') as f:
        json.dump(metadata, f, indent=4)

    return f"metadata/model_v{version}.json"

# Usage
metadata_file = create_model_metadata(
    version="3.0",
    metrics={'accuracy': 0.89, 'loss': 0.31},
    description="Fine-tuned the last two layers only",
    model_path="models/version_3.0/model.pth"
)

print(f"Model metadata saved to {metadata_file}")
# Output: Model metadata saved to metadata/model_v3.0.json
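
Because these JSON files are small and text-based, Git can diff and version them easily, and you can summarize every tracked model without touching the checkpoints themselves. A minimal sketch, assuming the metadata/ layout created above:

python
import glob
import json

# Summarize all tracked model versions from their Git-friendly metadata files
for path in sorted(glob.glob("metadata/model_v*.json")):
    with open(path) as f:
        meta = json.load(f)
    print(f"{meta['version']}: {meta['metrics']} - {meta['description']}")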

Implementing Semantic Versioning for Models

Adopt semantic versioning (MAJOR.MINOR.PATCH) to communicate the significance of model changes:

  • MAJOR: Significant architecture changes
  • MINOR: Feature additions or improvements with backward compatibility
  • PATCH: Bug fixes or minor improvements with no functionality changes

python
def increment_version(current_version, update_type='patch'):
    major, minor, patch = map(int, current_version.split('.'))

    if update_type == 'major':
        return f"{major + 1}.0.0"
    elif update_type == 'minor':
        return f"{major}.{minor + 1}.0"
    else:  # patch
        return f"{major}.{minor}.{patch + 1}"

# Example usage
current_version = "1.2.3"
new_version = increment_version(current_version, 'minor')
print(f"Updated from {current_version} to {new_version}")
# Output: Updated from 1.2.3 to 1.3.0

Real-World Example: Training with Automatic Versioning

Let's integrate versioning into a complete training loop:

python
import torch
import torch.nn as nn
import torch.optim as optim
import os
import json
import datetime

class ModelVersionManager:
    def __init__(self, base_dir="model_versions"):
        self.base_dir = base_dir
        self.version_file = os.path.join(base_dir, "version_history.json")
        os.makedirs(base_dir, exist_ok=True)

        # Initialize or load the version history
        if os.path.exists(self.version_file):
            with open(self.version_file, 'r') as f:
                self.version_history = json.load(f)
        else:
            self.version_history = {
                'latest_version': '0.1.0',
                'versions': {}
            }
            self._save_version_history()

    def _save_version_history(self):
        with open(self.version_file, 'w') as f:
            json.dump(self.version_history, f, indent=4)

    def create_new_version(self, model, optimizer, metrics,
                           description, update_type='patch'):
        # Get the new version number
        latest = self.version_history['latest_version']
        new_version = self._increment_version(latest, update_type)

        # Create the version directory
        version_dir = os.path.join(self.base_dir, f"v{new_version}")
        os.makedirs(version_dir, exist_ok=True)

        # Save the model and optimizer state together with metadata
        checkpoint = {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'metrics': metrics,
            'description': description,
            'date_created': str(datetime.datetime.now())
        }

        checkpoint_path = os.path.join(version_dir, "model_checkpoint.pth")
        torch.save(checkpoint, checkpoint_path)

        # Update the version history
        self.version_history['latest_version'] = new_version
        self.version_history['versions'][new_version] = {
            'metrics': metrics,
            'description': description,
            'path': checkpoint_path,
            'date': str(datetime.datetime.now()),
            'previous_version': latest
        }

        self._save_version_history()
        return new_version

    def _increment_version(self, version, update_type):
        major, minor, patch = map(int, version.split('.'))

        if update_type == 'major':
            return f"{major + 1}.0.0"
        elif update_type == 'minor':
            return f"{major}.{minor + 1}.0"
        else:  # patch
            return f"{major}.{minor}.{patch + 1}"

    def load_version(self, version=None):
        # If no version is specified, load the latest
        if version is None:
            version = self.version_history['latest_version']

        if version not in self.version_history['versions']:
            raise ValueError(f"Version {version} not found")

        version_info = self.version_history['versions'][version]
        checkpoint = torch.load(version_info['path'])

        return checkpoint, version_info

# Example usage in a training pipeline
def train_model(model, train_loader, epochs=5, version_manager=None):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(epochs):
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            running_loss += loss.item()

        epoch_loss = running_loss / len(train_loader)
        epoch_accuracy = correct / total

        print(f"Epoch {epoch+1}/{epochs} - Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.4f}")

        # Create a new model version every 2 epochs
        if (epoch + 1) % 2 == 0 and version_manager:
            metrics = {
                'loss': epoch_loss,
                'accuracy': epoch_accuracy,
                'epoch': epoch + 1
            }

            # Decide on the update type based on how well the model performs
            if epoch_accuracy > 0.9:
                update_type = 'minor'  # Significant improvement
            else:
                update_type = 'patch'  # Minor improvement

            version = version_manager.create_new_version(
                model=model,
                optimizer=optimizer,
                metrics=metrics,
                description=f"Model after epoch {epoch+1}",
                update_type=update_type
            )

            print(f"Created new model version {version}")

    return model

# Initialize the model and version manager
model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)

version_manager = ModelVersionManager(base_dir="mnist_model_versions")

# Train with automatic versioning:
# train_model(model, train_loader, epochs=6, version_manager=version_manager)

# Later, load a specific version
# checkpoint, version_info = version_manager.load_version("1.0.0")
# model.load_state_dict(checkpoint['model_state_dict'])
# print(f"Loaded model version: {version_info['description']}")

This comprehensive example demonstrates a complete versioning system that:

  • Tracks model versions with semantic versioning
  • Saves checkpoints with metadata
  • Creates new versions based on training progress
  • Allows loading specific versions
  • Maintains a version history with relationships between versions
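
Because all of this bookkeeping lives in version_history.json, simple reporting utilities are easy to layer on top. For example, a small sketch (a standalone helper, not part of the class above) that prints every recorded version with its metrics:

python
# Hypothetical helper: list all versions recorded by a ModelVersionManager
def list_versions(version_manager):
    for version, info in version_manager.version_history['versions'].items():
        print(f"v{version} ({info['date']}): {info['description']} - {info['metrics']}")

# list_versions(version_manager)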

Integration with Model Registry Systems

For production environments, consider integrating with dedicated model registry tools:

python
# Example integration with MLflow
import mlflow
import mlflow.pytorch
import torch.nn as nn

# Start an MLflow run
with mlflow.start_run(run_name="model_training"):
    # Define and train your model
    model = nn.Sequential(
        nn.Linear(10, 5),
        nn.ReLU(),
        nn.Linear(5, 2)
    )

    # Log parameters, metrics, and the model
    mlflow.log_param("hidden_size", 5)
    mlflow.log_param("input_size", 10)
    mlflow.log_metric("accuracy", 0.92)

    # Log the model with the PyTorch flavor
    mlflow.pytorch.log_model(model, "model")

    # Keep the run ID for future reference
    run_id = mlflow.active_run().info.run_id
    print(f"Model logged with run_id: {run_id}")

# Later, load the model by run_id
loaded_model = mlflow.pytorch.load_model(f"runs:/{run_id}/model")
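
If your MLflow setup includes the Model Registry (availability depends on your tracking backend), you can also promote a logged run to a named, versioned registry entry. A minimal sketch; the registry name "my_pytorch_model" is just an example:

python
# Register the logged model under a name; MLflow assigns registry versions (1, 2, ...)
result = mlflow.register_model(f"runs:/{run_id}/model", "my_pytorch_model")
print(f"Registered '{result.name}' as version {result.version}")

# Load a specific registry version later
model_v1 = mlflow.pytorch.load_model("models:/my_pytorch_model/1")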

Summary

Effective model versioning is essential for tracking the evolution of your PyTorch models. In this guide, we've explored:

  1. Basic versioning techniques using filenames and metadata
  2. Advanced versioning strategies with directory structures and semantic versioning
  3. A complete versioning system for training workflows
  4. Integration with model registry tools like MLflow

By implementing proper versioning practices, you'll be able to maintain a clear history of your model development, easily roll back to previous versions, and confidently deploy the right models to production.

Exercises

  1. Basic Versioning: Create a simple script that trains a PyTorch model for MNIST classification and saves checkpoints after every epoch with proper version information.

  2. Versioning Library: Extend the ModelVersionManager class to include functionality for comparing different model versions based on their metrics.

  3. Integration Challenge: Integrate your model versioning with either MLflow or DVC and test the workflow of training, versioning, and retrieving specific model versions.

  4. Collaborative Exercise: Set up a Git repository where multiple team members can work on the same model while maintaining proper versioning of both code and model checkpoints.

By mastering model versioning in PyTorch, you'll build better habits that will serve you well throughout your machine learning career!


