Understanding Generative Adversarial Networks With A How To Tutorial In TensorFlow And Python

by Neri Van Otten | Mar 8, 2023 | Artificial Intelligence, Machine Learning

What is a Generative Adversarial Network (GAN)? What are GANs used for? How do they work? And what different types are there? This article covers all of this and includes a tutorial on how to get started with GANs in TensorFlow and Python.

What is a Generative Adversarial Network (GAN)?

Generative Adversarial Networks (GANs) are a type of deep learning model consisting of two neural networks, a generator and a discriminator, that work in opposition. On the one hand, the generator learns to make new examples that resemble the input data. On the other hand, the discriminator learns to tell the difference between real examples and ones made by the generator.

Random noise is fed into the generator, which transforms it into fake data that looks like real data. The discriminator takes both the real data and the generated data as input and guesses whether each example is real or fake. The generator then learns to produce better outputs by trying to make the discriminator believe the data it creates is real.

During training, the two networks are trained simultaneously: the discriminator is trained to correctly classify real and generated data, while the generator is trained to produce data that is difficult for the discriminator to distinguish from real data.

GANs have been used to make realistic-looking pictures, videos, and music, and to produce data for other uses, such as augmenting training data for machine learning tasks.

But GAN training can be problematic because the two networks must stay balanced against each other. Problems like mode collapse, where the generator learns to produce only a handful of distinct outputs, can be difficult to fix.

What does a GAN network look like?

A GAN network consists of two neural networks: a generator and a discriminator. The generator takes random noise as input and generates fake data that resembles the input data. The discriminator takes both real data and generated data as input and predicts whether each example is real or fake.

Most of the time, the generator comprises multiple fully connected or convolutional layers. These layers map the random noise into a higher-dimensional space so that the output resembles the training data. The output of the generator is then fed into the discriminator.

The discriminator likewise comprises multiple fully connected or convolutional layers that learn to tell the difference between real data and data made by the generator. Finally, the output of the discriminator is a probability score that indicates how confident it is that the input data is real.

During training, the two networks are trained simultaneously, with the discriminator trained to correctly classify real and generated data. On the other hand, the generator is trained to create data that is difficult for the discriminator to distinguish from real data. This is done by optimising a loss function that pits the objectives of the two networks against each other.
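Formally, this adversarial setup is usually written as the minimax objective from the original GAN paper (Goodfellow et al., 2014), where $D(x)$ is the discriminator's estimated probability that $x$ is real and $G(z)$ is the generator's output for noise vector $z$:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$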

Overall, the GAN network architecture allows the generation of new data similar to the input data and can be used for various applications, such as image and video synthesis, data augmentation, and anomaly detection.

What are the types of GAN models?

Various types of Generative Adversarial Networks (GANs) have been developed, each with its own approach and architecture. Some of the common types of GANs include:

  1. Vanilla GANs: Vanilla GANs are the simplest and most basic form of GANs. They consist of two neural networks, a generator and a discriminator, that compete against each other in a game to generate realistic-looking images.
  2. Deep Convolutional GANs (DCGANs): DCGANs use convolutional neural networks (CNNs) for both the generator and discriminator. They are known for their ability to generate high-resolution and realistic images.
  3. Conditional GANs (cGANs): cGANs are a type of GAN that takes additional input data, called conditioning data, along with the noise vector to generate images that meet a specific condition. They are commonly used for tasks such as image-to-image translation and style transfer.
  4. Wasserstein GANs (WGANs): WGANs are a variant of GANs that use the Wasserstein distance to measure the difference between the generated and real data distributions. They are more stable and create higher-quality images than traditional GANs (see the loss sketch after this list).
  5. CycleGANs: CycleGANs are a type of GAN that can learn to translate between two domains, such as converting images of horses into pictures of zebras. They work by learning a mapping between the two domains without the need for paired training data.
  6. Progressive GANs: Progressive GANs produce high-quality images by gradually increasing the resolution during training. They start by generating low-resolution images and progressively scale up until they reach the desired size.
  7. StyleGANs: StyleGANs generate images by separating the style and content of an image. They allow for the control of specific features in the generated images, such as hair colour or facial expression.
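To make the Wasserstein idea from item 4 concrete, here is a minimal sketch of the WGAN losses in TensorFlow. The function names are illustrative rather than from any library, and the weight clipping follows the original WGAN paper (later variants replace it with a gradient penalty):

import tensorflow as tf

# WGAN critic loss: the critic should score real data higher than fake data
def wgan_critic_loss(real_scores, fake_scores):
    return tf.reduce_mean(fake_scores) - tf.reduce_mean(real_scores)

# WGAN generator loss: push the critic's scores for generated data up
def wgan_generator_loss(fake_scores):
    return -tf.reduce_mean(fake_scores)

# The original WGAN enforces the Lipschitz constraint by clipping the
# critic's weights after each update
def clip_critic_weights(critic, clip_value=0.01):
    for w in critic.trainable_variables:
        w.assign(tf.clip_by_value(w, -clip_value, clip_value))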

Conditional GAN

A popular GAN is the Conditional Generative Adversarial Network (cGAN). This type of GAN takes additional input data, called conditioning data, along with the noise vector to generate images that meet a specific condition. The conditioning data can be any information, such as class labels, text descriptions, or images.

In cGANs, the random noise vector and the conditioning data are both fed into the generator network, which produces an output that matches the condition. The discriminator network then takes both the generated output and the conditioning data and tries to determine whether the output is real or fake.

cGANs are often used for tasks like image-to-image translation, where the goal is to make an image that matches the given input. For example, cGANs can be used to translate a grayscale image into a coloured image or to remove noise from an image.

Another application of cGANs is style transfer, where the conditioning data is a style image. The goal is to generate an output image with the same style as the conditioning image. This technique can create artwork or modify an existing image’s style.

AI art generated by a conditional GAN

Overall, cGANs give you more control over the outputs by letting you use conditioning data to specify the outcome you want, as the sketch below illustrates.
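Here is a minimal sketch of a conditional generator in TensorFlow, assuming 10 integer class labels and a 100-dimensional noise vector; the layer sizes are illustrative, not taken from a specific paper:

import tensorflow as tf
from tensorflow.keras import layers

def make_conditional_generator(num_classes=10, noise_dim=100):
    noise = layers.Input(shape=(noise_dim,))
    label = layers.Input(shape=(1,), dtype='int32')
    # Turn the integer class label into a dense vector
    label_embedding = layers.Flatten()(layers.Embedding(num_classes, 50)(label))
    # Conditioning: concatenate the noise vector with the label embedding
    x = layers.Concatenate()([noise, label_embedding])
    x = layers.Dense(7 * 7 * 128, activation='relu')(x)
    x = layers.Reshape((7, 7, 128))(x)
    x = layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', activation='relu')(x)
    # Output a 28x28 single-channel image in [-1, 1]
    out = layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', activation='tanh')(x)
    return tf.keras.Model([noise, label], out)

The discriminator can be conditioned in the same way: embed the label and concatenate it with the image features before the final real/fake prediction.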

Generative Adversarial Networks applications

Generative Adversarial Networks (GANs) have been widely used in various applications, some of which include:

  1. Image generation: GANs have been used to generate realistic images that can be used for various applications, such as video games, virtual reality, and movie production.
  2. Image-to-image translation: GANs have been used to translate images from one domain to another, such as converting a black-and-white photo to a coloured image.
  3. Video generation: GANs have been used to generate videos that can be used for various applications, such as movie production and video editing.
  4. Text-to-image synthesis: GANs have been used to generate images from text descriptions. This technique can be used to generate images for product catalogues or to create pictures for specific scenarios.
  5. Style transfer: GANs have been used to transfer the style of one image to another. This technique can be used for image editing or to generate artwork.
  6. Anomaly detection: GANs have been used to detect anomalies in data, such as identifying fraudulent transactions or detecting defects in manufacturing processes.
  7. Data augmentation: GANs have been used to augment training data for machine learning models, which can improve their performance (a minimal sketch follows this list).
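As a sketch of the data augmentation idea from item 7: once a generator has been trained, extra examples can simply be sampled from it and mixed into the training set. The helper below is illustrative, assuming an already-trained generator like the one built in the tutorial later in this article:

import tensorflow as tf

# A minimal sketch: `generator` is assumed to be an already-trained generator
# and `real_images` a float32 tensor scaled to [-1, 1]
def augment_with_gan(generator, real_images, n_synthetic=1000, noise_dim=100):
    noise = tf.random.normal([n_synthetic, noise_dim])
    synthetic_images = generator(noise, training=False)
    # Mix synthetic examples into the real training set
    return tf.concat([real_images, synthetic_images], axis=0)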

Overall, GANs have shown much promise in many different settings and have become an essential tool for researchers and practitioners in machine learning and artificial intelligence.

Generative Adversarial Networks art

Generative Adversarial Networks (GANs) have become a popular tool for creating art and have been used in various artistic applications, such as generating images, videos, and music.

One example of GAN art is the creation of realistic portrait images. Artists have trained GANs on datasets of portrait images and used the resulting generator network to create new, unique portraits that look like they could be real people. These images can be used for various purposes, such as developing character designs for movies or video games.

Another example of GAN art is the creation of surreal or abstract images. Artists have trained GANs on datasets of abstract art or other non-representational imagery and used the resulting generator network to create new, unique images that push the boundaries of what is possible with traditional art techniques.

Surreal image generated by GANs

GANs have also been used to create videos, such as music videos or short films. Artists have trained GANs on datasets of video footage and used the resulting generator network to produce new, unique videos for artistic or entertainment purposes.

Overall, GANs have given artists and other creative people new opportunities. They can now experiment with new techniques and create one-of-a-kind works of art that were impossible to make before.

Generative Adversarial Networks tutorial in Python

Here’s an example of implementing a simple Generative Adversarial Network (GAN) using TensorFlow.

First, let’s import the necessary libraries:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt

Next, let’s define the generator and discriminator models:

# Generator model
def make_generator_model():
    model = keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256)

    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)

    return model

# Discriminator model
def make_discriminator_model():
    model = keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model
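Before moving on, it's worth sanity-checking both architectures with a quick untrained forward pass (this step is optional; the models are created again in the training section below):

# Sanity check: an untrained generator should still output a 28x28x1 image
generator = make_generator_model()
noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)
print(generated_image.shape)  # (1, 28, 28, 1)

# The discriminator outputs a single unbounded logit per image
discriminator = make_discriminator_model()
decision = discriminator(generated_image, training=False)
print(decision.shape)  # (1, 1)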

Next, let's define the loss functions for both the generator and discriminator:

cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)
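To build intuition for these losses, you can feed in some toy logits (these are made-up values, not real model outputs):

# A confident, correct discriminator gets a low loss, while a discriminator
# that is sure a fake is fake hands the generator a high loss
real_logits = tf.constant([[5.0]])   # discriminator is confident "real"
fake_logits = tf.constant([[-5.0]])  # discriminator is confident "fake"
print(discriminator_loss(real_logits, fake_logits).numpy())  # ~0.013
print(generator_loss(fake_logits).numpy())                   # ~5.007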

Then, let's define the optimizers for the generator and discriminator:

generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
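A learning rate of 1e-4 with Adam's default settings is a reasonable starting point. If training proves unstable, one widely used alternative comes from the DCGAN paper (Radford et al., 2016), which lowers Adam's beta_1 to 0.5:

# Optional DCGAN-style optimizers: learning rate 2e-4 and beta_1 = 0.5
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)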

Now, let’s train the GAN:

# Load the MNIST dataset
(train_images, train_labels), (_, _) = keras.datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5  # Normalize the images to [-1, 1]

# Set the batch size and number of epochs
BUFFER_SIZE = 60000
BATCH_SIZE = 256
EPOCHS = 100

# Create the generator
generator = make_generator_model()

# Create the discriminator
discriminator = make_discriminator_model()

# Define the training loop
@tf.function
def train_step(images):
  # Generate random noise
  noise = tf.random.normal([BATCH_SIZE, 100])
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      # Generate images using the generator
      generated_images = generator(noise, training=True)
  
      # Get the discriminator's predictions for the real and generated images
      real_output = discriminator(images, training=True)
      fake_output = discriminator(generated_images, training=True)
  
      # Calculate the loss for both the generator and discriminator
      gen_loss = generator_loss(fake_output)
      disc_loss = discriminator_loss(real_output, fake_output)

  # Calculate the gradients for the generator and discriminator
  gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
  gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
  
  # Apply the gradients to the optimizer
  generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
  discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

# Create a function to generate and save images
def generate_and_save_images(model, epoch, test_input):
  # Generate images from the model
  predictions = model(test_input, training=False)
  # Rescale the pixel values to [0, 1]
  predictions = (predictions + 1) / 2.0
  
  # Create a plot to display the images
  fig = plt.figure(figsize=(4, 4))
  for i in range(predictions.shape[0]):
      plt.subplot(4, 4, i+1)
      plt.imshow(predictions[i, :, :, 0], cmap='gray')
      plt.axis('off')
  
  # Save the plot
  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  plt.show()
  
# Create a function to train the GAN
def train(dataset, epochs):
  # Generate a fixed noise vector to use for visualization
  test_input = tf.random.normal([16, 100])
  for epoch in range(epochs):
    for image_batch in dataset:
        # Run one training step for both the generator and discriminator
        train_step(image_batch)

    # Generate and save images every 10 epochs
    if (epoch + 1) % 10 == 0:
        generate_and_save_images(generator, epoch + 1, test_input)

    print('Epoch {} completed'.format(epoch + 1))

  # Generate a final set of images and save them
  generate_and_save_images(generator, epochs, test_input)
  
# Load the dataset and create batches
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

# Train the GAN
train(train_dataset, EPOCHS)

And that’s it! This code should train a simple GAN on the MNIST dataset and save generated images every 10 epochs.

Note that there are many ways to modify this code to improve the performance and stability of the GAN, but this should serve as a good starting point.
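For example, one-sided label smoothing (Salimans et al., 2016) is a common, low-effort stabiliser: it softens the hard '1' targets for real images so the discriminator never becomes overconfident. A minimal variant of the discriminator loss defined earlier:

def discriminator_loss_smoothed(real_output, fake_output, smooth=0.9):
    # Soften only the "real" targets; targets for fakes stay at 0
    real_loss = cross_entropy(tf.ones_like(real_output) * smooth, real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return real_loss + fake_loss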

Summary of the code

Here’s a brief overview of what each part of the code is doing:

  1. Define the generator and discriminator: We define the generator and discriminator models using the make_generator_model and make_discriminator_model functions.
  2. Define the training loop: We define a training loop using the @tf.function decorator to improve performance. Inside the training loop, we generate random noise, pass it through the generator to generate fake images, and then pass both real and fake images through the discriminator to get its predictions. We then calculate the loss for both the generator and discriminator and use the tf.GradientTape() context manager to compute gradients. Finally, we apply the gradients to the optimizer using the apply_gradients method.
  3. Define the generate_and_save_images function: This function takes in the generator model, the current epoch number, and a fixed noise vector, and generates a set of images using the generator model. It then rescales the pixel values to [0, 1], creates a plot to display the images, and saves the plot to a file.
  4. Define the train function: This function takes in the training dataset and the number of epochs to train for, and trains the GAN using the training loop defined earlier. It also calls the generate_and_save_images function every 10 epochs to generate and save a set of generated images.
  5. Load the dataset and create batches: We load the MNIST dataset and create batches using the tf.data.Dataset API.
  6. Train the GAN: We call the train function to train the GAN on the MNIST dataset.

Again, note that there are many ways to modify this code to improve the performance and stability of the GAN, but this should serve as a good starting point.

Conclusion

In conclusion, Generative Adversarial Networks (GANs) are a class of deep learning models used to generate synthetic data that mimics real data. GANs consist of two neural networks, a generator and a discriminator, that compete with each other in a game-like setting. The generator produces synthetic data, while the discriminator distinguishes between synthetic and real data. The two networks continue to improve until the generator produces data indistinguishable from the real data.

GANs have been successfully used in various applications, such as image synthesis, text-to-image synthesis, and music generation. GANs have also been used to improve data quality, for example by upscaling low-resolution images or colourising black-and-white images.

However, GANs can be challenging to train and require large amounts of data. They are also susceptible to producing biased results if the training data is biased. Researchers are actively developing new techniques to address these challenges and improve the effectiveness of GANs.

Overall, GANs are a powerful tool for generating synthetic data that can be used for various applications. As research advances, we can expect to see even more exciting and innovative applications of GANs in the future.

About the Author

Neri Van Otten

Neri Van Otten is the founder of Spot Intelligence, a machine learning engineer with over 12 years of experience specialising in Natural Language Processing (NLP) and deep learning innovation. Dedicated to making your projects succeed.
