当前位置:网站首页>使用Keras构建GAN,以Mnist为例

使用Keras构建GAN,以Mnist为例

2022-08-11 05:35:00 Pr4da

在开始之前请先了解GAN的原理,有很多博主讲的都很好,在这里我就不再过多讲解,视频推荐台大李宏毅老师的课程。

GAN共包含两个主要结构generator和discriminator。generator负责生成假的数据来“欺骗”discriminator,discriminator负责判断输入的数据是否为generator生成的,二者互相迭代,最终实现generator生成能以假乱真的数据。以下以Mnist数据集为例,使用GAN来产生手写数字。

构建网络模型

1.generator

神经网络模型有输出就有输入,我们要想得到假的生成数据,就要给模型一个输入,这里采用形状为[100,]的向量作为输入,输出是形状为[28,28,1]的矩阵。

    def build_generator(self):
        # input shape = [100,]
        # output shape = [np.prod(self.img_shape)]
        
        model = Sequential()

        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        # image_shape = [28,28,1]
        model.add(Dense(np.prod(self.img_shape), activation='tanh')) #np.prod()计算形状乘积
        model.add(Reshape(self.img_shape))

        model.summary()

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)

2.discriminator

判别器的输入为生成的假的图片,形状为[28,28,1],输出为判别器给出的validity,区间为[0,1],数越大表面判别器任务输入是真实数据的可能性越大,反之则认为输入数据是真实数据的可能性越小。

    def build_discriminator(self):

        model = Sequential()

        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

3.构建完整模型

        optimizer = Adam(0.0002, 0.5)

        # 构建和编译判别器
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])

        # 构建生成器
        self.generator = self.build_generator()

        # 输入噪声给生成器,并产生假的图片
        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)

        # 冻结判别器
        self.discriminator.trainable = False

        # 将假的图片输入给判别器
        validity = self.discriminator(img)

        # 将生成器和判别器合二为一
        self.combined = Model(z, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

4.训练策略

  1. 先训练判别器,将真实图片和生成器生成的假的图片(真实图片标签为1,生成图片标签为0)分别输入到generator中,计算两个数据集损失的平均值

    2.然后训练生成器,但实际上训练的是刚刚构建的完整的模型combined,但是由于将discriminator冻结了,所以训练的是generator。然后将预测结果与1对比,如果越接近1说明生成器已经生成了能欺骗discriminator的图片
    def train(self, epochs, batch_size=128, sample_interval=50):

        # Load the dataset
        (X_train, _), (_, _) = mnist.load_data() # 分别是训练集数据,训练集标签,测试集数据,测试集标签 (tuple格式)
        # X_train.shape = (60000, 28, 28)
        
        # Rescale -1 to 1 归一化
        X_train = X_train / 127.5 - 1.
        X_train = np.expand_dims(X_train, axis=3) # 增加一维 ---> (60000,28,28,1)

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            # Train Discriminator
            # ---------------------

            # Select a random batch of images
            idx = np.random.randint(0, X_train.shape[0], batch_size) #产生0到60000,batchsize个随机整数
            imgs = X_train[idx] # 随机取出batchsize个图片

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim)) #随机产生输入,输入形状(batch_size, 100)

            # Generate a batch of new images
            gen_imgs = self.generator.predict(noise)

            # Train the discriminator
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)  # 输入的是真实图片,valid都是1
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake) # 输入的都是产生的图片,fake都是0
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            # Train Generator
            # ---------------------

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Train the generator (to have the discriminator label samples as valid)
            g_loss = self.combined.train_on_batch(noise, valid)

            # Plot the progress
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

5.GAN网络结构

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
flatten_1 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 512)               401920    
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 256)               131328    
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 256)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 1)                 257       
=================================================================
Total params: 533,505
Trainable params: 533,505
Non-trainable params: 0
_________________________________________________________________
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_4 (Dense)              (None, 256)               25856     
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 256)               0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 256)               1024      
_________________________________________________________________
dense_5 (Dense)              (None, 512)               131584    
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
batch_normalization_2 (Batch (None, 512)               2048      
_________________________________________________________________
dense_6 (Dense)              (None, 1024)              525312    
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 1024)              0         
_________________________________________________________________
batch_normalization_3 (Batch (None, 1024)              4096      
_________________________________________________________________
dense_7 (Dense)              (None, 784)               803600    
_________________________________________________________________
reshape_1 (Reshape)          (None, 28, 28, 1)         0         
=================================================================
Total params: 1,493,520
Trainable params: 1,489,936
Non-trainable params: 3,584

完整代码

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt

import sys
import os

import numpy as np

class GAN():
    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100

        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
            optimizer=optimizer,
            metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise as input and generates imgs
        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator takes generated images as input and determines validity
        validity = self.discriminator(img)

        # The combined model (stacked generator and discriminator)
        # Trains the generator to fool the discriminator
        self.combined = Model(z, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

    # 构建生成器
    def build_generator(self):
        # input shape = [100,]
        # output shape = [np.prod(self.img_shape)]
        
        model = Sequential()

        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        # image_shape = [28,28,1]
        model.add(Dense(np.prod(self.img_shape), activation='tanh')) #np.prod()计算形状乘积
        model.add(Reshape(self.img_shape))

        model.summary()

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)
    
    # 构建判别器
    def build_discriminator(self):

        model = Sequential()

        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

    def train(self, epochs, batch_size=128, sample_interval=50):

        # Load the dataset
        (X_train, _), (_, _) = mnist.load_data() # 分别是训练集数据,训练集标签,测试集数据,测试集标签 (tuple格式)
        # X_train.shape = (60000, 28, 28)
        
        # Rescale -1 to 1 归一化
        X_train = X_train / 127.5 - 1.
        X_train = np.expand_dims(X_train, axis=3) # 增加一维 ---> (60000,28,28,1)

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            # Train Discriminator
            # ---------------------

            # Select a random batch of images
            idx = np.random.randint(0, X_train.shape[0], batch_size) #产生0到60000,batchsize个随机整数
            imgs = X_train[idx] # 随机取出batchsize个图片

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim)) #随机产生输入,输入形状(batch_size, 100)

            # Generate a batch of new images
            gen_imgs = self.generator.predict(noise)

            # Train the discriminator
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)  # 输入的是真实图片,valid都是1
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake) # 输入的都是产生的图片,fake都是0
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            # Train Generator
            # ---------------------

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Train the generator (to have the discriminator label samples as valid)
            g_loss = self.combined.train_on_batch(noise, valid)

            # Plot the progress
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("images/%d.png" % epoch)
        plt.close()

if __name__ == '__main__':
    if not os.path.exists("./images"):
        os.makedirs("./images")
    gan = GAN()
    gan.train(epochs=30000, batch_size=32, sample_interval=200) #sample_interval => 采样间隔

下面为分别训练第0,10000,20000和29800个epoch时generator产生的图像:
0.png

10000.png

20000.png

29800.png
参考:
好像还挺好玩的GAN1——Keras搭建简单GAN生成MNIST手写体

原网站

版权声明
本文为[Pr4da]所创,转载请带上原文链接,感谢
https://blog.csdn.net/qq_40210586/article/details/114975095