
dcgan

2019-01-06

Licensed under the Apache License, Version 2.0 (the “License”).

DCGAN: An example with tf.keras and eager




This notebook demonstrates how to generate images of handwritten digits using tf.keras and eager execution. To do so, we use a Deep Convolutional Generative Adversarial Network (DCGAN).

The model takes about 30 seconds per epoch to train (using tf.contrib.eager.defun to create graph functions) on a single Tesla K80 on Colab, as of July 2018.

Below is the output generated after training the generator and discriminator models for 150 epochs.

sample output

# to generate gifs
!pip install imageio
Requirement already satisfied: imageio in /home/dongnanzhy/miniconda3/lib/python3.6/site-packages (2.1.2)
Requirement already satisfied: numpy in /home/dongnanzhy/miniconda3/lib/python3.6/site-packages (from imageio) (1.14.2)
Requirement already satisfied: pillow in /home/dongnanzhy/miniconda3/lib/python3.6/site-packages (from imageio) (5.0.0)

Import TensorFlow and enable eager execution

from __future__ import absolute_import, division, print_function

# Import TensorFlow >= 1.10 and enable eager execution
import tensorflow as tf
tf.enable_eager_execution()

import os
import time
import numpy as np
import glob
import matplotlib.pyplot as plt
import PIL
import imageio
from IPython import display

Load the dataset

We are going to use the MNIST dataset to train the generator and the discriminator. The generator will then generate handwritten digits.

(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
# We are normalizing the images to the range of [-1, 1]
train_images = (train_images - 127.5) / 127.5
BUFFER_SIZE = 60000
BATCH_SIZE = 256

Use tf.data to create batches and shuffle the dataset

train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
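
As a quick sanity check (a minimal sketch, assuming the cells above have been run), we can pull a single batch and confirm its shape and value range:

# peek at one batch to verify the pipeline (shape and normalization)
for batch in train_dataset.take(1):
    print(batch.shape)                               # (256, 28, 28, 1)
    print(batch.numpy().min(), batch.numpy().max())  # close to -1.0 and 1.0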

Write the generator and discriminator models

  • Generator

    • It is responsible for creating convincing images that are good enough to fool the discriminator.
    • It consists of Conv2DTranspose (upsampling) layers. We start with a fully connected layer and upsample twice to reach the desired MNIST image size of (28, 28, 1).
    • We use leaky relu activation for every layer except the last, which uses tanh.
      (Note: the code here actually uses relu, not leaky relu.)
  • Discriminator

    • The discriminator is responsible for distinguishing the fake images from the real images.
    • In other words, the discriminator is given generated images (from the generator) and the real MNIST images. Its job is to classify these images as fake (generated) or real (MNIST images).
    • Basically, the generator should become good enough to fool the discriminator into classifying the generated images as real.
class Generator(tf.keras.Model):
    def __init__(self):
        super(Generator, self).__init__()
        self.fc1 = tf.keras.layers.Dense(7*7*64, use_bias=False)
        self.batchnorm1 = tf.keras.layers.BatchNormalization()

        self.conv1 = tf.keras.layers.Conv2DTranspose(64, (5, 5), strides=(1, 1), padding='same', use_bias=False)
        self.batchnorm2 = tf.keras.layers.BatchNormalization()

        self.conv2 = tf.keras.layers.Conv2DTranspose(32, (5, 5), strides=(2, 2), padding='same', use_bias=False)
        self.batchnorm3 = tf.keras.layers.BatchNormalization()

        self.conv3 = tf.keras.layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False)

    def call(self, x, training=True):
        # input is a vector of shape [?, 100(noise_dim)] --> [?, 7*7*64]
        x = self.fc1(x)
        x = self.batchnorm1(x, training=training)
        x = tf.nn.relu(x)

        # [?, 7*7*64] --> [?, 7, 7, 64]
        x = tf.reshape(x, shape=(-1, 7, 7, 64))

        # [?, 7, 7, 64] --> [?, 7, 7, 64]
        x = self.conv1(x)
        x = self.batchnorm2(x, training=training)
        x = tf.nn.relu(x)

        # [?, 7, 7, 64] --> [?, 14, 14, 32]
        x = self.conv2(x)
        x = self.batchnorm3(x, training=training)
        x = tf.nn.relu(x)

        # [?, 14, 14, 32] --> [?, 28, 28, 1], output range (-1, 1)
        x = tf.nn.tanh(self.conv3(x))
        return x
class Discriminator(tf.keras.Model):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = tf.keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same')
        self.conv2 = tf.keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same')
        self.dropout = tf.keras.layers.Dropout(0.3)
        self.flatten = tf.keras.layers.Flatten()
        self.fc1 = tf.keras.layers.Dense(1)

    def call(self, x, training=True):
        # [?, 28, 28, 1] --> [?, 14, 14, 64]
        x = tf.nn.leaky_relu(self.conv1(x))
        x = self.dropout(x, training=training)
        # [?, 14, 14, 64] --> [?, 7, 7, 128]
        x = tf.nn.leaky_relu(self.conv2(x))
        x = self.dropout(x, training=training)
        x = self.flatten(x)
        x = self.fc1(x)
        return x
generator = Generator()
discriminator = Discriminator()
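
Before wiring the models into a training loop, a quick shape check (a minimal sketch; sample_noise and the other names are just for illustration) confirms that the generator maps a 100-dimensional noise vector to a (28, 28, 1) image, and that the discriminator maps that image to a single logit:

# run a random noise vector through both untrained models
sample_noise = tf.random_normal([1, 100])
sample_image = generator(sample_noise, training=False)
print(sample_image.shape)    # (1, 28, 28, 1)
sample_logit = discriminator(sample_image, training=False)
print(sample_logit.shape)    # (1, 1)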
# Defun gives a ~10 sec/epoch performance boost
# (it compiles a Python function into a callable TF graph, which speeds up execution)
generator.call = tf.contrib.eager.defun(generator.call)
discriminator.call = tf.contrib.eager.defun(discriminator.call)

Define the loss functions and the optimizer

  • Discriminator loss (this is essentially a standard classification loss, with a small twist)

    • The discriminator loss function takes two inputs: the discriminator's logits for the real images and for the generated images
    • real_loss is a sigmoid cross entropy loss of the real images' logits against an array of ones (since these are the real images)
    • generated_loss is a sigmoid cross entropy loss of the generated images' logits against an array of zeros (since these are the fake images)
    • Then the total_loss is the sum of real_loss and generated_loss
  • Generator loss (the generator's loss uses the opposite labels to the discriminator's generated_loss)

    • It is a sigmoid cross entropy loss of the generated images' logits against an array of ones
  • The discriminator and the generator optimizers are different since we will train them separately.
def discriminator_loss(real_output, generated_output):
    # [1,1,...,1] with real output since it is true and we want
    # our generated examples to look like it
    real_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=tf.ones_like(real_output), logits=real_output)

    # [0,0,...,0] with generated images since they are fake
    generated_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=tf.zeros_like(generated_output), logits=generated_output)

    total_loss = real_loss + generated_loss

    return total_loss
def generator_loss(generated_output):
    return tf.losses.sigmoid_cross_entropy(tf.ones_like(generated_output), generated_output)
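
As a quick check of the label conventions (a minimal sketch with made-up logits, where a positive logit means "real"), a discriminator that is confidently right about both inputs should give a near-zero discriminator_loss, and a fooled discriminator should give a near-zero generator_loss:

# made-up logits for a single example each
confident_real = tf.constant([[5.0]])   # discriminator says "real"
confident_fake = tf.constant([[-5.0]])  # discriminator says "fake"

# the discriminator is right on both counts --> near-zero loss
print(discriminator_loss(confident_real, confident_fake).numpy())
# the generator is penalized when its images are called fake...
print(generator_loss(confident_fake).numpy())
# ...and gets a near-zero loss when the discriminator is fooled
print(generator_loss(confident_real).numpy())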
discriminator_optimizer = tf.train.AdamOptimizer(1e-4)
generator_optimizer = tf.train.AdamOptimizer(1e-4)

Checkpoints (Object-based saving)

checkpoint_dir = './data/dcgan/training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

Training

  • We start by iterating over the dataset
  • The generator is given noise as an input, which, when passed through the generator model, outputs an image that looks like a handwritten digit
  • The discriminator is given the real MNIST images as well as the generated images (from the generator).
  • Next, we calculate the generator and the discriminator loss.
  • Then, we calculate the gradients of the loss with respect to both the generator and the discriminator variables and apply them via the optimizers (see the GradientTape sketch below).
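
The gradient bookkeeping relies on tf.GradientTape, which records operations during eager execution so they can be differentiated afterwards. Here is a minimal sketch of the record-then-apply pattern with a toy variable (just for illustration, not the GAN models):

# toy example of the tape pattern used in train() below
w = tf.contrib.eager.Variable(2.0)
with tf.GradientTape() as tape:
    loss = w * w                      # any computation involving w
grads = tape.gradient(loss, [w])      # d(w^2)/dw = 2w = 4.0
tf.train.GradientDescentOptimizer(0.1).apply_gradients(zip(grads, [w]))
print(w.numpy())                      # 2.0 - 0.1 * 4.0 = 1.6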

Generate Images

  • After training, it's time to generate some images!
  • We start by creating a noise array as an input to the generator.
  • The generator will then convert the noise into images of handwritten digits.
  • The last step is to plot the predictions, and voila!
EPOCHS = 150
noise_dim = 100
num_examples_to_generate = 16

# keeping the random vector constant for generation (prediction) so
# it will be easier to see the improvement of the gan.
# the noise is a 100-dimensional vector drawn from a normal distribution
random_vector_for_generation = tf.random_normal([num_examples_to_generate,
                                                 noise_dim])
def generate_and_save_images(model, epoch, test_input):
    # make sure the training parameter is set to False because we
    # don't want to train the batchnorm layer when doing inference.
    predictions = model(test_input, training=False)

    fig = plt.figure(figsize=(4, 4))

    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i+1)
        # each generated image has shape [28, 28, 1] with values in (-1, 1)
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')

    plt.savefig('./data/dcgan/image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()
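
It can be handy to call this once before training (a quick sketch, reusing the global random_vector_for_generation) to confirm that the untrained generator produces pure noise:

# epoch 0: the untrained generator should output noise-like images
generate_and_save_images(generator, 0, random_vector_for_generation)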
def train(dataset, epochs, noise_dim):
    for epoch in range(epochs):
        start = time.time()

        # dataset is a tf.data.Dataset, which is directly iterable
        for images in dataset:
            # generating noise from a normal distribution
            noise = tf.random_normal([BATCH_SIZE, noise_dim])

            # GradientTape computes gradients automatically under eager execution
            with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
                # images generated from the noise
                generated_images = generator(noise, training=True)
                # discriminator logits for the real images from the dataset and for the generated images
                real_output = discriminator(images, training=True)
                generated_output = discriminator(generated_images, training=True)
                # generator loss is zero when generated images are classified as 1 (real)
                gen_loss = generator_loss(generated_output)
                # discriminator loss is zero when generated images are classified as 0 and real images as 1
                disc_loss = discriminator_loss(real_output, generated_output)

            gradients_of_generator = gen_tape.gradient(gen_loss, generator.variables)
            gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.variables)

            # note: the generator and the discriminator are trained separately
            generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.variables))
            discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.variables))

        if epoch % 1 == 0:
            display.clear_output(wait=True)
            generate_and_save_images(generator,
                                     epoch + 1,
                                     random_vector_for_generation)  # reuse the global noise vector for testing

        # saving (checkpoint) the model every 15 epochs
        if (epoch + 1) % 15 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)

        print('Time taken for epoch {} is {} sec'.format(epoch + 1,
                                                         time.time() - start))
    # generating after the final epoch
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epochs,
                             random_vector_for_generation)
train(train_dataset, EPOCHS, noise_dim)

(output: a 4x4 grid of generated digits)

Restore the latest checkpoint

# restoring the latest checkpoint in checkpoint_dir
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
<tensorflow.python.training.checkpointable.util.CheckpointLoadStatus at 0x7f23695fb668>
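
If you want to verify that every saved value was matched to a model or optimizer variable, the returned status object exposes assert_consumed() (a minimal sketch; note that it can raise if some variables, such as optimizer slots, have not been created yet):

# optionally assert that the checkpoint was fully restored
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
status.assert_consumed()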

Display an image using the epoch number

def display_image(epoch_no):
    return PIL.Image.open('./data/dcgan/image_at_epoch_{:04d}.png'.format(epoch_no))
display_image(EPOCHS)

(output: generated digits at the final epoch)

Generate a GIF of all the saved images.

with imageio.get_writer('./data/dcgan/dcgan.gif', mode='I') as writer:
    filenames = glob.glob('./data/dcgan/image*.png')
    filenames = sorted(filenames)
    last = -1
    for i, filename in enumerate(filenames):
        # spread frames on a square-root schedule: early epochs appear
        # densely, later epochs sparsely
        frame = 2*(i**0.5)
        if round(frame) > round(last):
            last = frame
        else:
            continue
        image = imageio.imread(filename)
        writer.append_data(image)
    # repeat the final frame so the gif lingers on the last image
    image = imageio.imread(filename)
    writer.append_data(image)

# this is a hack to display the gif inside the notebook
os.system('cp ./data/dcgan/dcgan.gif ./data/dcgan/dcgan.gif.png')
0

display.Image(filename="./data/dcgan/dcgan.gif.png")

(output: animated GIF of the training progress)

To download the animation from Colab, uncomment the code below:

#from google.colab import files
#files.download('dcgan.gif')