by msbeta

Deep Convolutional Generative Adversarial Networks (DCGAN)

1. The Basic Principle of GANs

The basic idea of a GAN is quite simple: it consists of two networks, a Generator (G) and a Discriminator (D). G tries to generate images realistic enough to fool D, while D tries to tell G's generated images apart from real ones.

In the ideal end state, G can produce images that pass for real, and D can no longer tell whether an image from G is genuine or fake.

First, the core of a GAN in (admittedly dry) mathematical terms:

min_G max_D V(D, G) = E_{x~p_data(x)}[log D(x)] + E_{z~p_z(z)}[log(1 - D(G(z)))]

In this formula: x is a real image, z is the random noise fed into G, and G(z) is the image G generates; D(x) is the probability D assigns to a real image x being real (since x really is real, D wants this value as close to 1 as possible).

D(G(z)) is the probability D assigns to an image generated by G being real. G wants its images to look as real as possible, i.e. it wants D(G(z)) to be as large as possible, which makes V(D, G) smaller; so G plays the minimizing role (min_G).

The stronger D is, the larger D(x) and the smaller D(G(z)) should be, which makes V(D, G) larger; so D plays the maximizing role (max_D).
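The minimax objective can be sanity-checked numerically. A minimal sketch, using assumed toy values for D's outputs:

```python
import numpy as np

# V(D, G) = E[log D(x)] + E[log(1 - D(G(z)))], evaluated on toy values.
# All probabilities here are assumed for illustration only.
d_real = np.array([0.9, 0.8])   # D(x): D's scores on real images
d_fake = np.array([0.1, 0.2])   # D(G(z)): D's scores on generated images

v = np.mean(np.log(d_real)) + np.mean(np.log(1.0 - d_fake))
print(round(v, 4))  # -0.3285: close to 0, i.e. D is winning
```

As D(G(z)) rises toward D(x), V(D, G) falls, which is exactly the quantity G is minimizing.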

Below we implement a DCGAN that generates anime-style faces. First, the results of training on my low-end laptop.

Training on a laptop is slow, so only 1000 images were used as training data, for 50 epochs; even so, the preliminary results are already visible.

2. The Anime Faces Dataset

The official TensorFlow demo uses the MNIST dataset; here we switch to the Kaggle Anime Faces dataset, which contains 21,551 anime avatar images. The data link is:

Kaggle: Anime Faces

A few of the images:

2.1 Loading the Dataset

Import the necessary Python modules.

import tensorflow as tf
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time

from IPython import display

Load the data, build a TensorFlow Dataset, and scale the pixel values to the range [-1, 1].

def load_data():
    all_images = []

    # max_count = 1000
    # count = 0

    for dirname, _, filenames in os.walk('/GAN/archive/data/'):
        for filename in filenames:

            image = imageio.imread(os.path.join(dirname, filename))

            all_images.append(image)

            # count = count + 1
            # if count > max_count:
            #    break

    all_images = np.array(all_images).astype('float32')
    all_images = (all_images - 127.5) / 127.5  # scale pixels to [-1, 1]

    return all_images


train_images = load_data()

BUFFER_SIZE = 3000
BATCH_SIZE = 10

# Batch and shuffle the data
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

Let's take a look at the data in the Dataset first.

3. Defining the Models (Model)

3.1 Generator Model

The Generator uses tf.keras.layers.Conv2DTranspose layers to upsample random noise into 64x64x3 image data.

def make_generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(4 * 4 * 1024, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((4, 4, 1024)))
    assert model.output_shape == (None, 4, 4, 1024)  # Note: the batch size (None) is unconstrained

    model.add(layers.Conv2DTranspose(512, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 8, 8, 512)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(256, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 16, 16, 256)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 32, 32, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(3, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 64, 64, 3)

    return model

Let's use the (still untrained) generator to produce an image and see what it looks like:

generator = make_generator_model()

noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)
pred_img = (generated_image[0, :, :, :] + 1.0) / 2.0
plt.imshow(pred_img)
plt.axis('off')
plt.show()

The untrained generator's output looks like this:

3.2 Discriminator Model

The Discriminator is a CNN-based image classifier.

def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same',
                                     input_shape=[64, 64, 3]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(256, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(512, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(1024, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model

We use the untrained discriminator to classify the generated image as real or fake.

discriminator = make_discriminator_model()
decision = discriminator(generated_image)
print (decision)

Output:

tf.Tensor([[0.00011664]], shape=(1, 1), dtype=float32)
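Note that the Dense(1) output is a raw logit, not a probability (the loss below is built with from_logits=True). To read it as a probability, pass it through a sigmoid; using the value from the untrained run above:

```python
import tensorflow as tf

# A logit near 0 corresponds to a probability near 0.5: the untrained
# discriminator is essentially undecided about the image.
decision = tf.constant([[0.00011664]])  # the logit printed above
prob = tf.sigmoid(decision)
print(float(prob))  # ~0.5
```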

4. Defining the Loss Functions and Optimizers

# This method returns a helper function to compute cross entropy loss
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

4.1 Discriminator loss

The discriminator loss quantifies how well the discriminator distinguishes real images from fakes.

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss
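A quick numeric check of this loss, using assumed toy logits (large positive = "real", large negative = "fake"):

```python
import tensorflow as tf

cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# A confident, correct discriminator: real images get a high logit,
# fakes get a low one, so both loss terms are near zero.
real_output = tf.constant([[5.0]])
fake_output = tf.constant([[-5.0]])

real_loss = cross_entropy(tf.ones_like(real_output), real_output)
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
print(float(real_loss + fake_loss))  # ~0.013: D is doing well
```

Flipping the two logits (D fooled on both) drives the same sum above 10, so the loss cleanly separates a good discriminator from a bad one.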

4.2 Generator Loss

The generator loss quantifies how well the generator fools the discriminator.

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

Since the discriminator and generator are two separate, independent networks, we define a different optimizer for each.

generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
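The 1e-4 learning rate follows the TensorFlow tutorial. The original DCGAN paper (Radford et al., 2016) instead recommends Adam with learning rate 2e-4 and beta_1 = 0.5; a sketch of that variant, worth trying if training is unstable:

```python
import tensorflow as tf

# Hyperparameters recommended in the DCGAN paper; the lower beta_1 makes
# Adam's momentum less aggressive, which often stabilizes GAN training.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
```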

5. Defining the Training Loop

EPOCHS = 50
noise_dim = 100
num_examples_to_generate = 16

seed = tf.random.normal([num_examples_to_generate, noise_dim])

In the training loop, the generator turns random noise (seeded above) into images, and the discriminator classifies those images as real or fake.

# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      generated_images = generator(noise, training=True)

      real_output = discriminator(images, training=True)
      fake_output = discriminator(generated_images, training=True)

      gen_loss = generator_loss(fake_output)
      disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

def train(dataset, epochs):
  for epoch in range(epochs):
    start = time.time()

    for image_batch in dataset:
      train_step(image_batch)

    # Produce images for the GIF as we go
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epoch + 1,
                             seed)

    # Save the model every 15 epochs
    if (epoch + 1) % 15 == 0:
      checkpoint.save(file_prefix = checkpoint_prefix)

    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))

  # Generate after the final epoch
  display.clear_output(wait=True)
  generate_and_save_images(generator,
                           epochs,
                           seed)

6. Saving and Generating Images

def generate_and_save_images(model, epoch, test_input):
  # Notice `training` is set to False.
  # This is so all layers run in inference mode (batchnorm).
  predictions = model(test_input, training=False)

  fig = plt.figure(figsize=(4,4))

  for i in range(predictions.shape[0]):
      plt.subplot(4, 4, i+1)
      # These are RGB images in [-1, 1]; map back to [0, 1] for display
      plt.imshow((predictions[i, :, :, :] + 1.0) / 2.0)
      plt.axis('off')

  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  plt.show()

7. Training the Model

Call the train() method to train the generator and discriminator together. At the start of training the generated images look like random noise; as training progresses, they become increasingly realistic.

train(train_dataset, EPOCHS)

(Generated samples at epochs 0, 10, 20, 30, 40, and 50.)

8. Generating a GIF

Finally, merge the images saved during training into a GIF, so the GAN's training progress can be viewed at a glance.

anim_file = 'dcgan.gif'

with imageio.get_writer(anim_file, mode='I') as writer:
  filenames = glob.glob('image*.png')
  filenames = sorted(filenames)
  for filename in filenames:
    image = imageio.imread(filename)
    writer.append_data(image)
  # Repeat the last frame so the GIF lingers on the final result
  image = imageio.imread(filename)
  writer.append_data(image)

The result:

References

1.https://www.tensorflow.org/tutorials/generative/dcgan?hl=zh-cn

2.https://zhuanlan.zhihu.com/p/24767059

3.https://my.oschina.net/u/4264470/blog/4422434
