Graph Neural Network Explained & How To Tutorial In Python With PyTorch

by | Jun 30, 2023 | Artificial Intelligence, Machine Learning, Natural Language Processing

Graph Neural Network (GNN) is revolutionizing the field of machine learning by enabling effective modelling and analysis of structured data. Originally designed for graph-based data, GNNs have found application in various domains, including natural language processing (NLP). By incorporating the inherent structural dependencies in text data, GNNs offer a promising approach to tackling complex NLP tasks, such as text classification.

Text classification involves categorizing or assigning predefined labels to text documents based on their content. Traditional approaches often rely on bag-of-words or n-gram representations, which may not capture the rich semantic relationships between words. GNNs address this limitation by treating text data as a graph, where words are represented as nodes, and their relationships are modelled through edges or adjacency matrices.

In recent years, Graph Convolutional Networks (GCNs) have become a popular GNN architecture for text classification. GCNs enable information aggregation from neighbouring words, capturing contextual dependencies and semantic relationships in a text. By applying graph convolutions to word embeddings and leveraging the connectivity patterns among words, GCNs can generate more expressive representations for text documents, improving classification accuracy.

Using PyTorch, a popular deep learning library, implementing GCNs for text classification becomes more accessible. PyTorch provides a flexible framework for defining and training GNN models, allowing researchers and practitioners to leverage the power of GCNs in their NLP tasks. Combining the strengths of GNNs and PyTorch makes it possible to create sophisticated models that effectively capture the structural characteristics of text data, enabling more accurate and robust text classification.

In this article, we will delve into the application of Graph Convolutional Networks (GCNs) in text classification tasks. We will explore the architecture, data preparation, model implementation using PyTorch, and the potential benefits of leveraging GCNs for text classification. By harnessing the power of GNNs, we can unlock new avenues for understanding and analyzing text data, leading to advancements in various NLP applications.

What is a graph neural network (GNN)?

A graph neural network (GNN) is a neural network designed to process and analyze structured data represented as graphs. Unlike traditional neural networks that operate on grid-like or sequential data, GNNs can effectively capture the relationships and dependencies between elements in a graph.

A graph neural network (GNN) is a neural network designed to process and analyze structured data represented as graphs

A graph neural network is designed to process and analyze structured data represented as graphs.

Graphs consist of nodes (also called vertices) connected by edges (also called links). Each node in a graph can have attributes or features associated with it, and the edges represent relationships or connections between the nodes. For example, in a social network, nodes can represent individuals, and edges can represent friendships between them.

The main idea behind GNNs is to propagate information through the graph structure by iteratively updating node representations based on the features of their neighbouring nodes. This process allows GNNs to learn meaningful and context-aware node representations that capture local and global graph structures.

Typically, a GNN consists of multiple layers, each performing two main operations: message passing and aggregation. In the message-passing step, each node aggregates information from its neighbours and updates its representation. The aggregation step combines the updated representations of neighbouring nodes to obtain a refined representation for each node. These operations are performed iteratively across multiple layers, allowing the GNN to capture increasingly complex graph patterns.

GNNs have been successfully applied to various tasks, including node classification, link prediction, graph classification, recommendation systems, and molecule property prediction. They have shown promising results in domains where the data can be naturally represented as graphs, such as social networks, knowledge graphs, biological networks, and citation networks.

What is the graph neural network (GNN) architecture?

Graph neural network (GNN) architectures can vary depending on the specific task and the desired properties of the model. Here we will describe a commonly used GNN architecture, the Graph Convolutional Network (GCN), which forms the foundation for many other GNN variants.

Please note that there have been advancements and variations in GNN architectures beyond the GCN, but we will focus on describing the GCN architecture as a starting point.

The Graph Convolutional Network (GCN) architecture operates on a graph with N nodes and an adjacency matrix A. Here’s a step-by-step overview of the GCN architecture:

1. Input Representation: Each node in the graph is associated with a feature vector. These initial node features can be obtained from the nodes’ attributes or other sources. Let’s denote the feature matrix for the graph as X, where X ∈ R^(N x D), where N is the number of nodes, and D is the dimensionality of the node features.

2. Convolutional Layer: The convolutional layer is the core component of the GCN. Each node aggregates information from its neighbouring nodes and updates its representation in this layer. The update is performed by applying a graph convolution operation inspired by the convolution operation in traditional convolutional neural networks (CNNs).

  • In the GCN, the graph convolution operation can be defined as follows:

    H = σ(AXW)

    W is a learnable weight matrix, AX represents matrix multiplication, and σ is an activation function such as ReLU.

    The matrix multiplication AX represents the aggregation of neighbouring node features. AX is equivalent to multiplying the adjacency matrix A with the feature matrix X, resulting in a matrix of size (N x D), where each row corresponds to the aggregated features of a node’s neighbours. Multiplying this aggregated feature matrix with the weight matrix W and applying the activation function yields the updated node representations in matrix H.

3. Non-linear Activation: A non-linear activation function introduces non-linearity into the node representations after the convolutional layer. Common activation functions include ReLU, sigmoid, or tanh.

4. Pooling or Aggregation: You may want to aggregate the node representations into a graph-level representation depending on the task. This can be done by applying pooling or aggregation operations, such as summation, mean pooling, or graph-level attention mechanisms.

5. Output Layer: The graph-level representation can be fed into a fully connected layer or another type of classifier to produce the desired output, depending on the task. For example, in node classification tasks, the output layer may consist of a softmax function for predicting the class labels of each node.

The above steps can be repeated for multiple layers to capture increasingly complex graph patterns. Each layer’s output can serve as the input to the subsequent layer, allowing information to propagate through the graph structure.

Beyond the GCN, various other GNN architectures have been proposed, including GraphSAGE, GAT (Graph Attention Network), Graph Isomorphism Network (GIN), GraphSAGE, Graph Wavelet Neural Network (GWNN), and many more. These architectures may introduce additional components, such as attention mechanisms, skip connections, or higher-order operations, to enhance the model’s expressiveness or address specific challenges in graph learning tasks.

It’s important to note that the specific details of GNN architectures can vary based on the research paper or implementation you refer to. The architectural choices often depend on the particular task, the graph data characteristics, and the model’s desired properties.

What are the types of Graph Neural Network (GNN)?

There are several Graph Neural Networks (GNNs) types, each with architectural variations and characteristics. Here are some commonly used types of GNNs:

  1. Graph Convolutional Networks (GCNs): GCNs, as described earlier, are one of the foundational GNN architectures. They aggregate information from neighbouring nodes and update node representations through graph convolution operations.
  2. Graph Attention Networks (GATs): GATs leverage attention mechanisms to assign different weights to the neighbouring nodes during aggregation. This allows the model to focus on more informative nodes and perform adaptive neighbourhood aggregation.
  3. GraphSAGE: GraphSAGE (Graph Sample and Aggregated) is a variant of GNN that samples and aggregates features from the local neighbourhood of each node. It incorporates a “neighbourhood sampling” step, where a fixed-size neighbourhood is sampled for each node, followed by feature aggregation and updating node representations.
  4. Graph Isomorphism Network (GIN): GINs are designed to be permutation invariant and can operate on both directed and undirected graphs. They use multiple aggregation steps and apply a shared multi-layer perceptron (MLP) to aggregate features from the node’s neighbourhood.
  5. Graph Neural Networks with Recurrent Units (GRUs): GRUs extend GNN architectures by incorporating recurrent units to model temporal dependencies in sequential graph data. They are often used in tasks involving dynamic graphs or time-series graph data.
  6. Graph Wavelet Neural Networks (GWNNs): GWNNs leverage graph wavelet transforms to perform localized spectral analysis on graphs. They use wavelet filters to capture local and global graph structures, enabling effective information propagation and feature extraction.
  7. Graph Autoencoders: Graph autoencoders are unsupervised learning models that learn to encode graph data into low-dimensional embeddings and decode them back to reconstruct the original graph. They are useful for graph generation, anomaly detection, and representation learning.
  8. Graph Generative Models: These models focus on generating new graphs that possess similar properties to a given set of training graphs. Examples of graph-generative models include Graph Variational Autoencoders (Graph-VAEs), Graph Generative Adversarial Networks (Graph-GANs), and Graph Neural Networks for Graph Generation (GraphRNN).

These are just a few examples of GNN architectures, and the field of graph neural networks is rapidly evolving. Researchers continue developing new variants and adaptations of GNNs to address different tasks, improve performance, and handle various graph data types.

It’s worth noting that some GNN architectures can be combined or extended to create hybrid models or address specific challenges in graph learning tasks.

Applications of Graph Neural Networks (GNNs)

Graph Neural Networks (GNNs) have found application in various domains due to their ability to model and analyze structured data. Here are some notable applications of GNNs:

  1. Social Network Analysis: GNNs have been extensively used in social network analysis tasks. They can capture social relationships and community structures and influence propagation within networks. GNNs enable tasks such as node classification, link prediction, community detection, and recommendation systems in social networks.
  2. Knowledge Graphs: GNNs have proven effective in modelling and reasoning over knowledge graphs, representing structured information about entities and their relationships. GNNs can learn entity and relation embeddings, perform link prediction, entity classification, and question answering in knowledge graphs, improving tasks such as information retrieval and knowledge base completion.
  3. Recommendation Systems: GNNs have shown promise in recommendation systems by incorporating user-item interactions and modelling the underlying graph structure. They can capture user preferences and item relationships and generate personalized recommendations based on the learned embeddings. GNNs have been used in movie recommendations, e-commerce, and social media platforms.
  4. Computer Vision: GNNs have been successfully applied in computer vision tasks, particularly in graph-structured data scenarios. They can model object relationships and scene graphs and perform image segmentation, object detection, and action recognition tasks. GNNs in computer vision provide context-aware representations and capture dependencies between objects.
  5. Drug Discovery: GNNs have gained attention in drug discovery and materials science. They can model molecular structures and chemical interactions to predict molecular properties, screen drug candidates, and assist in drug design. GNNs enable the analysis of molecular graphs, capturing atom-level dependencies and structural patterns for more accurate predictions.
  6. Natural Language Processing (NLP): GNNs are increasingly applied in NLP tasks. They can model the relationships between words, sentences, or documents as graphs and capture semantic dependencies. GNNs have been used for sentiment analysis, named entity recognition, relation extraction, text classification, and document summarization.
  7. Bioinformatics: GNNs have found application in bioinformatics for tasks such as protein-protein interaction prediction, protein structure prediction, and genomics analysis. GNNs can model biological networks and capture dependencies between biomolecules, facilitating a better understanding and analysis of complex biological systems.

Graph Neural Network (GNN) and NLP

Graph Neural Networks (GNNs) can be applied to various natural language processing (NLP) tasks, enabling the modelling and analysis of structured relationships between words, sentences, or documents. Here are some ways GNNs are used in NLP:

  1. Semantic Role Labeling (SRL): GNNs can be employed to perform SRL, which involves identifying the predicate-argument structure in a sentence. GNNs can capture the dependencies and relationships between words and their corresponding roles to determine the semantic structure of a sentence.
  2. Relation Extraction: GNNs can be applied to relation extraction tasks, where the goal is to identify and classify relationships between entities in text. GNNs can effectively model the dependencies and interactions between entities by representing the words and their relationships as a graph, allowing for more accurate relation extraction.
  3. Sentiment Analysis: GNNs can perform sentiment analysis on text data by capturing the dependencies and contextual relationships between words. By representing sentences as graphs and propagating information through the graph structure, GNNs can learn representations that encode sentiment information and predict the sentiment expressed in the text.
  4. Text Classification: GNNs can be employed for text classification tasks, where the goal is to assign predefined categories or labels to text documents. By treating the documents as nodes in a graph and modelling their relationships, GNNs can capture the structural information and dependencies within the document collection, leading to improved text classification performance.
  5. Text Generation: GNNs can be used in text generation tasks, such as generating coherent and contextually relevant sentences or documents. GNNs can capture the dependencies and relationships between words and phrases, allowing for more informed and context-aware text generation.
  6. Question Answering: GNNs can be utilized for question-answering tasks, where the goal is to generate answers based on a given question and a context paragraph. By modelling the question, context, and candidate answers as a graph and propagating information through the graph structure, GNNs can capture the relevant relationships and dependencies, enabling more accurate and context-aware question answering.
  7. Document Summarization: GNNs can be applied to document summarization tasks, where the objective is to generate concise summaries of long documents. By representing the document as a graph and leveraging the graph structure, GNNs can capture important relationships between sentences or paragraphs, aiding in extracting salient information for summarization.

These are just a few examples of how GNNs can be applied to NLP tasks. GNNs offer a promising approach to leveraging graph-based representations and capturing structural dependencies in textual data, improving performance on various NLP tasks.

Graph Neural Network Tutorial with PyTorch

Graph Convolutional Networks for text classification

Graph Convolutional Networks (GCNs) can be adapted for text classification tasks by representing the text data as a graph and performing graph convolutions to capture the relationships between words. Here’s an outline of how to apply GCNs for text classification:

1. Data Preparation:

  • Preprocess the text data by tokenizing the documents into words or subwords.
  • Build a vocabulary by mapping each unique word to a numerical index.
  • Convert the text documents into sequences of word indices.
  • Create an adjacency matrix that represents the connections between words in the documents based on their co-occurrence or other relationships.

2. Define the GCN model class:

import torch
import torch.nn as nn
import torch.nn.functional as F

class GCN(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, output_dim):
        super(GCN, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.conv1 = nn.Conv1d(embed_dim, hidden_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(hidden_dim, output_dim, kernel_size=3, padding=1)

    def forward(self, x, adjacency_matrix):
        x = self.embedding(x)
        x = x.permute(0, 2, 1)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.permute(0, 2, 1)
        x = torch.matmul(adjacency_matrix, x)
        x = x.mean(dim=1)
        return x

The GCN model includes an embedding layer (embedding) to convert the word indices into dense word embeddings. The convolutional layers (conv1 and conv2) perform graph convolutions, followed by activation functions. The adjacency matrix is multiplied by node features (x) to propagate information through the graph structure. Finally, the mean pooling operation is applied along the sequence dimension to obtain a fixed-length representation for each document.

3. Prepare the input data:

x = torch.tensor(document_sequences, dtype=torch.long)  # Tensor of document sequences
adjacency_matrix = torch.tensor(adjacency_matrix, dtype=torch.float32)  # Adjacency matrix
target_labels = torch.tensor(labels, dtype=torch.long)  # Tensor of target labels

4. Create an instance of the GCN model:

vocab_size = # Size of the vocabulary
embed_dim = # Dimensionality of word embeddings
hidden_dim = # Dimensionality of hidden layer
output_dim = # Dimensionality of output layer
model = GCN(vocab_size, embed_dim, hidden_dim, output_dim)

5. Define the loss function and optimizer:

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

6. Training loop:

for epoch in range(num_epochs):
    output = model(x, adjacency_matrix)
    loss = criterion(output, target_labels)

7. Evaluation:

with torch.no_grad():
    output = model(x, adjacency_matrix)
    predicted_labels = torch.argmax(output, dim=1)
    # Perform evaluation metrics or further processing

This basic outline of applying Graph Convolutional Networks (GCNs) for text classification. You can customize the model architecture, experiment with different hyperparameters, or incorporate additional layers to suit your specific task and requirements.


Graph Neural Network (GNN) has emerged as a powerful framework for modelling and analyzing structured data, including graphs and text data. They allow for capturing relationships and dependencies between elements in the data, enabling more effective representation learning and predictive modelling.

Graph Convolutional Networks (GCNs) can be adapted for text classification tasks by representing text data as a graph and performing graph convolutions to capture the relationships between words. By treating words as nodes and leveraging adjacency matrices, GCNs can propagate information through the graph structure and learn expressive representations for text documents. This approach can improve the performance of text classification models by capturing contextual and semantic relationships between words.

Implementing GCNs for text classification using PyTorch involves defining a model class with embedding and convolutional layers, preparing the input data with word sequences and adjacency matrices, and training the model using appropriate loss functions and optimization techniques. PyTorch provides a flexible and efficient framework for building and training GNN models.

It’s important to note that the field of GNNs is rapidly evolving, and there are various other GNN architectures and techniques beyond GCNs that can be explored for text classification and other NLP tasks. Researchers continue to develop new variations and advancements in GNNs, expanding their applications and improving their performance.

By leveraging the power of GNNs, researchers and practitioners can enhance their text classification models and achieve state-of-the-art performance by effectively modelling the complex relationships and dependencies present in text data.

About the Author

Neri Van Otten

Neri Van Otten

Neri Van Otten is the founder of Spot Intelligence, a machine learning engineer with over 12 years of experience specialising in Natural Language Processing (NLP) and deep learning innovation. Dedicated to making your projects succeed.

Recent Articles

ROC curve

ROC And AUC Curves In Machine Learning Made Simple & How To Tutorial In Python

What are ROC and AUC Curves in Machine Learning? The ROC Curve The ROC (Receiver Operating Characteristic) curve is a graphical representation used to evaluate the...

decision boundaries for naive bayes

Naive Bayes Classification Made Simple & How To Tutorial In Python

What is Naive Bayes? Naive Bayes classifiers are a group of supervised learning algorithms based on applying Bayes' Theorem with a strong (naive) assumption that every...

One class SVM anomaly detection plot

How To Implement Anomaly Detection With One-Class SVM In Python

What is One-Class SVM? One-class SVM (Support Vector Machine) is a specialised form of the standard SVM tailored for unsupervised learning tasks, particularly anomaly...

decision tree example of weather to play tennis

Decision Trees In ML Complete Guide [How To Tutorial, Examples, 5 Types & Alternatives]

What are Decision Trees? Decision trees are versatile and intuitive machine learning models for classification and regression tasks. It represents decisions and their...

graphical representation of an isolation forest

Isolation Forest For Anomaly Detection Made Easy & How To Tutorial

What is an Isolation Forest? Isolation Forest, often abbreviated as iForest, is a powerful and efficient algorithm designed explicitly for anomaly detection. Introduced...

Illustration of batch gradient descent

Batch Gradient Descent In Machine Learning Made Simple & How To Tutorial In Python

What is Batch Gradient Descent? Batch gradient descent is a fundamental optimization algorithm in machine learning and numerical optimisation tasks. It is a variation...

Techniques for bias detection in machine learning

Bias Mitigation in Machine Learning [Practical How-To Guide & 12 Strategies]

In machine learning (ML), bias is not just a technical concern—it's a pressing ethical issue with profound implications. As AI systems become increasingly integrated...

text similarity python

Full-Text Search Explained, How To Implement & 6 Powerful Tools

What is Full-Text Search? Full-text search is a technique for efficiently and accurately retrieving textual data from large datasets. Unlike traditional search methods...

the hyperplane in a support vector regression (SVR)

Support Vector Regression (SVR) Simplified & How To Tutorial In Python

What is Support Vector Regression (SVR)? Support Vector Regression (SVR) is a machine learning technique for regression tasks. It extends the principles of Support...


Submit a Comment

Your email address will not be published. Required fields are marked *

nlp trends

2024 NLP Expert Trend Predictions

Get a FREE PDF with expert predictions for 2024. How will natural language processing (NLP) impact businesses? What can we expect from the state-of-the-art models?

Find out this and more by subscribing* to our NLP newsletter.

You have Successfully Subscribed!