Generative Adversarial Networks

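A generative adversarial network (GAN) is a machine learning framework in which two neural networks compete against each other in a zero-sum game. Through this competition, one of the networks learns to generate new data that resembles its training data.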

Example Project

I created a GAN that generates new handwritten digits modeled on the MNIST dataset.

The MNIST handwritten digit dataset is a canonical dataset in the field of computer vision. It contains 70,000 grayscale images of handwritten digits (60,000 for training and 10,000 for testing), each 28x28 pixels. While a convolutional neural network (CNN) can learn to classify these digits by value, a GAN learns to generate new digits that look like them.

For this project, I built and trained a GAN to do exactly that. The main implementation steps are described below, along with a video showing samples of generated digits at different stages of training.

The Game

A GAN has two components that play a zero-sum game against each other: the Discriminator and the Generator. The Generator produces new data that is meant to be statistically indistinguishable from the training data (here, handwritten digits). The Discriminator's job is to tell real data (digits from the MNIST dataset) apart from fake data (digits produced by the Generator), while the Generator's job is to learn to produce digits realistic enough to fool the Discriminator.
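One standard way to write this game down is the minimax objective from the original GAN formulation. This page does not state the exact loss the project used, so treat this as the textbook form rather than the project's own code. Here $D(x)$ is the Discriminator's estimated probability that an image $x$ is real, $G(z)$ is the image the Generator produces from a random point $z$, $p_{\text{data}}$ is the distribution of real digits, and $p_z$ is the distribution the random points are drawn from.

```latex
\min_G \max_D \; V(D, G)
  = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right]
  + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]
```

The Discriminator pushes $V$ up by scoring real digits near 1 and generated digits near 0; the Generator pushes it down by driving $D(G(z))$ toward 1. Because one player's gain is exactly the other's loss, the game is zero-sum.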

Creating the Discriminator

The Discriminator is simply a traditional CNN. Whereas most CNNs trained on MNIST classify digits by value (0 through 9), the Discriminator performs binary classification: real MNIST digits are labeled real, and images of random noise that have been shuffled into the MNIST data are labeled fake. Below are examples of real MNIST data (left) and fake examples (right).

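The Discriminator is a simple sequence of convolutional layers, each followed by a leaky ReLU activation and 40% dropout. The output of the last convolutional layer is flattened and fed into a single node with sigmoid activation, which outputs the probability that the input image is real. Since this is a binary classification task, binary cross-entropy is the natural choice of loss.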
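The Discriminator's activation and output side can be sketched in plain Python to show what the leaky ReLU, the sigmoid node, and binary cross-entropy compute. This is a minimal illustration rather than the project's code: the leaky ReLU slope of 0.2 and the clipping constant `eps` are assumed values, and the convolutional, dropout, and flatten layers are left out.

```python
import math


def leaky_relu(x, slope=0.2):
    """Pass positive values through and scale negative ones by `slope` (0.2 is an assumed value)."""
    return x if x > 0 else slope * x


def sigmoid(z):
    """Numerically stable logistic function: maps a raw score to a probability of 'real'."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy over a batch; labels are 1 for real and 0 for fake."""
    total = 0.0
    for t, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1.0 - eps)  # clip so log() never receives exactly 0
        total += -(t * math.log(p) + (1 - t) * math.log(1 - p))
    return total / len(y_true)


# A confident, correct prediction costs little; a confident, wrong one costs a lot.
print(round(binary_crossentropy([1], [0.9]), 3))  # real digit scored 0.9 -> 0.105
print(round(binary_crossentropy([0], [0.9]), 3))  # fake digit scored 0.9 -> 2.303
```

The asymmetry in the two printed losses is what drives training: the Discriminator is penalized heavily whenever it is confidently wrong about an image.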

The Generator

The Generator is more complicated to implement. It effectively works in reverse compared to a CNN: rather than processing images into feature maps, it takes points in a feature space (often called the latent space), which are initially meaningless, and turns them into images. With feedback from the Discriminator, the Generator learns to assign meaning to this space, which comes to act as a space of compressed MNIST images.
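Sampling points from the feature space can be sketched as below. The 100-dimensional space and the standard normal distribution are assumptions (both are common choices, but this page does not specify them), and `sample_latent_points` is a hypothetical helper name.

```python
import random

LATENT_DIM = 100  # assumed dimensionality of the feature (latent) space


def sample_latent_points(n, latent_dim=LATENT_DIM, seed=None):
    """Draw n random points from a standard normal distribution over the latent space.

    Before training, these points carry no meaning; the Generator learns to map them to digits.
    """
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(latent_dim)] for _ in range(n)]


batch = sample_latent_points(16, seed=42)
print(len(batch), len(batch[0]))  # 16 100
```

Passing a seed makes a batch reproducible, which is handy for rendering the same latent points at every epoch to watch how the Generator's output for them changes.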

We start with a dense layer that has a very large number of nodes. Its output is reshaped into many small feature maps, which act in parallel as compressed, low-resolution versions of an MNIST image. These maps are then upsampled and fed into convolutional layers, which effectively "deconvolve" the learned representation into finer detail. This upsample-and-convolve step is repeated until the output reaches 28x28 pixels.
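The shape bookkeeping in this paragraph can be traced with a small sketch. It assumes a 7x7 starting resolution with 128 feature maps (so the dense layer has 128 x 7 x 7 = 6,272 nodes) and two rounds of 2x nearest-neighbour upsampling; these numbers are illustrative, not taken from the project. The learned convolutions after each upsampling step are represented only by their output shape, assuming "same" padding so they keep the height and width unchanged.

```python
def upsample_nearest_2x(grid):
    """Double a 2-D grid's height and width by copying each value into a 2x2 block."""
    out = []
    for row in grid:
        wide = [v for v in row for _ in range(2)]  # repeat each value horizontally
        out.append(wide)
        out.append(list(wide))  # repeat the widened row vertically
    return out


def generator_shapes(base=7, n_maps=128, n_upsamples=2):
    """List (stage, shape) pairs from the dense layer to the final single-channel image.

    Convolutions are assumed to use 'same' padding, so only upsampling changes height/width.
    """
    shapes = [("dense", n_maps * base * base), ("reshape", (base, base, n_maps))]
    size = base
    for _ in range(n_upsamples):
        size *= 2
        shapes.append(("upsample + conv", (size, size, n_maps)))
    shapes.append(("output conv", (size, size, 1)))
    return shapes


for stage, shape in generator_shapes():
    print(stage, shape)

# One 7x7 low-resolution map grows to 28x28 after two upsampling steps.
low_res = [[r * 7 + c for c in range(7)] for r in range(7)]
image = upsample_nearest_2x(upsample_nearest_2x(low_res))
print(len(image), len(image[0]))  # 28 28
```

Nearest-neighbour upsampling only enlarges the maps; it is the convolutions that follow which learn to fill in realistic detail.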

Training the GAN

We first train the Discriminator on its own to distinguish real digits from fake ones. Then we sample random points in the feature space, feed them to the Generator to create MNIST lookalikes, and pass these lookalikes to the Discriminator. Crucially, we label the lookalikes as real, so the error is large whenever the Discriminator can tell that an image is fake. This error is then backpropagated through the Discriminator and used to update the Generator.
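The alternating schedule described above can be sketched as a plain loop. The callables passed in (`sample_latent`, `generate`, `update_discriminator`, `update_generator`) are hypothetical hooks standing in for real model code, and the sketch assumes one Discriminator update followed by one Generator update per batch of real digits; the page does not give the exact schedule or batch size.

```python
def train_gan(real_batches, sample_latent, generate,
              update_discriminator, update_generator, epochs):
    """Alternate Discriminator and Generator updates, one pair per batch of real digits."""
    for _ in range(epochs):
        for real in real_batches:
            n = len(real)
            # Discriminator step: real digits labelled 1, generated lookalikes labelled 0.
            fake = generate(sample_latent(n))
            update_discriminator(list(real) + list(fake), [1] * n + [0] * n)
            # Generator step: fresh lookalikes are deliberately labelled 1 ("real").
            # The Discriminator's weights are assumed frozen here, so only the Generator learns.
            update_generator(sample_latent(n), [1] * n)


# Dry run with stand-in hooks that just record which labels each update receives.
log = []
train_gan(
    real_batches=[["real0", "real1"], ["real2", "real3"]],
    sample_latent=lambda n: list(range(n)),
    generate=lambda points: [f"fake{p}" for p in points],
    update_discriminator=lambda images, labels: log.append(("D", labels)),
    update_generator=lambda points, labels: log.append(("G", labels)),
    epochs=1,
)
print(log)
```

With binary cross-entropy, labelling every lookalike as 1 means the Generator's loss on each one is -log D(G(z)), which grows large exactly when the Discriminator gives that lookalike a low probability of being real.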

The video below shows how the generated images evolve over 50 epochs of training on 5,000 MNIST images:
