当前位置:网站首页>[Deep learning] pix2pix GAN theory and code implementation

[Deep learning] pix2pix GAN theory and code implementation

2022-08-09 22:05:00 blameless.lsy

目录

1.什么是pix2pix GAN

2.pix2pix GAN生成器的设计

3.pix2pixGAN判别器的设计

4.损失函数

5.代码实现 


1.什么是pix2pix GAN

pix2pix GAN本质上是一个cGAN,图片x作为此cGAN的条件,需要输入到G和D中.G的输入是x(x是需要转换的图片),输出是生成的图片G(x).D则需要分辨出(x,G(x))和(x,y),其中y是与x配对的真实目标图片.

pix2pix GAN is mainly used for image-to-image conversion, also known as image translation.

2.pix2pix GAN生成器的设计

For image translation tasks, the input and output share a lot of information — for example, contour information. How can this shared information be carried through the network? We need to take it into account in the design of the generator (and, as shown in Section 4, of the loss function).

With an ordinary convolutional neural network, every layer has to carry and preserve all of the information. This makes the network error-prone: some information is easily lost along the way.

所以,我们使用UNet模型作为生成器

3.pix2pixGAN判别器的设计

D takes pairs of images as input, similar to a cGAN. For a real pair (x, y), D should output 1;

for a generated pair (x, G(x)), D should output 0, while the generator tries to make D output 1 for (x, G(x)).

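In the paper, D in pix2pix GAN is implemented as a patch discriminator (patch_D). "Patch" means that, no matter how large the generated image is, it is divided into multiple fixed-size patches, and D judges each patch separately (as shown in the figure above). In the code below, D outputs a 1×60×60 map for a 256×256 input, i.e. one real/fake score per patch.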

这样设计的好处是:D的输入变小,计算量小,训练速度快

4.损失函数

Loss of D: a real input pair (x, y) should be judged as 1; a pair made of the input image and the generated image (x, G(x)) should be judged as 0.

Loss of G: the pair made of the input image and the generated image (x, G(x)) should be judged as 1 by D.

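公式如下所示(D 最大化该目标,G 最小化该目标):

$$\mathcal{L}_{cGAN}(G,D)=\mathbb{E}_{x,y}\left[\log D(x,y)\right]+\mathbb{E}_{x}\left[\log\left(1-D(x,G(x))\right)\right]$$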

For image translation tasks, the input and output of G share a lot of information. Therefore, to keep the generated image G(x) close to the ground-truth image y, an L1 loss is also added:

$$\mathcal{L}_{L1}(G)=\mathbb{E}_{x,y}\left[\lVert y-G(x)\rVert_1\right]$$

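Combining the two terms, the total objective is:

$$G^*=\arg\min_G\max_D\ \mathcal{L}_{cGAN}(G,D)+\lambda\,\mathcal{L}_{L1}(G)$$

where $\lambda$ weights the L1 term (LAMBDA = 7 in the code below).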

5.代码实现 

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import torchvision #加载图片
from torchvision import transforms #图片变换

import numpy as np
import matplotlib.pyplot as plt #绘图
import os
import glob
from PIL import Image

# Training file lists. glob.glob returns paths in arbitrary, filesystem-dependent
# order, so both lists are sorted: the dataset pairs imgs_path[i] with
# annos_path[i]. This assumes each photo and its annotation share a file stem
# (e.g. x.jpg / x.png); verify for other datasets.
imgs_path = sorted(glob.glob('base/*.jpg'))
annos_path = sorted(glob.glob('base/*.png'))
# Preprocessing shared by photos and annotation images:
#   ToTensor  -> CHW float tensor (8-bit images are scaled to [0, 1])
#   Resize    -> fixed 256x256 spatial size
#   Normalize -> (t - 0.5) / 0.5, i.e. [0, 1] becomes [-1, 1], matching the
#                generator's tanh output range
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((256, 256)),
        transforms.Normalize(mean=0.5, std=0.5),
    ]
)

# Dataset definition
class CMP_dataset(data.Dataset):
    """Paired (annotation, photo) dataset for pix2pix training.

    ``imgs_path[i]`` (a photo) and ``annos_path[i]`` (its annotation image) are
    treated as a pair, so both lists must be given in matching order.
    Each item is returned as ``(annotation_tensor, photo_tensor)``, both passed
    through the module-level ``transform``.

    Raises:
        ValueError: if the two path lists have different lengths.
    """

    def __init__(self, imgs_path, annos_path):
        # A length mismatch would silently mis-pair samples or raise IndexError
        # part-way through an epoch; fail early with a clear message instead.
        if len(imgs_path) != len(annos_path):
            raise ValueError(
                f"got {len(imgs_path)} images but {len(annos_path)} annotations"
            )
        self.imgs_path = imgs_path
        self.annos_path = annos_path

    def __getitem__(self, index):
        img_path = self.imgs_path[index]
        anno_path = self.annos_path[index]
        # Context managers close the file handles deterministically. Both images
        # are forced to 3-channel RGB because the generator expects 3 input
        # channels and the discriminator 3 + 3; grayscale/RGBA/palette files
        # would otherwise yield tensors with a different channel count.
        with Image.open(img_path) as img:
            pil_img = transform(img.convert("RGB"))
        with Image.open(anno_path) as anno:
            pil_anno = transform(anno.convert("RGB"))
        return pil_anno, pil_img

    def __len__(self):
        return len(self.imgs_path)

# Build the training dataset and wrap it in a shuffling DataLoader so it can be
# iterated in batches.
dataset = CMP_dataset(imgs_path, annos_path)
BATCHSIZE = 32
dataloader = data.DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True)
# Grab one fixed (annotation, photo) batch up front; it is used later for the
# per-epoch preview plots.
annos_batch, imgs_batch = next(iter(dataloader))

# Downsampling block used by both the generator and the discriminator
class Downsample(nn.Module):
    """Conv2d(k=3, s=2, p=1) -> LeakyReLU -> optional BatchNorm2d.

    The stride-2 convolution halves even spatial sizes. Note the order:
    BatchNorm is applied *after* the activation, and only when ``is_bn`` is
    true.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        # Attribute names are part of the state_dict keys; keep them stable.
        self.conv_relu = nn.Sequential(conv, nn.LeakyReLU(inplace=True))
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x, is_bn=True):
        out = self.conv_relu(x)
        return self.bn(out) if is_bn else out


# Upsampling block used by the generator's decoder
class Upsample(nn.Module):
    """ConvTranspose2d(k=3, s=2, p=1, output_padding=1) -> LeakyReLU
    -> BatchNorm2d -> optional dropout2d.

    The transposed convolution exactly doubles the spatial size.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        deconv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
        )
        # Attribute names are part of the state_dict keys; keep them stable.
        self.upconv_relu = nn.Sequential(deconv, nn.LeakyReLU(inplace=True))
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x, is_drop=False):
        out = self.bn(self.upconv_relu(x))
        if not is_drop:
            return out
        # F.dropout2d defaults to training=True, so this dropout is applied
        # regardless of the module's train()/eval() mode.
        return F.dropout2d(out)
        

# Generator: 6 downsampling blocks, 5 upsampling blocks and one output layer (U-Net)
class Generator(nn.Module):
    """U-Net generator mapping a 3-channel image to a 3-channel image in [-1, 1].

    Each decoder output is concatenated with the encoder feature map of the
    same resolution (skip connection), which is why most decoder blocks take
    twice as many input channels as the previous block produced. Shape
    comments (C, H, W) assume a 3x256x256 input.
    """

    def __init__(self):
        super().__init__()
        # Encoder
        self.down1 = Downsample(3, 64)      # 64,128,128
        self.down2 = Downsample(64, 128)    # 128,64,64
        self.down3 = Downsample(128, 256)   # 256,32,32
        self.down4 = Downsample(256, 512)   # 512,16,16
        self.down5 = Downsample(512, 512)   # 512,8,8
        self.down6 = Downsample(512, 512)   # 512,4,4

        # Decoder (input channels include the concatenated skip features)
        self.up1 = Upsample(512, 512)       # 512,8,8
        self.up2 = Upsample(1024, 512)      # 512,16,16
        self.up3 = Upsample(1024, 256)      # 256,32,32
        self.up4 = Upsample(512, 128)       # 128,64,64
        self.up5 = Upsample(256, 64)        # 64,128,128

        # Output layer back to 3 channels at full resolution
        self.last = nn.ConvTranspose2d(
            128, 3, kernel_size=3, stride=2, padding=1, output_padding=1
        )  # 3,256,256

    def forward(self, x):
        encoder = (self.down1, self.down2, self.down3, self.down4, self.down5, self.down6)
        decoder = (self.up1, self.up2, self.up3, self.up4, self.up5)
        # Dropout only in the three innermost decoder blocks.
        dropout_flags = (True, True, True, False, False)

        features = []
        h = x
        for block in encoder:
            h = block(h)
            features.append(h)
        # The deepest map is the bottleneck input itself, not a skip connection.
        features.pop()

        for block, skip, drop in zip(decoder, reversed(features), dropout_flags):
            h = torch.cat([block(h, is_drop=drop), skip], dim=1)

        return torch.tanh(self.last(h))

# Discriminator: input is the annotation plus an image (generated or real), concatenated
class Discriminator(nn.Module):
    """Patch discriminator scoring (annotation, image) pairs.

    The two 3-channel inputs are concatenated along the channel axis
    (6 channels). The output is a map of sigmoid probabilities, one per spatial
    location: batch x 1 x 60 x 60 for 256x256 inputs.
    """

    def __init__(self):
        super().__init__()
        self.down1 = Downsample(6, 64)
        self.down2 = Downsample(64, 128)
        self.conv1 = nn.Conv2d(128, 256, 3)
        self.bn = nn.BatchNorm2d(256)
        self.last = nn.Conv2d(256, 1, 3)

    def forward(self, anno, img):
        pair = torch.cat([anno, img], dim=1)
        h = self.down1(pair, is_bn=False)
        h = self.down2(h, is_bn=True)
        h = self.bn(F.leaky_relu(self.conv1(h)))
        # F.dropout2d defaults to training=True: active even in eval mode.
        h = F.dropout2d(h)
        return torch.sigmoid(self.last(h))  # batch*1*60*60

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = 'cpu'

# Instantiate both networks on that device.
gen = Generator().to(device)
dis = Discriminator().to(device)

# Adam for both networks with the same settings (lr 1e-3, beta1 = 0.5).
d_optimizer = torch.optim.Adam(dis.parameters(), lr=1e-3, betas=(0.5, 0.999))
g_optimizer = torch.optim.Adam(gen.parameters(), lr=1e-3, betas=(0.5, 0.999))
# Plotting
def generate_images(model, test_anno, test_real):
    """Plot input annotation, ground truth and model output for the first sample.

    Args:
        model: generator mapping an annotation batch to an image batch.
        test_anno: annotation batch (N, C, H, W); expected to already be on
            the model's device.
        test_real: ground-truth image batch paired with ``test_anno``.

    The dataset normalizes images with ``Normalize(mean=0.5, std=0.5)`` and the
    generator ends in ``tanh``, so all three images lie in [-1, 1]. They are
    mapped back to [0, 1] before display; ``plt.imshow`` would otherwise clip
    every negative value to black.
    """
    # Inference only: no autograd graph is needed for the preview.
    with torch.no_grad():
        prediction = model(test_anno)

    def to_display(batch):
        # First sample, CHW -> HWC, [-1, 1] -> [0, 1] (clipped for safety).
        img = batch[0].permute(1, 2, 0).detach().cpu().numpy()
        return np.clip((img + 1) / 2, 0, 1)

    display_list = [to_display(test_anno), to_display(test_real), to_display(prediction)]
    titles = ['Input', 'Ground Truth', 'Output']
    plt.figure(figsize=(10, 10))
    for i, (title, image) in enumerate(zip(titles, display_list)):
        plt.subplot(1, 3, i + 1)
        plt.title(title)
        plt.imshow(image)
        plt.axis('off')  # hide the axes
    plt.show()

# Test file lists, sorted so test_imgs_path[i] pairs with test_annos_path[i]
# (glob.glob returns paths in arbitrary order; assumes matching file stems).
test_imgs_path = sorted(glob.glob('extended/*.jpg'))
test_annos_path = sorted(glob.glob('extended/*.png'))

# Held-out dataset from the 'extended' split; shuffle is left at its default
# (False), so batches come in file-list order.
test_dataset = CMP_dataset(test_imgs_path, test_annos_path)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCHSIZE)

# Loss functions
# Adversarial (cGAN) loss: binary cross-entropy on the discriminator's
# sigmoid outputs.
loss_fn = torch.nn.BCELoss()
# The L1 reconstruction term is computed inline in the training loop;
# LAMBDA is its weight.
LAMBDA = 7

# Move the fixed preview batch to the training device once, up front.
annos_batch = annos_batch.to(device)
imgs_batch = imgs_batch.to(device)

# Per-epoch mean losses of the discriminator / generator, filled during training.
D_loss, G_loss = [], []

# Training: 10 epochs, each batch does one discriminator update followed by
# one generator update.
for epoch in range(10):
    D_epoch_loss = 0
    G_epoch_loss = 0
    count = len(dataloader)  # batches per epoch, used to average the losses
    for step,(annos,imgs) in enumerate(dataloader):
        # annos: annotation images (the condition x); imgs: paired real images (target y)
        imgs = imgs.to(device)
        annos = annos.to(device)
        # --- Discriminator step ---
        # zero_grad also clears the gradients that the previous generator step
        # accumulated in dis's parameters (via dis(annos, gen_output) below).
        d_optimizer.zero_grad()
        disc_real_output = dis(annos,imgs)# real pair (x, y) -> target 1
        d_real_loss = loss_fn(disc_real_output,torch.ones_like(disc_real_output,
                                                             device=device))
        d_real_loss.backward()

        # gen_output is reused in the generator step below.
        gen_output = gen(annos)
        # detach(): the discriminator loss must not backpropagate into the generator.
        disc_gen_output = dis(annos,gen_output.detach())
        # generated pair (x, G(x)) -> target 0 ("fack" = fake)
        d_fack_loss = loss_fn(disc_gen_output,torch.zeros_like(disc_gen_output,
                                                              device=device))
        d_fack_loss.backward()

        disc_loss = d_real_loss+d_fack_loss# total D loss; used for logging only (gradients were accumulated by the two backward calls above)
        d_optimizer.step()

        # --- Generator step ---
        g_optimizer.zero_grad()
        # Not detached here, so gradients flow through dis back into gen.
        disc_gen_out = dis(annos,gen_output)
        # Adversarial term: G wants dis to judge (x, G(x)) as real (target 1).
        gen_loss_crossentropyloss = loss_fn(disc_gen_out,
                                            torch.ones_like(disc_gen_out,
                                                              device=device))
        # L1 term: mean absolute error between generated and real images.
        gen_l1_loss = torch.mean(torch.abs(gen_output-imgs))
        gen_loss = gen_loss_crossentropyloss +LAMBDA*gen_l1_loss
        gen_loss.backward() # backpropagate
        g_optimizer.step() # update generator parameters

        # Accumulate the per-batch losses (.item() already returns Python
        # floats, so this no_grad context is not strictly required).
        with torch.no_grad():
            D_epoch_loss +=disc_loss.item()
            G_epoch_loss +=gen_loss.item()
    # Average the losses over the epoch's batches.
    with torch.no_grad():
            D_epoch_loss /=count
            G_epoch_loss /=count
            D_loss.append(D_epoch_loss)
            G_loss.append(G_epoch_loss)
            # After each epoch: print progress and plot a preview of the
            # generator's output on the fixed preview batch.
            print("Epoch:",epoch)
            generate_images(gen,annos_batch,imgs_batch)
原网站

版权声明
本文为[blameless.lsy]所创,转载请带上原文链接,感谢
https://yzsam.com/2022/221/202208091903264411.html