How To Implement Logistic Regression Text Classification In Python With Scikit-learn and PyTorch

by | Feb 22, 2023 | Data Science, Machine Learning, Natural Language Processing

Text classification is a fundamental problem in natural language processing (NLP) that involves categorising text data into predefined classes or categories. It can be used in many real-world situations, like sentiment analysis, spam filtering, topic modelling, and content classification, to name a few.

While logistic regression has some limitations, such as assuming that the log-odds of each class are a linear function of the input features, it remains a useful and practical approach to text classification.

Logistic regression assumes a linear relationship between the input features and the class labels

In this article, we show how to use logistic regression to classify text in Python in two ways: first with scikit-learn and then with PyTorch. We also discuss how to deal with multiple classes and look at some alternatives to logistic regression.

Why use logistic regression?

Logistic regression is a popular algorithm for text classification and is also our go-to favourite for several reasons:

  1. Simplicity: Logistic regression is a relatively simple algorithm that is easy to implement and interpret. It can be trained efficiently even on large datasets, making it a practical choice for many real-world applications.
  2. Easily understood: Logistic regression models can be understood by looking at the coefficients of the input features, which can show which words or phrases are most important for classification.
  3. Works well with sparse data: Text data is often very high-dimensional and sparse, meaning many features are zero for most data points. Logistic regression can handle sparse data well and can be regularised to prevent overfitting.
  4. Versatile: Logistic regression works well for both binary and multi-class text classification tasks.
  5. Baseline model: Logistic regression makes a good baseline for text classification, giving you a simple, easy-to-understand reference point against which to measure more complicated algorithms.

Logistic regression is a practical algorithm for classifying text that can give good results in many situations, especially for more straightforward classification tasks or as a starting point for more complicated algorithms.

How to use logistic regression for text classification

Logistic regression is a commonly used statistical method for binary classification tasks, including text classification.

In text classification, the goal is to assign a given piece of text to one or more predefined categories or classes.

To use logistic regression for text classification, we first need to represent the text as numerical features that can be used as input to the model. One popular approach for this is to use the bag-of-words representation, where we represent each document as a vector of word frequencies.

Once we have a numerical representation of the text, we can use logistic regression to learn a model that predicts the probability of each document belonging to a given class. The model learns one weight per feature plus a bias term; for a new document, it computes the weighted sum of the document's features and passes the result through the sigmoid function, which turns it into a probability between 0 and 1.

During training, we adjust the weights to minimise a loss function, such as cross-entropy, that measures the difference between the predicted probabilities and the actual labels. Once the model is trained, we can use it to predict the class labels for new text inputs.
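To make the weights and the loss concrete, here is a from-scratch sketch in NumPy that fits binary logistic regression by plain batch gradient descent on the mean cross-entropy loss. It is only an illustration of the idea, not how scikit-learn or PyTorch implement it internally; the four-document count matrix, the learning rate and the number of epochs are made up, and there is no regularisation.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def cross_entropy(y, p, eps=1e-12):
    """Mean binary cross-entropy between labels y and predicted probabilities p."""
    p = np.clip(p, eps, 1 - eps)
    return -np.mean(y * np.log(p) + (1 - y) * np.log(1 - p))


def train_logistic_regression(X, y, lr=0.5, epochs=500):
    """Fit weights w and bias b by batch gradient descent on the mean cross-entropy."""
    n_samples, n_features = X.shape
    w = np.zeros(n_features)
    b = 0.0
    for _ in range(epochs):
        p = sigmoid(X @ w + b)  # predicted P(class = 1) for every document
        error = p - y           # gradient of the cross-entropy w.r.t. each logit
        w -= lr * (X.T @ error) / n_samples
        b -= lr * error.mean()
    return w, b


# Hypothetical bag-of-words counts for the vocabulary ["good", "great", "bad", "awful"]
X = np.array([[1, 1, 0, 0],
              [0, 2, 0, 0],
              [0, 0, 1, 1],
              [0, 0, 2, 0]], dtype=float)
y = np.array([1, 1, 0, 0], dtype=float)

w, b = train_logistic_regression(X, y)
probs = sigmoid(X @ w + b)
print("loss:", cross_entropy(y, probs))
print("predicted labels:", (probs >= 0.5).astype(int))
```

The key line is `error = p - y`: for the sigmoid combined with cross-entropy, the gradient with respect to each document's score is simply the predicted probability minus the true label.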

Overall, logistic regression is a simple but effective method for text classification and can be used as a baseline model or combined with more complex models in ensemble approaches. However, it may struggle with more complex relationships between features and labels, such as word order or interactions between words, and may not capture the full range of patterns in natural language data.

Logistic regression may not capture the more complex relationships in natural language

Logistic regression for multi-class text classification

Logistic regression can also be used for multi-class text classification tasks, assigning a given text to one of several possible classes or categories. For classifying more than two groups, logistic regression can be used in two main ways: one-vs-all (OvA, also called one-vs-rest) and softmax (multinomial) regression.

In the OvA approach, we train one binary logistic regression classifier for each class, where the positive class is the target class, and the negative class is all other classes. Then, during prediction, we calculate the probability of each class using all the binary classifiers and select the category with the highest probability as the predicted class.
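A minimal sketch of the one-vs-all approach using scikit-learn's OneVsRestClassifier wrapper around LogisticRegression. The three-class "sport / politics / technology" sentences are invented for illustration only.

```python
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

# Made-up three-class toy data: 0 = sport, 1 = politics, 2 = technology
texts = ["the team won the match", "a late goal decided the match",
         "the minister announced a new policy", "parliament voted on the policy",
         "the new phone has a faster chip", "the laptop chip runs cooler"]
labels = [0, 0, 1, 1, 2, 2]

vectorizer = CountVectorizer()
X = vectorizer.fit_transform(texts)

# One binary LogisticRegression per class: "this class" vs "all other classes"
ovr = OneVsRestClassifier(LogisticRegression()).fit(X, labels)
print(len(ovr.estimators_))  # one fitted binary classifier per class

X_new = vectorizer.transform(["who won the match"])
print(ovr.predict_proba(X_new))  # per-class scores, normalised to sum to 1
print(ovr.predict(X_new))        # the class whose binary classifier is most confident
```

For single-label problems, OneVsRestClassifier rescales the binary classifiers' probabilities so that each row sums to 1 before picking the largest.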

In the softmax regression approach, we use a single logistic regression model with a softmax activation function, which can output the probability of each class as a vector of values that sum to 1. During training, we adjust the model weights to maximise the log-likelihood of the correct class label given the input features. Then, during prediction, we select the class with the highest probability as the predicted class.
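The softmax step itself is short enough to write out. The sketch below, in NumPy, turns a made-up vector of per-class scores (one score per class from a single model) into probabilities; subtracting the maximum score first does not change the result but avoids overflow in the exponential.

```python
import numpy as np


def softmax(logits):
    """Turn a vector of class scores into probabilities that sum to 1."""
    shifted = logits - np.max(logits)  # subtract the max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum()


# Hypothetical scores (w_k . x + b_k) for three classes from a single model
logits = np.array([2.0, 1.0, 0.1])
probs = softmax(logits)
print(probs.round(3))         # [0.659 0.242 0.099]
print(int(np.argmax(probs)))  # 0 -> the predicted class
```

In practice you rarely need to write this yourself: in recent scikit-learn versions, LogisticRegression with its default lbfgs solver already fits this multinomial (softmax) model when there are more than two classes, and in PyTorch nn.CrossEntropyLoss applies the log-softmax to the raw scores internally.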

Both approaches can be practical for multi-class text classification, and the choice between them depends on the specific task and data.

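For example, softmax regression trains a single model whose class probabilities sum to 1 by construction, which is usually simpler and cheaper than training one classifier per class. On the other hand, OvA's independent binary classifiers can be easier to inspect one class at a time and can be trained in parallel.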

Logistic regression text classification in Python

Sklearn logistic regression text classification

To perform text classification using logistic regression in Python, you can use popular libraries such as scikit-learn and nltk. Here is an example of how to use scikit-learn's CountVectorizer and LogisticRegression for binary text classification:

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# Load the dataset
texts = [...]  # list of text samples
labels = [...]  # list of corresponding labels (0 or 1)

# Split the data into training and test sets
X_train, X_test, y_train, y_test = train_test_split(texts, labels, test_size=0.2, random_state=42)

# Convert the text to a bag-of-words representation
vectorizer = CountVectorizer()
X_train = vectorizer.fit_transform(X_train)
X_test = vectorizer.transform(X_test)

# Train a logistic regression classifier
classifier = LogisticRegression()
classifier.fit(X_train, y_train)

# Make predictions on the test set
y_pred = classifier.predict(X_test)

# Evaluate the accuracy of the classifier
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)

In this example, we first load the text dataset and split it into training and test sets. We then use the CountVectorizer class from scikit-learn to convert the text to a bag-of-words representation.

Next, we train a logistic regression classifier using the LogisticRegression class from scikit-learn and make predictions on the test set. Finally, we evaluate the accuracy of the classifier using the accuracy_score function from scikit-learn.
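Because the example leaves the dataset as a placeholder, here is a self-contained variant you can run as-is. It swaps the separate vectorizer and classifier steps for scikit-learn's make_pipeline, so the fitted model accepts raw strings directly; the six "reviews" are made up and far too small to say anything about real accuracy.

```python
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

# Made-up toy reviews (1 = positive, 0 = negative) standing in for a real dataset
texts = ["I loved this film", "what a great movie", "great acting and a great plot",
         "I hated this film", "what a terrible movie", "terrible acting and a boring plot"]
labels = [1, 1, 1, 0, 0, 0]

# The pipeline runs the vectorizer and then the classifier, so it can be
# fitted on, and applied to, raw text without calling transform() by hand
model = make_pipeline(CountVectorizer(), LogisticRegression())
model.fit(texts, labels)

new_texts = ["a great film", "a terrible plot"]
print(model.predict(new_texts))        # predicted labels for unseen text
print(model.predict_proba(new_texts))  # [P(class 0), P(class 1)] per text
```

Keeping the vectorizer inside the pipeline also guarantees that new text is transformed with the vocabulary learned from the training data, just as the vectorizer.transform(X_test) call does in the example above.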

Logistic regression text classification PyTorch example

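To perform text classification using logistic regression in PyTorch, you can use the torchtext library to load and preprocess the text data and then use PyTorch's nn and optim modules to define and train the logistic regression model. The example relies on the torchtext.legacy API, which only exists in torchtext 0.9 to 0.11 (it was removed in 0.12), so you need one of those versions installed alongside a compatible PyTorch release. Here's an example of how to do this: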

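import torch
import torch.nn as nn
import torch.optim as optim
from torchtext.legacy import data
from torchtext.legacy import datasets

# Define the fields for the text data and the labels
# (tokenize='spacy' needs spaCy and its en_core_web_sm model installed)
TEXT = data.Field(tokenize='spacy', tokenizer_language='en_core_web_sm')
LABEL = data.LabelField()

# Load the IMDB dataset and split it into training and test sets
train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)

# Build the vocabulary from the training data
TEXT.build_vocab(train_data, max_size=25000)
LABEL.build_vocab(train_data)

# Define the logistic regression model
class LogisticRegression(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        # Return raw class scores (logits); nn.CrossEntropyLoss applies the softmax
        return self.linear(x)

def to_bow(text, vocab_size, pad_idx):
    # Convert a [seq_len, batch] tensor of token indices into
    # a [batch, vocab_size] tensor of word counts (a bag-of-words)
    idx = text.t()
    bow = torch.zeros(idx.size(0), vocab_size, device=idx.device)
    bow.scatter_add_(1, idx, torch.ones_like(idx, dtype=torch.float))
    bow[:, pad_idx] = 0  # padding tokens are not real words
    return bow

# Set the device to use for training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Set the hyperparameters
input_dim = len(TEXT.vocab)
output_dim = 2
pad_idx = TEXT.vocab.stoi[TEXT.pad_token]
lr = 0.001
batch_size = 64
epochs = 10

# Define the iterators for the training and test sets
train_iter, test_iter = data.BucketIterator.splits(
    (train_data, test_data), batch_size=batch_size, device=device)

# Initialize the model, loss function and optimizer
model = LogisticRegression(input_dim, output_dim).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

# Train the model
for epoch in range(epochs):
    model.train()
    for batch in train_iter:
        x = to_bow(batch.text, input_dim, pad_idx)
        y = batch.label
        optimizer.zero_grad()
        y_pred = model(x)
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()

    # Evaluate the model on the test set
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for batch in test_iter:
            x = to_bow(batch.text, input_dim, pad_idx)
            y = batch.label
            y_pred = model(x)
            _, predicted = torch.max(y_pred, 1)
            total += y.size(0)
            correct += (predicted == y).sum().item()
        accuracy = 100 * correct / total
        print('Epoch: {}, Test Accuracy: {:.2f}%'.format(epoch + 1, accuracy))

In this example, we first define fields for the text and the labels using torchtext's data.Field and data.LabelField classes, load the IMDB dataset split into training and test sets, and build both vocabularies with the build_vocab method. We then define the logistic regression model as a subclass of nn.Module whose single nn.Linear layer maps a bag-of-words vector to two class scores. Because the iterators return each review as a sequence of token indices, the to_bow helper converts every batch into word counts over the vocabulary before it reaches the model. We create the iterators for the training and test sets with torchtext's data.BucketIterator class, initialise the model, loss function and optimiser, and train with the standard PyTorch loop: zero the gradients, compute the cross-entropy loss, backpropagate and update the weights. Finally, we evaluate the model on the test set and print the test accuracy after each epoch.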
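If you cannot install one of the old torchtext releases, the same model can be sketched in plain PyTorch. The version below is a minimal sketch under clear assumptions: a made-up six-sentence dataset, a naive whitespace tokenizer and a hand-built vocabulary dictionary stand in for torchtext's fields, and the whole dataset is trained as a single batch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Made-up toy data (1 = positive, 0 = negative)
texts = ["loved this film", "great movie", "great fun",
         "hated this film", "terrible movie", "boring and terrible"]
labels = torch.tensor([1, 1, 1, 0, 0, 0])

# A simple whitespace-token vocabulary (a stand-in for torchtext's vocab)
vocab = {word: i for i, word in enumerate(sorted({w for t in texts for w in t.split()}))}


def to_bow(batch_texts):
    """Turn a list of strings into a [batch, vocab_size] tensor of word counts.

    Words that are not in the vocabulary are skipped.
    """
    bow = torch.zeros(len(batch_texts), len(vocab))
    for row, text in enumerate(batch_texts):
        for word in text.split():
            if word in vocab:
                bow[row, vocab[word]] += 1
    return bow


X = to_bow(texts)

# Logistic regression is a single linear layer; CrossEntropyLoss applies the softmax
model = nn.Linear(len(vocab), 2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(200):
    optimizer.zero_grad()
    loss = loss_fn(model(X), labels)
    loss.backward()
    optimizer.step()

with torch.no_grad():
    predictions = model(to_bow(["great film", "terrible film"])).argmax(dim=1)
print(predictions.tolist())
```

The structure mirrors the torchtext example: build a vocabulary, turn each text into a count vector, and train an nn.Linear layer with cross-entropy. Only the data loading differs.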

Alternatives to logistic regression – taking it one step further

Once you have your logistic regression working, the next logical step is to try another algorithm.

There are several alternatives to logistic regression for text classification, some of which include:

  1. Naive Bayes: Naive Bayes is a probabilistic classifier commonly used for text classification tasks. It assumes that the features (i.e., the words) are conditionally independent given the class label and uses Bayes’ theorem to calculate the probability of each class given the features.
  2. Support Vector Machines (SVMs): SVMs are a popular class of machine learning algorithms that can be used for text classification. They work by finding the hyperplane that separates the classes with the largest possible margin, and they can be used with various kernel functions to map the input features into a higher-dimensional space.
  3. Decision Trees: Decision trees are supervised learning algorithms that can be used for text classification. They work by recursively partitioning the feature space into smaller regions based on the values of the input features and assigning a label to each part based on the majority class of the training data in that region.
  4. Neural Networks: Neural networks are a class of machine learning algorithms inspired by the human brain’s structure and function. They can be used for various text classification tasks, including binary and multi-class classification.
  5. Random Forests: Random forests are an ensemble learning method that combines multiple decision trees to improve the accuracy of the classifier. For text classification, each tree is trained on a bootstrap sample of the training data, considering a random subset of the features at each split, and the trees' predictions are combined by voting to make the final classification decision.

Each of these alternatives has its strengths and weaknesses, and the choice of which algorithm to use will depend on the specific characteristics of the text classification task at hand.
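A convenient way to try these alternatives in scikit-learn is to keep the same bag-of-words features and swap only the final estimator. The sketch below uses scikit-learn's implementations of four of the listed algorithms alongside logistic regression (neural networks are left to the PyTorch section); the toy reviews are made up, so the printed predictions say nothing about which algorithm is better.

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import make_pipeline
from sklearn.svm import LinearSVC
from sklearn.tree import DecisionTreeClassifier

# Made-up toy reviews (1 = positive, 0 = negative)
train_texts = ["I loved this film", "what a great movie", "great acting and a great plot",
               "I hated this film", "what a terrible movie", "terrible acting and a boring plot"]
train_labels = [1, 1, 1, 0, 0, 0]
test_texts = ["a great film", "a terrible plot"]

candidates = {
    "logistic regression": LogisticRegression(),
    "naive Bayes": MultinomialNB(),
    "linear SVM": LinearSVC(),
    "decision tree": DecisionTreeClassifier(random_state=0),
    "random forest": RandomForestClassifier(n_estimators=100, random_state=0),
}

predictions = {}
for name, classifier in candidates.items():
    # Same bag-of-words features each time; only the final estimator changes
    model = make_pipeline(CountVectorizer(), classifier)
    model.fit(train_texts, train_labels)
    predictions[name] = model.predict(test_texts)
    print(f"{name:>20}: {predictions[name].tolist()}")
```

On a real dataset, you would compare the candidates on a held-out test set (for example with accuracy_score, as in the scikit-learn example above) or with cross-validation, rather than on a couple of hand-written sentences.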


Logistic regression is a popular algorithm for text classification due to its simplicity, interpretability, ability to handle sparse data, versatility for binary and multi-class classification tasks, and usefulness as a baseline model. Even though there are other ways to classify text besides logistic regression, each with its pros and cons, logistic regression can often give good results and is a good choice for many real-world applications.

About the Author

Neri Van Otten

Neri Van Otten is the founder of Spot Intelligence, a machine learning engineer with over 12 years of experience specialising in Natural Language Processing (NLP) and deep learning innovation. Dedicated to making your projects succeed.
