当前位置:网站首页>[Deep learning] pix2pix GAN theory and code implementation
[Deep learning] pix2pix GAN theory and code implementation
2022-08-09 22:05:00 【blameless.lsy】
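Contents
1. What is pix2pix GAN
pix2pix GAN is essentially a conditional GAN (cGAN) in which the input image x serves as the condition and is fed to both the generator G and the discriminator D. G takes x (the image to be translated) as input and outputs a generated image G(x). D must tell the generated pair (x, G(x)) apart from the real pair (x, y), where y is the ground-truth target image.
pix2pix GAN is mainly used for image-to-image conversion, also known as image translation.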
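2. pix2pix GAN generator design
In image translation tasks the input and the output share a lot of information — for example, the contours of objects. How can the network make use of this shared information? The answer lies in the design of the generator's architecture.
With an ordinary encoder–decoder convolutional network, every layer has to carry and preserve all of the information, which makes the network error-prone (details are easily lost along the way).
Therefore, a U-Net is used as the generator: its skip connections pass each encoder feature map directly to the matching decoder layer, so shared low-level information does not have to travel through the bottleneck.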
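3. pix2pix GAN discriminator design
As in a cGAN, D takes a pair of images as input. For a real pair (x, y), D should output 1; for a generated pair (x, G(x)), D should output 0. The generator, in turn, tries to make D output 1 for (x, G(x)).
The discriminator in pix2pix is the patch discriminator (PatchGAN) proposed in the paper. "Patch" means that no matter how large the image is, it is judged as a set of fixed-size patches, and D classifies each patch as real or fake.
The benefit of this design is that D only has to judge small patches, so it needs less computation and trains faster. In practice (as in the code below), this is implemented as a fully convolutional discriminator that outputs a grid of scores — 60×60 for a 256×256 input — where each score judges one patch of the input.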
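4. Loss function
Discriminator loss: real pairs (x, y) should be classified as 1, and generated pairs (x, G(x)) as 0.
Generator loss: generated pairs (x, G(x)) should be classified as 1.
The conditional GAN objective is:
$$\mathcal{L}_{cGAN}(G, D) = \mathbb{E}_{x,y}\left[\log D(x, y)\right] + \mathbb{E}_{x}\left[\log\left(1 - D(x, G(x))\right)\right]$$
In image translation, the generator's input and output share a lot of information. To keep the generated image close to the ground-truth image, an L1 loss is also added:
$$\mathcal{L}_{L1}(G) = \mathbb{E}_{x,y}\left[\lVert y - G(x) \rVert_1\right]$$
Combining the two, the overall objective is
$$G^{*} = \arg\min_{G}\max_{D}\ \mathcal{L}_{cGAN}(G, D) + \lambda\,\mathcal{L}_{L1}(G)$$
where λ weights the L1 term (LAMBDA in the code below).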
5. Code implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import torchvision #加载图片
from torchvision import transforms #图片变换
import numpy as np
import matplotlib.pyplot as plt #绘图
import os
import glob
from PIL import Image
# Paired training data: photos (.jpg) and annotation maps (.png) in base/.
# glob() returns files in arbitrary, filesystem-dependent order, and CMP_dataset
# pairs the two lists by index, so both lists are sorted to keep pairs aligned.
# NOTE(review): this assumes the two files of a pair share the same base name.
imgs_path = sorted(glob.glob('base/*.jpg'))
annos_path = sorted(glob.glob('base/*.png'))

# Preprocessing: PIL image -> float tensor -> 256x256 -> (x - 0.5) / 0.5.
# The resulting [-1, 1] range matches the tanh output of the generator.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((256, 256)),
    transforms.Normalize(mean=0.5, std=0.5),
])
# Dataset definition
class CMP_dataset(data.Dataset):
    """Paired (annotation, photo) dataset for pix2pix training.

    Item ``i`` pairs ``annos_path[i]`` with ``imgs_path[i]``, so the two lists
    must be aligned (same length, same ordering).

    Args:
        imgs_path: paths of the target photos.
        annos_path: paths of the annotation images used as generator input.

    Raises:
        ValueError: if the two lists have different lengths.
    """

    def __init__(self, imgs_path, annos_path):
        if len(imgs_path) != len(annos_path):
            raise ValueError(
                f"got {len(imgs_path)} images but {len(annos_path)} annotations; "
                "they must be paired one-to-one")
        self.imgs_path = imgs_path
        self.annos_path = annos_path

    def __getitem__(self, index):
        """Return ``(annotation_tensor, photo_tensor)`` for the pair at ``index``.

        Both images are converted to RGB before the module-level ``transform`` is
        applied, so each tensor has the 3 channels the generator's first layer expects.
        """
        # The context managers close the file handles; convert() returns a new,
        # fully loaded image that stays valid after the file is closed.
        with Image.open(self.imgs_path[index]) as img:
            pil_img = transform(img.convert("RGB"))
        with Image.open(self.annos_path[index]) as anno:
            pil_anno = transform(anno.convert("RGB"))
        return pil_anno, pil_img

    def __len__(self):
        return len(self.imgs_path)
# Build the training dataset and wrap it in a shuffling DataLoader for mini-batches.
BATCHSIZE = 32
dataset = CMP_dataset(imgs_path, annos_path)
dataloader = data.DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True)

# One (annotation, photo) batch kept aside; generate_images is called on it
# after every training epoch.
annos_batch, imgs_batch = next(iter(dataloader))
# Downsampling (encoder) block
class Downsample(nn.Module):
    """Stride-2 encoder block: 3x3 conv -> LeakyReLU -> optional BatchNorm.

    Maps ``in_channels`` to ``out_channels`` and halves the spatial size,
    (H, W) -> (ceil(H / 2), ceil(W / 2)).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        # Keep the Sequential layout (0: conv, 1: activation): it defines the
        # state_dict keys ``conv_relu.0.weight`` / ``conv_relu.0.bias``.
        self.conv_relu = nn.Sequential(conv, nn.LeakyReLU(inplace=True))
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x, is_bn=True):
        """Apply the block; pass ``is_bn=False`` to skip the BatchNorm layer."""
        out = self.conv_relu(x)
        return self.bn(out) if is_bn else out
# Upsampling (decoder) block
class Upsample(nn.Module):
    """Stride-2 decoder block: 3x3 transposed conv -> LeakyReLU -> BatchNorm -> optional dropout.

    Maps ``in_channels`` to ``out_channels`` and doubles the spatial size,
    (H, W) -> (2H, 2W).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # output_padding=1 makes the output exactly twice the input size.
        upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2,
                                    padding=1, output_padding=1)
        self.upconv_relu = nn.Sequential(upconv, nn.LeakyReLU(inplace=True))
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x, is_drop=False):
        """Apply the block; ``is_drop=True`` adds channel-wise dropout (p=0.5).

        ``F.dropout2d`` is called with its default ``training=True``, so when
        requested the dropout is applied regardless of ``train()`` / ``eval()`` mode.
        """
        out = self.bn(self.upconv_relu(x))
        if not is_drop:
            return out
        return F.dropout2d(out)
# Generator: 6 downsampling blocks, 5 upsampling blocks and one output layer
class Generator(nn.Module):
    """U-Net generator with skip connections and a tanh output.

    Shape comments give (C, H, W) for a (N, 3, 256, 256) input; the output then
    has shape (N, 3, 256, 256) with values in (-1, 1).
    """

    def __init__(self):
        super().__init__()
        # Encoder
        self.down1 = Downsample(3, 64)      # 64, 128, 128
        self.down2 = Downsample(64, 128)    # 128, 64, 64
        self.down3 = Downsample(128, 256)   # 256, 32, 32
        self.down4 = Downsample(256, 512)   # 512, 16, 16
        self.down5 = Downsample(512, 512)   # 512, 8, 8
        self.down6 = Downsample(512, 512)   # 512, 4, 4
        # Decoder: from up2 on, each input is the previous output concatenated
        # with the mirrored encoder map, hence the doubled input channel counts.
        self.up1 = Upsample(512, 512)       # 512, 8, 8
        self.up2 = Upsample(1024, 512)      # 512, 16, 16
        self.up3 = Upsample(1024, 256)      # 256, 32, 32
        self.up4 = Upsample(512, 128)       # 128, 64, 64
        self.up5 = Upsample(256, 64)        # 64, 128, 128
        self.last = nn.ConvTranspose2d(128, 3, kernel_size=3, stride=2,
                                       padding=1, output_padding=1)  # 3, 256, 256

    def forward(self, x):
        # Encoder pass: keep down1..down5 outputs for the skip connections.
        skips = []
        h = x
        for down in (self.down1, self.down2, self.down3, self.down4, self.down5):
            h = down(h)
            skips.append(h)
        h = self.down6(h)

        # Decoder pass: upsample, then concatenate the matching encoder map
        # (deepest first). Dropout is on for the three innermost decoder blocks.
        decoder = (self.up1, self.up2, self.up3, self.up4, self.up5)
        for depth, (up, skip) in enumerate(zip(decoder, reversed(skips))):
            h = torch.cat([up(h, is_drop=depth < 3), skip], dim=1)
        return torch.tanh(self.last(h))
# Discriminator: input is anno concatenated with img (generated or real)
class Discriminator(nn.Module):
    """Patch discriminator for (annotation, image) pairs.

    The two inputs are concatenated along the channel axis (6 channels in total,
    matching ``Downsample(6, 64)``). For 256x256 inputs the output is a
    (N, 1, 60, 60) map of sigmoid scores, one per patch.
    """

    def __init__(self):
        super().__init__()
        self.down1 = Downsample(6, 64)           # 64, 128, 128
        self.down2 = Downsample(64, 128)         # 128, 64, 64
        self.conv1 = nn.Conv2d(128, 256, 3)      # 256, 62, 62 (no padding)
        self.bn = nn.BatchNorm2d(256)
        self.last = nn.Conv2d(256, 1, 3)         # 1, 60, 60

    def forward(self, anno, img):
        pair = torch.cat([anno, img], dim=1)
        # The first block skips BatchNorm, the second applies it.
        feats = self.down2(self.down1(pair, is_bn=False), is_bn=True)
        # conv -> LeakyReLU -> BatchNorm -> channel dropout; F.dropout2d uses its
        # default training=True, so the dropout is always active.
        feats = self.conv1(feats)
        feats = F.leaky_relu(feats)
        feats = self.bn(feats)
        feats = F.dropout2d(feats)
        return torch.sigmoid(self.last(feats))
# Train on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = 'cpu'

# Instantiate both networks directly on the chosen device.
gen = Generator().to(device)
dis = Discriminator().to(device)

# One Adam optimizer per network, both with lr=1e-3 and betas=(0.5, 0.999).
d_optimizer, g_optimizer = (
    torch.optim.Adam(net.parameters(), lr=1e-3, betas=(0.5, 0.999))
    for net in (dis, gen)
)
#绘图
def generate_images(model,test_anno,test_real):
prediction = model(test_anno).permute(0,2,3,1).detach().cpu().numpy()
test_anno = test_anno.permute(0,2,3,1).cpu().numpy()
test_real = test_real.permute(0,2,3,1).cpu().numpy()
plt.figure(figsize = (10,10))
display_list = [test_anno[0],test_real[0],prediction[0]]
title = ['Input','Ground Truth','Output']
for i in range(3):
plt.subplot(1,3,i+1)
plt.title(title[i])
plt.imshow(display_list[i])
plt.axis('off') #Coordinate system is off
plt.show()
# Held-out pairs from extended/. Sorted for the same reason as any index-paired
# list: glob() order is arbitrary, and CMP_dataset pairs the two lists by index.
test_imgs_path = sorted(glob.glob('extended/*.jpg'))
test_annos_path = sorted(glob.glob('extended/*.png'))
test_dataset = CMP_dataset(test_imgs_path, test_annos_path)
# NOTE(review): test_dataloader is not used by the training loop that follows;
# progress is plotted on the training batch (annos_batch, imgs_batch) instead.
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=BATCHSIZE,
)

# Adversarial (conditional GAN) loss: binary cross-entropy on the
# discriminator's sigmoid outputs.
loss_fn = torch.nn.BCELoss()
# Move the visualisation batch to the training device once, up front.
annos_batch, imgs_batch = annos_batch.to(device), imgs_batch.to(device)
LAMBDA = 7  # weight of the L1 reconstruction term in the generator loss
D_loss = []  # mean discriminator loss of each epoch
G_loss = []  # mean generator loss of each epoch
# Training loop: for every batch, one discriminator update followed by one
# generator update.
count = len(dataloader)  # batches per epoch; loop-invariant, used to average losses
for epoch in range(10):
    D_epoch_loss = 0
    G_epoch_loss = 0
    for step, (annos, imgs) in enumerate(dataloader):
        imgs = imgs.to(device)
        annos = annos.to(device)

        # --- Discriminator step: real pairs -> 1, generated pairs -> 0 ---
        d_optimizer.zero_grad()
        disc_real_output = dis(annos, imgs)  # real (annotation, photo) pair
        d_real_loss = loss_fn(disc_real_output, torch.ones_like(disc_real_output))
        d_real_loss.backward()

        gen_output = gen(annos)
        # detach() keeps the discriminator update from back-propagating into gen.
        disc_gen_output = dis(annos, gen_output.detach())
        d_fake_loss = loss_fn(disc_gen_output, torch.zeros_like(disc_gen_output))
        d_fake_loss.backward()
        disc_loss = d_real_loss + d_fake_loss  # used only for logging below
        d_optimizer.step()

        # --- Generator step: make the updated discriminator output 1, plus L1 ---
        g_optimizer.zero_grad()
        disc_gen_out = dis(annos, gen_output)
        gen_loss_crossentropyloss = loss_fn(disc_gen_out, torch.ones_like(disc_gen_out))
        gen_l1_loss = F.l1_loss(gen_output, imgs)  # mean absolute error to the target
        gen_loss = gen_loss_crossentropyloss + LAMBDA * gen_l1_loss
        # This backward also writes into dis' .grad buffers; they are cleared by
        # d_optimizer.zero_grad() at the start of the next step.
        gen_loss.backward()
        g_optimizer.step()

        # .item() returns a plain Python float, so no torch.no_grad() is needed.
        D_epoch_loss += disc_loss.item()
        G_epoch_loss += gen_loss.item()

    # Mean loss per batch for this epoch.
    D_epoch_loss /= count
    G_epoch_loss /= count
    D_loss.append(D_epoch_loss)
    G_loss.append(G_epoch_loss)
    # After each epoch, report progress and plot samples from the fixed batch.
    print("Epoch:", epoch)
    generate_images(gen, annos_batch, imgs_batch)
边栏推荐
- 小满nestjs(第三章 前置知识装饰器)
- 如何在WPF中设置Grid ColumnDefinitions的样式
- 2021 RoboCom 世界机器人开发者大赛-本科组(决赛)
- Flume (五) --------- 自定义 Interceptor、自定义 Source 与 自定义 Sink
- 启动 CM agent 报错——ImportError: libssl.so.10: cannot open shared object file: No such file or directory
- Haven't tried line art videos this year??
- shell之变量详解,让你秒懂!
- IS31FL3737B general 12 x 12 LED drive 40 QFN I2C 42 ma
- visual studio 2022调试技巧介绍
- 【NOI模拟赛】防AK题(生成函数,单位根,Pollard-Rho)
猜你喜欢
随机推荐
MFC tutorial
数据分散情况的统计图-盒须图
时序攻击
OpenSSL SSL_read: Connection was reset, errno 10054
mysql 重复数据 分组 多条最新的记录
基于CC2530 E18-MS1-PCB Zigbee DIY作品
ebook下载 | 《 企业高管IT战略指南——企业为何要落地DevOps》
Flume (六) --------- Flume 数据流监控
shell脚本编写 hash方法,shell中字符到ascii码或数字的转换
matlab 神经网络 ANN 分类
php删除字符串的空格
毕昇编译器优化:Lazy Code Motion
基于光通信的6G水下信道建模综述
面试官:Redis 大 key 要如何处理?
技术分享 | 接口自动化测试如何处理 Header cookie
【IoT毕设】STM32与机智云自助开发平台的宠物智能喂养系统
laravel之phpunit单元测试
buuctf(探险2)
请问一下flink cdc mysql source 报这种错怎么处理呢?我都设置了useSSL=f
基于Web的疫情隔离区订餐系统