Multi-class Classification Explained With 3 How To Python Tutorials [Scikit-Learn, PyTorch & Keras]

by Neri Van Otten | Aug 11, 2023 | Data Science, Machine Learning

What is multi-class classification in machine learning?

Multi-class classification is a machine learning task that aims to assign input data points to one of several predefined classes or categories. Each data point is associated with one and only one class label, making it different from multi-label classification, where data points can belong to multiple classes simultaneously.

In a multi-class classification problem, the algorithm learns from a labelled dataset where each data point is paired with the correct class label. The goal is to train a model that can generalize from this training data to predict the class labels of unseen data points accurately.

Common algorithms used for multi-class classification

  1. Logistic Regression: Although the basic model is binary, logistic regression extends to multiple classes through techniques like one-vs-rest or softmax (multinomial) regression.
  2. Decision Trees: Decision trees handle multiple classes natively: each leaf predicts the most frequent class among the training samples that reach it.
  3. Random Forest: An ensemble method that builds multiple decision trees and combines their predictions to achieve more accurate results.
  4. Support Vector Machines (SVM): SVMs can be adapted to multi-class problems using techniques like one-vs-one or one-vs-rest.
  5. Neural Networks: Neural networks, including deep learning models, are highly versatile for multi-class classification tasks. They stack multiple layers with non-linear activation functions to capture complex patterns in the data.
  6. K-Nearest Neighbors (KNN): KNN is a simple algorithm that classifies data points based on the majority class among their k nearest neighbours.

Evaluation metrics for multi-class classification

  1. Accuracy: The proportion of correctly classified instances out of all instances. However, accuracy can be misleading on imbalanced datasets.
  2. Precision, Recall, and F1-Score: These metrics provide a more nuanced view of performance, particularly when dealing with imbalanced classes.
  3. Confusion Matrix: A table that summarizes the actual vs. predicted class labels, aiding in understanding where the model is making errors.
  4. Cross-Entropy Loss (Log Loss): Often used as a loss function during training in probabilistic models like logistic regression or neural networks.
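  5. Macro and Micro Averaging: Useful when dealing with class imbalances in the dataset. Macro averaging calculates the metric independently for each class and then takes the unweighted mean, so every class counts equally. Micro averaging pools the individual decisions (true positives, false positives, and false negatives) across all classes and computes the metric once, so frequent classes dominate.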
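
The metrics listed above can be computed with scikit-learn. The snippet below is a minimal sketch rather than a full evaluation pipeline: the labels in y_true and y_pred are made up for illustration, with class 1 deliberately under-represented.

```python
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    precision_recall_fscore_support,
)

# Hypothetical labels for a 3-class problem (class 1 is the minority class)
y_true = [0, 0, 0, 0, 1, 1, 2, 2, 2, 2]
y_pred = [0, 0, 0, 1, 1, 2, 2, 2, 2, 0]

acc = accuracy_score(y_true, y_pred)
cm = confusion_matrix(y_true, y_pred)  # rows: true class, columns: predicted class

# Macro: compute the metric per class, then take the unweighted mean
macro_p, macro_r, macro_f1, _ = precision_recall_fscore_support(
    y_true, y_pred, average="macro"
)
# Micro: pool all decisions across classes, then compute the metric once
micro_p, micro_r, micro_f1, _ = precision_recall_fscore_support(
    y_true, y_pred, average="micro"
)

print(f"Accuracy: {acc:.2f}")
print("Confusion matrix:")
print(cm)
print(f"Macro P/R/F1: {macro_p:.2f} / {macro_r:.2f} / {macro_f1:.2f}")
print(f"Micro P/R/F1: {micro_p:.2f} / {micro_r:.2f} / {micro_f1:.2f}")
```

When every data point has exactly one label, micro-averaged precision, recall, and F1 all coincide with accuracy, so the macro average is usually the more informative of the two on imbalanced data.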

Multi-class classification problems arise in many domains, such as image classification, text categorization, and medical diagnosis. To achieve the best results for a given task, it’s vital to preprocess the data properly, choose a suitable algorithm, and tune its hyperparameters.

Multi-class vs multi-label classification

It’s important not to confuse multi-class and multi-label classification. Both assign labels to input data points, but they are separate machine learning tasks. Let’s explore the differences between these two approaches:

1. Multi-Class Classification: In multi-class classification, each data point is assigned to precisely one class out of a set of predefined classes. The goal is to predict the most appropriate class label for each input. Examples of multi-class classification tasks include:

  • Handwritten digit recognition (assigning digits 0-9)
  • Natural language processing tasks like sentiment analysis (assigning labels like positive, neutral, or negative)
  • Image classification (assigning a single label, such as cat, dog, or car, to each image)

2. Multi-Label Classification: In multi-label classification, each data point can be associated with several labels at once from a set of possible labels, so a single data point can belong to multiple classes simultaneously. Examples of multi-label classification tasks include:

  • Document categorization (assigning topics/tags to a news article)
  • Image tagging (giving multiple labels to describe objects, attributes, or concepts in an image)
  • Music genre classification (giving multiple genres to a song)

Key Differences:

1. Label Assignment:

  • Multi-Class: Each data point is assigned to only one class label.
  • Multi-Label: Each data point can be assigned multiple class labels.

2. Output Format:

  • Multi-Class: The output is a single predicted class label for each data point.
  • Multi-Label: The result is a binary vector indicating the presence or absence of each label for each data point.

3. Modelling Approach:

  • Multi-Class: Common algorithms include logistic regression, decision trees, random forests, SVMs, and neural networks.
  • Multi-Label: Algorithms must be adapted to handle multiple labels, such as binary relevance, label powerset, or neural network architectures designed for multi-label scenarios.

4. Evaluation Metrics:

  • Multi-Class: Metrics like accuracy, precision, recall, and F1-score are commonly used.
  • Multi-Label: Metrics might include subset accuracy, precision at k, recall at k, and F1-score at k, where k is the number of top predictions to consider.

5. Problem Complexity:

  • Multi-Class: Generally considered more straightforward, as each data point has exactly one correct label.
  • Multi-Label: More complex due to the potential combinations of labels for each data point.

In summary, multi-class classification assigns exactly one label to each data point, while multi-label classification can assign several labels to each data point. The choice between these approaches depends on the nature of the problem and the data at hand.
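
The difference in output format can be made concrete with scikit-learn's label encoders. This is an illustrative sketch only: the animal names and article tags are invented examples, not data used elsewhere in this article.

```python
from sklearn.preprocessing import LabelEncoder, MultiLabelBinarizer

# Multi-class: exactly one label per sample -> one integer per sample
animals = ["cat", "dog", "cat", "bird"]  # hypothetical labels
encoder = LabelEncoder()
y_multiclass = encoder.fit_transform(animals)
print(encoder.classes_, y_multiclass)

# Multi-label: a set of labels per sample -> one binary indicator row per sample
article_tags = [["politics", "economy"], ["sports"], ["economy"]]  # hypothetical tags
binarizer = MultiLabelBinarizer()
y_multilabel = binarizer.fit_transform(article_tags)
print(binarizer.classes_)
print(y_multilabel)
```

The multi-class target is a vector with one entry per sample, while the multi-label target is a matrix with one column per possible label.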

Multi-class classification example

A common example of multi-class classification in machine learning is recognizing handwritten digits.

The goal is to classify handwritten digits from 0 to 9 based on their images. Each image contains a single digit, and the task is to predict the correct digit represented by the image.

Dataset: The MNIST dataset is a popular dataset for this problem. It contains 28×28 grayscale images of handwritten digits along with their corresponding labels.

Classes: The classes in this problem correspond to the digits 0 to 9, making it a 10-class classification task.


  1. Data Preparation: The dataset is divided into training and testing sets. Each image is converted into a numerical representation suitable for input to a machine learning model.
  2. Model Selection: Common models for this problem include neural networks, especially convolutional neural networks (CNNs), due to their ability to capture spatial patterns in images.
  3. Model Training: The selected model is trained on the training dataset. During training, the model learns to recognize patterns and features in the images indicative of the different digit classes.
  4. Model Evaluation: The model is evaluated on the testing dataset after training. The model’s predictions are compared to the ground truth labels to measure its accuracy and performance.
  5. Prediction: Once the model is trained and evaluated, it can be used to predict the digit labels for new handwritten images that were not part of the training or testing sets.
  6. Metrics: Common evaluation metrics for this multi-class classification problem include accuracy, precision, recall, and F1-score. Accuracy is the proportion of correctly classified digits. Precision and recall are computed per class: for a given digit, precision is the fraction of images predicted as that digit that really are that digit, and recall is the fraction of images of that digit that the model identifies correctly. F1-score is the harmonic mean of precision and recall.
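
The workflow above can be sketched in a few lines of scikit-learn. Two assumptions are made to keep the example small and fast: it uses scikit-learn's built-in 8×8 digits dataset as a stand-in for MNIST, and a logistic regression baseline instead of a CNN.

```python
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split

# 1. Data preparation: 8x8 grayscale digits, flattened to 64 features in [0, 1]
digits = load_digits()
X = digits.data / 16.0  # pixel intensities in this dataset range from 0 to 16
y = digits.target
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# 2-3. Model selection and training (a linear baseline, not a CNN)
model = LogisticRegression(max_iter=5000)
model.fit(X_train, y_train)

# 4-5. Evaluation and prediction on held-out images
y_pred = model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.3f}")

# 6. Per-class precision, recall and F1-score
print(classification_report(y_test, y_pred))
```

The classification report prints one row per digit, which makes it easy to see which digits the model confuses most often.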

Applications: Handwritten digit recognition has various applications, including:

  • Postal services for sorting mail with handwritten addresses.
  • Banks for processing handwritten checks and forms.
  • Optical character recognition (OCR) systems for converting handwritten text into digital text.
  • The medical field for analyzing handwritten medical notes and prescriptions.

This example illustrates how multi-class classification can be applied to recognizing handwritten digits. Each digit is a distinct class, and the goal is to accurately predict the class label (digit) based on the input image data.

Common loss functions for multi-class classification

Choosing a suitable loss function is crucial for effectively training a machine learning model in multi-class classification. The loss function quantifies the difference between the predicted class probabilities and the actual class labels, helping the model adjust its parameters during training. Several loss functions are commonly used for multi-class classification tasks:

Cross-Entropy Loss (Log Loss)

Cross-entropy loss, also known as log loss, is one of the most commonly used loss functions for multi-class classification. It measures the dissimilarity between the predicted class probabilities and the true class labels. For each data point, the loss is the negative logarithm of the probability the model assigns to the true class; these per-sample losses are typically averaged over the dataset. The goal is to minimize this loss during training.

Mathematically, for a single data point with true label y and predicted class probabilities p, the cross-entropy loss is calculated as:

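L = -∑_i (y_i * log(p_i))

where the sum runs over all classes i, y_i is the indicator that equals 1 if i is the true class label and 0 otherwise, and p_i is the predicted probability for class i. Because only the true class has y_i = 1, the loss reduces to -log(p_y), the negative log-probability of the true class.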
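
As a worked illustration of the formula, the NumPy sketch below evaluates the loss for a single data point. The probability vectors are made-up values, and the helper name cross_entropy is hypothetical; it assumes every predicted probability is strictly positive.

```python
import numpy as np

def cross_entropy(y_one_hot, p):
    """Cross-entropy for one sample: -sum_i y_i * log(p_i)."""
    return -np.sum(y_one_hot * np.log(p))

# Hypothetical prediction over 3 classes; the true class is class 0
y = np.array([1.0, 0.0, 0.0])
p_confident = np.array([0.7, 0.2, 0.1])  # most mass on the true class
p_wrong = np.array([0.1, 0.2, 0.7])      # most mass on a wrong class

loss_confident = cross_entropy(y, p_confident)  # equals -log(0.7)
loss_wrong = cross_entropy(y, p_wrong)          # equals -log(0.1)
print(loss_confident, loss_wrong)
```

The loss grows as the probability assigned to the true class shrinks, which is what pushes the model towards confident, correct predictions.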

Hinge Loss (for Support Vector Machines)

Hinge loss is commonly used with Support Vector Machines (SVMs) for multi-class classification. It aims to maximize the margin between classes while penalizing misclassifications.

Mathematically, for a single data point with true label y and predicted class scores s, the hinge loss is calculated as:

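L = ∑_{i ≠ y} max(0, 1 + s_i - s_y)

Here, s_i is the score for class i, s_y is the score for the true class y, and the sum is taken over all classes except the true class. The loss is zero only when the true class score exceeds every other class score by at least the margin of 1.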
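
A small NumPy sketch of the summed multi-class hinge loss follows. It assumes the convention that each wrong class is penalized when its score comes within a margin of 1 of the true class score; the scores and the function name multiclass_hinge are illustrative, not taken from a particular library.

```python
import numpy as np

def multiclass_hinge(scores, true_class, margin=1.0):
    """Sum over wrong classes of max(0, margin + s_i - s_y)."""
    margins = np.maximum(0.0, margin + scores - scores[true_class])
    margins[true_class] = 0.0  # the true class does not contribute
    return margins.sum()

scores = np.array([2.0, 1.5, -0.5])  # hypothetical class scores

# True class 0 has the top score, but class 1 is inside the margin
loss_correct = multiclass_hinge(scores, true_class=0)
# True class 2 has the lowest score, so both other classes are penalized
loss_incorrect = multiclass_hinge(scores, true_class=2)
print(loss_correct, loss_incorrect)
```

Note that the first case still incurs a small loss even though the prediction is correct, because class 1 scores within the margin of the true class.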

Sparse Categorical Cross-Entropy Loss

This is a variant of cross-entropy loss for situations where the true labels are represented as integers (class indices) instead of one-hot encoded vectors. It computes the same value as cross-entropy but looks up the predicted probability at the true label index directly, which avoids building the one-hot vectors and saves memory when there are many classes.
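
The equivalence between the two label formats can be checked with a short NumPy sketch. The probabilities and labels are invented for illustration; this is not the internal implementation of any framework.

```python
import numpy as np

# Hypothetical predicted probabilities for 2 samples over 3 classes
probs = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.3, 0.6],
])
labels_sparse = np.array([0, 2])           # integer class indices
labels_one_hot = np.eye(3)[labels_sparse]  # the same labels, one-hot encoded

# Categorical form: multiply by the one-hot vector and sum over classes
loss_categorical = -np.sum(labels_one_hot * np.log(probs), axis=1)
# Sparse form: index the probability of the true class directly
loss_sparse = -np.log(probs[np.arange(len(labels_sparse)), labels_sparse])

print(loss_categorical)
print(loss_sparse)
```

Both forms produce the same per-sample losses; only the label representation differs.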

Kullback-Leibler Divergence (KL Divergence)

KL Divergence measures how one probability distribution differs from a second reference distribution. It can be used as a loss function for multi-class classification tasks, though it’s less commonly used than cross-entropy loss.
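
The relationship between KL divergence and cross-entropy can be sketched in NumPy. The target distribution q below is a hypothetical soft label (for example, from label smoothing), chosen so that every entry is non-zero and the logarithms are defined.

```python
import numpy as np

def kl_divergence(q, p):
    """KL(q || p) = sum_i q_i * log(q_i / p_i), for strictly positive q and p."""
    return np.sum(q * np.log(q / p))

q = np.array([0.8, 0.1, 0.1])  # hypothetical soft target distribution
p = np.array([0.7, 0.2, 0.1])  # hypothetical predicted distribution

kl = kl_divergence(q, p)
cross_entropy = -np.sum(q * np.log(p))
entropy = -np.sum(q * np.log(q))

# KL divergence is the cross-entropy minus the entropy of the target
print(kl, cross_entropy - entropy)
```

Since the entropy of the target does not depend on the model, minimizing KL divergence and minimizing cross-entropy lead to the same optimum, which is one reason cross-entropy is the more common choice.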

These are some of the primary loss functions used in multi-class classification. The choice of loss function often depends on the specific problem, the model architecture (e.g., neural network, SVM), and the desired properties of the optimization process during training. Cross-entropy loss is a general-purpose and widely-used choice due to its effectiveness in training models for multi-class classification tasks.

How to implement multi-class classification with Python

How to implement multi-class classification with sklearn

In this example, we will use the Iris dataset to classify iris flowers into three species based on their sepal and petal dimensions.

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

# Load the Iris dataset
iris = load_iris()
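X = iris.data    # Features: sepal and petal measurements
y = iris.target  # Target labels: the three species encoded as 0, 1, 2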

# Split the dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Create a logistic regression model for multi-class classification
model = LogisticRegression(max_iter=1000)

# Train the model
model.fit(X_train, y_train)

# Make predictions on the test set
y_pred = model.predict(X_test)

# Calculate the accuracy of the model
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.2f}")

In this example, we:

  1. Import the necessary libraries: load_iris for loading the dataset, train_test_split for splitting the data into training and testing sets, LogisticRegression for creating a logistic regression classifier, and accuracy_score for evaluating the model’s accuracy.
  2. Load the Iris dataset and split it into features (X) and target labels (y).
  3. Split the dataset into training and testing sets using the train_test_split function.
  4. Create a LogisticRegression model with the max_iter parameter set to 1000 to allow the model to converge.
  5. Train the model on the training data using the fit method.
  6. Make predictions on the test set using the trained model.
  7. Calculate the model’s accuracy by comparing the predicted and true labels using the accuracy_score function.

This example uses a simple logistic regression model. More complex models like decision trees, random forests, or neural networks might be more appropriate for more challenging multi-class classification tasks.
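
To compare logistic regression with some of the other algorithms mentioned earlier, one option is cross-validation on the same data. The sketch below is one possible setup, not a benchmark: the choice of models, hyperparameters, and 5 folds are assumptions made for illustration.

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# Candidate multi-class classifiers with illustrative settings
candidates = {
    "logistic regression": LogisticRegression(max_iter=1000),
    "decision tree": DecisionTreeClassifier(random_state=42),
    "random forest": RandomForestClassifier(n_estimators=100, random_state=42),
    "k-nearest neighbours": KNeighborsClassifier(n_neighbors=5),
}

mean_scores = {}
for name, clf in candidates.items():
    # 5-fold cross-validated accuracy for each model
    scores = cross_val_score(clf, X, y, cv=5, scoring="accuracy")
    mean_scores[name] = scores.mean()
    print(f"{name}: {scores.mean():.3f} (+/- {scores.std():.3f})")
```

On a small, clean dataset like Iris the models tend to score similarly; differences usually become more visible on larger or noisier problems.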

How to implement multi-class classification with PyTorch

1. Import Libraries: Start by importing the necessary libraries:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader

2. Load and Preprocess Data: Load the dataset (such as CIFAR-10 or Fashion MNIST) and apply appropriate preprocessing steps:

transform = transforms.Compose([
    transforms.ToTensor(),  # Convert PIL images to tensors in [0, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize images
])

train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

3. Define Model: Create a neural network model using PyTorch’s nn.Module class. This example uses a simple architecture with convolutional and fully connected layers:

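class Classifier(nn.Module):
    def __init__(self, num_classes):
        super(Classifier, self).__init__()
        # Two conv blocks; each 2x2 pooling halves the spatial size,
        # so 32x32 CIFAR-10 images become 8x8 feature maps
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(32 * 8 * 8, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)  # Flatten to (batch_size, 32 * 8 * 8)
        x = self.classifier(x)
        return x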

4. Instantiate Model and Define Loss Function and Optimizer: Instantiate the model, define a loss function (e.g., cross-entropy), and choose an optimizer (e.g., SGD or Adam):

model = Classifier(num_classes=10) 
criterion = nn.CrossEntropyLoss() 
optimizer = optim.Adam(model.parameters(), lr=0.001)

5. Training Loop: Iterate over the dataset, forward pass the data through the model, calculate the loss, backpropagate gradients, and update model weights:

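num_epochs = 10
for epoch in range(num_epochs):
    for inputs, labels in train_loader:
        optimizer.zero_grad()              # Reset gradients from the previous step
        outputs = model(inputs)            # Forward pass
        loss = criterion(outputs, labels)  # Compute the loss
        loss.backward()                    # Backpropagate gradients
        optimizer.step()                   # Update model weights
    # Report the loss of the last batch in the epoch
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')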

6. Model Evaluation: After training, you can evaluate the model’s performance on a test dataset:

test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)  # Index of the highest score per sample
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Test Accuracy: {100 * correct / total:.2f}%')

This example demonstrates multi-class classification using PyTorch. You can modify and expand upon this template to experiment with different architectures, datasets, and hyperparameters to develop accurate models for various classification tasks.
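
One detail of the loss defined in step 4 is worth illustrating: nn.CrossEntropyLoss expects raw, unnormalized scores (logits) and integer class labels, and applies log-softmax internally, which is why the model above has no softmax layer. The sketch below checks this on made-up logits; the tensor values are purely illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical logits for 2 samples over 3 classes, with integer labels
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])
labels = torch.tensor([0, 2])

# Built-in loss: takes logits directly, averages over the batch by default
criterion = nn.CrossEntropyLoss()
loss = criterion(logits, labels)

# Manual equivalent: negative log-softmax at the true class, then the mean
log_probs = F.log_softmax(logits, dim=1)
manual_loss = -log_probs[torch.arange(len(labels)), labels].mean()

# Probabilities and predicted classes, as used at inference time
probs = F.softmax(logits, dim=1)
predicted = probs.argmax(dim=1)

print(loss.item(), manual_loss.item())
print(predicted)
```

Because softmax preserves the ordering of the scores, taking the argmax of the logits gives the same predicted class as taking the argmax of the probabilities.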

How to implement multi-class classification with Keras

In this example, we will use the Fashion MNIST dataset, which contains images of various clothing items categorized into ten classes. We’ll build and train a neural network for multi-class classification.

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# Load the Fashion MNIST dataset
fashion_mnist = keras.datasets.fashion_mnist
(X_train, y_train), (X_test, y_test) = fashion_mnist.load_data()

# Normalize pixel values to between 0 and 1
X_train, X_test = X_train / 255.0, X_test / 255.0

# Split the dataset into training and validation sets
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42)

# Define the neural network architecture
model = models.Sequential([
    layers.Flatten(input_shape=(28, 28)),
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax')
])

# Compile the model
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Train the model
model.fit(X_train, y_train, epochs=10, validation_data=(X_val, y_val))

# Evaluate the model on the test set
test_loss, test_accuracy = model.evaluate(X_test, y_test)
print(f"Test Accuracy: {test_accuracy:.4f}")

# Make predictions on a few examples
predictions = model.predict(X_test[:5])
predicted_labels = np.argmax(predictions, axis=1)
print("Predicted Labels:", predicted_labels)

In this example:

  1. We load the Fashion MNIST dataset, which contains 28×28 grayscale images of clothing items and their corresponding labels.
  2. We normalize the pixel values of the images to be between 0 and 1.
  3. We split the training data further into training and validation sets for model evaluation during training.
  4. We define a simple neural network architecture using Keras’s Sequential API. It consists of a flattening layer, a dense hidden layer with ReLU activation, and a dense output layer with softmax activation for multi-class classification.
  5. We compile the model using the Adam optimizer and sparse categorical cross-entropy loss, suitable for multi-class classification.
  6. We train the model on the training data and validate it using the validation set.
  7. We evaluate the model’s performance on the test set and print the test accuracy.
  8. We make predictions on a few examples from the test set and print the predicted labels.

This is a basic example to demonstrate multi-class classification using Keras. For more complex problems, you can consider using more advanced architectures, data augmentation, and hyperparameter tuning to achieve better results.
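
The softmax output layer and the np.argmax step used above can be illustrated without a trained model. The NumPy sketch below uses made-up output scores and a hypothetical softmax helper; it is meant to show the computation, not Keras's internal implementation.

```python
import numpy as np

def softmax(logits):
    """Convert raw scores to probabilities that sum to 1 along the last axis."""
    # Subtracting the row maximum avoids overflow without changing the result
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=-1, keepdims=True)

# Hypothetical raw scores for 2 samples over 3 classes
logits = np.array([
    [2.0, 1.0, 0.1],
    [0.5, 0.5, 3.0],
])

probabilities = softmax(logits)
predicted_labels = np.argmax(probabilities, axis=1)  # most probable class per sample

print(probabilities)
print("Predicted Labels:", predicted_labels)
```

Each row of probabilities sums to 1, and the predicted label is simply the index of the largest entry in that row.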


Multi-class classification is an essential machine learning task that involves assigning input data points to one of several predefined classes. It’s commonly encountered in various fields, such as image recognition, text categorization, and medical diagnosis. Understanding multi-class classification’s fundamental concepts and techniques is crucial for building effective and accurate predictive models.

Here’s a summary of the main points covered:

Multi-Class Classification vs Multi-Label Classification:

  • Multi-class classification assigns each data point to a single class out of several.
  • Multi-label classification can assign each data point to multiple classes simultaneously.

Algorithms and Models:

  • Standard algorithms for multi-class classification include logistic regression, decision trees, random forests, support vector machines, and neural networks.

Loss Functions:

  • Cross-entropy loss (log loss) is a widely used loss function for multi-class classification tasks. It quantifies the difference between predicted class probabilities and actual class labels.

Evaluation Metrics:

  • Common metrics for evaluating multi-class classification include accuracy, precision, recall, F1-score, and confusion matrix.

Python Libraries and Tools:

  • Python libraries like scikit-learn, TensorFlow, Keras, and PyTorch provide tools and frameworks for building, training, and evaluating multi-class classification models.


  • Worked examples of multi-class classification using Scikit-Learn, PyTorch and Keras were provided to illustrate the process.

Successful multi-class classification involves appropriate data preprocessing, model selection, hyperparameter tuning, and evaluation. As you delve into this topic, remember that practical experience and experimentation are vital to mastering multi-class classification and developing accurate predictive models for real-world tasks.

About the Author

Neri Van Otten

Neri Van Otten is the founder of Spot Intelligence, a machine learning engineer with over 12 years of experience specialising in Natural Language Processing (NLP) and deep learning innovation. Dedicated to making your projects succeed.
