What is few-shot learning?
Few-shot learning is a machine learning technique that aims to train models to learn new tasks or recognise new classes of objects using only a small amount of labelled data. Traditional machine learning approaches require large amounts of labelled training data to generalise to new tasks or classes. However, obtaining a large labelled dataset in real-world scenarios can be challenging, time-consuming, or impractical.
Few-shot learning addresses this limitation by focusing on training models that can learn from a few examples or a small labelled dataset. The goal is to enable the model to generalise to new, unseen examples or classes with minimal additional training. Few-shot learning is particularly valuable when data collection is expensive or time-consuming, or when new classes emerge frequently.
Types of few-shot learning
There are various approaches to few-shot learning, but a popular one is meta-learning or learning-to-learn. Meta-learning involves training a model on multiple related tasks to quickly adapt to new, similar tasks with only a few labelled examples. This is achieved by training the model to learn a more generalisable representation or by learning to update its parameters efficiently based on a few examples.
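To make this concrete, here is a minimal, illustrative sketch of first-order meta-learning in the MAML style. It uses a toy sine-wave regression task distribution as a stand-in for "multiple related tasks"; the task sampler, tiny model, and hyperparameters are hypothetical choices for illustration, not a reference implementation.
import copy
import torch
import torch.nn as nn

# Toy task distribution: each task is a sine wave with a random amplitude and phase
def sample_sine_task(num_support=5, num_query=10):
    amplitude = torch.rand(1) * 4.0 + 1.0
    phase = torch.rand(1) * 3.14
    x = torch.rand(num_support + num_query, 1) * 10.0 - 5.0
    y = amplitude * torch.sin(x + phase)
    return x[:num_support], y[:num_support], x[num_support:], y[num_support:]

model = nn.Sequential(nn.Linear(1, 40), nn.ReLU(), nn.Linear(40, 1))
meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()
inner_lr = 0.01

for meta_step in range(1000):
    support_x, support_y, query_x, query_y = sample_sine_task()

    # Inner loop: adapt a copy of the model using only the few support examples
    adapted = copy.deepcopy(model)
    inner_optimizer = torch.optim.SGD(adapted.parameters(), lr=inner_lr)
    inner_optimizer.zero_grad()
    loss_fn(adapted(support_x), support_y).backward()
    inner_optimizer.step()

    # Outer loop: evaluate the adapted copy on the query set and apply its gradients
    # to the original (meta) parameters -- a first-order approximation of MAML
    query_loss = loss_fn(adapted(query_x), query_y)
    adapted.zero_grad()
    query_loss.backward()
    meta_optimizer.zero_grad()
    for meta_param, adapted_param in zip(model.parameters(), adapted.parameters()):
        meta_param.grad = adapted_param.grad.clone()
    meta_optimizer.step()
The intent is that, after meta-training, a single inner-loop gradient step on a handful of support points is already enough to fit a new task reasonably well.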
Another approach is to use generative models such as generative adversarial networks (GANs) or variational autoencoders (VAEs). Given a limited number of labelled examples, these models can generate new samples in the target class, allowing the model to learn from the generated and labelled data.
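As a sketch of the generative approach, the snippet below assumes some pre-trained VAE exposing encode and decode methods (hypothetical names, not tied to a particular library) and uses it to synthesise extra examples around the few labelled support images.
import torch

# Hypothetical illustration: augmenting a small support set with a pre-trained VAE.
# `vae.encode` is assumed to return (mean, log-variance); adjust to your model's interface.
def augment_with_vae(vae, support_images, samples_per_image=5, noise_scale=0.1):
    # Generate synthetic examples by perturbing the latent codes of the labelled support images
    with torch.no_grad():
        mu, _ = vae.encode(support_images)                    # latent means, shape (N, latent_dim)
        mu = mu.repeat_interleave(samples_per_image, dim=0)   # one row per synthetic sample
        noisy_latents = mu + noise_scale * torch.randn_like(mu)
        synthetic_images = vae.decode(noisy_latents)
    return synthetic_images
The synthetic images inherit the label of the support example they were generated from, so they can simply be appended to the labelled training data.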
Few-shot learning has shown promising results in various domains, including computer vision, natural language processing, and reinforcement learning. It can improve the flexibility and adaptability of machine learning models in real-world applications where obtaining large labelled datasets is challenging.
Few-shot learning vs zero-shot learning
Few-shot learning and zero-shot learning are two related but distinct approaches in machine learning:
- Few-shot Learning: Few-shot learning aims to train models to learn new tasks or recognise new classes with only a small amount of labelled data. It assumes that some labelled examples are available for each new task or class. The goal is to leverage the available labelled data to generalise well to unseen examples or classes with minimal additional training. These techniques often employ meta-learning or transfer-learning approaches to learn a more generalisable representation or update rule.
- Zero-shot Learning: On the other hand, zero-shot learning focuses on recognising or classifying classes not seen during training. In zero-shot learning, the model is expected to generalise to entirely new classes for which it has never seen any labelled examples. Instead, zero-shot learning relies on auxiliary information, such as semantic attributes, class descriptions, or embeddings, to bridge the gap between the seen and unseen classes. The model learns to transfer knowledge from seen classes to unseen classes based on the provided supplementary information.
Few-shot learning addresses the problem of adapting to new tasks or classes with limited labelled data, while zero-shot learning uses auxiliary information to recognise or classify entirely new classes without any labelled examples.
Both approaches are valuable when obtaining large labelled datasets or training on all classes is challenging or impractical.
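To make the contrast concrete, here is a minimal zero-shot classification sketch. The image_encoder and text_encoder names are placeholders for any pre-trained encoders that map images and class descriptions into a shared embedding space; they are assumptions for illustration only.
import torch
import torch.nn.functional as F

def zero_shot_classify(image, class_descriptions, image_encoder, text_encoder):
    # Pick the class whose textual description is closest to the image in embedding space;
    # no labelled examples of the target classes are used
    with torch.no_grad():
        image_emb = F.normalize(image_encoder(image.unsqueeze(0)), dim=-1)   # (1, D)
        text_embs = F.normalize(text_encoder(class_descriptions), dim=-1)    # (C, D)
        similarities = image_emb @ text_embs.T                               # (1, C)
    return similarities.argmax(dim=-1).item()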
Advantages and disadvantages of few-shot learning
Advantages of Few-Shot Learning
- Flexibility and Adaptability: Few-shot learning enables models to quickly adapt to new tasks or classes with limited labelled data. It provides flexibility in scenarios where new classes emerge frequently or where data collection for each class is expensive or time-consuming.
- Reduced Annotation Effort: Few-shot learning reduces the need for extensive manual annotation of large labelled datasets. By leveraging a few labelled examples, models can generalise well to new tasks or classes, reducing the annotation effort required.
- Data-Efficient: Few-shot learning techniques are designed to use limited labelled data efficiently. They can learn from a small support set and generalise to classify or recognise examples in the query set accurately. This data efficiency is beneficial in domains where acquiring large amounts of labelled data is challenging.
- Fast Adaptation: Models are trained to adapt quickly to new tasks or classes. They can generalise from a few examples and make accurate predictions on unseen examples with minimal additional training. This rapid adaptation makes few-shot learning suitable for real-time or dynamic applications.
Disadvantages of Few-Shot Learning
- Limited Generalisation: Although few-shot learning models can adapt to new tasks or classes with limited data, their generalisation capability may be limited compared to models trained on large labelled datasets. They may struggle with classes that significantly differ from the examples seen or lack sufficient representative samples.
- Sensitivity to Support Set Composition: Few-shot learning models heavily rely on the composition and quality of the support set. The selection of support examples and their inherent biases can influence the model’s performance. Inadequate or biased support sets may lead to suboptimal generalisation and performance.
- Difficulty with Out-of-Distribution Data: Few-shot learning models may struggle when faced with out-of-distribution data or examples that significantly differ from the training distribution. They may have trouble distinguishing between known and unknown classes and exhibit poor performance on unseen data.
- Overfitting to Support Set: Since few-shot learning models have limited data, they are prone to overfitting to the support set. They may not generalise well to unseen examples or classes if the support set is too small or not representative of the overall data distribution.
It is essential to consider these advantages and disadvantages when applying few-shot learning techniques and carefully evaluate their suitability for specific applications and data conditions.
Applications of few-shot learning
Few-shot learning has various applications across different domains:
- Image Classification: Few-shot learning can be applied to image classification tasks where the availability of labelled data is limited. It allows models to quickly adapt to new classes or recognise objects with only a few labelled examples, making it valuable when collecting large labelled datasets is challenging or expensive.
- Object Detection and Segmentation: Few-shot learning techniques can be extended to object detection and segmentation tasks. By leveraging a few annotated examples, models can learn to detect and segment novel objects in images or video frames without requiring extensive labelled data for each new class.
- Natural Language Processing (NLP): Few-shot learning is increasingly applied in NLP tasks. For example, models can be trained to recognise new categories or sentiments with a small labelled support set in text classification. Similarly, in machine translation, few-shot learning can enable models to quickly adapt to new language pairs with limited parallel training data.
- Anomaly Detection: Few-shot learning techniques can be utilised for anomaly detection, where the goal is to identify rare or unseen events or patterns. By training models on normal or representative examples and providing a few anomalous examples as support, models can detect and flag deviations from the expected patterns.
- Personalised Recommendation: Few-shot learning can enhance personalised recommendation systems by adapting to new user preferences or niche items with minimal data. Models can learn to make accurate recommendations by leveraging a user’s historical behaviour and a few labelled examples of the user’s preferences for new items or categories.
- Medical Diagnosis: Few-shot learning holds promise in medical diagnosis, where collecting large labelled datasets for rare diseases or conditions is often impractical. By training models on a small set of labelled examples, healthcare professionals can utilise few-shot learning to aid in diagnosing new and rare medical conditions.
- Robotics and Autonomous Systems: Few-shot learning techniques are relevant in robotics and autonomous systems, allowing robots to adapt to new tasks or objects with limited supervision quickly. This enables robots to learn new skills or recognise and interact with novel objects encountered in real-world environments.
These applications demonstrate the versatility of few-shot learning across domains where adapting to new tasks, classes, or environments with limited labelled data is crucial.
How do you implement few-shot learning?
Few-shot learning works by training machine learning models to quickly adapt and generalise to new tasks or classes with only a small amount of labelled data. The underlying idea is to leverage prior knowledge and transferable representations to facilitate learning from limited examples.
Here is a step-by-step explanation of how few-shot learning typically works:
- Dataset Setup: The setup consists of two main components: a support set and a query set. The support set contains a small number of labelled examples for each class/task, while the query set contains unlabelled examples for evaluation. The goal is to train a model that can generalise from the support set to classify or recognise examples in the query set accurately (a minimal episode-sampling sketch is given below).
- Model Training: The training phase involves optimising the model’s parameters to learn a generalisable representation or update rule that can adapt to new tasks or classes. One common approach is meta-learning, where the model is trained on multiple meta-tasks or episodes. Each meta-task comprises a support set and a query set from different classes or tasks. After being exposed to the support set, the model is trained to perform well on the query set.
- Feature Extraction and Embeddings: The model typically employs deep neural networks to extract meaningful features or embeddings from the input data to facilitate generalisation. These embeddings aim to capture the essential characteristics and patterns shared across different tasks or classes.
- Meta-Learner Adaptation: During meta-training, the model is optimised to quickly adapt its parameters based on the support set of each meta-task. The adaptation process can involve updating the model’s internal representations, fine-tuning its parameters, or learning an initial state that allows rapid learning on new tasks.
- Inference and Evaluation: After training, the model is evaluated on the query set of each meta-task. It should be able to generalise well to new examples and accurately classify or recognise them despite having limited labelled data. Evaluation metrics such as accuracy, precision, recall, or F1 score are commonly used to assess the model’s performance.
- Transfer and Generalisation: The trained model can be deployed to new, unseen tasks or classes by providing a small support set of labelled examples specific to the target task. The model leverages its learned transferable knowledge to adapt to the new task and make predictions on the query set.
By following these steps, few-shot learning techniques enable models to generalise from limited labelled data and perform well on new, unseen tasks or classes. This ability to rapidly adapt to new scenarios is particularly valuable when obtaining large labelled datasets is challenging or impractical.
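Before moving on to the tutorial, here is a minimal sketch of the dataset setup described above: sampling an N-way, K-shot episode (a support set plus a query set) from any labelled dataset that yields (image, label) pairs. The naive full pass that groups indices by class is only meant to illustrate the structure of an episode, not to be efficient.
import random
import torch

def sample_episode(dataset, n_way=5, k_shot=1, k_query=5):
    # Group example indices by class label (naive full pass; fine for a small dataset)
    indices_by_class = {}
    for idx in range(len(dataset)):
        _, label = dataset[idx]
        indices_by_class.setdefault(label, []).append(idx)

    # Pick N classes for this episode and relabel them 0..N-1
    episode_classes = random.sample(list(indices_by_class.keys()), n_way)
    support, query = [], []
    for episode_label, cls in enumerate(episode_classes):
        chosen = random.sample(indices_by_class[cls], k_shot + k_query)
        support += [(dataset[i][0], episode_label) for i in chosen[:k_shot]]
        query += [(dataset[i][0], episode_label) for i in chosen[k_shot:]]

    def stack(pairs):
        images = torch.stack([image for image, _ in pairs])
        labels = torch.tensor([label for _, label in pairs])
        return images, labels

    return stack(support), stack(query)

# Example usage with a dataset of (image, label) pairs, e.g. the Omniglot dataset used below:
# (support_images, support_labels), (query_images, query_labels) = sample_episode(train_dataset)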
Tutorial: few-shot learning with Python
Here’s a step-by-step tutorial on few-shot learning using the popular technique of Prototypical Networks. Prototypical Networks are simple yet effective models that use the concept of prototypes to classify new examples.
Step 1: Dataset Preparation
- This tutorial will use the Omniglot dataset, which contains 1,623 handwritten characters from 50 alphabets.
- Download the dataset and split it into training and test sets. Ensure that each alphabet has a balanced representation in both sets.
# Load and split the Omniglot dataset
# Ensure that you have the dataset downloaded and extracted in the specified path
import torch
from torchvision.datasets import Omniglot
from torchvision.transforms import ToTensor

# Set the path to the Omniglot dataset
dataset_path = "path/to/omniglot/dataset"

# Define the transformations
transform = ToTensor()

# Load the Omniglot dataset (background=True is the training split, background=False the evaluation split)
train_dataset = Omniglot(dataset_path, background=True, transform=transform, download=True)
test_dataset = Omniglot(dataset_path, background=False, transform=transform, download=True)

# Split the training data into training and validation sets
train_ratio = 0.8
train_size = int(train_ratio * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [train_size, val_size])
Step 2: Model Architecture
- Prototypical Networks use a deep neural network to extract features from the input images. We will use a convolutional neural network (CNN) as our feature extractor for simplicity.
- Define a CNN architecture with several convolutional and pooling layers followed by fully connected layers. You can use popular architectures like VGG or ResNet as a starting point.
import torch
import torch.nn as nn
import torch.nn.functional as F

# Define the CNN architecture
class CNN(nn.Module):
    def __init__(self, num_classes):
        super(CNN, self).__init__()
        # Define your CNN layers here
        # Example architecture:
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(2)
        # Omniglot images are 105x105, so two 2x2 max-pools leave 26x26 feature maps
        self.fc = nn.Linear(128 * 26 * 26, num_classes)

    def forward(self, x):
        # Implement the forward pass of your CNN here
        # Example forward pass:
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# Initialize the model
# After random_split, train_dataset is a Subset, so the class list lives on the underlying Omniglot dataset
# (uses the dataset's internal character list; adjust if your torchvision version differs)
num_classes = len(train_dataset.dataset._characters)
model = CNN(num_classes)
Step 3: Prototype Calculation
- Prototypes are representations of each class learned from the support set.
- During training, iterate over the support set and calculate the mean feature vector for each class. These mean vectors will serve as the prototypes.
- Compute the prototypes by averaging the feature vectors of all support set examples belonging to each class.
import torch

def calculate_prototypes(embeddings, labels):
    # Get the unique class labels present in the support set
    unique_labels = torch.unique(labels)
    prototypes = {}
    for label in unique_labels:
        # Select the embeddings belonging to the current label
        class_embeddings = embeddings[labels == label]
        # Calculate the mean feature vector for the current class
        prototype = torch.mean(class_embeddings, dim=0)
        # Store the prototype for the current class
        prototypes[label.item()] = prototype
    return prototypes
Step 4: Distance Calculation
- During testing, calculate the Euclidean distance between the feature vector of a query example and each class prototype.
- Use the calculated distances to determine the predicted class. The class with the closest prototype is the predicted class for the query example.
import torch

def calculate_distance(query_example, prototypes):
    distances = {}
    for label, prototype in prototypes.items():
        # Calculate the Euclidean distance between the query example and the prototype
        distance = torch.norm(query_example - prototype)
        distances[label] = distance
    return distances

def predict_class(query_example, prototypes):
    # Calculate distances between the query example and the prototypes
    distances = calculate_distance(query_example, prototypes)
    # Select the class with the closest prototype
    predicted_class = min(distances, key=distances.get)
    return predicted_class
Step 5: Loss Function and Training
- Define a loss function that encourages the feature vectors of examples from the same class to be close to their class prototype while being far from prototypes of other classes.
- In Prototypical Networks, this is typically a cross-entropy loss applied to a softmax over the negative distances between the query embeddings and the class prototypes.
- Train the network using mini-batches of support set examples. Compute the loss, backpropagate the gradients, and update the model’s parameters using gradient descent optimisation.
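For reference, here is a minimal sketch of that episodic loss. It assumes episode labels are re-indexed 0 to N-1 and reuses the calculate_prototypes helper defined above; note that the simplified training loop that follows instead trains the CNN as a standard classifier on mini-batches of support examples.
import torch
import torch.nn.functional as F

def prototypical_loss(support_embeddings, support_labels, query_embeddings, query_labels):
    # One prototype per episode class, stacked so that row i corresponds to class i
    prototypes = calculate_prototypes(support_embeddings, support_labels)
    prototype_matrix = torch.stack([prototypes[label] for label in sorted(prototypes)])
    # Negative squared Euclidean distances act as class logits
    logits = -torch.cdist(query_embeddings, prototype_matrix) ** 2
    return F.cross_entropy(logits, query_labels)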
import torch.optim as optim
from torch.utils.data import DataLoader

# Define the loss function
loss_fn = nn.CrossEntropyLoss()

# Define the optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training settings
num_epochs = 10
batch_size = 32

# DataLoader that serves the training data in mini-batches
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Training loop
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for batch_idx, (support_set, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        # Forward pass
        outputs = model(support_set)
        # Calculate the prototypes from the embeddings and labels of this batch
        prototypes = calculate_prototypes(outputs, targets)
        # Compute the loss
        loss = loss_fn(outputs, targets)
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if (batch_idx + 1) % 10 == 0:
            print(f"Epoch [{epoch + 1}/{num_epochs}], Batch [{batch_idx + 1}/{len(train_loader)}], Loss: {running_loss / 10}")
            running_loss = 0.0
Step 6: Evaluation
- Evaluate the trained model on the test set by iterating over the query set and calculating the accuracy.
- Compute the accuracy by comparing the predicted class with the ground truth class labels of the query examples.
# DataLoader for the evaluation data
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Evaluation loop
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for query_set, targets in test_loader:
        # Forward pass
        outputs = model(query_set)
        # (The class prototypes could also be computed from a labelled support set here and used
        # for nearest-prototype classification, as shown in Step 7.)
        # Predict the classes
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += (predicted == targets).sum().item()

accuracy = 100 * correct / total
print(f"Accuracy on the test set: {accuracy}%")
Step 7: Fine-tuning and Generalization
- To adapt the model to new tasks or classes with few examples, provide a small support set of labelled examples specific to the target task.
- Calculate the prototypes based on the support set, and use them to classify the query examples.
- Fine-tune the model on the target task if needed, using additional labelled examples.
# Provide a new support set for the target task: a small batch of labelled images
target_support_images, target_support_labels = ...

# Embed the support images and calculate the prototypes for the target task
with torch.no_grad():
    target_prototypes = calculate_prototypes(model(target_support_images), target_support_labels)

# Use the target prototypes to classify the query examples
target_query_set = ...
predictions = [predict_class(model(query_example.unsqueeze(0)).squeeze(0), target_prototypes)
               for query_example in target_query_set]

# Fine-tune the model on the target task if needed
This tutorial provides a high-level overview of the steps involved in using Prototypical Networks. Implementation details such as data loading, model architecture, loss function, and training loop may vary depending on your specific framework or library. It’s recommended to consult relevant documentation and examples for a detailed implementation guide.
The future of few-shot learning
The future is promising, with several exciting directions and potential advancements. Here are some key aspects that could shape the future of this field:
- Improved Model Architectures: Researchers continuously explore novel model architectures and network designs to enhance the performance of few-shot learning. This includes developing more efficient and effective convolutional neural networks (CNNs), recurrent neural networks (RNNs), graph neural networks (GNNs), and attention mechanisms that can better capture and generalise from limited labelled data.
- Meta-Learning and Learning to Learn: Meta-learning, or learning to learn, is a prominent research area within few-shot learning. Future advancements may focus on developing meta-learning frameworks that can efficiently learn and generalise from a few examples across a wide range of tasks or classes. This includes exploring meta-learning algorithms, optimisation techniques, and memory-augmented architectures.
- Integration of Unsupervised and Self-Supervised Learning: Unsupervised and self-supervised learning methods hold great potential in few-shot learning. By leveraging unsupervised or self-supervised pretraining, models can learn helpful representations that facilitate generalisation to new tasks or classes with limited labelled data. Future research may focus on effectively combining unsupervised, self-supervised, and few-shot learning techniques to achieve better performance.
- Domain Adaptation and Transfer Learning: Domain adaptation and transfer learning techniques can be crucial in few-shot learning by enabling knowledge transfer from related domains or tasks. Advancements in domain adaptation algorithms, including domain adaptation GANs (generative adversarial networks) and domain adaptation methods for deep learning, can facilitate better adaptation and generalisation to new tasks or classes.
- Combining Few-Shot Learning with Reinforcement Learning: Combining few-shot learning and reinforcement learning holds promise in solving complex sequential decision-making problems. Future research may develop algorithms that can efficiently learn from a few examples in reinforcement learning settings, allowing agents to adapt and generalise quickly to new environments and tasks.
- Real-World Applications and Deployment: Few-shot learning techniques are increasingly applied to real-world applications, such as healthcare, robotics, and personalised services. The future lies in refining and adapting these techniques to specific domains and ensuring their practical deployment, considering robustness, interpretability, and scalability factors.
- Bridging the Gap between Few-Shot and Zero-Shot Learning: The boundaries between few-shot learning and zero-shot learning are starting to blur. Future research may focus on developing approaches to bridge the gap between these two paradigms, enabling models to leverage a few labelled examples while incorporating auxiliary information for recognising entirely new classes without any labelled data.
These are just a few potential directions to shape the future of few-shot learning. As the field advances, we can expect more innovative techniques, improved model generalisation, and increased applicability in real-world scenarios, ultimately enabling models to learn and adapt with even more limited labelled data.
Conclusion
Few-shot learning is a powerful approach that addresses the challenge of learning from limited labelled data. It enables models to adapt to new tasks quickly, recognise new classes, or generalise to unseen examples with only a few labelled examples. Few-shot learning offers several advantages, including flexibility, reduced annotation effort, data efficiency, and fast adaptation. It finds applications in image classification, object detection, NLP, anomaly detection, personalised recommendation, medical diagnosis, robotics, and more.
However, few-shot learning also has its limitations. Models may have limited generalisation to classes significantly different from the seen examples, sensitivity to support set composition, difficulty with out-of-distribution data, and the potential for overfitting. These challenges must be carefully considered and addressed in applying few-shot learning techniques.
Overall, few-shot learning provides a valuable tool for learning in scenarios where labelled data is scarce, new tasks or classes emerge frequently, or adaptation to novel environments is required. It opens up possibilities for leveraging limited labelled data effectively and efficiently, paving the way for more flexible and adaptive machine learning systems.
Hi,
I ran the code from this tutorial in Colab and I am getting the error below; can you please help fix this issue? Thanks very much.
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py in forward(self, input)
    112
    113     def forward(self, input: Tensor) -> Tensor:
--> 114         return F.linear(input, self.weight, self.bias)
    115
    116     def extra_repr(self) -> str:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x676 and 6272x15424)
Hi Krishna,
I checked the code and there is indeed a mismatch between the tensor sizes.
You will need to replace the correct size to get this to work.
Replace this:
self.fc = nn.Linear(128 * 7 * 7, num_classes)
with this:
self.fc = nn.Linear(86528, num_classes)
Let me know if this works!
All the best,
Neri
        running_loss = 0.0
---> 16 for batch_idx, (support_set, _) in enumerate(train_loader):
     17     optimizer.zero_grad()
     18

NameError: name 'train_loader' is not defined
I am getting this error when I run your code. Please help!
Hi Prakriti,
The train_loader is a DataLoader object which loads data into batches. You can ignore this if you aren’t batch loading your dataset.
from torch.utils.data import DataLoader
# Define your DataLoader for training data
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
Hope this helps,
Neri