PyTorch Model Versioning
When developing machine learning models with PyTorch, you'll often create multiple versions of your model during the experimentation and improvement process. Properly versioning your models is crucial for tracking progress, reproducing results, and deploying the best version to production.
Why Model Versioning Matters
Imagine spending weeks training various iterations of a model, only to lose track of which version performed best or which hyperparameters produced specific results. Model versioning helps you:
- Track progress across different training runs
- Reproduce results for scientific validation
- Roll back to previous versions if needed
- Share specific versions with team members
- Deploy the right model to production
Basic Model Versioning Techniques
Let's start with simple versioning approaches that require minimal setup.
1. Including Version Information in Filenames
The simplest approach is to include version information directly in your model's filename:
```python
import torch
import torch.nn as nn

model = nn.Linear(784, 10)  # stand-in for your trained model

# Saving a model with version information in the filename
version = "v1.0"
epoch = 10
accuracy = 0.85
filename = f"model_{version}_epoch_{epoch}_acc_{accuracy:.2f}.pth"
torch.save(model.state_dict(), filename)
# Example output filename: model_v1.0_epoch_10_acc_0.85.pth
```
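This method is straightforward, but it becomes unwieldy as versions accumulate: a filename can only hold a few fields, and every new field makes the names longer and harder to parse.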
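Because the metadata lives in the filename, using it later means parsing the name back. The following is a minimal sketch, assuming the `model_{version}_epoch_{epoch}_acc_{accuracy:.2f}.pth` pattern used above; the helpers `parse_checkpoint_name` and `best_checkpoint` are hypothetical, not part of PyTorch:

```python
import re
from typing import Optional

# Matches names like "model_v1.0_epoch_10_acc_0.85.pth" (the pattern used above)
CHECKPOINT_RE = re.compile(
    r"^model_(?P<version>v[\d.]+)_epoch_(?P<epoch>\d+)_acc_(?P<acc>\d+\.\d+)\.pth$"
)

def parse_checkpoint_name(filename: str) -> Optional[dict]:
    """Return the fields encoded in a checkpoint filename, or None if it doesn't match."""
    match = CHECKPOINT_RE.match(filename)
    if match is None:
        return None
    return {
        "version": match.group("version"),
        "epoch": int(match.group("epoch")),
        "accuracy": float(match.group("acc")),
    }

def best_checkpoint(filenames: list[str]) -> Optional[str]:
    """Pick the filename with the highest encoded accuracy, ignoring non-matching names."""
    parsed = [(name, parse_checkpoint_name(name)) for name in filenames]
    candidates = [(name, info) for name, info in parsed if info is not None]
    if not candidates:
        return None
    return max(candidates, key=lambda item: item[1]["accuracy"])[0]

files = [
    "model_v1.0_epoch_10_acc_0.85.pth",
    "model_v1.1_epoch_20_acc_0.88.pth",
    "notes.txt",
]
print(best_checkpoint(files))  # model_v1.1_epoch_20_acc_0.88.pth
```

Every new field added to the filename needs a matching change to the regular expression, which is one reason this approach scales poorly.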
2. Saving Version Information in the Model Checkpoint
You can include version metadata within your saved model:
```python
import datetime

import torch
import torch.nn as nn
import torch.optim as optim

# Stand-ins for objects your training code would produce
model = nn.Linear(784, 10)
optimizer = optim.Adam(model.parameters(), lr=0.001)
epoch = 10
loss_value = 0.42
accuracy = 0.85

# Creating a dictionary with model state and metadata
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': epoch,
    'loss': loss_value,
    'accuracy': accuracy,
    'version': "v1.0",
    'description': "Initial model with batch normalization",
    'date_created': str(datetime.datetime.now())
}

# Save the checkpoint
torch.save(checkpoint, "model_checkpoint.pth")

# Later, load the model with its metadata
checkpoint = torch.load("model_checkpoint.pth")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
version = checkpoint['version']
description = checkpoint['description']
print(f"Loaded model version {version}: {description}")
# Output: Loaded model version v1.0: Initial model with batch normalization
```
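Because the checkpoint is an ordinary dictionary, its metadata fields can drift between versions: an older file may lack a `description` or `accuracy` key. Here is a minimal sketch of a reader that tolerates missing fields, assuming the key names used above; `describe_checkpoint` is a hypothetical helper, and the temporary directory only keeps the example self-contained:

```python
import os
import tempfile

import torch
import torch.nn as nn

def describe_checkpoint(path: str) -> str:
    """Summarize a checkpoint's metadata, falling back to defaults for missing keys."""
    # map_location="cpu" lets a GPU-saved checkpoint be inspected on a CPU-only machine
    checkpoint = torch.load(path, map_location="cpu")
    version = checkpoint.get("version", "unknown")
    description = checkpoint.get("description", "no description")
    accuracy = checkpoint.get("accuracy")
    acc_text = f"{accuracy:.2f}" if accuracy is not None else "n/a"
    return f"{version}: {description} (accuracy {acc_text})"

model = nn.Linear(4, 2)
with tempfile.TemporaryDirectory() as tmp_dir:
    old_path = os.path.join(tmp_dir, "old.pth")
    new_path = os.path.join(tmp_dir, "new.pth")
    # An "old" checkpoint saved before the description/accuracy fields existed
    torch.save({"model_state_dict": model.state_dict(), "version": "v0.9"}, old_path)
    torch.save({"model_state_dict": model.state_dict(), "version": "v1.0",
                "description": "Initial model", "accuracy": 0.85}, new_path)
    old_summary = describe_checkpoint(old_path)
    new_summary = describe_checkpoint(new_path)

print(old_summary)  # v0.9: no description (accuracy n/a)
print(new_summary)  # v1.0: Initial model (accuracy 0.85)
```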
Advanced Model Versioning
As your projects grow, you'll need more sophisticated versioning strategies.
1. Creating a Versioning Directory Structure
Organize your models in a structured directory format, with one directory per version holding the checkpoint and a human-readable metadata file:
```python
import datetime
import os

import torch
import torch.nn as nn
import torch.optim as optim

def save_model_version(model, optimizer, metrics, version, description):
    # Create version directory if it doesn't exist
    version_dir = f"models/version_{version}"
    os.makedirs(version_dir, exist_ok=True)

    # Bundle the model and optimizer state with the metadata
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'version': version,
        'metrics': metrics,
        'description': description,
        'date_created': str(datetime.datetime.now())
    }

    # Save the checkpoint
    torch.save(checkpoint, f"{version_dir}/model.pth")

    # Also write a plain-text metadata file for easy reading
    with open(f"{version_dir}/metadata.txt", 'w') as f:
        f.write(f"Version: {version}\n")
        f.write(f"Description: {description}\n")
        f.write(f"Date: {checkpoint['date_created']}\n")
        f.write(f"Metrics: {metrics}\n")

    return version_dir

# Usage example (stand-in model and optimizer)
model = nn.Linear(784, 10)
optimizer = optim.Adam(model.parameters(), lr=0.001)
model_info = {
    'accuracy': 0.87,
    'loss': 0.34,
    'f1_score': 0.86
}
save_path = save_model_version(
    model=model,
    optimizer=optimizer,
    metrics=model_info,
    version="2.1",
    description="Added dropout layers to prevent overfitting"
)
print(f"Model saved to {save_path}")
# Output: Model saved to models/version_2.1
```
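With one directory per version, listing the available versions is a directory scan, but sorting the names as strings misorders them (`version_10.0` sorts before `version_2.1`). A minimal sketch that sorts numerically, assuming the `models/version_<number>` layout above and purely numeric dotted versions; `list_model_versions` and `version_key` are hypothetical helpers:

```python
import os
import tempfile

def version_key(version: str) -> tuple[int, ...]:
    """Turn "2.10" into (2, 10) so versions compare numerically, not as strings."""
    return tuple(int(part) for part in version.split("."))

def list_model_versions(base_dir: str = "models") -> list[str]:
    """Return the versions found under base_dir, oldest first."""
    prefix = "version_"
    versions = []
    for name in os.listdir(base_dir):
        path = os.path.join(base_dir, name)
        # Only count directories that follow the version_<number> convention
        # and actually contain a saved model
        if name.startswith(prefix) and os.path.isfile(os.path.join(path, "model.pth")):
            versions.append(name[len(prefix):])
    return sorted(versions, key=version_key)

# Demo on a throwaway directory tree with empty placeholder files
with tempfile.TemporaryDirectory() as base:
    for v in ["2.1", "10.0", "2.10", "1.0"]:
        os.makedirs(os.path.join(base, f"version_{v}"))
        open(os.path.join(base, f"version_{v}", "model.pth"), "wb").close()
    found = list_model_versions(base)

print(found)  # ['1.0', '2.1', '2.10', '10.0']
```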
2. Using Version Control Systems with Model Checkpoints
Git isn't well suited to large binary files like model checkpoints. You can either store them with Git LFS (Large File Storage), or keep the checkpoints themselves in a designated directory excluded from Git (for example, via .gitignore) and track only small metadata files:
```python
# Create a JSON file with model metadata that can be tracked in Git
import datetime
import json
import os

def create_model_metadata(version, metrics, description, model_path):
    metadata = {
        'version': version,
        'metrics': metrics,
        'description': description,
        'date_created': str(datetime.datetime.now()),
        'model_path': model_path
    }
    # Write metadata to a JSON file (creating the folder on first use)
    os.makedirs("metadata", exist_ok=True)
    metadata_file = f"metadata/model_v{version}.json"
    with open(metadata_file, 'w') as f:
        json.dump(metadata, f, indent=4)
    return metadata_file

# Usage
metadata_file = create_model_metadata(
    version="3.0",
    metrics={'accuracy': 0.89, 'loss': 0.31},
    description="Fine-tuned the last two layers only",
    model_path="models/version_3.0/model.pth"
)
print(f"Model metadata saved to {metadata_file}")
# Output: Model metadata saved to metadata/model_v3.0.json
```
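When the checkpoint stays out of Git, the tracked JSON is the only link to it, and nothing guarantees that the file at `model_path` is the one the metadata describes. One option, sketched here as an assumption rather than part of the approach above, is to record a SHA-256 digest of the checkpoint in the metadata and check it before loading; `file_sha256` and `verify_checkpoint` are hypothetical helpers, and the bytes written stand in for real `torch.save` output:

```python
import hashlib
import json
import os
import tempfile

def file_sha256(path: str, chunk_size: int = 1 << 20) -> str:
    """Hash a file in chunks so large checkpoints need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

def verify_checkpoint(metadata_file: str) -> bool:
    """Return True if the checkpoint on disk matches the digest recorded in the metadata."""
    with open(metadata_file) as f:
        metadata = json.load(f)
    return file_sha256(metadata["model_path"]) == metadata["sha256"]

with tempfile.TemporaryDirectory() as tmp:
    model_path = os.path.join(tmp, "model.pth")
    with open(model_path, "wb") as f:
        f.write(b"stand-in for torch.save output")
    metadata_file = os.path.join(tmp, "model_v3.0.json")
    with open(metadata_file, "w") as f:
        json.dump({"version": "3.0", "model_path": model_path,
                   "sha256": file_sha256(model_path)}, f, indent=4)
    intact = verify_checkpoint(metadata_file)    # True
    with open(model_path, "ab") as f:
        f.write(b"!")  # simulate the file being overwritten or corrupted
    tampered = verify_checkpoint(metadata_file)  # False

print(intact, tampered)  # True False
```

Since the digest lives in the Git-tracked JSON, a commit then pins down exactly which checkpoint bytes it refers to.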
Implementing Semantic Versioning for Models
Adopt semantic versioning (MAJOR.MINOR.PATCH) to communicate the significance of model changes:
- MAJOR: Significant architecture changes that break compatibility with existing checkpoints or calling code
- MINOR: Backward-compatible improvements or added features
- PATCH: Bug fixes or small tweaks that don't change the model's interface
```python
def increment_version(current_version, update_type='patch'):
    major, minor, patch = map(int, current_version.split('.'))
    if update_type == 'major':
        return f"{major + 1}.0.0"
    elif update_type == 'minor':
        return f"{major}.{minor + 1}.0"
    elif update_type == 'patch':
        return f"{major}.{minor}.{patch + 1}"
    # Reject typos instead of silently treating them as a patch bump
    raise ValueError(f"Unknown update type: {update_type!r}")

# Example usage
current_version = "1.2.3"
new_version = increment_version(current_version, 'minor')
print(f"Updated from {current_version} to {new_version}")
# Output: Updated from 1.2.3 to 1.3.0
```
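Semantic version strings also need to be compared, for example to find the latest release, and plain string comparison gets this wrong (`"1.10.0" < "1.9.0"` is `True` for strings). A minimal sketch that compares parsed integer tuples instead, assuming plain `MAJOR.MINOR.PATCH` strings without pre-release tags; `parse_version`, `latest_version`, and `is_breaking_change` are hypothetical helpers:

```python
def parse_version(version: str) -> tuple[int, int, int]:
    """Split "MAJOR.MINOR.PATCH" into integers so versions compare numerically."""
    major, minor, patch = (int(part) for part in version.split("."))
    return major, minor, patch

def latest_version(versions: list[str]) -> str:
    """Return the highest version; tuples compare element by element."""
    return max(versions, key=parse_version)

def is_breaking_change(old: str, new: str) -> bool:
    """Under the MAJOR/MINOR/PATCH scheme, only a MAJOR bump signals incompatibility."""
    return parse_version(new)[0] > parse_version(old)[0]

versions = ["1.9.0", "1.10.0", "1.2.3"]
print(max(versions))                           # 1.9.0  (string comparison: wrong)
print(latest_version(versions))                # 1.10.0
print(is_breaking_change("1.10.0", "2.0.0"))  # True
```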
Real-World Example: Training with Automatic Versioning
Let's integrate versioning into a complete training loop:
```python
import datetime
import json
import os

import torch
import torch.nn as nn
import torch.optim as optim


class ModelVersionManager:
    def __init__(self, base_dir="model_versions"):
        self.base_dir = base_dir
        self.version_file = os.path.join(base_dir, "version_history.json")
        os.makedirs(base_dir, exist_ok=True)

        # Initialize or load version history
        if os.path.exists(self.version_file):
            with open(self.version_file, 'r') as f:
                self.version_history = json.load(f)
        else:
            self.version_history = {
                'latest_version': '0.1.0',
                'versions': {}
            }
            self._save_version_history()

    def _save_version_history(self):
        with open(self.version_file, 'w') as f:
            json.dump(self.version_history, f, indent=4)

    def create_new_version(self, model, optimizer, metrics,
                           description, update_type='patch'):
        # Get new version number
        latest = self.version_history['latest_version']
        new_version = self._increment_version(latest, update_type)

        # Create version directory
        version_dir = os.path.join(self.base_dir, f"v{new_version}")
        os.makedirs(version_dir, exist_ok=True)

        # Save model and optimizer state together with the metadata
        created = str(datetime.datetime.now())
        checkpoint = {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'metrics': metrics,
            'description': description,
            'date_created': created
        }
        checkpoint_path = os.path.join(version_dir, "model_checkpoint.pth")
        torch.save(checkpoint, checkpoint_path)

        # Update version history, linking the new version to its predecessor
        self.version_history['latest_version'] = new_version
        self.version_history['versions'][new_version] = {
            'metrics': metrics,
            'description': description,
            'path': checkpoint_path,
            'date': created,
            'previous_version': latest
        }
        self._save_version_history()
        return new_version

    def _increment_version(self, version, update_type):
        major, minor, patch = map(int, version.split('.'))
        if update_type == 'major':
            return f"{major + 1}.0.0"
        elif update_type == 'minor':
            return f"{major}.{minor + 1}.0"
        elif update_type == 'patch':
            return f"{major}.{minor}.{patch + 1}"
        raise ValueError(f"Unknown update type: {update_type!r}")

    def load_version(self, version=None):
        # If no version specified, load latest
        if version is None:
            version = self.version_history['latest_version']
        if version not in self.version_history['versions']:
            raise ValueError(f"Version {version} not found")
        version_info = self.version_history['versions'][version]
        checkpoint = torch.load(version_info['path'], map_location='cpu')
        return checkpoint, version_info


# Example usage in a training pipeline
def train_model(model, train_loader, epochs=5, version_manager=None):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            running_loss += loss.item()

        epoch_loss = running_loss / len(train_loader)
        epoch_accuracy = correct / total
        print(f"Epoch {epoch+1}/{epochs} - Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.4f}")

        # Snapshot a new model version every 2 epochs
        if (epoch + 1) % 2 == 0 and version_manager is not None:
            metrics = {
                'loss': epoch_loss,
                'accuracy': epoch_accuracy,
                'epoch': epoch + 1
            }
            # Simple illustrative rule: MINOR bump once training accuracy
            # exceeds 0.9, otherwise a PATCH bump
            if epoch_accuracy > 0.9:
                update_type = 'minor'
            else:
                update_type = 'patch'
            version = version_manager.create_new_version(
                model=model,
                optimizer=optimizer,
                metrics=metrics,
                description=f"Model after epoch {epoch+1}",
                update_type=update_type
            )
            print(f"Created new model version {version}")

    return model


# Initialize model and version manager
model = nn.Sequential(
    nn.Flatten(),  # MNIST batches arrive as [N, 1, 28, 28]; flatten to [N, 784]
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)
version_manager = ModelVersionManager(base_dir="mnist_model_versions")

# After training:
# train_model(model, train_loader, epochs=6, version_manager=version_manager)

# Later, load the latest version (or pass one of the version strings printed during training)
# checkpoint, version_info = version_manager.load_version()
# model.load_state_dict(checkpoint['model_state_dict'])
# print(f"Loaded model: {version_info['description']}")
```
This example demonstrates a small but complete versioning system that:
- Tracks model versions with semantic versioning
- Saves checkpoints with metadata
- Creates a new version every two epochs, choosing a MINOR or PATCH bump from the current training accuracy
- Allows loading specific versions
- Maintains a version history that links each version to its predecessor
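The `previous_version` field turns `version_history.json` into a linked chain, so a version's lineage can be recovered by following those links backwards. A minimal sketch that walks the chain in a dictionary shaped like the one `ModelVersionManager` writes; `version_lineage` is a hypothetical helper, and the sample history is made up for illustration:

```python
def version_lineage(history: dict, version: str) -> list[str]:
    """Follow previous_version links from `version` back to the oldest saved version."""
    versions = history["versions"]
    lineage = []
    seen = set()
    current = version
    # Stop at a version with no saved entry (e.g. the initial "0.1.0" placeholder)
    # or if a malformed history file contains a cycle
    while current in versions and current not in seen:
        seen.add(current)
        lineage.append(current)
        current = versions[current].get("previous_version")
    return lineage

# Sample history in the shape written by ModelVersionManager (values are illustrative)
history = {
    "latest_version": "0.2.0",
    "versions": {
        "0.1.1": {"metrics": {"accuracy": 0.84}, "previous_version": "0.1.0"},
        "0.1.2": {"metrics": {"accuracy": 0.88}, "previous_version": "0.1.1"},
        "0.2.0": {"metrics": {"accuracy": 0.91}, "previous_version": "0.1.2"},
    },
}

lineage = version_lineage(history, history["latest_version"])
print(" <- ".join(lineage))  # 0.2.0 <- 0.1.2 <- 0.1.1
```

In practice you would load `history` with `json.load` from the manager's `version_history.json` rather than building it by hand.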
Integration with Model Registry Systems
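For production environments, consider integrating with a dedicated model registry tool. The example below uses MLflow to log a model to a tracking run; MLflow's Model Registry can then assign named, numbered versions to logged models (for example, with mlflow.register_model):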
```python
# Example integration with MLflow
import mlflow
import mlflow.pytorch
import torch.nn as nn

# Start an MLflow run
with mlflow.start_run(run_name="model_training"):
    # Define your model (training omitted; the metric below is a placeholder)
    model = nn.Sequential(
        nn.Linear(10, 5),
        nn.ReLU(),
        nn.Linear(5, 2)
    )

    # Log parameters, metrics, and model
    mlflow.log_param("hidden_size", 5)
    mlflow.log_param("input_size", 10)
    mlflow.log_metric("accuracy", 0.92)

    # Log the model with the PyTorch flavor
    mlflow.pytorch.log_model(model, "model")

    # Get the run ID for future reference
    run_id = mlflow.active_run().info.run_id
    print(f"Model logged with run_id: {run_id}")

# Later, load the model by run_id
loaded_model = mlflow.pytorch.load_model(f"runs:/{run_id}/model")
```
Summary
Effective model versioning is essential for tracking the evolution of your PyTorch models. In this guide, we've explored:
- Basic versioning techniques using filenames and metadata
- Advanced versioning strategies with directory structures and semantic versioning
- A complete versioning system for training workflows
- Integration with model registry tools like MLflow
By implementing proper versioning practices, you'll be able to maintain a clear history of your model development, easily roll back to previous versions, and confidently deploy the right models to production.
Additional Resources and Exercises
Exercises
- Basic Versioning: Create a simple script that trains a PyTorch model for MNIST classification and saves checkpoints after every epoch with proper version information.
- Versioning Library: Extend the ModelVersionManager class to include functionality for comparing different model versions based on their metrics.
- Integration Challenge: Integrate your model versioning with either MLflow or DVC and test the workflow of training, versioning, and retrieving specific model versions.
- Collaborative Exercise: Set up a Git repository where multiple team members can work on the same model while maintaining proper versioning of both code and model checkpoints.
By mastering model versioning in PyTorch, you'll build better habits that will serve you well throughout your machine learning career!