PyTorch Meta Learning
Introduction
Meta learning, often referred to as "learning to learn," is a fascinating paradigm in machine learning where models are trained to quickly adapt to new tasks with minimal data and training time. Unlike traditional deep learning approaches that require large datasets and long training periods, meta learning enables models to leverage knowledge from previous tasks to rapidly learn new ones.
In this tutorial, we'll explore how to implement meta learning techniques in PyTorch, with a specific focus on Model-Agnostic Meta-Learning (MAML), one of the most popular meta learning algorithms. By the end, you'll understand how to build models that can adapt to new tasks with just a few examples.
Prerequisites
Before diving into meta learning, you should have:
- Intermediate knowledge of PyTorch
- Understanding of gradient descent optimization
- Familiarity with neural networks
Understanding Meta Learning
Meta learning addresses a fundamental challenge in machine learning: how can we create models that quickly adapt to new tasks with minimal data? This is particularly important in scenarios where collecting large amounts of data is expensive or impossible, such as in medical imaging or robotics.
Key Concepts
- Task Distribution: In meta learning, we assume there's a distribution of tasks that share some underlying structure.
- Support and Query Sets: For each task, we have:
  - A support set (a few labeled examples used for adaptation)
  - A query set (held-out examples used for evaluation after adaptation); see the sketch after this list
- Meta-Training and Meta-Testing: We train the model across many tasks during meta-training and test its ability to adapt to new tasks during meta-testing.
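To make the support and query sets concrete, here is a minimal sketch (with made-up shapes) of how a single 5-way 1-shot task is often represented as tensors:
import torch

# 5-way, 1-shot task: 5 classes, 1 labeled support example per class
support_images = torch.randn(5, 3, 28, 28)
support_labels = torch.tensor([0, 1, 2, 3, 4])

# Query set: e.g. 15 held-out images per class, used to judge the adaptation
query_images = torch.randn(5 * 15, 3, 28, 28)
query_labels = torch.arange(5).repeat_interleave(15)

task = {'support': (support_images, support_labels),
        'query': (query_images, query_labels)}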
Model-Agnostic Meta-Learning (MAML)
MAML is a powerful meta learning algorithm introduced by Chelsea Finn et al. The key insight of MAML is to find a good initialization for a model's parameters such that it can quickly adapt to new tasks with just a few gradient steps.
MAML Algorithm Overview
- Initialize model parameters θ
- For each task in a batch of tasks:
  - Create a task-specific copy of the parameters θ'
  - Update θ' with one or more gradient-descent steps on the task's support set
  - Evaluate the updated parameters on the task's query set
- Update the original parameters θ based on the query-set performance across all tasks (formalized below)
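Concretely, for a task T_i with loss L_i, the inner loop computes adapted parameters θ'_i = θ − α ∇_θ L_i(θ) (possibly repeated for several gradient steps), and the meta-update is θ ← θ − β ∇_θ Σ_i L_i(θ'_i), where α is the inner (adaptation) learning rate and β is the meta learning rate. Because the meta-gradient differentiates through the inner update, MAML involves second-order derivatives.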
Implementing MAML in PyTorch
Let's implement a simple version of MAML for a few-shot image classification task.
First, let's set up our environment:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from copy import deepcopy
from torch.func import functional_call  # stateless forward passes (requires PyTorch 2.0+)
Define a simple CNN model:
class SimpleCNN(nn.Module):
    def __init__(self, num_classes):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, 3, padding=1)
        self.conv4 = nn.Conv2d(64, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(64 * 7 * 7, num_classes)  # 28x28 input -> 7x7 after two 2x2 poolings
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = self.pool(F.relu(self.conv4(x)))
        x = x.view(x.size(0), -1)  # flatten to (batch, 64 * 7 * 7)
        x = self.fc(x)
        return x
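As a quick sanity check (assuming 28x28 RGB inputs, which matches the synthetic task generator used later in this tutorial), you can verify the output shape with a dummy batch:
model = SimpleCNN(num_classes=5)
dummy_batch = torch.randn(8, 3, 28, 28)  # a batch of 8 fake RGB images
logits = model(dummy_batch)
print(logits.shape)  # expected: torch.Size([8, 5])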
Implement the MAML algorithm:
class MAML:
    def __init__(self, model, inner_lr=0.01, meta_lr=0.001, num_inner_steps=5):
        self.model = model
        self.inner_lr = inner_lr  # Learning rate for task adaptation
        self.meta_lr = meta_lr    # Learning rate for meta-update
        self.num_inner_steps = num_inner_steps  # Number of adaptation steps
        self.meta_optimizer = optim.Adam(model.parameters(), lr=self.meta_lr)
        
    def inner_loop(self, support_images, support_labels):
        """Perform adaptation steps on the support set and return the adapted ("fast") weights"""
        # Start from the current meta-parameters; keeping the adaptation steps
        # in the autograd graph (create_graph=True) lets the meta-gradient flow
        # back through them, which is what makes MAML second-order.
        # (Setting create_graph=False instead gives the cheaper first-order approximation.)
        fast_weights = dict(self.model.named_parameters())
        
        # Perform adaptation steps
        for _ in range(self.num_inner_steps):
            # Forward pass with the current fast weights (stateless call)
            logits = functional_call(self.model, fast_weights, (support_images,))
            loss = F.cross_entropy(logits, support_labels)
            
            # Gradients of the support loss with respect to the fast weights
            grads = torch.autograd.grad(loss, list(fast_weights.values()),
                                        create_graph=True)
            
            # One SGD step on the fast weights
            fast_weights = {name: param - self.inner_lr * grad
                            for (name, param), grad in zip(fast_weights.items(), grads)}
        
        return fast_weights
    
    def outer_loop(self, tasks_batch):
        """Perform meta-update across a batch of tasks"""
        meta_loss = 0.0
        
        for task in tasks_batch:
            support_images, support_labels = task['support']
            query_images, query_labels = task['query']
            
            # Adapt the model to the current task (returns task-specific fast weights)
            fast_weights = self.inner_loop(support_images, support_labels)
            
            # Compute loss on the query set using the adapted weights
            query_logits = functional_call(self.model, fast_weights, (query_images,))
            task_loss = F.cross_entropy(query_logits, query_labels)
            meta_loss += task_loss
        
        # Average meta-loss across tasks
        meta_loss = meta_loss / len(tasks_batch)
        
        # Meta-update
        self.meta_optimizer.zero_grad()
        meta_loss.backward()
        self.meta_optimizer.step()
        
        return meta_loss.item()
    
    def train(self, task_generator, num_episodes=1000, tasks_per_episode=4):
        """Train the model using MAML"""
        for episode in range(num_episodes):
            # Sample a batch of tasks
            tasks_batch = [task_generator.sample_task() for _ in range(tasks_per_episode)]
            
            # Perform meta-update
            meta_loss = self.outer_loop(tasks_batch)
            
            if episode % 100 == 0:
                print(f"Episode {episode}, Meta Loss: {meta_loss:.4f}")
    
    def evaluate(self, task):
        """Evaluate the model on a new task after adaptation"""
        support_images, support_labels = task['support']
        query_images, query_labels = task['query']
        
        # Adapt to the support set
        fast_weights = self.inner_loop(support_images, support_labels)
        
        # Evaluate on the query set with the adapted weights
        with torch.no_grad():
            query_logits = functional_call(self.model, fast_weights, (query_images,))
        query_preds = torch.argmax(query_logits, dim=1)
        accuracy = (query_preds == query_labels).float().mean().item()
        
        return accuracy
Creating a simple task generator for demo purposes:
class SimpleFewShotTaskGenerator:
    def __init__(self, num_classes=5, num_samples=10, img_size=28):
        self.num_classes = num_classes  # N-way classification
        self.num_samples = num_samples  # K-shot learning
        self.img_size = img_size
        
    def sample_task(self):
        """Generate a synthetic few-shot task"""
        # For demonstration purposes, we'll create random data
        # In practice, you would use real datasets like Omniglot or mini-ImageNet
        
        # Generate support set (few examples for adaptation)
        support_images = torch.randn(self.num_classes * self.num_samples, 3, self.img_size, self.img_size)
        support_labels = torch.cat([torch.full((self.num_samples,), i) 
                                   for i in range(self.num_classes)]).long()
        
        # Generate query set (for evaluation)
        query_samples = 15  # Number of query samples per class
        query_images = torch.randn(self.num_classes * query_samples, 3, self.img_size, self.img_size)
        query_labels = torch.cat([torch.full((query_samples,), i) 
                                 for i in range(self.num_classes)]).long()
        
        return {
            'support': (support_images, support_labels),
            'query': (query_images, query_labels)
        }
Let's run a simple training and evaluation:
# Initialize the model and MAML
model = SimpleCNN(num_classes=5)  # 5-way classification
maml = MAML(model)
# Initialize task generator
task_generator = SimpleFewShotTaskGenerator()
# Train MAML
print("Starting MAML training...")
maml.train(task_generator, num_episodes=1000, tasks_per_episode=4)
# Evaluate on a new task
print("Evaluating on a new task...")
new_task = task_generator.sample_task()
accuracy = maml.evaluate(new_task)
print(f"Accuracy on new task: {accuracy:.4f}")
Output (illustrative; because the demo tasks are pure random noise, a real run will not show this kind of improvement, so use an actual few-shot dataset such as Omniglot or mini-ImageNet for meaningful results):
Starting MAML training...
Episode 0, Meta Loss: 1.6314
Episode 100, Meta Loss: 1.5082
Episode 200, Meta Loss: 1.3876
Episode 300, Meta Loss: 1.2531
Episode 400, Meta Loss: 1.1245
Episode 500, Meta Loss: 0.9876
Episode 600, Meta Loss: 0.8521
Episode 700, Meta Loss: 0.7432
Episode 800, Meta Loss: 0.6587
Episode 900, Meta Loss: 0.5821
Evaluating on a new task...
Accuracy on new task: 0.7133
Real-World Applications of Meta Learning
1. Few-Shot Image Classification
One of the most common applications of meta learning is few-shot image classification, where models need to recognize new categories with just a few examples.
# Conceptual sketch: using MAML to identify rare medical conditions from X-rays
# with only a handful of example images per condition (helpers like
# medical_task_generator and get_few_example_images are placeholders)
def medical_diagnosis_example():
    # In a real scenario, you would:
    # 1. Meta-train on many diagnosis tasks built from common, data-rich conditions
    # 2. Use the learned initialization to adapt quickly to rare conditions
    
    # Initialize model
    model = SimpleCNN(num_classes=10)  # e.g. 10-way diagnosis tasks
    maml = MAML(model)
    
    # Meta-train on a generator of medical imaging tasks
    maml.train(medical_task_generator)
    
    # When a new rare condition is discovered, collect a few labeled examples
    support_images, support_labels = get_few_example_images()  # only 5 examples!
    
    # Quickly adapt the model to the new condition
    fast_weights = maml.inner_loop(support_images, support_labels)
    
    # The adapted weights can now be used (via functional_call) to classify
    # new X-rays for this rare condition
    return fast_weights
2. Personalized Recommender Systems
Meta learning can be used to quickly adapt recommendations to new users with minimal interaction history.
# Conceptual sketch: adapting a recommendation model to a new user
class RecommenderMAML:
    def __init__(self, base_model):
        self.base_model = base_model
        self.maml = MAML(base_model)
    
    def personalize_for_new_user(self, initial_interactions):
        # Treat the user's few initial interactions as a support set
        # (item features as inputs, ratings/clicks as targets) and adapt the model
        fast_weights = self.maml.inner_loop(initial_interactions['items'],
                                            initial_interactions['ratings'])
        
        # Score candidate items with the personalized weights and return
        # the top-ranked recommendations (ranking code omitted here)
        return fast_weights
3. Robotic Control and Reinforcement Learning
Meta learning is particularly valuable in robotics, where robots need to quickly adapt to new environments or tasks.
# Pseudo-code: a robot learning to walk on different terrains
class RobotMAML:
    def __init__(self):
        self.policy_network = PolicyNetwork()  # placeholder policy network
        self.maml = MAML(self.policy_network)
    
    def train_on_multiple_terrains(self, terrain_task_generator):
        # Meta-train on tasks sampled from various terrains (smooth, rocky, sandy, etc.)
        self.maml.train(terrain_task_generator)
    
    def adapt_to_new_terrain(self, terrain_states, terrain_actions):
        # Quickly adapt to a new terrain from a handful of demonstrations
        # (a full RL setup would use a policy-gradient loss rather than cross-entropy)
        adapted_policy = self.maml.inner_loop(terrain_states, terrain_actions)
        return adapted_policy
Using Higher Library for Cleaner MAML Implementation
For more elegant and efficient meta-learning implementations, you can use the higher library, which simplifies the process of working with nested optimization problems in PyTorch.
import higher
def maml_training_with_higher(model, tasks, inner_lr=0.01, meta_lr=0.001):
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=meta_lr)
    
    for task_batch in tasks:
        meta_loss = 0.0
        
        for task in task_batch:
            support_x, support_y = task['support']
            query_x, query_y = task['query']
            
            # Create a differentiable ("stateless") copy of the model.
            # copy_initial_weights=False is important for MAML: it keeps the initial
            # fast weights tied to the real parameters, so the meta-gradient flows back to them.
            inner_opt = torch.optim.SGD(model.parameters(), lr=inner_lr)
            with higher.innerloop_ctx(model, inner_opt,
                                      copy_initial_weights=False) as (fmodel, diffopt):
                # Inner loop adaptation on support set
                for _ in range(5):  # 5 adaptation steps
                    support_pred = fmodel(support_x)
                    support_loss = F.cross_entropy(support_pred, support_y)
                    diffopt.step(support_loss)
                
                # Evaluate on query set
                query_pred = fmodel(query_x)
                task_loss = F.cross_entropy(query_pred, query_y)
                meta_loss += task_loss
        
        # Meta-update (average the query loss over the tasks in the batch)
        meta_optimizer.zero_grad()
        (meta_loss / len(task_batch)).backward()
        meta_optimizer.step()
Beyond MAML: Other Meta-Learning Approaches
While MAML is popular, there are other effective meta-learning approaches:
1. Prototypical Networks
Prototypical Networks compute class prototypes by averaging embeddings of support examples and classify query examples by their distance to these prototypes.
class ProtoNet(nn.Module):
    def __init__(self):
        super(ProtoNet, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten()
        )
    
    def forward(self, x):
        return self.encoder(x)
    
    def compute_prototypes(self, support_images, support_labels, n_classes):
        support_embeddings = self(support_images)
        prototypes = torch.zeros(n_classes, support_embeddings.size(1)).to(support_images.device)
        
        for c in range(n_classes):
            mask = support_labels == c
            prototypes[c] = support_embeddings[mask].mean(0)
        
        return prototypes
    
    def classify(self, query_images, prototypes):
        query_embeddings = self(query_images)
        
        # Compute distances to prototypes
        dists = torch.cdist(query_embeddings, prototypes)
        
        # Negative distance as logits
        return -dists
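For context, here is a minimal episodic training step for the ProtoNet above, reusing the synthetic SimpleFewShotTaskGenerator from earlier; with random-noise data the loss is meaningless, but the same pattern applies to real few-shot datasets:
protonet = ProtoNet()
optimizer = optim.Adam(protonet.parameters(), lr=1e-3)
task = SimpleFewShotTaskGenerator().sample_task()

support_images, support_labels = task['support']
query_images, query_labels = task['query']

# Build one prototype per class from the support set, then classify the query set
prototypes = protonet.compute_prototypes(support_images, support_labels, n_classes=5)
logits = protonet.classify(query_images, prototypes)  # negative distances act as logits
loss = F.cross_entropy(logits, query_labels)

optimizer.zero_grad()
loss.backward()
optimizer.step()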
2. Relation Networks
Relation Networks learn to compare query examples with support examples using a learnable relation module.
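To make the idea more tangible, here is a rough sketch of a Relation Network; the layer sizes and pooling are illustrative assumptions rather than the architecture from the original paper:
class RelationNetwork(nn.Module):
    def __init__(self, embedding_dim=64, hidden_dim=64):
        super().__init__()
        # Embedding module: maps images to fixed-size feature vectors
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten()  # -> 64-dim embedding
        )
        # Relation module: learns to score how well a (query, class) pair matches
        self.relation = nn.Sequential(
            nn.Linear(2 * embedding_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 1), nn.Sigmoid()
        )
    
    def forward(self, query_images, class_embeddings):
        # class_embeddings: one (e.g. averaged) support embedding per class, shape [C, D]
        q = self.encoder(query_images)  # [Q, D]
        n_q, n_c = q.size(0), class_embeddings.size(0)
        pairs = torch.cat([q.unsqueeze(1).expand(-1, n_c, -1),
                           class_embeddings.unsqueeze(0).expand(n_q, -1, -1)], dim=-1)
        return self.relation(pairs).squeeze(-1)  # relation scores in [0, 1], shape [Q, C]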
3. Reptile
A simplified version of MAML that works surprisingly well and is easier to implement:
def reptile_update(model, task_generator, k_shots, n_iterations, inner_lr, meta_lr):
    # Store the original parameters
    original_params = deepcopy([p.data for p in model.parameters()])
    
    # Sample a task
    task = task_generator.sample_task()
    support_x, support_y = task['support']
    
    # Inner loop optimizer
    inner_opt = torch.optim.SGD(model.parameters(), lr=inner_lr)
    
    # Inner loop training
    for _ in range(n_iterations):
        inner_opt.zero_grad()
        logits = model(support_x)
        loss = F.cross_entropy(logits, support_y)
        loss.backward()
        inner_opt.step()
    
    # Store the adapted parameters after the inner-loop training
    adapted_params = [p.data.clone() for p in model.parameters()]
    
    # Reptile update: move each original parameter towards its adapted value
    # (meta_lr plays the role of the Reptile step size)
    with torch.no_grad():
        for p, orig, adapted in zip(model.parameters(), original_params, adapted_params):
            p.data.copy_(orig + meta_lr * (adapted - orig))
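A minimal meta-training loop for Reptile might then look like this (illustrative; it reuses the synthetic task generator from above, so no meaningful accuracy should be expected, and the step sizes are just placeholder values):
model = SimpleCNN(num_classes=5)
task_generator = SimpleFewShotTaskGenerator()

for episode in range(1000):
    reptile_update(model, task_generator, k_shots=10, n_iterations=5,
                   inner_lr=0.01, meta_lr=0.1)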
Summary
Meta learning is a powerful approach for creating models that can quickly adapt to new tasks with minimal data. In this tutorial, we covered:
- The fundamental concepts of meta learning
- Model-Agnostic Meta-Learning (MAML) and its implementation in PyTorch
- Real-world applications of meta learning
- Alternative meta learning techniques like Prototypical Networks and Reptile
By mastering these techniques, you can create models that are more flexible, data-efficient, and adaptable to new situations – a crucial capability in many real-world scenarios.
Additional Resources
- MAML Paper: "Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks"
- Prototypical Networks Paper
- Higher Library for Meta-Learning in PyTorch
- learn2learn: A PyTorch Library for Meta-Learning Research
Exercises
- Implement MAML for a simple regression task where the goal is to quickly adapt to new sine wave functions.
- Extend the SimpleFewShotTaskGenerator to work with a real dataset like Omniglot or mini-ImageNet.
- Compare the performance of MAML and Prototypical Networks on a few-shot image classification task.
- Implement a meta-learning approach for a reinforcement learning problem.
- Try implementing Meta-SGD, an extension of MAML that learns per-parameter inner learning rates.