Building a GAN with Keras, Using MNIST as an Example
2022-08-11 05:35:00 【Pr4da】
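Before starting, make sure you understand how a GAN works. Many bloggers have explained it well, so I won't go over it again here; for video material I recommend Professor Hung-yi Lee's course at National Taiwan University.
A GAN consists of two main parts, the generator and the discriminator. The generator produces fake data to "fool" the discriminator, and the discriminator judges whether its input was produced by the generator. The two are trained in alternation until the generator can produce data that passes for real. Below, using the MNIST dataset as an example, we use a GAN to generate handwritten digits.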
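Building the network model
1. generator
A neural network needs an input in order to produce an output, so to get generated (fake) data we have to feed the model something. Here the input is a noise vector of shape [100,] and the output is an array of shape [28,28,1].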
def build_generator(self):
    # input shape  = [100,]
    # output shape = self.img_shape = [28, 28, 1]
    model = Sequential()

    model.add(Dense(256, input_dim=self.latent_dim))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))

    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))

    model.add(Dense(1024))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))

    # image_shape = [28,28,1]
    # np.prod() multiplies the dimensions: 28 * 28 * 1 = 784 output units
    model.add(Dense(np.prod(self.img_shape), activation='tanh'))
    model.add(Reshape(self.img_shape))
    model.summary()

    noise = Input(shape=(self.latent_dim,))
    img = model(noise)
    return Model(noise, img)
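The generator's last layer uses tanh, so every generated pixel lies in [-1, 1]. The training code later in this post rescales the MNIST pixels with X_train / 127.5 - 1 so that real images cover the same range, and sample_images maps generated values back with 0.5 * gen_imgs + 0.5 before plotting. A minimal pure-Python sketch of that round trip (the helper names to_tanh_range and to_unit_range are made up for illustration and do not appear in the original code):

```python
import math

def to_tanh_range(pixel):
    """Map a raw 8-bit pixel value (0..255) to [-1, 1], like X_train / 127.5 - 1."""
    return pixel / 127.5 - 1.0

def to_unit_range(value):
    """Map a tanh-range value back to [0, 1] for plotting, like 0.5 * gen_imgs + 0.5."""
    return 0.5 * value + 0.5

# The darkest and brightest raw pixels land exactly on the ends of the tanh range.
scaled_extremes = [to_tanh_range(p) for p in (0, 255)]

# tanh never leaves [-1, 1], whatever the pre-activation value is.
tanh_samples = [math.tanh(x) for x in (-10.0, -1.0, 0.0, 1.0, 10.0)]
```

If the real images were left in 0..255 while the generator could only output values in [-1, 1], the discriminator could tell them apart from the value range alone.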
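2. discriminator
The discriminator takes an image of shape [28,28,1] as input, either a real one or a fake one produced by the generator, and outputs a validity score in the range [0,1]. The larger the score, the more likely the discriminator considers the input to be real data; the smaller the score, the less likely it considers the input to be real.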
def build_discriminator(self):
    model = Sequential()

    model.add(Flatten(input_shape=self.img_shape))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(256))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(1, activation='sigmoid'))
    model.summary()

    img = Input(shape=self.img_shape)
    validity = model(img)
    return Model(img, validity)
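To make the two activations used here concrete, the following is a small sketch in plain Python rather than Keras. It assumes the textbook definitions of LeakyReLU with slope 0.2 and of the logistic sigmoid; the function names are only for illustration.

```python
import math

def leaky_relu(x, alpha=0.2):
    """LeakyReLU: pass positive values through, scale negative values by alpha."""
    return x if x > 0 else alpha * x

def sigmoid(x):
    """Logistic sigmoid: squashes any real number into the open interval (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))

# A strongly negative pre-activation gives a validity near 0 ("generated"),
# a strongly positive one gives a validity near 1 ("real").
validities = [sigmoid(score) for score in (-4.0, 0.0, 4.0)]
```

The sigmoid on the last Dense layer is what keeps the validity score between 0 and 1, which is also what the binary cross-entropy loss used below expects.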
3. Building the complete model
optimizer = Adam(0.0002, 0.5)

# Build and compile the discriminator
self.discriminator = self.build_discriminator()
self.discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])

# Build the generator
self.generator = self.build_generator()

# Feed noise to the generator to produce fake images
z = Input(shape=(self.latent_dim,))
img = self.generator(z)

# Freeze the discriminator
self.discriminator.trainable = False

# Feed the fake images to the discriminator
validity = self.discriminator(img)

# Stack the generator and the discriminator into one model
self.combined = Model(z, validity)
self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
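The effect of freezing the discriminator inside the combined model can be illustrated with a toy example that has nothing to do with Keras. This is only a sketch under strong simplifying assumptions: the "generator" is G(z) = w_g * z, the "discriminator" is D(x) = sigmoid(w_d * x), each has a single weight, and the gradients are written out by hand. All names here are hypothetical.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def generator_loss(w_g, w_d, z):
    """Binary cross-entropy of D(G(z)) against the label 1."""
    return -math.log(sigmoid(w_d * w_g * z))

def combined_step(w_g, w_d, z, lr=0.1, discriminator_trainable=False):
    """One gradient-descent step on the stacked model G -> D."""
    s = sigmoid(w_d * w_g * z)
    grad_w_g = -(1.0 - s) * w_d * z  # dL/dw_g
    grad_w_d = -(1.0 - s) * w_g * z  # dL/dw_d
    new_w_g = w_g - lr * grad_w_g
    # A frozen discriminator keeps its weight even though a gradient exists.
    new_w_d = w_d - lr * grad_w_d if discriminator_trainable else w_d
    return new_w_g, new_w_d

w_g, w_d, z = 0.1, 0.5, 1.0
loss_before = generator_loss(w_g, w_d, z)
w_g, w_d = combined_step(w_g, w_d, z)
loss_after = generator_loss(w_g, w_d, z)
```

After the step, w_d is unchanged and the loss against the label 1 is smaller: the generator has moved toward output that the fixed discriminator rates as more real.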
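4. Training strategy
1. First train the discriminator: feed a batch of real images (label 1) and a batch of fake images from the generator (label 0) separately into the discriminator, then take the mean of the two batch losses.
2. Then train the generator. What is actually trained is the complete model combined built above, but because the discriminator is frozen inside it, only the generator's weights are updated. The predictions are compared against the label 1: the closer they are to 1, the better the generated images fool the discriminator.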
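The two steps above differ only in which label the binary cross-entropy is computed against. A minimal sketch for a single sample, where the discriminator outputs 0.9 and 0.2 are invented numbers for illustration, not measured results:

```python
import math

def binary_crossentropy(label, prediction):
    """Binary cross-entropy for one sample; prediction must lie strictly in (0, 1)."""
    return -(label * math.log(prediction) + (1 - label) * math.log(1 - prediction))

# Hypothetical discriminator outputs for one real and one generated image.
d_out_real = 0.9
d_out_fake = 0.2

# Step 1: the discriminator is trained with real -> 1 and fake -> 0.
d_loss_real = binary_crossentropy(1, d_out_real)
d_loss_fake = binary_crossentropy(0, d_out_fake)
d_loss = 0.5 * (d_loss_real + d_loss_fake)

# Step 2: the generator is trained on the same kind of fake image, but against label 1.
g_loss = binary_crossentropy(1, d_out_fake)
```

With these numbers the discriminator's loss is small because it is doing its job well, while the generator's loss is large because its image only got a score of 0.2. Training the generator pushes that score toward 1.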
def train(self, epochs, batch_size=128, sample_interval=50):
    # Load the dataset: returns (train images, train labels), (test images, test labels)
    # as tuples; only the training images are used here
    (X_train, _), (_, _) = mnist.load_data()
    # X_train.shape = (60000, 28, 28)

    # Rescale pixel values from [0, 255] to [-1, 1]
    X_train = X_train / 127.5 - 1.
    X_train = np.expand_dims(X_train, axis=3)  # add a channel axis ---> (60000,28,28,1)

    # Adversarial ground truths
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for epoch in range(epochs):
        # ---------------------
        #  Train Discriminator
        # ---------------------
        # Select a random batch of images
        idx = np.random.randint(0, X_train.shape[0], batch_size)  # batch_size random integers in [0, 60000)
        imgs = X_train[idx]  # pick batch_size images at random

        noise = np.random.normal(0, 1, (batch_size, self.latent_dim))  # random input of shape (batch_size, 100)

        # Generate a batch of new images
        gen_imgs = self.generator.predict(noise)

        # Train the discriminator
        d_loss_real = self.discriminator.train_on_batch(imgs, valid)     # real images, labels are all 1
        d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)  # generated images, labels are all 0
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # ---------------------
        #  Train Generator
        # ---------------------
        noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

        # Train the generator (to have the discriminator label samples as valid)
        g_loss = self.combined.train_on_batch(noise, valid)

        # Plot the progress
        print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
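One detail in the loop above is easy to miss. The discriminator was compiled with metrics=['accuracy'], so each train_on_batch call on it returns a pair [loss, accuracy] rather than a single number, and 0.5 * np.add(...) averages the two pairs element by element. That is why the print statement reads d_loss[0] and d_loss[1]. The combined model has no metrics, so g_loss is a plain scalar. A pure-Python sketch of the averaging, with invented values rather than measured ones:

```python
def average_metrics(real, fake):
    """Element-wise mean of two [loss, accuracy] pairs, mirroring 0.5 * np.add(real, fake)."""
    return [0.5 * (r + f) for r, f in zip(real, fake)]

# Hypothetical [loss, accuracy] pairs from the real batch and the generated batch.
d_loss_real = [0.75, 0.5]
d_loss_fake = [0.25, 1.0]

d_loss = average_metrics(d_loss_real, d_loss_fake)
progress = "[D loss: %f, acc.: %.2f%%]" % (d_loss[0], 100 * d_loss[1])
```

The reported accuracy is therefore the mean of the accuracy on the real batch and the accuracy on the generated batch.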
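5. GAN network structure
The two tables below are the model.summary() output printed while the models are built: first the discriminator, then the generator.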
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
flatten_1 (Flatten)          (None, 784)               0
_________________________________________________________________
dense_1 (Dense)              (None, 512)               401920
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 512)               0
_________________________________________________________________
dense_2 (Dense)              (None, 256)               131328
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 256)               0
_________________________________________________________________
dense_3 (Dense)              (None, 1)                 257
=================================================================
Total params: 533,505
Trainable params: 533,505
Non-trainable params: 0
_________________________________________________________________
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense_4 (Dense)              (None, 256)               25856
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 256)               0
_________________________________________________________________
batch_normalization_1 (Batch (None, 256)               1024
_________________________________________________________________
dense_5 (Dense)              (None, 512)               131584
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 512)               0
_________________________________________________________________
batch_normalization_2 (Batch (None, 512)               2048
_________________________________________________________________
dense_6 (Dense)              (None, 1024)              525312
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 1024)              0
_________________________________________________________________
batch_normalization_3 (Batch (None, 1024)              4096
_________________________________________________________________
dense_7 (Dense)              (None, 784)               803600
_________________________________________________________________
reshape_1 (Reshape)          (None, 28, 28, 1)         0
=================================================================
Total params: 1,493,520
Trainable params: 1,489,936
Non-trainable params: 3,584
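The parameter counts in the two summaries can be checked by hand. The sketch below assumes the usual counting rules: a Dense layer has n_in * n_out weights plus n_out biases, and a BatchNormalization layer has four values per feature, of which the moving mean and moving variance are not trainable. The helper names are made up for illustration.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

def batchnorm_params(n_features):
    """gamma and beta (trainable) plus moving mean and moving variance (non-trainable)."""
    return 4 * n_features

# Discriminator: 784 -> 512 -> 256 -> 1
discriminator_total = (dense_params(784, 512)
                       + dense_params(512, 256)
                       + dense_params(256, 1))

# Generator: 100 -> 256 -> 512 -> 1024 -> 784, with BatchNormalization after the first three
generator_dense = (dense_params(100, 256)
                   + dense_params(256, 512)
                   + dense_params(512, 1024)
                   + dense_params(1024, 784))
generator_bn = sum(batchnorm_params(n) for n in (256, 512, 1024))
generator_total = generator_dense + generator_bn
generator_non_trainable = generator_bn // 2  # half of each BatchNormalization layer

print(discriminator_total, generator_total, generator_non_trainable)
```

The results agree with the summaries: 533,505 parameters for the discriminator, and 1,493,520 for the generator, of which 3,584 are non-trainable.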
Full code
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten
from keras.layers import BatchNormalization, LeakyReLU
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt

import os

import numpy as np
class GAN():
    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100

        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=optimizer,
                                   metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise as input and generates imgs
        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator takes generated images as input and determines validity
        validity = self.discriminator(img)

        # The combined model (stacked generator and discriminator)
        # Trains the generator to fool the discriminator
        self.combined = Model(z, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
    # Build the generator
    def build_generator(self):
        # input shape  = [100,]
        # output shape = self.img_shape = [28, 28, 1]
        model = Sequential()

        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        # image_shape = [28,28,1]
        # np.prod() multiplies the dimensions: 28 * 28 * 1 = 784 output units
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))
        model.summary()

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)
        return Model(noise, img)
    # Build the discriminator
    def build_discriminator(self):
        model = Sequential()

        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        img = Input(shape=self.img_shape)
        validity = model(img)
        return Model(img, validity)
    def train(self, epochs, batch_size=128, sample_interval=50):
        # Load the dataset: returns (train images, train labels), (test images, test labels)
        # as tuples; only the training images are used here
        (X_train, _), (_, _) = mnist.load_data()
        # X_train.shape = (60000, 28, 28)

        # Rescale pixel values from [0, 255] to [-1, 1]
        X_train = X_train / 127.5 - 1.
        X_train = np.expand_dims(X_train, axis=3)  # add a channel axis ---> (60000,28,28,1)

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):
            # ---------------------
            #  Train Discriminator
            # ---------------------
            # Select a random batch of images
            idx = np.random.randint(0, X_train.shape[0], batch_size)  # batch_size random integers in [0, 60000)
            imgs = X_train[idx]  # pick batch_size images at random

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))  # random input of shape (batch_size, 100)

            # Generate a batch of new images
            gen_imgs = self.generator.predict(noise)

            # Train the discriminator
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)     # real images, labels are all 1
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)  # generated images, labels are all 0
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Train the generator (to have the discriminator label samples as valid)
            g_loss = self.combined.train_on_batch(noise, valid)

            # Plot the progress
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch)
    def sample_images(self, epoch):
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("images/%d.png" % epoch)
        plt.close()
if __name__ == '__main__':
    if not os.path.exists("./images"):
        os.makedirs("./images")
    gan = GAN()
    # sample_interval: number of iterations between saved sample images
    gan.train(epochs=30000, batch_size=32, sample_interval=200)
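Below are the images produced by the generator at epochs 0, 10000, 20000 and 29800 respectively (in this code each "epoch" is a single training batch, not a full pass over the dataset):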