当前位置:网站首页>[Deep learning] pix2pix GAN theory and code implementation
[Deep learning] pix2pix GAN theory and code implementation
2022-08-09 22:05:00 【blameless.lsy】
目录
1.什么是pix2pix GAN
pix2pix GAN本质上是一个cGAN,图片x作为此cGAN的条件,需要输入到G和D中。G的输入是x(x是需要转换的图片),输出是生成的图片G(x)。D则需要分辨出(x,G(x))和(x,y)。

pix2pix GAN is mainly used for converting one image into another, which is also known as image translation.
2.pix2pixGAN生成器的设计
For image translation tasks, a lot of information is shared between the input and the output; for example, the contour information is shared. How should this sharing be handled? We need to think about it from the design of the loss function.
If an ordinary convolutional neural network is used, every layer has to carry and preserve all of that information, which makes the network error-prone (it can easily lose some of the information).
所以,我们使用UNet模型作为生成器

3.pix2pixGAN判别器的设计
D takes a pair of images as input. This is similar to cGAN: if G(x) and x correspond to each other, the generator wants the discriminator to output 1;
if G(x) and x do not correspond, the generator wants the discriminator to output 0.
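The discriminator D in pix2pix GAN is implemented in the paper as a patch discriminator (patch_D). "Patch" means that no matter how large the generated image is, it is divided into multiple fixed-size patches, and each patch is fed into D to be judged, as shown in the figure above.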
The advantage of this design is that the input to D becomes smaller, so the computation is cheaper and training is faster.
4.损失函数
D network loss: real paired images fed to D should be judged as 1; a generated image paired with the original image should be judged as 0.
G network loss: a generated image paired with the original image should be judged as 1.
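The formula is as follows:
$$\mathcal{L}_{cGAN}(G,D)=\mathbb{E}_{x,y}\left[\log D(x,y)\right]+\mathbb{E}_{x}\left[\log\left(1-D(x,G(x))\right)\right]$$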

For image translation tasks, a lot of information is in fact shared between the input and the output of G. Therefore, to ensure the similarity between the input image and the output image, an L1 loss is also added, as follows:
$$\mathcal{L}_{L1}(G)=\mathbb{E}_{x,y}\left[\lVert y-G(x)\rVert_{1}\right]$$
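So, combining the two formulas, the total loss function is
$$G^{*}=\arg\min_{G}\max_{D}\ \mathcal{L}_{cGAN}(G,D)+\lambda\,\mathcal{L}_{L1}(G)$$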

5.代码实现
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import torchvision #加载图片
from torchvision import transforms #图片变换
import numpy as np
import matplotlib.pyplot as plt #绘图
import os
import glob
from PIL import Image
# Collect the training photos (*.jpg) and their annotation images (*.png).
# glob.glob() returns matches in arbitrary order, while the dataset pairs
# imgs_path[i] with annos_path[i]; sorting both lists keeps files that share
# a stem (foo.jpg / foo.png) at the same index.
imgs_path = sorted(glob.glob('base/*.jpg'))
annos_path = sorted(glob.glob('base/*.png'))
# Preprocessing applied to every loaded image: convert to a tensor, resize to
# 256x256, then normalize each value as (x - 0.5) / 0.5.
_preprocess_steps = [
    transforms.ToTensor(),
    transforms.Resize((256, 256)),
    transforms.Normalize(mean=0.5, std=0.5),
]
transform = transforms.Compose(_preprocess_steps)
# Paired dataset of photos and their annotation images
class CMP_dataset(data.Dataset):
    """Dataset yielding ``(annotation, photo)`` tensor pairs.

    Item ``index`` pairs ``imgs_path[index]`` with ``annos_path[index]``, so
    the two path lists must be aligned element by element.

    Args:
        imgs_path: Sequence of photo file paths.
        annos_path: Sequence of annotation file paths, in the same order.
        transform: Optional callable applied to each loaded PIL image. When
            omitted, the module-level ``transform`` pipeline is used, which
            is what the original two-argument constructor did.
    """

    def __init__(self, imgs_path, annos_path, transform=None):
        self.imgs_path = imgs_path
        self.annos_path = annos_path
        self.transform = transform

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an in-memory 3-channel RGB copy."""
        # The context manager closes the file handle right away; convert()
        # returns an independent image, so it stays usable afterwards.
        with Image.open(path) as opened:
            return opened.convert("RGB")

    def __getitem__(self, index):
        """Return ``(annotation_tensor, photo_tensor)`` for ``index``."""
        # Inside this method the bare name `transform` is the module-level
        # pipeline; it is the fallback when no per-dataset callable was given.
        apply = self.transform if self.transform is not None else transform
        # Both images are forced to RGB (previously only the annotation was),
        # so a grayscale or CMYK photo cannot produce a mismatched channel
        # count.
        pil_img = apply(self._load_rgb(self.imgs_path[index]))
        pil_anno = apply(self._load_rgb(self.annos_path[index]))
        return pil_anno, pil_img

    def __len__(self):
        """Return the number of photos."""
        return len(self.imgs_path)
# Build the training dataset
dataset = CMP_dataset(imgs_path, annos_path)

# Wrap it in a shuffling DataLoader so it can be iterated batch by batch
BATCHSIZE = 32
dataloader = data.DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True)

# Draw one batch up front; every batch is an (annotations, photos) pair
_first_batch = next(iter(dataloader))
annos_batch, imgs_batch = _first_batch
# Downsampling block: strided conv -> LeakyReLU -> optional BatchNorm
class Downsample(nn.Module):
    """Shrink a feature map with a stride-2 3x3 convolution.

    With kernel 3, stride 2 and padding 1 an even spatial size is halved
    exactly. The activation is applied before the (optional) batch norm.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
        )
        activation = nn.LeakyReLU(inplace=True)
        self.conv_relu = nn.Sequential(conv, activation)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x, is_bn=True):
        """Return the downsampled features, batch-normalized unless ``is_bn`` is false."""
        features = self.conv_relu(x)
        return self.bn(features) if is_bn else features
# Upsampling block: transposed conv -> LeakyReLU -> BatchNorm -> optional dropout
class Upsample(nn.Module):
    """Double the spatial size of a feature map with a transposed convolution.

    Kernel 3, stride 2, padding 1 and output_padding 1 give an output that is
    exactly twice as high and wide as the input.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the transposed conv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.upconv_relu = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1,
            ),
            nn.LeakyReLU(inplace=True),
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x, is_drop=False):
        """Return the upsampled, batch-normalized features.

        Args:
            x: Input feature map in NCHW layout.
            is_drop: When true, apply channel-wise dropout to the result.
        """
        x = self.upconv_relu(x)
        x = self.bn(x)
        if is_drop:
            # F.dropout2d defaults to training=True, which kept dropout
            # active even after .eval(). Tie it to the module's mode instead:
            # train mode behaves as before, eval mode is deterministic.
            x = F.dropout2d(x, training=self.training)
        return x
# Generator: 6 downsampling blocks, 5 upsampling blocks and one output layer
class Generator(nn.Module):
    """U-Net style generator with skip connections between mirrored stages.

    The shape notes below are for a 3x256x256 input. Input height and width
    should be multiples of 64 so that each upsampled map matches the encoder
    map it is concatenated with.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels) per encoder stage; spatial size halves
        # each time: 128, 64, 32, 16, 8, 4.
        encoder_spec = [
            (3, 64),
            (64, 128),
            (128, 256),
            (256, 512),
            (512, 512),
            (512, 512),
        ]
        for stage, (c_in, c_out) in enumerate(encoder_spec, start=1):
            setattr(self, f"down{stage}", Downsample(c_in, c_out))

        # Decoder stages after the first receive the previous output
        # concatenated with a skip tensor, hence the doubled input channels.
        # Spatial size doubles each time: 8, 16, 32, 64, 128.
        decoder_spec = [
            (512, 512),
            (1024, 512),
            (1024, 256),
            (512, 128),
            (256, 64),
        ]
        for stage, (c_in, c_out) in enumerate(decoder_spec, start=1):
            setattr(self, f"up{stage}", Upsample(c_in, c_out))

        # Output layer: 128 channels in (64 + 64 skip) -> 3x256x256
        self.last = nn.ConvTranspose2d(
            128,
            3,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
        )

    def forward(self, x):
        """Map an input batch to a generated batch squashed into [-1, 1] by tanh."""
        # Encoder: keep every intermediate map for the skip connections
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)

        # Decoder: the three deepest stages use dropout
        u = self.up1(d6, is_drop=True)
        u = self.up2(torch.cat([u, d5], dim=1), is_drop=True)
        u = self.up3(torch.cat([u, d4], dim=1), is_drop=True)
        u = self.up4(torch.cat([u, d3], dim=1))
        u = self.up5(torch.cat([u, d2], dim=1))

        merged = torch.cat([u, d1], dim=1)
        return torch.tanh(self.last(merged))
# Discriminator: takes the annotation and an image (generated or real),
# concatenated along the channel axis
class Discriminator(nn.Module):
    """Patch discriminator scoring ``(annotation, image)`` pairs.

    The output is a map of per-patch probabilities rather than one scalar.
    For 256x256 inputs the map is ``batch x 1 x 60 x 60``.
    """

    def __init__(self):
        super().__init__()
        self.down1 = Downsample(6, 64)  # 6 = 3 annotation + 3 image channels
        self.down2 = Downsample(64, 128)
        self.conv1 = nn.Conv2d(128, 256, 3)
        self.bn = nn.BatchNorm2d(256)
        self.last = nn.Conv2d(256, 1, 3)

    def forward(self, anno, img):
        """Return per-patch probabilities that ``img`` is the real partner of ``anno``."""
        x = torch.cat([anno, img], dim=1)
        x = self.down1(x, is_bn=False)
        x = self.down2(x, is_bn=True)
        x = self.bn(F.leaky_relu(self.conv1(x)))
        # F.dropout2d defaults to training=True, which kept dropout active
        # even after .eval(). Follow the module's mode so eval-mode scores
        # are deterministic; train-mode behavior is unchanged.
        x = F.dropout2d(x, training=self.training)
        return torch.sigmoid(self.last(x))
# Run on the GPU when one is available, otherwise on the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = 'cpu'

# Instantiate both networks on the chosen device
gen = Generator().to(device)
dis = Discriminator().to(device)

# Both networks are optimized by Adam with identical settings
_adam_settings = dict(lr=1e-3, betas=(0.5, 0.999))
d_optimizer = torch.optim.Adam(dis.parameters(), **_adam_settings)
g_optimizer = torch.optim.Adam(gen.parameters(), **_adam_settings)
# Plotting helper
def generate_images(model, test_anno, test_real):
    """Show the first sample of a batch next to the model's prediction for it.

    Three panels are drawn side by side: the input annotation, the ground
    truth image and the model output.

    Args:
        model: Callable mapping ``test_anno`` to a batch of generated images.
        test_anno: Batch of input tensors in NCHW layout.
        test_real: Batch of target tensors in NCHW layout.

    All three tensors are expected to hold values in [-1, 1] (the file's
    ``transform`` normalizes with mean 0.5 / std 0.5 and the generator ends
    in tanh). They are mapped back to [0, 1] before plotting.
    """
    # Preview only: no autograd graph is needed for this forward pass.
    with torch.no_grad():
        prediction = model(test_anno)

    def _to_displayable(batch):
        """Turn sample 0 of an NCHW batch into an HWC array in [0, 1]."""
        first = batch[0].detach().cpu()
        # imshow clips float RGB data to [0, 1]; without undoing the
        # normalization every negative value would be rendered as black.
        return (first * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0).numpy()

    display_list = [
        _to_displayable(batch) for batch in (test_anno, test_real, prediction)
    ]
    title = ['Input', 'Ground Truth', 'Output']

    plt.figure(figsize=(10, 10))
    for i, (caption, image) in enumerate(zip(title, display_list)):
        plt.subplot(1, 3, i + 1)
        plt.title(caption)
        plt.imshow(image)
        plt.axis('off')  # hide the axes
    plt.show()
# Held-out images. glob.glob() returns matches in arbitrary order and the
# dataset pairs photo i with annotation i, so both lists are sorted to keep
# files that share a stem at the same index.
test_imgs_path = sorted(glob.glob('extended/*.jpg'))
test_annos_path = sorted(glob.glob('extended/*.png'))
test_dataset = CMP_dataset(test_imgs_path, test_annos_path)
# NOTE(review): nothing in this file iterates test_dataloader; the per-epoch
# preview uses annos_batch/imgs_batch drawn from the training loader. A line
# such as `annos_batch, imgs_batch = next(iter(test_dataloader))` looks like
# it was intended here; confirm before adding.
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=BATCHSIZE,
)
# Loss setup
# cGAN (adversarial) loss: binary cross-entropy on the discriminator output
loss_fn = torch.nn.BCELoss()

# The L1 loss has no module of its own; it is computed inline during training.
# Move the preview batch to the training device.
annos_batch = annos_batch.to(device)
imgs_batch = imgs_batch.to(device)

LAMBDA = 7  # weight of the L1 loss term

# Per-epoch mean losses recorded during training
D_loss = []  # discriminator
G_loss = []  # generator
# Training loop: alternate one discriminator step and one generator step per
# batch, record the mean losses per epoch, then preview the generator output.
for epoch in range(10):
    D_epoch_loss = 0
    G_epoch_loss = 0
    count = len(dataloader)  # number of batches per epoch
    for annos, imgs in dataloader:
        imgs = imgs.to(device)
        annos = annos.to(device)

        # ---- Discriminator step ----
        # zero_grad() also clears the gradients that the previous generator
        # step back-propagated into the discriminator.
        d_optimizer.zero_grad()
        # Real pairs should be scored as 1
        disc_real_output = dis(annos, imgs)
        d_real_loss = loss_fn(disc_real_output,
                              torch.ones_like(disc_real_output))
        d_real_loss.backward()
        # Generated pairs should be scored as 0; detach() keeps this loss
        # from back-propagating into the generator.
        gen_output = gen(annos)
        disc_gen_output = dis(annos, gen_output.detach())
        d_fake_loss = loss_fn(disc_gen_output,
                              torch.zeros_like(disc_gen_output))
        d_fake_loss.backward()
        disc_loss = d_real_loss + d_fake_loss
        d_optimizer.step()

        # ---- Generator step ----
        g_optimizer.zero_grad()
        # Re-score the generated pair with the updated discriminator, this
        # time keeping the graph back to the generator.
        disc_gen_out = dis(annos, gen_output)
        # The generator wants its output to be scored as 1
        gen_loss_crossentropyloss = loss_fn(disc_gen_out,
                                            torch.ones_like(disc_gen_out))
        # L1 distance between the generated and the target image
        gen_l1_loss = torch.mean(torch.abs(gen_output - imgs))
        gen_loss = gen_loss_crossentropyloss + LAMBDA * gen_l1_loss
        gen_loss.backward()
        g_optimizer.step()

        # Accumulate the batch losses; .item() returns plain Python floats,
        # so no no_grad() context is needed.
        D_epoch_loss += disc_loss.item()
        G_epoch_loss += gen_loss.item()

    # Mean loss over the epoch
    D_epoch_loss /= count
    G_epoch_loss /= count
    D_loss.append(D_epoch_loss)
    G_loss.append(G_epoch_loss)

    # End of epoch: report progress and plot the generator's output
    print("Epoch:", epoch)
    generate_images(gen, annos_batch, imgs_batch)
猜你喜欢
随机推荐
buuctf(探险2)
小满nestjs(第五章 nestjs cli)
Redis 大的情况下,key 要如何处理?
mysql死锁的排查和解决
ebook下载 | 《 企业高管IT战略指南——企业为何要落地DevOps》
基于SSM实现手机销售商城系统
laravel 时区问题timezone
competed中访问ref为undefined
MYSQL记录、自用
pat链表专题训练+搜索专题
以技术创新加速国家“碳中和”建设进程,华为云创新中心助力欣冠精密实现云智控“气”
华为云全流程护航《流浪方舟》破竹首发,打造口碑爆款
【NOI模拟赛】防AK题(生成函数,单位根,Pollard-Rho)
切绳子【洛谷P1577】【二分】
软件测试技术之如何编写测试用例(6)
2021 RoboCom 世界机器人开发者大赛-本科组(决赛)
STM32WB55的FUS更新及协议栈固件烧写方法
2.3 监督学习-2
下秒数据:湖仓一体带来的现代数据堆栈变革开始了
MFC tutorial
![[Free column] Xposed plug-in development for Android security [from scratch] tutorial](/img/7b/a036ac664c7e27ed7d87e7ee18c05d.png)








