
Original Meta Networks Overview

📚 What are Original Meta Networks?

Original Meta Networks is a faithful implementation of the Meta Networks algorithm from the paper "Meta Networks" by Munkhdalai & Yu (2017). It is a model-based meta-learning approach in which one neural network (the meta-learner) learns to generate the actual parameters (weights and biases) of another neural network (the base learner) for task-specific classification.

Paper: Meta Networks - Munkhdalai & Yu, ICML 2017


⚠️ Important: This is the Original Meta Networks Algorithm

This implementation follows the Meta Networks approach as formulated in the original paper:

🔍 Key Characteristics:

| Aspect | Original Meta Networks | Embedding-based Variant |
| --- | --- | --- |
| Name | Meta Networks (Original) | Embedding-based Meta Networks |
| Category | 🏗️ Model-based Meta Learning | 🎯 Metric-based Meta Learning |
| Approach | Generates the base network's FC layer weights and biases | Generates task-specific embeddings for metric learning |
| Fast Weights | Used as actual network parameters | Used for computing similarity metrics |
| Meta-learner Output | FC layer weights W and biases b | Task-specific embeddings for classification |

📝 What This Means:

  • This implementation uses the meta-learner to predict the actual weights and biases of the base network's final classification layer
  • ⚖️ Embedding-based variant (in EB_Meta_Network.py) uses the meta-learner to generate task-specific embeddings that are used in a metric-based classification approach

🏗️ Architecture

Three Main Components:

  1. Shared Embedding Network (EmbeddingNetwork.py; sketched after this list)

    • 4 convolutional layers with batch normalization and Meta Dropout
    • Extracts fixed-dimensional feature embeddings from images
    • Input: 105×105 grayscale images
    • Output: 64-dimensional embeddings
    • No classification layer - weights generated by meta-learner
    • 🔗 Shared with Embedding-based variant for consistent comparisons
  2. Meta-Learner (MetaLearner)

    • Learns three key parameters:
      • U Matrix (hidden_dim × embedding_dim): Projects support embeddings
      • V Matrix (hidden_dim × embedding_dim): Additional transformation matrix
      • e Vector (hidden_dim): Base embedding capturing task structure
    • Generates actual FC layer parameters:
      • Weight matrix W [embedding_dim × num_classes]
      • Bias vector b [num_classes]
  3. Original Meta Network (OriginalMetaNetwork)

    • Combines EmbeddingNetwork and MetaLearner
    • End-to-end trainable system

🔄 How It Works

Training Algorithm:

For each batch of tasks:

  1. Extract embeddings from support and query sets using EmbeddingNetwork
  2. Process support embeddings through meta-learner:
    • For each support example (x_i, y_i):
      • Compute embedding: h_i = EmbeddingNetwork(x_i)
      • Project: r_i = tanh(U @ h_i + e)
    • Average per class: class_rep_c = mean(r_i for all i where y_i = c)
  3. Generate FC layer parameters:
    • For each class c, use weight_generator to create column w_c of weight matrix W
    • For each class c, use bias_generator to create bias value b_c
    • Construct W [embedding_dim × num_classes] and b [num_classes]
  4. Classify queries using generated weights:
    • For each query x:
      • Compute embedding: h = EmbeddingNetwork(x)
      • Compute logits: logits = h @ W + b
  5. Backpropagate the query loss to update U, V, e, the weight/bias generators, and the EmbeddingNetwork (see the sketch below)
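
A minimal sketch of one training step implementing these five steps; embedding_network, meta_learner (returning the generated (W, b)), optimizer, and the batch layout are illustrative assumptions, not the repo's exact API:

import torch.nn.functional as F

for support_x, support_y, query_x, query_y in task_dataloader:
    h_support = embedding_network(support_x)     # step 1: support embeddings
    h_query = embedding_network(query_x)         # step 1: query embeddings
    W, b = meta_learner(h_support, support_y)    # steps 2-3: generate FC parameters
    logits = h_query @ W + b                     # step 4: classify queries
    loss = F.cross_entropy(logits, query_y)
    optimizer.zero_grad()
    loss.backward()                              # step 5: joint update of all components
    optimizer.step()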

Inference:

  • Single forward pass - no gradient-based adaptation needed!
  • Meta-learner directly generates task-specific classifier parameters
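
Concretely, inference under the same assumed names reduces to:

import torch

with torch.no_grad():  # no gradients, no inner adaptation loop
    W, b = meta_learner(embedding_network(support_x), support_y)
    predictions = (embedding_network(query_x) @ W + b).argmax(dim=-1)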

🎯 Key Differences from Other Algorithms

vs MAML:

  • MAML: Learns good initialization + adapts via gradients
  • Original Meta Networks: Learns to directly generate task-specific parameters
  • Speed: Meta Networks are faster at inference (no adaptation loop)
  • Approach: Weight prediction vs gradient-based fine-tuning

vs Embedding-based Meta Networks:

  • Original: Generates actual FC layer weights and biases
  • Embedding-based: Generates embeddings for metric-based classification
  • Classification: Original uses predicted weights, Embedding-based uses distance metrics
  • Complexity: Original has weight/bias generators, Embedding-based has simpler projection

vs Prototypical Networks:

  • Original Meta Networks: Learns to generate classifier parameters
  • Prototypical Networks: Direct distance comparison to class prototypes
  • Parameters: Meta Networks add U, V, e plus generator networks; Prototypical Networks add none beyond the embedding network
  • Flexibility: Meta Networks can generate complex decision boundaries

🧪 Implementation Details

MetaLearner Architecture:

import torch
import torch.nn as nn

class MetaLearner(nn.Module):
    def __init__(self, embedding_dim=64, hidden_dim=128):
        super().__init__()
        # Core meta-learner parameters (from original paper)
        self.U = nn.Parameter(torch.randn(hidden_dim, embedding_dim) * 0.01)
        self.V = nn.Parameter(torch.randn(hidden_dim, embedding_dim) * 0.01)
        self.e = nn.Parameter(torch.randn(hidden_dim) * 0.01)

        # Weight generator: hidden_dim -> embedding_dim (one column of W per class)
        self.weight_generator = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim)
        )

        # Bias generator: hidden_dim -> 1 (one bias per class)
        self.bias_generator = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1)
        )

Algorithm Flow:

  1. Support Processing: r_i = tanh(U @ h_i + e)
  2. Class Aggregation: class_rep_c = mean(r_i for y_i = c)
  3. Weight Generation: w_c = weight_generator(class_rep_c)
  4. Bias Generation: b_c = bias_generator(class_rep_c)
  5. Classification: logits = query_embeddings @ W + b
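
A hedged sketch of this flow as a helper function; meta is a MetaLearner as sketched above, and the function itself is illustrative rather than the repo's exact API:

import torch

def generate_classifier(meta, support_embeddings, support_labels, num_classes):
    # Step 1: project each support embedding, r_i = tanh(U @ h_i + e)
    r = torch.tanh(support_embeddings @ meta.U.t() + meta.e)   # [n_support, hidden_dim]
    # Step 2: average the projections per class
    class_reps = torch.stack(
        [r[support_labels == c].mean(dim=0) for c in range(num_classes)]
    )                                                          # [num_classes, hidden_dim]
    # Steps 3-4: one weight column and one bias per class
    W = meta.weight_generator(class_reps).t()                  # [embedding_dim, num_classes]
    b = meta.bias_generator(class_reps).squeeze(-1)            # [num_classes]
    return W, b

# Step 5: logits = query_embeddings @ W + b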

📊 Expected Performance

Typical Results (5-way 1-shot Omniglot):

  • Accuracy: 80-90% (competitive with MAML)
  • Training Time: ~2-5 minutes on GPU
  • Inference Speed: Fast (single forward pass)
  • Memory Usage: Moderate (unlike MAML, no second-order gradients)

Advantages:

  • Fast inference: No adaptation loop required
  • Direct parameter generation: Learns to create good classifiers
  • End-to-end training: All components optimized jointly
  • Flexible: Can generate complex decision boundaries

Considerations:

  • ⚠️ More complex: Additional weight/bias generator networks
  • ⚠️ More parameters: U, V, e + generator networks
  • ⚠️ Training sensitivity: Requires careful initialization and learning rates

🚀 Usage Example

from Original_Meta_Network import OriginalMetaNetwork, train_original_meta_network, evaluate_original_meta_network
from utils.load_omniglot import OmniglotDataset, OmniglotTaskDataset
from utils.evaluate import plot_evaluation_results, plot_training_progress
from torch.utils.data import DataLoader

# 1. Load dataset
dataset = OmniglotDataset("omniglot/images_background")
task_dataset = OmniglotTaskDataset(dataset, n_way=5, k_shot=1, k_query=15, num_tasks=2000)
task_dataloader = DataLoader(task_dataset, batch_size=4, shuffle=True)

# 2. Create Original Meta Network
model = OriginalMetaNetwork(
    embedding_dim=64,
    hidden_dim=128,
    num_classes=5
)

# 3. Train model
model, optimizer, losses = train_original_meta_network(
    model=model,
    task_dataloader=task_dataloader,
    learning_rate=0.001
)

# 4. Evaluate model (the held-out split and task settings below are illustrative;
#    images_evaluation is the standard Omniglot evaluation split)
test_dataset = OmniglotDataset("omniglot/images_evaluation")
test_task_dataset = OmniglotTaskDataset(test_dataset, n_way=5, k_shot=1, k_query=15, num_tasks=600)
test_dataloader = DataLoader(test_task_dataset, batch_size=4, shuffle=False)
eval_results = evaluate_original_meta_network(model, test_dataloader, num_classes=5)
plot_evaluation_results(eval_results)

🎓 Educational Value

This implementation is perfect for understanding:

  • Model-based meta-learning: One model generating parameters for another
  • Weight prediction: How neural networks can learn to create other neural networks
  • Meta-learning paradigms: Comparison between different approaches
  • Parameter generation: Understanding how fast weights are created and used

🔗 Related Implementations

This repository contains multiple meta-learning approaches:

  • MAML: Gradient-based meta-learning (MAML.py)
  • Embedding-based Meta Networks: Metric-based variant (EB_Meta_Network.py)
  • Original Meta Networks: This implementation (Original_Meta_Network.py)
  • Meta Dropout: Consistent regularization (Meta_Dropout.py)

📚 References

  1. Meta Networks: Munkhdalai & Yu, "Meta Networks", ICML 2017
  2. MAML: Finn et al., "Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks", ICML 2017
  3. Prototypical Networks: Snell et al., "Prototypical Networks for Few-shot Learning", NIPS 2017
  4. Meta-Learning Survey: Hospedales et al., "Meta-Learning in Neural Networks: A Survey", TPAMI 2021

Note: This is the original Meta Networks implementation as described in the paper. For the embedding-based variant, see docs/META_NETWORKS_OVERVIEW.md.