Introduction to Generative Adversarial Networks with Code - Part 2


Hello and a warm welcome to this blog, where all the action is about to happen!

Two Neural Networks are going to face off against each other, and only one of them will be left standing!

The only question is, which one are you rooting for? 

The Generator or the Discriminator?

We’re sorry if we’re hyping up this thing a little. 

Okay fine, we’re totally overdoing it. 

But we promise you’ll share our enthusiasm too by the end of this blog. 

Before we begin: for readers who have stumbled upon this blog first, we’d like you to know that we’ve done a conceptual introduction to GANs in our previous blog, Introduction to Generative Adversarial Networks - Part 1. We highly recommend going through that blog, because the code in this one requires a thorough conceptual understanding of the entire idea behind GANs, which we covered there with a jargon-free analogy. It neither assumes nor requires any advanced prior AI knowledge. That’s the way we always keep it at Eduonix.

Right then. 

It’s about time. 

GANs IN ACTION! 

So, the problem statement: we have to generate handwritten digits like those found in the MNIST dataset.

Here, the goal of our discriminator, when shown an instance from the true MNIST dataset, is to recognize it as authentic.

Meanwhile, our generator is creating new, synthetic images that it passes to the discriminator. This would be you with the canvas trying to replicate the painting in our analogy. The generator does this in the hopes that the generated images, too, will be deemed authentic by the discriminator (our art expert), even though they are fake. The goal of the generator is to generate plausible handwritten digits, which is to lie without being caught. The goal of the discriminator is to identify images coming from the generator as fake. 

In Situation B, your first attempt at the painting wouldn’t be very accurate. In fact, it would hardly resemble the original painting. The analogy for this in GANs is that we start by passing the generator random noise, which it eventually learns to mold into plausible images of handwritten digits.

We illustrate the scenario below: 

Figure 1. Schematic Representation of the Generative Adversarial Network for our use case

For our use case, the discriminator network is a standard convolutional network that classifies the images fed to it: a binary classifier labeling images as real or fake. The generator is an inverse convolutional network, in the sense that while a standard convolutional classifier takes an image and downsamples it to produce a probability, the generator takes a vector of random noise and upsamples it to an image. The first throws away data through downsampling (for example, with strided convolutions or pooling layers such as MaxPooling), and the second generates new data.

Each net is trying to optimize a different, opposing objective function, or loss function.

This is essentially our student and art expert model. As the discriminator changes its behavior, looking for ever finer details in the images produced, the generator responds by producing ever more plausible images, and vice versa. Their losses push against each other.

Now that the groundwork is laid, we’ll move on to some code. But before that, here’s the final structure of the GAN that we’ll be training: 

Figure 2. The overall structure of the GAN we’ll be training

Right! Let’s move on to some code then. 

I.  Importing Libraries
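As a sketch, assuming the Keras API that ships with TensorFlow 2.x plus NumPy and Matplotlib (our assumption about the toolkit; any Keras-style setup would do), the imports for a GAN like this one could look as follows:

```python
# Numerical arrays and random sampling
import numpy as np
# Plotting grids of real and generated digits
import matplotlib.pyplot as plt
# Assumption: Keras as bundled with TensorFlow 2.x (models, layers, optimizers, MNIST loader)
from tensorflow import keras
from tensorflow.keras import layers
```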

II.  Loading the MNIST Dataset 
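A minimal sketch of this step, assuming the built-in MNIST loader in keras.datasets. It downloads the dataset on first use, so it needs network access once, and returns the images and labels as NumPy arrays:

```python
from tensorflow import keras

# Download (on first use) and load the MNIST train/test splits as uint8 arrays
(trainX, trainy), (testX, testy) = keras.datasets.mnist.load_data()
# Summarize the shapes of the loaded arrays
print('Train', trainX.shape, trainy.shape)  # Train (60000, 28, 28) (60000,)
print('Test', testX.shape, testy.shape)     # Test (10000, 28, 28) (10000,)
```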

Running the above code will result in the following output:

Output of Running MNIST Dataset

In case readers are unfamiliar with the MNIST dataset, we recommend checking out our earlier blogs in the Deep Learning with Keras series, where we take up several Neural Network architectures and use them to solve various problems with Deep Learning.


For completeness of this blog, however, we show the MNIST dataset below: 

Figure 3. MNIST Handwritten Digits Dataset

We will use the images in the training dataset as the basis for training a Generative Adversarial Network. 

Specifically, the generator model will learn how to generate new plausible handwritten digits between 0 and 9, using a discriminator that will try to distinguish between real images from the MNIST training dataset and new images output by the generator model.

III.  Defining the Discriminator

The role of the discriminator should now be clear to our readers. The discriminator model must take a sample image (from the MNIST dataset or from the generator) as input and output a classification prediction as to whether the sample is real or fake.

This is thus a binary classification problem, with a 28×28 image as input and the likelihood that the sample is real (rather than fake) as output.

As described above, we’ll be using a Convolutional Neural Network for our discriminator. The objective function (or loss) that we’ll aim to minimize is binary cross-entropy (binary_crossentropy in Keras), which is appropriate for binary classification. Our optimizer of choice will be Adam.

We’ve covered some best practices for defining Neural Network architectures for various problems in a blog compiled just for you, Deeper with Deep Neural Networks; be sure to check it out.

The function define_discriminator() below defines the discriminator model: 
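A sketch of define_discriminator(), assuming a small strided-convolution design like the one in reference [4]. The layer sizes, the LeakyReLU slope of 0.2, the dropout rate of 0.4 and the Adam settings (learning rate 0.0002, beta_1 0.5) are our assumptions, not necessarily the exact values behind Figure 4:

```python
from tensorflow import keras
from tensorflow.keras import layers

def define_discriminator(in_shape=(28, 28, 1)):
    """Binary CNN classifier: outputs the probability that an image is real."""
    model = keras.Sequential([
        keras.Input(shape=in_shape),
        # Strided convolutions downsample 28x28 -> 14x14 -> 7x7 (assumed sizes)
        layers.Conv2D(64, (3, 3), strides=(2, 2), padding='same'),
        layers.LeakyReLU(0.2),
        layers.Dropout(0.4),
        layers.Conv2D(64, (3, 3), strides=(2, 2), padding='same'),
        layers.LeakyReLU(0.2),
        layers.Dropout(0.4),
        layers.Flatten(),
        # A single sigmoid unit: probability that the input is real
        layers.Dense(1, activation='sigmoid'),
    ])
    # Assumed optimizer settings; binary cross-entropy suits the real/fake task
    opt = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
    return model

d_model = define_discriminator()
d_model.summary()
```

If pydot and Graphviz are installed, keras.utils.plot_model(d_model, show_shapes=True) can also draw a diagram of the model.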

To define the model, all we have to do now is call the define_discriminator() function in our code. As good practice, we always recommend printing a summary of the model you’ve created, along with a plot of it, so that it’s easy to spot any errors you might have made during model definition.

The plot of the model that we expect is shown below: 

Figure 4. Discriminator Model

If you recall, the discriminator has to discriminate the real images from the fake ones. Currently, all we have are the real images in the trainX variable. We do not have any fake images, so we’ll have to sort that part out.

First, we’ll need to load and prepare the dataset of real images and then we’ll worry about the fake samples. 

It should be noted that the images are 2D arrays of pixels and Convolutional Neural Networks expect 3D arrays of images as input, where each image has one or more channels. 

Consequently, we must give the images an additional dimension for the grayscale channel. We can do this with NumPy’s expand_dims() function, adding the new axis as the final dimension for the channels-last image format.
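A minimal sketch of this preparation step. The helper name prepare_real_samples() is hypothetical; it adds the channel axis with expand_dims() and, as an assumption consistent with the [0,1] range mentioned below for the fake samples, scales the pixel values from [0, 255] to [0, 1]:

```python
import numpy as np

def prepare_real_samples(images):
    """Hypothetical helper: (n, 28, 28) uint8 images -> (n, 28, 28, 1) float32 in [0, 1]."""
    X = np.expand_dims(images, axis=-1)  # add the channels-last grayscale axis
    X = X.astype('float32')
    return X / 255.0                     # assumed scaling: [0, 255] -> [0, 1]
```

For example, dataset = prepare_real_samples(trainX) turns the raw training images into arrays a convolutional discriminator can consume.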

Now for training, the model will be updated in batches, specifically with a collection of real samples and a collection of generated samples. 

So we write another function, generate_real_samples(), which takes the training dataset as an argument and selects a random subsample of images. It also returns class labels for the sample, specifically a class label of 1, to indicate real images.
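A minimal sketch of generate_real_samples(), assuming the dataset is a NumPy array of prepared images and that indices are drawn at random with replacement:

```python
import numpy as np

def generate_real_samples(dataset, n_samples):
    """Pick n_samples random images from the dataset and label them 1 (real)."""
    ix = np.random.randint(0, dataset.shape[0], n_samples)  # random indices, with replacement
    X = dataset[ix]
    y = np.ones((n_samples, 1))  # class label 1 = real
    return X, y
```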

There. We’ve sorted out the real samples. Now for the fake samples. 

We don’t have a generator model yet, so instead we can generate images made of random pixel values, specifically values in the range [0,1] like our scaled real images. The generate_fake_samples() function generates these fake images.
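A minimal sketch of this first version of generate_fake_samples(), assuming 28×28×1 images and uniform noise from np.random.rand (values in [0, 1)):

```python
import numpy as np

def generate_fake_samples(n_samples):
    """Random-noise 'images' with pixels in [0, 1), labelled 0 (fake)."""
    X = np.random.rand(n_samples, 28, 28, 1)  # uniform random pixel values
    y = np.zeros((n_samples, 1))              # class label 0 = fake
    return X, y
```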

Now, we’ll need to train our discriminative model.

IV.   Training the Discriminative Model

The function might look complicated, but here’s what it essentially does: each iteration uses a batch of 256 images, 128 real and 128 fake. We update the discriminator separately on the real and the fake examples so that we can calculate the model’s accuracy on each half prior to the update. This gives insight into how the discriminator model is performing over time.
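A sketch of this training loop. The name train_discriminator() and the default of 100 iterations are our assumptions; to stay self-contained, it inlines the sampling logic of generate_real_samples() and generate_fake_samples() described above. It expects a Keras-style model compiled with metrics=['accuracy'], so that train_on_batch() returns the loss and the accuracy:

```python
import numpy as np

def train_discriminator(model, dataset, n_iter=100, n_batch=256):
    """Train on half-real, half-noise batches; report accuracy on each half."""
    half_batch = n_batch // 2
    history = []
    for i in range(n_iter):
        # Real half: random images from the dataset, labelled 1
        ix = np.random.randint(0, dataset.shape[0], half_batch)
        X_real, y_real = dataset[ix], np.ones((half_batch, 1))
        # Fake half: uniform random pixels in [0, 1), labelled 0
        X_fake = np.random.rand(half_batch, *dataset.shape[1:])
        y_fake = np.zeros((half_batch, 1))
        # Separate updates, so the accuracy on each half can be reported
        _, real_acc = model.train_on_batch(X_real, y_real)
        _, fake_acc = model.train_on_batch(X_fake, y_fake)
        print(f'>{i + 1} real={real_acc * 100:.0f}% fake={fake_acc * 100:.0f}%')
        history.append((real_acc, fake_acc))
    return history
```

With a compiled discriminator and the prepared dataset, train_discriminator(d_model, dataset) prints the real and fake accuracy for every iteration.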

You’ll see that within about 40 iterations, it reaches peak performance, easily telling the fake samples apart from the real ones.

Figure 5. Discriminator Performance

Now that we know how to define and train the discriminator model, we need to look at developing the generator model.

V.   Defining the Generator

The next step is defining the generator. The generator will start out by outputting random noise at first, analogous to your first crude replication of the painting, and will then eventually learn to mold the randomness into seemingly plausible handwritten digits to pass the discriminator, analogous to your final attempts at getting the painting right to the details. 

The question, perhaps, is how we go from random noise to a plausible image. The idea is simple. In the discriminator network, we downsampled the image to obtain a feature map from which the network decided how likely the image was to be real or fake. In the generator, we go the other way around: we start with something small and then upsample it, giving the network room to invent and create. This is achieved with transposed convolutions.
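A sketch of a generator built this way. The name define_generator(), the 100-dimensional latent space, the 7×7×128 starting block and the 4×4 transposed-convolution kernels are our assumptions, modeled on the kind of design in reference [4]. Each stride-2 transposed convolution with 'same' padding doubles the spatial size, 7 → 14 → 28:

```python
from tensorflow import keras
from tensorflow.keras import layers

def define_generator(latent_dim):
    """Map a latent vector to a 28x28x1 image with pixels in [0, 1]."""
    model = keras.Sequential([
        keras.Input(shape=(latent_dim,)),
        # Project the latent vector to a small 7x7x128 stack of feature maps
        layers.Dense(128 * 7 * 7),
        layers.LeakyReLU(0.2),
        layers.Reshape((7, 7, 128)),
        # Upsample 7x7 -> 14x14
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same'),
        layers.LeakyReLU(0.2),
        # Upsample 14x14 -> 28x28
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same'),
        layers.LeakyReLU(0.2),
        # One output channel; sigmoid keeps pixels in [0, 1] like the scaled real images
        layers.Conv2D(1, (7, 7), activation='sigmoid', padding='same'),
    ])
    return model  # not compiled: it is trained through the combined GAN model

g_model = define_generator(100)  # assumed latent_dim of 100
g_model.summary()
```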

We’ve included the code to summarize and plot the model just for the reader’s convenience, and we highly recommend checking out the plots and the summary.

The plot of the generator model is shown below:

Figure 6. Generator Model

The variable latent_dim deserves an explanation.

Drawing a parallel with our conceptual explanation (Situation B), this variable represents the colors and brushes available to you. You can only paint your picture with these colors and brushes; how you learn to use them is up to you.

This is exactly the case here. latent_dim is the size of the random input vector the generator receives: each input is a point in a space with latent_dim dimensions, known as the latent space, and these points are drawn at random (typically from a standard Gaussian distribution). It is from such random points that the generator learns to produce plausible images.

We’ll now show you a convenient function to generate points in this latent space.
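A minimal sketch of such a function. The name generate_latent_points() and the choice of a standard Gaussian are our assumptions:

```python
import numpy as np

def generate_latent_points(latent_dim, n_samples):
    """Draw n_samples points from a standard Gaussian in a latent_dim-dimensional space."""
    # Each row is one latent vector, i.e. one input for the generator
    return np.random.randn(n_samples, latent_dim)
```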

As a matter of fact, this model cannot do much at the moment. We can give it some latent points from which it generates an image, but that image will look nothing like a handwritten digit, since we have not yet oriented the generator towards generating handwritten digits.

A question that a reader may ask nevertheless is what happens when we give the latent input to the generator just for the sake of it? What does it ‘generate’? 

Below, we show exactly what the model generates.

As expected, the output is essentially random and does not represent a handwritten digit image. This is of course because we haven’t really ‘trained’ the generator to output plausible images. 

We’ll use a function from the previous discriminator section to demonstrate the output of the generator to you. 

All we have to do is call the predict() method on our generator and give it points from the latent space as input.
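A sketch of the updated generate_fake_samples() that takes its fakes from a generator instead of from raw noise. The signature (g_model, latent_dim, n_samples) is our assumption; it samples standard-Gaussian latent points and calls predict() on them:

```python
import numpy as np

def generate_fake_samples(g_model, latent_dim, n_samples):
    """Generate n_samples images with the generator and label them 0 (fake)."""
    x_input = np.random.randn(n_samples, latent_dim)  # random points in the latent space
    X = g_model.predict(x_input, verbose=0)           # generator output: (n, 28, 28, 1)
    y = np.zeros((n_samples, 1))                      # class label 0 = fake
    return X, y
```

To look at one output of an untrained generator, plt.imshow(X[0, :, :, 0], cmap='gray') is enough.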

 

Figure 7. The output of an untrained generator model

Now, of course, we want to orient the generator towards generating plausible images which we’ll achieve by training it. It is described in the subsequent section. 

VI.   Training the Generator Model 

Think about Situation B, in which you are essentially the generator model. What does the process look like?

Let’s go step by step. 

  1. You first produce some really off drawings.
  2. You show the paintings to the expert for their opinion.
  3. The expert rejects the paintings.
  4. You go back and try to produce better paintings in the hope of satisfying the expert.

Now let’s take this exact same concept and apply it to GANs.

We have two models now, the generator and the discriminator. 

The generator will take latent points as input and generate an image, which the discriminator should assess as fake or real. Based on this feedback from the discriminator, the generator should do better the next time.

So what we’re basically doing is we’re stacking the generator and discriminator such that the generator receives as input random points in the latent space and generates samples that are fed into the discriminator model directly, classified, and the output of this larger model can be used to update the model weights of the generator. 

This, dear readers, is the Generative Adversarial Network. 

One important note

When training the generator via this GAN model, there are two important changes. 

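First, we do not update the weights of the discriminator. The discriminator is trained separately on its own batches of real and fake images, so its weights are frozen inside this combined model and only the generator’s weights change.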

Secondly, we want the discriminator to think that the samples outputted by the generator are real, and not fake. 

Now, why would we want to do this? If the discriminator thinks the samples are real from the start, what’s the point? 

The answer is simple. 

Let us imagine the generator generated a random gibberish image like in the case of the untrained generator model. Now we mark this deliberately as a real sample. 

We can imagine that the discriminator will then classify the generated sample as not real (class 0), or give it a low probability of being real (0.2 or 0.3), since the discriminator is already trained to detect fake and real samples and our input image is surely fake.

Now the backpropagation process used to update the model weights will see this as a large error and will update the model weights (i.e. only the weights in the generator since the discriminator weights are fixed) to correct for this error, in turn making the generator better at generating good fake samples.
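To see why this error is large, here is a small worked example of binary cross-entropy, the loss named in Section III, written in plain NumPy as an illustration rather than the Keras implementation. It uses the illustrative probability of 0.2 from above:

```python
import numpy as np

def binary_cross_entropy(y_true, p):
    """Mean binary cross-entropy: -[y*log(p) + (1-y)*log(1-p)]."""
    y_true = np.asarray(y_true, dtype=float)
    p = np.clip(np.asarray(p, dtype=float), 1e-7, 1 - 1e-7)  # avoid log(0)
    return float(np.mean(-(y_true * np.log(p) + (1 - y_true) * np.log(1 - p))))

p_real = 0.2  # the discriminator thinks this generated image is probably fake
loss_if_labelled_real = binary_cross_entropy([1.0], [p_real])  # about 1.609
loss_if_labelled_fake = binary_cross_entropy([0.0], [p_real])  # about 0.223
```

Labelling the fake as real turns a confident "fake" verdict into a loss roughly seven times larger, and that gradient is what pushes the generator's weights.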

Therefore, when the generator is trained as part of the GAN model, we will mark the generated samples as real (class 1).

We’ll now define our GAN in accordance with the discussion above: 
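A sketch of this composite model. The name define_gan() and the Adam settings (learning rate 0.0002, beta_1 0.5) are our assumptions. The discriminator is frozen before the combined model is compiled, so training the GAN only moves the generator's weights:

```python
from tensorflow import keras

def define_gan(g_model, d_model):
    """Stack generator and discriminator; only the generator learns through this model."""
    # Freeze the discriminator inside the combined model
    d_model.trainable = False
    # Latent point -> generator -> image -> discriminator -> probability "real"
    model = keras.Sequential([g_model, d_model])
    opt = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)  # assumed settings
    model.compile(loss='binary_crossentropy', optimizer=opt)
    return model
```

This relies on the tf.keras 2.x behavior where a model's trainable settings are fixed when it is compiled, so the discriminator, compiled earlier, still learns in its own updates. Keras 3 may handle this differently; if so, toggling d_model.trainable before each update is a common workaround.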

The plot of our GAN is shown below: 

Figure 8. GAN Model

For GAN training, we’ve defined a function to proceed in the fashion as discussed. We’ve included all the comments necessary for the reader to really understand what is happening with each line. 
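A sketch of such a training function. The name train(), the 100 epochs and the batch size of 256 are our assumptions, and the periodic evaluation discussed in the Evaluation section is left out to keep it short. Each step updates the discriminator on half real, half generated images, then updates the generator through the GAN with fresh latent points labelled as real:

```python
import numpy as np

def train(g_model, d_model, gan_model, dataset, latent_dim, n_epochs=100, n_batch=256):
    """Alternate discriminator and generator updates over the whole dataset."""
    bat_per_epo = dataset.shape[0] // n_batch
    half_batch = n_batch // 2
    for epoch in range(n_epochs):
        for batch in range(bat_per_epo):
            # Real images (label 1) and generated images (label 0), half a batch each
            ix = np.random.randint(0, dataset.shape[0], half_batch)
            X_real, y_real = dataset[ix], np.ones((half_batch, 1))
            X_fake = g_model.predict(np.random.randn(half_batch, latent_dim), verbose=0)
            y_fake = np.zeros((half_batch, 1))
            # Update the discriminator on the combined batch
            X, y = np.vstack((X_real, X_fake)), np.vstack((y_real, y_fake))
            d_loss, _ = d_model.train_on_batch(X, y)
            # Update the generator via the GAN: new latent points, deliberately labelled real (1)
            x_gan = np.random.randn(n_batch, latent_dim)
            y_gan = np.ones((n_batch, 1))
            g_loss = gan_model.train_on_batch(x_gan, y_gan)
            print(f'>{epoch + 1}, {batch + 1}/{bat_per_epo}, d={d_loss:.3f}, g={g_loss:.3f}')
```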

This concludes the training part of our GAN. 

There it is guys! That’s the showdown we’ve all been waiting for! The generator will try to synthesize images to ‘bypass’ or ‘fool’ the discriminator and the discriminator will try to foil the attempts of the generator, indirectly causing the generator to step his game up and create better fakes to finally bypass the discriminator. 

Does it succeed in our case?

We’ll find out. 

On running the code, we get the following output: 

Figure 9. Training the GAN

The only thing remaining now is evaluation. 

VII.   Evaluation

Generally, there is no objective way to evaluate the performance of a GAN model.

Unlike a classifier’s predictions, generated images have no ground-truth target to compare against, so we cannot calculate an objective error score for them.

Instead, images must be subjectively evaluated for quality by a human operator. This means that we cannot know when to stop training without looking at examples of generated images. In turn, the adversarial nature of the training process means that the generator is changing after every batch, meaning that once “good enough” images can be generated, the subjective quality of the images may then begin to vary, improve, or even degrade with subsequent updates. 

There are three ways to handle this complex training situation.  

  1. Periodically evaluate the classification accuracy of the discriminator on real and fake images.
  2. Periodically generate many images and save them to file for subjective review. 
  3. Periodically save the generator model.

All three of these actions can be performed at the same time for a given training epoch, such as every five or 10 training epochs. The result will be a saved generator model for which we have a way of subjectively assessing the quality of its output and objectively knowing how well the discriminator was fooled at the time the model was saved. Training the GAN over many epochs, such as hundreds or thousands of epochs, will result in many snapshots of the model that can be inspected and from which specific outputs and models can be cherry-picked for later use.

First, we can define a function called summarize_performance() that summarizes the performance of the discriminator model. It does this by retrieving a sample of real MNIST images, generating the same number of fake MNIST images with the generator model, then evaluating the discriminator’s classification accuracy on each sample and reporting these scores. Next, we can update summarize_performance() to also save the generator model and to create and save a plot of generated examples. As we are evaluating the discriminator on 100 generated MNIST images, we can plot all 100 images as a 10 by 10 grid.
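A sketch of both pieces. The helper save_plot(), the out_dir parameter, the file names and the .keras model format are our assumptions; the discriminator is assumed to be compiled with metrics=['accuracy'] so that evaluate() returns the loss and the accuracy:

```python
import os
import numpy as np
import matplotlib
matplotlib.use('Agg')  # render to files; no display needed
import matplotlib.pyplot as plt

def save_plot(examples, epoch, n=10, out_dir='.'):
    """Save an n x n grid of generated images to a PNG file (hypothetical helper)."""
    fig = plt.figure(figsize=(n, n))
    for i in range(n * n):
        ax = fig.add_subplot(n, n, i + 1)
        ax.axis('off')
        ax.imshow(examples[i, :, :, 0], cmap='gray')  # white digits on black
    filename = os.path.join(out_dir, f'generated_plot_e{epoch + 1:03d}.png')
    fig.savefig(filename)
    plt.close(fig)
    return filename

def summarize_performance(epoch, g_model, d_model, dataset, latent_dim,
                          n_samples=100, out_dir='.'):
    # Discriminator accuracy on real images (label 1)
    ix = np.random.randint(0, dataset.shape[0], n_samples)
    _, acc_real = d_model.evaluate(dataset[ix], np.ones((n_samples, 1)), verbose=0)
    # Discriminator accuracy on generated images (label 0)
    x_fake = g_model.predict(np.random.randn(n_samples, latent_dim), verbose=0)
    _, acc_fake = d_model.evaluate(x_fake, np.zeros((n_samples, 1)), verbose=0)
    print(f'>Accuracy real: {acc_real * 100:.0f}%, fake: {acc_fake * 100:.0f}%')
    # Save a 10x10 grid of the generated images and a snapshot of the generator
    plot_file = save_plot(x_fake, epoch, out_dir=out_dir)
    model_file = os.path.join(out_dir, f'generator_model_{epoch + 1:03d}.keras')
    g_model.save(model_file)
    return acc_real, acc_fake, plot_file, model_file
```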

We then call this function from our train() function for evaluation, as follows:
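Conceptually, the call sits at the end of each epoch inside train(). A framework-free sketch of that placement, where run_with_checkpoints(), train_one_epoch and evaluate are hypothetical names (evaluate would wrap summarize_performance()), and 100 epochs with an evaluation every 10 epochs are assumptions:

```python
def run_with_checkpoints(n_epochs, train_one_epoch, evaluate, every=10):
    """Call train_one_epoch(epoch) every epoch and evaluate(epoch) every `every` epochs."""
    for epoch in range(n_epochs):
        train_one_epoch(epoch)
        # epoch is zero-based, so this fires after epochs 10, 20, ... in one-based terms
        if (epoch + 1) % every == 0:
            evaluate(epoch)
```

With 100 epochs and every=10, this yields 10 evaluations, and therefore 10 plots and 10 saved generator snapshots.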

In the last 3 lines of the above code, we’ve inserted the summarize_performance method.

VIII.   Conclusions & Results 

On running the entire code as a whole, you should see the GAN begin to train. The model performance is reported every batch, including the loss of both the discriminative (d) and generative (g) models. The generator is evaluated every 10 epochs, resulting in 10 evaluations, 10 plots of generated images, and 10 saved models.

Figure 10. The plot of 100 GAN Generated MNIST Figures After 10 Epochs

For our run, the results after 10 epochs were found to be low quality, although we can see that the generator has learned to generate centered figures in white on a black background.

After 20 or 30 more epochs, the model begins to generate very plausible MNIST figures.

Figure 11. The plot of 100 GAN Generated MNIST Figures After 40 Epochs

The generated images after 100 epochs are not greatly different, but we can detect less blockiness in the curves.

Figure 12. The plot of 100 GAN Generated MNIST Figures After 100 Epochs

It should be noted that the performance of a GAN may deteriorate over the epochs: the model at the 100th epoch is not necessarily better than, or even as good as, the model at the 40th epoch. It is therefore good practice to use evaluation functions like the ones in this blog whenever you train a GAN, so you can see which snapshot actually performs best for your particular application.

Once a final generator model is selected, it can be used in a standalone manner for your application. 

This involves first loading the model from a file, then using it to generate images. The generation of each image requires a point in the latent space as input. 

The complete example of loading the saved model and generating images is listed below. In this case, we will use the model saved after 100 training epochs, but the model saved after 40 or 50 epochs would work just as well.
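A sketch of this standalone step. The helpers generate_images() and save_grid(), the file names and the 100-dimensional latent space are our assumptions; main() loads the saved generator with keras.models.load_model() (compile=False, since the generator is only used for inference) and writes 25 generated digits as one 5×5 grid:

```python
import numpy as np
import matplotlib
matplotlib.use('Agg')  # write image files without a display
import matplotlib.pyplot as plt

def generate_images(model, latent_dim, n_samples, seed=None):
    """Sample standard-Gaussian latent points and map them to images."""
    rng = np.random.default_rng(seed)
    latent_points = rng.standard_normal((n_samples, latent_dim))
    return model.predict(latent_points, verbose=0)

def save_grid(images, n, filename):
    """Plot the first n*n images as an n x n grid and save it to filename."""
    fig, axes = plt.subplots(n, n, figsize=(n, n))
    for ax, img in zip(axes.flat, images):
        ax.imshow(img[:, :, 0], cmap='gray')
        ax.axis('off')
    fig.savefig(filename)
    plt.close(fig)

def main(model_path='generator_model_100.keras', latent_dim=100):
    from tensorflow import keras  # only needed when loading a real saved model
    model = keras.models.load_model(model_path, compile=False)
    # 25 latent points -> 25 images -> one 5 x 5 grid
    save_grid(generate_images(model, latent_dim, 25), 5, 'generated_25.png')
```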

Running the example first loads the model, samples 25 random points in the latent space, generates 25 images, and then plots the results as a single image.

We can see that most of the images are plausible, or plausible pieces of handwritten digits.

Figure 13. Example of 25 GAN Generated MNIST Handwritten Images

This concludes the complete code example for GANs. In the subsequent blogs, we’ll take up the more advanced variations in GANs that have come up in the literature over the past few years. Specifically, we’ll be looking at the pix2pix GAN used for image translation tasks in the next blog. 

Be sure to tune in; you do not want to miss this! We’re not sure the showdown lived up to expectations, but we’re sure it enticed our readers. GANs are, after all, the bleeding edge in AI.

If you want to master deep learning and neural networks, you can check out the latest tutorial, Deep Learning & Neural Networks Using Python & Keras For Dummies by Eduonix, which teaches various aspects of deep learning in detail across an exhaustive 81 lectures.

References

[1] Generative Adversarial Networks, 2014

Ian J. Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, Yoshua Bengio 

[2] Deep Boltzmann Machines, 2013

Ruslan Salakhutdinov, Geoffrey Hinton

[3] An Introduction to Restricted Boltzmann Machines, 2012

Asja Fischer, Christian Igel

[4] How to Develop a GAN for Generating MNIST Handwritten Digits

Jason Brownlee

https://machinelearningmastery.com/how-to-develop-a-generative-adversarial-network-for-an-mnist-handwritten-digits-from-scratch-in-keras/

[5] Improved Techniques for Training GANs, 2016

Tim Salimans, Ian Goodfellow, Wojciech Zaremba, Vicki Cheung, Alec Radford, Xi Chen

https://arxiv.org/abs/1606.03498
