Understanding Generative Adversarial Network With A How To Tutorial In TensorFlow And Python

by Neri Van Otten | Mar 8, 2023 | Artificial Intelligence, Machine Learning

What is a Generative Adversarial Network (GAN)? What are they used for? How do they work? And what different types are there? This article includes a tutorial on how to get started with GANs in TensorFlow and Python.

What is a Generative Adversarial Network (GAN)?

Generative Adversarial Networks (GANs) are a type of deep learning model consisting of two neural networks, a generator and a discriminator, that work in opposition. The generator learns to make new examples that resemble the training data, while the discriminator learns to tell the difference between real examples and ones made by the generator.

Random noise is fed into the generator, which transforms it into fake data intended to look like real data. The discriminator takes both the real data and the generated data as input and predicts whether each example is real or fake. The generator then learns to make better outputs by trying to get the discriminator to classify its outputs as real.

During training, the two networks are trained simultaneously: the discriminator is trained to correctly tell real data from generated data, and the generator is trained to produce data that the discriminator finds difficult to distinguish from real data.

GANs have been used to make realistic pictures, videos, and music, and to produce synthetic data for other uses, such as augmenting datasets for machine learning tasks.

However, GAN training can be difficult because the two networks must be kept in balance as they train together. Problems such as mode collapse, where the generator learns to produce only a few distinct outputs, can be hard to fix.

What does a GAN network look like?

A GAN consists of two neural networks: a generator and a discriminator. The generator takes random noise as input and generates fake data that resembles the training data. The discriminator takes both real data and generated data as input and predicts whether each example is real or fake.

Most of the time, the generator comprises multiple fully connected or convolutional layers. These layers map the random noise into a higher-dimensional space with the same shape as the training data. The output of the generator is then fed into the discriminator.

The discriminator also has multiple fully connected or convolutional layers that learn to tell the difference between real data and data made by the generator. Finally, the output of the discriminator is a probability score that indicates how confident it is that the input data is real.

During training, the two networks are trained simultaneously. The discriminator is trained to tell real data from generated data, while the generator is trained to create data that is difficult for the discriminator to distinguish from real data. Each network minimises its own loss function, and the two losses pull in opposite directions.
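For reference, the competition between the two networks is commonly written as the minimax value function from the original GAN formulation. The symbols are notation introduced here rather than names from this article's code: G is the generator, D is the discriminator (outputting a probability), x is a real example, and z is the input noise.

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right]
  + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z))\right)\right]
```

The discriminator tries to maximise V by pushing D(x) towards 1 and D(G(z)) towards 0, while the generator tries to minimise it by pushing D(G(z)) towards 1.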

Overall, the GAN architecture allows the generation of new data similar to the training data and can be used for various applications, such as image and video synthesis, data augmentation, and anomaly detection.

What are the types of GAN models?

Various types of Generative Adversarial Networks (GANs) have been developed, each with its own approach and architecture. Some of the common types of GANs include:

  1. Vanilla GANs: Vanilla GANs are the simplest and most basic form of GANs. They consist of two neural networks, a generator and a discriminator, that compete against each other in a game to generate realistic-looking images.
  2. Deep Convolutional GANs (DCGANs): DCGANs use convolutional neural networks (CNNs) for both the generator and discriminator. They are known for their ability to generate high-resolution and realistic images.
  3. Conditional GANs (cGANs): cGANs are a type of GAN that takes additional input data, called conditioning data, along with the noise vector to generate images that meet a specific condition. They are commonly used for tasks such as image-to-image translation and style transfer.
  4. Wasserstein GANs (WGANs): WGANs are a variant of GANs that use the Wasserstein distance to measure the difference between the generated and real data distributions. They are often more stable to train and can create higher-quality images than traditional GANs.
  5. CycleGANs: CycleGANs are a type of GAN that can learn to translate between two domains, such as converting images of horses into pictures of zebras. They work by learning a mapping between the two domains without the need for paired training data.
  6. Progressive GANs: Progressive GANs make high-quality images by gradually increasing the resolution of the images they generate during training. They start by making low-resolution images and add detail step by step until they reach the desired size.
  7. StyleGANs: StyleGANs generate images by separating the style and content of an image. They allow for the control of specific features in the generated images, such as hair colour or facial expression.
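To make the difference between a vanilla GAN and a WGAN more concrete, here is a minimal sketch of Wasserstein-style losses in plain Python. The helper names and the scores are made up for illustration. A real WGAN also needs a constraint on the critic (such as weight clipping or a gradient penalty), which is not shown here.

```python
# Minimal sketch of Wasserstein-style losses (hypothetical helper names).
# The critic outputs unbounded scores, not probabilities.

def mean(values):
    return sum(values) / len(values)

def critic_loss(real_scores, fake_scores):
    # The critic wants real scores high and fake scores low,
    # so it minimises mean(fake) - mean(real).
    return mean(fake_scores) - mean(real_scores)

def wgan_generator_loss(fake_scores):
    # The generator wants the critic to score its samples highly.
    return -mean(fake_scores)

real_scores = [2.0, 1.5, 2.5]    # made-up critic scores for real samples
fake_scores = [-1.0, 0.0, -0.5]  # made-up critic scores for generated samples

print(critic_loss(real_scores, fake_scores))  # -2.5
print(wgan_generator_loss(fake_scores))       # 0.5
```

Unlike the cross-entropy losses used in the tutorial later in this article, these losses work directly on raw scores, so there is no sigmoid or logarithm involved.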

Conditional GAN

A popular type of GAN is the Conditional Generative Adversarial Network (cGAN). A cGAN takes additional input data, called conditioning data, alongside the noise vector to generate images that meet a specific condition. The conditioning data could be any information, such as class labels, text descriptions, or pictures.

In cGANs, the random noise vector and the conditioning data are both inputs to the generator network, which produces an output that matches the condition. The discriminator network then takes both the output and the conditioning data and tries to determine whether the output is real or fake.
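One common way to feed conditioning data to the generator is to concatenate the noise vector with a one-hot encoding of a class label. The sketch below shows only that input construction, in plain Python. The function names and the choice of ten classes are assumptions for illustration, and other schemes (such as learned label embeddings) are also used in practice.

```python
import random

NOISE_DIM = 100   # same noise size as the tutorial later in this article
NUM_CLASSES = 10  # assumption: ten class labels, e.g. the digits 0-9

def one_hot(label, num_classes=NUM_CLASSES):
    """Encode an integer class label as a one-hot list."""
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec

def conditional_generator_input(label, noise_dim=NOISE_DIM):
    """Concatenate a random noise vector with the one-hot condition."""
    noise = [random.gauss(0.0, 1.0) for _ in range(noise_dim)]
    return noise + one_hot(label)

gen_input = conditional_generator_input(label=3)
print(len(gen_input))   # 110
print(gen_input[-10:])  # one-hot encoding of class 3
```

The discriminator would receive the same label alongside each image, so it can judge whether the image is both realistic and consistent with the condition.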

cGANs are often used for tasks like image-to-image translation, where the goal is to make an image that matches the given input. For example, cGANs can be used to translate a grayscale image into a coloured image or to remove noise from an image.

Another application of cGANs is style transfer, where the conditioning data is a style image. The goal is to generate an output image with the same style as the conditioning image. This technique can create artwork or modify an existing image’s style.

AI art generated by a conditional GAN

Overall, cGANs give you more control over the outputs by letting you use conditioning data to say what outcome you want.

Generative Adversarial Networks applications

Generative Adversarial Networks (GANs) have been widely used in various applications, some of which include:

  1. Image generation: GANs have been used to generate realistic images that can be used for various applications, such as video games, virtual reality, and movie production.
  2. Image-to-image translation: GANs have been used to translate images from one domain to another, such as converting a black-and-white photo to a coloured image.
  3. Video generation: GANs have been used to generate videos that can be used for various applications, such as movie production and video editing.
  4. Text-to-image synthesis: GANs have been used to generate images from text descriptions. This technique can be used to generate images for product catalogues or to create pictures for specific scenarios.
  5. Style transfer: GANs have been used to transfer the style of one image to another. This technique can be used for image editing or to generate artwork.
  6. Anomaly detection: GANs have been used to detect anomalies in data, such as identifying fraudulent transactions or detecting defects in manufacturing processes.
  7. Data augmentation: GANs have been used to augment training data for machine learning models. This technique can improve the performance of machine learning models.

Overall, GANs have shown much promise in many different settings and have become an essential tool for researchers and practitioners in machine learning and artificial intelligence.

Generative Adversarial networks art

Generative Adversarial Networks (GANs) have become a popular tool for creating art and have been used in various artistic applications, such as generating images, videos, and music.

One example of GAN art is the creation of realistic portrait images. Artists have trained GANs on datasets of portrait images and used the resulting generator network to create new, unique portraits that look like they could be real people. These images can be used for various purposes, such as developing character designs for movies or video games.

Another example of GAN art is the creation of surreal or abstract images. Artists have trained GANs on datasets of abstract art or other non-representational imagery and used the resulting generator network to create new, unique images that push the boundaries of what is possible with traditional art techniques.

Surreal image generated by GANs

GANs have also been used to create videos, such as music videos or short films. Artists have trained GANs on datasets of video footage and used the resulting generator network to develop new, unique videos for artistic or entertainment purposes.

Overall, GANs have given artists and other creative people new opportunities. They can now try out new techniques and make unique, one-of-a-kind works of art that were impossible to make before.

Generative Adversarial networks tutorial in Python

Here’s an example of implementing a simple Generative Adversarial Network (GAN) using TensorFlow.

First, let’s import the necessary libraries:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt

Next, let’s define the generator and discriminator models:

# Generator model
def make_generator_model():
    model = keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256)

    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)

    return model

# Discriminator model
def make_discriminator_model():
    model = keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model
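The shape assertions in the generator can be checked by hand. With padding='same', a transposed convolution multiplies each spatial dimension by its stride, and a strided convolution divides it (rounding up). The helper functions below are hypothetical and only reproduce that arithmetic, without building any Keras layers.

```python
import math

def conv_transpose_same(size, stride):
    # With padding='same', a transposed convolution scales each spatial
    # dimension by the stride.
    return size * stride

def conv_same(size, stride):
    # With padding='same', a strided convolution produces ceil(size / stride).
    return math.ceil(size / stride)

# Generator: 7x7 feature map after the Reshape layer, then strides 1, 2, 2
size = 7
generator_sizes = [size]
for stride in (1, 2, 2):
    size = conv_transpose_same(size, stride)
    generator_sizes.append(size)
print(generator_sizes)  # [7, 7, 14, 28]

# Discriminator: 28x28 input, then two stride-2 convolutions
size = 28
for stride in (2, 2):
    size = conv_same(size, stride)
flattened = size * size * 128  # 128 filters in the last Conv2D layer
print(size, flattened)  # 7 6272
```

So the generator grows a 7x7 feature map into a 28x28 image, and the discriminator shrinks a 28x28 image back to 7x7 before flattening it into 6272 values for the final Dense layer.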

Following on, let’s define the loss functions for both the generator and discriminator:

cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)
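To get a feel for the loss values, the sketch below recomputes binary cross-entropy from a raw logit for a single example in plain Python. It is an illustration of the formula, not the Keras implementation, and the helper names are made up. It shows the values to expect when the discriminator is completely undecided.

```python
import math

def sigmoid(logit):
    return 1.0 / (1.0 + math.exp(-logit))

def bce_from_logit(target, logit):
    # Binary cross-entropy for a single example, given a raw logit.
    p = sigmoid(logit)
    return -(target * math.log(p) + (1 - target) * math.log(1 - p))

# An undecided discriminator outputs a logit of 0 (probability 0.5).
real_logit, fake_logit = 0.0, 0.0

d_loss = bce_from_logit(1, real_logit) + bce_from_logit(0, fake_logit)
g_loss = bce_from_logit(1, fake_logit)

print(round(d_loss, 4))  # 1.3863 (2 * ln 2)
print(round(g_loss, 4))  # 0.6931 (ln 2)
```

These values are a useful reference point when watching the losses during training: a discriminator loss well below 2 ln 2 means the discriminator is separating real from generated images better than chance.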

Next, let’s define the optimizer for both the generator and discriminator:

generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

Now, let’s train the GAN:

# Load the MNIST dataset
(train_images, train_labels), (_, _) = keras.datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5  # Normalize the images to [-1, 1]

# Set the batch size and number of epochs
BUFFER_SIZE = 60000
BATCH_SIZE = 256
EPOCHS = 100

# Create the generator
generator = make_generator_model()

# Create the discriminator
discriminator = make_discriminator_model()

# Define the training loop
@tf.function
def train_step(images):
  # Generate one noise vector per real image; the final batch of an epoch
  # can be smaller than BATCH_SIZE
  noise = tf.random.normal([tf.shape(images)[0], 100])
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      # Generate images using the generator
      generated_images = generator(noise, training=True)
  
      # Get the discriminator's predictions for the real and generated images
      real_output = discriminator(images, training=True)
      fake_output = discriminator(generated_images, training=True)
  
      # Calculate the loss for both the generator and discriminator
      gen_loss = generator_loss(fake_output)
      disc_loss = discriminator_loss(real_output, fake_output)

  # Calculate the gradients for the generator and discriminator
  gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
  gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
  
  # Apply the gradients to each network's variables using its optimizer
  generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
  discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

# Create a function to generate and save images
def generate_and_save_images(model, epoch, test_input):
  # Generate images from the model
  predictions = model(test_input, training=False)
  # Rescale the pixel values to [0, 1]
  predictions = (predictions + 1) / 2.0
  
  # Create a plot to display the images
  fig = plt.figure(figsize=(4, 4))
  for i in range(predictions.shape[0]):
      plt.subplot(4, 4, i+1)
      plt.imshow(predictions[i, :, :, 0], cmap='gray')
      plt.axis('off')
  
  # Save the plot
  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  plt.show()
  
# Create a function to train the GAN
def train(dataset, epochs):
  # Generate a fixed noise vector to use for visualization
  test_input = tf.random.normal([16, 100])
  for epoch in range(epochs):
    for image_batch in dataset:
        # Run one training step, which updates both the generator and the discriminator
        train_step(image_batch)

    # Generate and save images every 10 epochs
    if (epoch + 1) % 10 == 0:
        generate_and_save_images(generator, epoch + 1, test_input)

    print('Epoch {} completed'.format(epoch + 1))

  # Generate a final set of images and save them
  generate_and_save_images(generator, epochs, test_input)
  
# Load the dataset and create batches
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

# Train the GAN
train(train_dataset, EPOCHS)

And that’s it! This code should train a simple GAN on the MNIST dataset and save generated images every 10 epochs.
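As a quick sanity check on the training setup, the number of batches per epoch follows directly from the constants above. The dataset is batched without dropping the remainder, so the final batch of each epoch is smaller than BATCH_SIZE. The arithmetic below is a standalone sketch in plain Python.

```python
import math

NUM_IMAGES = 60000  # size of the MNIST training set (BUFFER_SIZE above)
BATCH_SIZE = 256

batches_per_epoch = math.ceil(NUM_IMAGES / BATCH_SIZE)
last_batch_size = NUM_IMAGES - (batches_per_epoch - 1) * BATCH_SIZE

print(batches_per_epoch)  # 235
print(last_batch_size)    # 96
```

Each epoch therefore runs 235 training steps, and the last one sees only 96 real images. If every batch needs to be the same size, passing drop_remainder=True to the batch call is one option.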

Note that there are many ways to modify this code to improve the performance and stability of the GAN, but this should serve as a good starting point.
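One frequently suggested modification is one-sided label smoothing, where the discriminator's target for real images is set slightly below 1. The sketch below is a plain Python illustration with made-up helper names. The value 0.9 is an assumption (a commonly chosen value, not something taken from this article's code). It shows how smoothing penalises an over-confident discriminator.

```python
import math

def sigmoid(logit):
    return 1.0 / (1.0 + math.exp(-logit))

def bce_from_logit(target, logit):
    # Binary cross-entropy for a single example, given a raw logit.
    p = sigmoid(logit)
    return -(target * math.log(p) + (1 - target) * math.log(1 - p))

REAL_LABEL = 0.9  # assumption: smoothed target for real images

confident_logit = 5.0  # discriminator is very sure a real image is real

hard_loss = bce_from_logit(1.0, confident_logit)
smooth_loss = bce_from_logit(REAL_LABEL, confident_logit)

print(hard_loss < smooth_loss)  # True
```

With a hard target of 1, extreme confidence is rewarded with a near-zero loss. With the smoothed target, the same confidence costs more, which discourages the discriminator from becoming too certain too early.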

Summary code explanation

Here’s a brief overview of what each part of the code is doing:

  1. Define the generator and discriminator: We define the generator and discriminator models using the make_generator_model and make_discriminator_model functions. We also define a binary cross-entropy loss for each network and give each its own Adam optimizer.
  2. Define the training loop: We define a training step using the @tf.function decorator to improve performance. Inside the training step, we generate random noise, pass it through the generator to generate fake images, and then pass both real and fake images through the discriminator to get its predictions. We then calculate the loss for both the generator and discriminator and use the tf.GradientTape() context manager to compute gradients. Finally, we update each network's weights by calling apply_gradients on its optimizer.
  3. Define the generate_and_save_images function: This function takes in the generator model, the current epoch number, and a fixed noise vector, and generates a set of images using the generator model. It then rescales the pixel values to [0, 1], creates a plot to display the images, and saves the plot to a file.
  4. Define the train function: This function takes in the training dataset and the number of epochs to train for, and trains the GAN using the training loop defined earlier. It also calls the generate_and_save_images function every 10 epochs to generate and save a set of generated images.
  5. Load the dataset and create batches: We load the MNIST dataset and create batches using the tf.data.Dataset API.
  6. Train the GAN: We call the train function to train the GAN on the MNIST dataset.
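The two rescaling steps in the code are inverses of a sort: training images are mapped from [0, 255] to [-1, 1] to match the generator's tanh output, and generated images are mapped from [-1, 1] to [0, 1] for plotting. The function names below are hypothetical. They simply wrap the two expressions used in the tutorial.

```python
def normalize(pixel):
    # Map a raw pixel value in [0, 255] to [-1, 1], as in the training code.
    return (pixel - 127.5) / 127.5

def denormalize(value):
    # Map a tanh-range value in [-1, 1] to [0, 1] for plotting.
    return (value + 1) / 2.0

for pixel in (0, 127.5, 255):
    scaled = normalize(pixel)
    print(pixel, scaled, denormalize(scaled))
# 0 -1.0 0.0
# 127.5 0.0 0.5
# 255 1.0 1.0
```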

Again, note that there are many ways to modify this code to improve the performance and stability of the GAN, but this should serve as a good starting point.

Conclusion

In conclusion, Generative Adversarial Networks (GANs) are a class of deep learning models used to generate synthetic data that mimics real data. GANs consist of two neural networks, a generator and a discriminator, that compete with each other in a game-like setting. The generator produces synthetic data, while the discriminator distinguishes between synthetic and real data. Ideally, the two networks continue to improve until the generator produces data that the discriminator can no longer distinguish from the real data.

GANs have been successfully used in various applications, such as image synthesis, text-to-image synthesis, and music generation. GANs have also been used to improve data quality, such as upscaling low-resolution images or converting black and white images to colour.

However, GANs can be challenging to train and require large amounts of data. They are also susceptible to producing biased results if the training data is biased. Researchers are actively developing new techniques to address these challenges and improve the effectiveness of GANs.

Overall, GANs are a powerful tool for generating synthetic data that can be used for various applications. As research advances, we can expect to see even more exciting and innovative applications of GANs in the future.

About the Author

Neri Van Otten

Neri Van Otten is the founder of Spot Intelligence, a machine learning engineer with over 12 years of experience specialising in Natural Language Processing (NLP) and deep learning innovation. Dedicated to making your projects succeed.
