当前位置:网站首页>[Deep learning] pix2pix GAN theory and code implementation
[Deep learning] pix2pix GAN theory and code implementation
2022-08-09 22:05:00 【blameless.lsy】
目录
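1. What is pix2pix GAN
Pix2pix GAN is essentially a conditional GAN (cGAN): the image x is the condition and is fed to both G and D. G takes x (the image to be translated) as input and outputs the generated image G(x); D has to distinguish the generated pair (x, G(x)) from the real pair (x, y).
pix2pix GAN is mainly used for converting one image into another, a task also known as image translation.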
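2. pix2pix GAN generator design
In image translation tasks the input and the output share a lot of information; for example, object contours are the same in both. How can this shared information be preserved? It has to be addressed in the design of the network architecture.
With an ordinary encoder-decoder convolutional network, every layer has to carry and preserve all of the information, which makes the network prone to errors (it easily loses part of the information).
Therefore, we use a U-Net as the generator: its skip connections pass each encoder feature map directly to the decoder layer of the same resolution, so low-level information such as contours does not have to squeeze through the bottleneck.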
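3. pix2pix GAN discriminator design
D takes a pair of images as input, as in a cGAN. If G(x) corresponds to x, the generator wants D to judge the pair as 1 (real);
if G(x) does not correspond to x, D should judge the pair as 0 (fake).
The D in pix2pix GAN is the patch discriminator (PatchGAN, patch_D) proposed in the paper. "Patch" means that, no matter how large the generated image is, it is divided into several fixed-size patches and D judges each patch separately, as shown in the figure above. In the code below this is done implicitly: D is fully convolutional and, for 256×256 inputs, outputs a 60×60 map in which each element judges one patch of the input.
The advantage of this design is that D's input becomes smaller, so it needs less computation and trains faster.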
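4. Loss function
Discriminator loss: D should judge a real input pair as 1, and a pair made of the generated image and the original input image as 0.
Generator loss: G wants D to judge the pair made of the generated image and the original input image as 1.
The conditional GAN objective is:
$$\mathcal{L}_{cGAN}(G, D) = \mathbb{E}_{x,y}\left[\log D(x, y)\right] + \mathbb{E}_{x}\left[\log\left(1 - D(x, G(x))\right)\right]$$
In image translation tasks, the input and output of G share a lot of information. Therefore, to keep the generated image close to the target image, an L1 loss is added as well:
$$\mathcal{L}_{L1}(G) = \mathbb{E}_{x,y}\left[\lVert y - G(x) \rVert_1\right]$$
Combining the two terms, the total objective is
$$G^{*} = \arg\min_{G}\max_{D}\ \mathcal{L}_{cGAN}(G, D) + \lambda\,\mathcal{L}_{L1}(G)$$
where λ is the weight of the L1 term (LAMBDA in the code below).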
5. Code implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import torchvision #加载图片
from torchvision import transforms #图片变换
import numpy as np
import matplotlib.pyplot as plt #绘图
import os
import glob
from PIL import Image
# Training data: photos (*.jpg) and their annotation images (*.png) in 'base/'.
# CMP_dataset pairs the two lists by index, but glob() returns files in an
# arbitrary, filesystem-dependent order, so both lists are sorted and then
# checked to pair up one-to-one by file name (stem without extension).
imgs_path = sorted(glob.glob('base/*.jpg'))
annos_path = sorted(glob.glob('base/*.png'))
if ([os.path.splitext(os.path.basename(p))[0] for p in imgs_path]
        != [os.path.splitext(os.path.basename(p))[0] for p in annos_path]):
    raise ValueError(
        "photos (*.jpg) and annotations (*.png) in 'base/' do not pair up "
        "one-to-one by file name"
    )

# Preprocessing applied to both photos and annotations:
# PIL image -> float tensor in [0, 1] -> resize to 256x256 -> (v - 0.5) / 0.5,
# i.e. values in [-1, 1], the same range as the generator's tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((256, 256)),
    transforms.Normalize(mean=0.5, std=0.5),
])
# Paired dataset definition.
class CMP_dataset(data.Dataset):
    """Paired (annotation, photo) dataset for pix2pix training.

    Items are paired by index: ``annos_path[i]`` is the generator input for
    the real photo ``imgs_path[i]``, so both lists must have the same length
    and the same order.

    Args:
        imgs_path: paths of the real photos (the translation targets).
        annos_path: paths of the annotation images (the generator inputs).

    Raises:
        ValueError: if the two path lists differ in length.
    """

    def __init__(self, imgs_path, annos_path):
        if len(imgs_path) != len(annos_path):
            raise ValueError(
                f"got {len(imgs_path)} images but {len(annos_path)} annotations"
            )
        self.imgs_path = imgs_path
        self.annos_path = annos_path

    def __getitem__(self, index):
        """Return ``(annotation, photo)`` tensors for item *index*.

        Both images are converted to RGB and passed through the module-level
        ``transform``.
        """
        # `with` closes each file handle right away instead of leaving it open
        # until garbage collection.
        with Image.open(self.imgs_path[index]) as img:
            # Force 3 channels so grayscale/CMYK/RGBA photos still match the
            # 3-channel Generator output and 6-channel Discriminator input.
            pil_img = transform(img.convert("RGB"))
        with Image.open(self.annos_path[index]) as anno:
            pil_anno = transform(anno.convert("RGB"))
        return pil_anno, pil_img

    def __len__(self):
        """Number of (annotation, photo) pairs."""
        return len(self.imgs_path)
# Batch size shared by the training and test loaders.
BATCHSIZE = 32

# Training set built from the paired path lists, wrapped in a DataLoader
# that reshuffles the pairs every epoch.
dataset = CMP_dataset(imgs_path, annos_path)
dataloader = data.DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True)

# Set one (annotation, photo) batch aside; it is used after every epoch to
# visualise the generator's output.
annos_batch, imgs_batch = next(iter(dataloader))
# Downsampling (encoder) module.
class Downsample(nn.Module):
    """Stride-2 encoder stage: 3x3 conv -> LeakyReLU -> optional BatchNorm.

    Halves the spatial size (rounding up for odd sizes) and maps
    ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels,
                         kernel_size=3, stride=2, padding=1)
        activation = nn.LeakyReLU(inplace=True)
        self.conv_relu = nn.Sequential(conv, activation)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x, is_bn=True):
        """Apply conv + activation, then BatchNorm unless ``is_bn`` is False."""
        features = self.conv_relu(x)
        return self.bn(features) if is_bn else features
# Upsampling (decoder) module.
class Upsample(nn.Module):
    """Decoder stage: 3x3 transposed conv -> LeakyReLU -> BatchNorm.

    Doubles the spatial size and maps ``in_channels`` to ``out_channels``;
    optionally applies channel-wise dropout at the end.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # stride=2, padding=1, output_padding=1 gives exactly 2x the input size.
        self.upconv_relu = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1,
            ),
            nn.LeakyReLU(inplace=True),
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x, is_drop=False):
        """Upsample ``x``; apply ``F.dropout2d`` afterwards when ``is_drop``."""
        out = self.bn(self.upconv_relu(x))
        if not is_drop:
            return out
        # F.dropout2d is called with its default training=True, so dropout is
        # applied whenever is_drop is set, regardless of train/eval mode.
        return F.dropout2d(out)
# Generator: 6 downsampling stages, 5 upsampling stages and one output layer.
class Generator(nn.Module):
    """U-Net generator mapping a 3-channel image to a 3-channel image.

    Six stride-2 encoder stages are followed by five decoder stages. Each
    decoder output is concatenated with the encoder feature map of the same
    resolution (skip connection). A final transposed conv restores the input
    size, and ``tanh`` squashes the result to [-1, 1]. Shape comments assume
    a 3 x 256 x 256 input.
    """

    def __init__(self):
        super().__init__()
        # Encoder.
        self.down1 = Downsample(3, 64)      # 64 x 128 x 128
        self.down2 = Downsample(64, 128)    # 128 x 64 x 64
        self.down3 = Downsample(128, 256)   # 256 x 32 x 32
        self.down4 = Downsample(256, 512)   # 512 x 16 x 16
        self.down5 = Downsample(512, 512)   # 512 x 8 x 8
        self.down6 = Downsample(512, 512)   # 512 x 4 x 4
        # Decoder: in_channels of up2..up5 and of `last` are doubled because
        # each input is the previous output concatenated with a skip map.
        self.up1 = Upsample(512, 512)       # 512 x 8 x 8
        self.up2 = Upsample(1024, 512)      # 512 x 16 x 16
        self.up3 = Upsample(1024, 256)      # 256 x 32 x 32
        self.up4 = Upsample(512, 128)       # 128 x 64 x 64
        self.up5 = Upsample(256, 64)        # 64 x 128 x 128
        self.last = nn.ConvTranspose2d(128, 3,
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1)  # 3 x 256 x 256

    def forward(self, x):
        """Encode ``x``, decode with skip connections, return a tanh image."""
        encoder = (self.down1, self.down2, self.down3,
                   self.down4, self.down5, self.down6)
        # (stage, apply dropout?) -- dropout only on the three deepest stages.
        decoder = ((self.up1, True), (self.up2, True), (self.up3, True),
                   (self.up4, False), (self.up5, False))

        skips = []
        h = x
        for stage in encoder:
            h = stage(h)
            skips.append(h)

        # The deepest encoder output is the decoder's input, not a skip map.
        skips.pop()
        for stage, drop in decoder:
            h = stage(h, is_drop=drop)
            # Pair with down5, down4, ..., down1 output in turn.
            h = torch.cat([h, skips.pop()], dim=1)
        return torch.tanh(self.last(h))
# Discriminator: input is anno + img (generated or real) concatenated.
class Discriminator(nn.Module):
    """Patch-based conditional discriminator.

    Concatenates an annotation and a photo (real or generated) along the
    channel axis (3 + 3 = 6 channels) and returns a map of probabilities in
    (0, 1). For 256 x 256 inputs the map is ``batch x 1 x 60 x 60``.
    """

    def __init__(self):
        super().__init__()
        self.down1 = Downsample(6, 64)       # 64 x 128 x 128
        self.down2 = Downsample(64, 128)     # 128 x 64 x 64
        self.conv1 = nn.Conv2d(128, 256, 3)  # 256 x 62 x 62 (no padding)
        self.bn = nn.BatchNorm2d(256)
        self.last = nn.Conv2d(256, 1, 3)     # 1 x 60 x 60

    def forward(self, anno, img):
        """Score the (anno, img) pair; higher values mean "real pair"."""
        x = torch.cat([anno, img], dim=1)
        x = self.down1(x, is_bn=False)
        x = self.down2(x, is_bn=True)
        x = self.bn(F.leaky_relu(self.conv1(x)))
        # Tie dropout to the module's train/eval mode. F.dropout2d defaults to
        # training=True, which would keep dropping channels after dis.eval().
        x = F.dropout2d(x, training=self.training)
        return torch.sigmoid(self.last(x))  # batch x 1 x 60 x 60
# Train on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = 'cpu'

gen = Generator().to(device)
dis = Discriminator().to(device)

# Both networks use Adam with beta1 = 0.5 and the same learning rate.
d_optimizer = torch.optim.Adam(
    dis.parameters(),
    lr=1e-3,
    betas=(0.5, 0.999),
)
g_optimizer = torch.optim.Adam(
    gen.parameters(),
    lr=1e-3,
    betas=(0.5, 0.999),
)
# Plotting helpers.
def _to_display(batch):
    """Convert an (N, C, H, W) tensor in [-1, 1] to an (N, H, W, C) array in [0, 1]."""
    images = batch.detach().cpu().permute(0, 2, 3, 1)
    # imshow clips float RGB data to [0, 1]; undo the (v - 0.5) / 0.5
    # normalization first so the full value range is shown.
    return ((images + 1) / 2).clamp(0, 1).numpy()


def generate_images(model, test_anno, test_real):
    """Plot input annotation, ground truth and model output side by side.

    Only the first item of each batch is shown.

    Args:
        model: generator mapping an annotation batch to a photo batch.
        test_anno: annotation batch (N, C, H, W), normalized to [-1, 1].
        test_real: matching real photo batch with the same layout.
    """
    # Visualisation only: no autograd graph is needed. The model's train/eval
    # mode is left unchanged.
    with torch.no_grad():
        prediction = model(test_anno)
    display_list = [_to_display(t[:1])[0]
                    for t in (test_anno, test_real, prediction)]
    title = ['Input', 'Ground Truth', 'Output']
    plt.figure(figsize=(10, 10))
    for i, (name, image) in enumerate(zip(title, display_list), start=1):
        plt.subplot(1, 3, i)
        plt.title(name)
        plt.imshow(image)
        plt.axis('off')  # hide the axes
    plt.show()
# Test split from 'extended/', paired by index like the training data.
# glob() order is filesystem-dependent, so sort both lists to keep each photo
# aligned with its annotation.
# NOTE(review): test_dataloader is not used by the training loop, which
# visualises a fixed training batch instead.
test_imgs_path = sorted(glob.glob('extended/*.jpg'))
test_annos_path = sorted(glob.glob('extended/*.png'))
test_dataset = CMP_dataset(test_imgs_path, test_annos_path)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=BATCHSIZE,
)
# Adversarial (cGAN) term: binary cross-entropy on the discriminator's
# sigmoid outputs.
loss_fn = torch.nn.BCELoss()

# Move the fixed visualisation batch to the same device as the models.
annos_batch, imgs_batch = annos_batch.to(device), imgs_batch.to(device)

# Weight of the L1 reconstruction term in the generator loss.
LAMBDA = 7

# Per-epoch average losses of the discriminator and the generator.
D_loss, G_loss = [], []
# Training loop: 10 epochs; every step updates D once, then G once.
for epoch in range(10):
    D_epoch_loss = 0
    G_epoch_loss = 0
    count = len(dataloader)  # batches per epoch, used to average the losses
    for step,(annos,imgs) in enumerate(dataloader):
        imgs = imgs.to(device)
        annos = annos.to(device)
        # --- Discriminator loss and update ---
        d_optimizer.zero_grad()
        disc_real_output = dis(annos,imgs)  # real (annotation, photo) pair -> target 1
        d_real_loss = loss_fn(disc_real_output,torch.ones_like(disc_real_output,
                                                               device=device))
        d_real_loss.backward()
        gen_output = gen(annos)
        # detach(): the discriminator update must not backpropagate into G.
        disc_gen_output = dis(annos,gen_output.detach())  # generated pair -> target 0
        d_fack_loss = loss_fn(disc_gen_output,torch.zeros_like(disc_gen_output,
                                                               device=device))
        d_fack_loss.backward()
        # Gradients of both terms were already accumulated by the two
        # backward() calls above; disc_loss is only kept for logging.
        disc_loss = d_real_loss+d_fack_loss  # total discriminator loss
        d_optimizer.step()
        # --- Generator loss and update ---
        g_optimizer.zero_grad()
        # Re-score the non-detached generated pair with the just-updated D;
        # the generator wants D to output 1 for it.
        disc_gen_out = dis(annos,gen_output)
        gen_loss_crossentropyloss = loss_fn(disc_gen_out,
                                            torch.ones_like(disc_gen_out,
                                                            device=device))
        # L1 term: mean absolute difference between generated and real photo.
        gen_l1_loss = torch.mean(torch.abs(gen_output-imgs))
        gen_loss = gen_loss_crossentropyloss +LAMBDA*gen_l1_loss
        # This backward also writes gradients into D's parameters; they are
        # cleared by d_optimizer.zero_grad() before D's next update.
        gen_loss.backward()  # backpropagate
        g_optimizer.step()  # update the generator
        # Accumulate each batch's losses as Python floats.
        with torch.no_grad():
            D_epoch_loss +=disc_loss.item()
            G_epoch_loss +=gen_loss.item()
    # Average the accumulated losses over the epoch's batches.
    with torch.no_grad():
        D_epoch_loss /=count
        G_epoch_loss /=count
        D_loss.append(D_epoch_loss)
        G_loss.append(G_epoch_loss)
        # After each epoch, print progress and plot the generator's output on
        # the fixed batch set aside before training.
        print("Epoch:",epoch)
        generate_images(gen,annos_batch,imgs_batch)
边栏推荐
猜你喜欢
数据集成API如何成为企业数字化转型的关键?
source install/setup.bash时出现错误
【图文并茂】如何进行Win7系统的重装
没有 accept,我可以建立 TCP 连接吗?
力扣15-三数之和——HashSet&双指针法
访问控制知识
真香|持一建证书央企可破格录取
IS31FL3737B general 12 x 12 LED drive 40 QFN I2C 42 ma
源码编译安装与yum和rpm软件安装详解
启动 CM agent 报错——ImportError: libssl.so.10: cannot open shared object file: No such file or directory
随机推荐
【kali-权限提升】(4.2.6)社会工程学工具包(中):中间人攻击工具Ettercap
matlab 神经网络 ANN 分类
请问一下flink cdc mysql source 报这种错怎么处理呢?我都设置了useSSL=f
STM32WB55的FUS更新及协议栈固件烧写方法
ebook download | "Business executives' IT strategy guide - why enterprises should implement DevOps"
[Free column] APK dynamic reverse application of Android security [Three Smali injection methods]
Queue topic: Implementing stacks with queues
软件测试技术之如何编写测试用例(6)
Flume (六) --------- Flume 数据流监控
[] free column Android dynamic debugging GDB APP of safety
laravel之phpunit单元测试
AttributeError: module 'click' has no attribute 'get_os_args'
时序攻击
OpenSSL SSL_read: Connection was reset, errno 10054
基于CC2530 E18-MS1-PCB Zigbee DIY作品(二)
《评估、创建和使用知识图谱的限制》2022最新230页博士论文,根特大学
面试官:Redis 大 key 要如何处理?
真香|持一建证书央企可破格录取
看完这波 Android 面试题;助你斩获心中 offer
如何从800万数据中快速捞出自己想要的数据?