基于pytorch构建cyclegan示例

Overview

cyclegan-demo

基于Pytorch构建CycleGAN示例

如何运行

准备数据集

将数据集整理成4个文件,分别命名为

  • trainA, trainB:训练集,A、B代表两类图片
  • testA, testB:测试集,A、B代表两类图片

例如

D:\CODE\CYCLEGAN-DEMO\DATA\SUMMER2WINTER
├─testA
├─testB
├─trainA
└─trainB

之后在main.py中将root设为数据集的路径。

参数设置

main.py中的初始化参数

# 初始化参数
# seed: 随机种子
# root: 数据集路径
# output_model_root: 模型的输出路径
# image_size: 图片尺寸
# batch_size: 一次喂入的数据量
# lr: 学习率
# betas: 一阶和二阶动量
# epochs: 训练总次数
# historical_epochs: 历史训练次数
# - 0表示不沿用历史模型
# - >0表示对应训练次数的模型
# - -1表示最后一次训练的模型
# save_every: 保存频率
# loss_range: Loss的显示范围
seed = 123
data_root = 'D:/code/cyclegan-demo/data/summer2winter'
output_model_root = 'output/model'
image_size = 64
batch_size = 16
lr = 2e-4
betas = (.5, .999)
epochs = 100
historical_epochs = -1
save_every = 1
loss_range = 1000

安装和运行

  1. 安装依赖
pip install -r requirements.txt
  1. 打开命令行,运行Visdom
python -m visdom.server
  1. 运行主程序
python main.py

训练过程的可视化展示

访问地址http://localhost:8097即可进入Visdom可视化页面,页面中将展示:

  • A类真实图片 -【A2B生成器】 -> B类虚假图片 -【B2A生成器】 -> A类重构图片
  • B类真实图片 -【B2A生成器】 -> A类虚假图片 -【A2B生成器】 -> B类重构图片
  • 判别器A、B以及生成器的Loss曲线

一些可视化的具体用法可见Visdom的使用方法。

测试

TODO

介绍

目录结构

  • dataset.py 数据集
  • discriminator.py 判别器
  • generater.py 生成器
  • main.py 主程序
  • replay_buffer.py 缓冲区
  • resblk.py 残差块
  • util.py 工具方法

原理介绍

残差块是生成器的组成部分,其结构如下

Resblk(
  (main): Sequential(
    (0): ReflectionPad2d((1, 1, 1, 1))
    (1): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
    (2): InstanceNorm2d(3, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (3): ReLU(inplace=True)
    (4): ReflectionPad2d((1, 1, 1, 1))
    (5): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
    (6): InstanceNorm2d(3, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  )
)

生成器结构如下,由于采用全卷积结构,事实上其结构与图片尺寸无关

Generater(
  (input): Sequential(
    (0): ReflectionPad2d((3, 3, 3, 3))
    (1): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1))
    (2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (3): ReLU(inplace=True)
  )
  (downsampling): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (2): ReLU(inplace=True)
    (3): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (5): ReLU(inplace=True)
  )
  (resnet): Sequential(
    (0): Resblk
    (1): Resblk
    (2): Resblk
    (3): Resblk
    (4): Resblk
    (5): Resblk
    (6): Resblk
    (7): Resblk
    (8): Resblk
  )

  (upsampling): Sequential(
    (0): ConvTranspose2d(256, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (4): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (5): ReLU(inplace=True)
  )
  (output): Sequential(
    (0): ReplicationPad2d((3, 3, 3, 3))
    (1): Conv2d(64, 3, kernel_size=(7, 7), stride=(1, 1))
    (2): Tanh()
  )
)

判别器结构如下,池化层具体尺寸由图片尺寸决定,64x64的图片对应池化层为6x6

Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
    (9): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  )
  (output): Sequential(
    (0): AvgPool2d(kernel_size=torch.Size([6, 6]), stride=torch.Size([6, 6]), padding=0)
    (1): Flatten(start_dim=1, end_dim=-1)
  )
)

训练共有三个优化器,分别负责生成器、判别器A、判别器B的优化。

损失有三种类型:

  • 一致性损失:A(B)类真实图片与经生成器生成的图片的误差,该损失使得生成后的风格与原图更接近,采用L1Loss
  • 对抗损失:A(B)类图片经生成器得到B(A)类图片,再经判别器判别的错误率,采用MSELoss
  • 循环损失:A(B)类图片经生成器得到B(A)类图片,再经生成器得到A(B)类的重建图片,原图和重建图片的误差,采用L1Loss

生成器的训练过程:

  1. 将A(B)类真实图片送入生成器,得到生成的图片,计算生成图片与原图的一致性损失
  2. 将A(B)类真实图片送入生成器得到虚假图片,再送入判别器得到判别结果,计算判别结果与真实标签1的对抗损失(虚假图片应能被判别器判别为真实图片,即生成器能骗过判别器)
  3. 将A(B)类虚假图片送入生成器,得到重建图片,计算重建图片与原图的循环损失
  4. 计算、更新梯度

判别器A的训练过程:

  1. 将A类真实图片送入判别器A,得到判别结果,计算判别结果与真实标签1的对抗损失(判别器应将真实图片判别为真实)
  2. 将A类虚假图片送入判别器A,得到判别结果,计算判别结果与虚假标签0的对抗损失(判别器应将虚假图片判别为虚假)
  3. 计算、更新梯度
Owner
Koorye
学习?学个屁
Koorye
BRNet - code for Automated assessment of BI-RADS categories for ultrasound images using multi-scale neural networks with an order-constrained loss function

BRNet code for "Automated assessment of BI-RADS categories for ultrasound images using multi-scale neural networks with an order-constrained loss func

Yong Pi 2 Mar 09, 2022
A hyperparameter optimization framework

Optuna: A hyperparameter optimization framework Website | Docs | Install Guide | Tutorial Optuna is an automatic hyperparameter optimization software

7.4k Jan 04, 2023
This repository contains a PyTorch implementation of the paper Learning to Assimilate in Chaotic Dynamical Systems.

Amortized Assimilation This repository contains a PyTorch implementation of the paper Learning to Assimilate in Chaotic Dynamical Systems. Abstract: T

4 Aug 16, 2022
Towards Representation Learning for Atmospheric Dynamics (AtmoDist)

Towards Representation Learning for Atmospheric Dynamics (AtmoDist) The prediction of future climate scenarios under anthropogenic forcing is critical

Sebastian Hoffmann 4 Dec 15, 2022
Pytorch implementation of "M-LSD: Towards Light-weight and Real-time Line Segment Detection"

M-LSD: Towards Light-weight and Real-time Line Segment Detection Pytorch implementation of "M-LSD: Towards Light-weight and Real-time Line Segment Det

123 Jan 04, 2023
Learning kernels to maximize the power of MMD tests

Code for the paper "Generative Models and Model Criticism via Optimized Maximum Mean Discrepancy" (arXiv:1611.04488; published at ICLR 2017), by Douga

Danica J. Sutherland 201 Dec 17, 2022
Official codebase for "B-Pref: Benchmarking Preference-BasedReinforcement Learning" contains scripts to reproduce experiments.

B-Pref Official codebase for B-Pref: Benchmarking Preference-BasedReinforcement Learning contains scripts to reproduce experiments. Install conda env

48 Dec 20, 2022
YouRefIt: Embodied Reference Understanding with Language and Gesture

YouRefIt: Embodied Reference Understanding with Language and Gesture YouRefIt: Embodied Reference Understanding with Language and Gesture by Yixin Che

16 Jul 11, 2022
QHack—the quantum machine learning hackathon

Official repo for QHack—the quantum machine learning hackathon

Xanadu 72 Dec 21, 2022
Code samples for my book "Neural Networks and Deep Learning"

Code samples for "Neural Networks and Deep Learning" This repository contains code samples for my book on "Neural Networks and Deep Learning". The cod

Michael Nielsen 13.9k Dec 26, 2022
Official implementation of EdiTTS: Score-based Editing for Controllable Text-to-Speech

EdiTTS: Score-based Editing for Controllable Text-to-Speech Official implementation of EdiTTS: Score-based Editing for Controllable Text-to-Speech. Au

Neosapience 98 Dec 25, 2022
The dynamics of representation learning in shallow, non-linear autoencoders

The dynamics of representation learning in shallow, non-linear autoencoders The package is written in python and uses the pytorch implementation to ML

Maria Refinetti 4 Jun 08, 2022
PyTorch implementation of popular datasets and models in remote sensing

PyTorch Remote Sensing (torchrs) (WIP) PyTorch implementation of popular datasets and models in remote sensing tasks (Change Detection, Image Super Re

isaac 222 Dec 28, 2022
Training data extraction on GPT-2

Training data extraction from GPT-2 This repository contains code for extracting training data from GPT-2, following the approach outlined in the foll

Florian Tramer 62 Dec 07, 2022
A rule-based log analyzer & filter

Flog 一个根据规则集来处理文本日志的工具。 前言 在日常开发过程中,由于缺乏必要的日志规范,导致很多人乱打一通,一个日志文件夹解压缩后往往有几十万行。 日志泛滥会导致信息密度骤减,给排查问题带来了不小的麻烦。 以前都是用grep之类的工具先挑选出有用的,再逐条进行排查,费时费力。在忍无可忍之后决

上山打老虎 9 Jun 23, 2022
Real time Human Detection Counting

In this python project, we are going to build the Human Detection and Counting System through Webcam or you can give your own video or images. This is a deep learning project on computer vision, whic

Mir Nawaz Ahmad 2 Jun 17, 2022
FairMOT - A simple baseline for one-shot multi-object tracking

FairMOT - A simple baseline for one-shot multi-object tracking

Yifu Zhang 3.6k Jan 08, 2023
Tensorflow implementation for Self-supervised Graph Learning for Recommendation

If the compilation is successful, the evaluator of cpp implementation will be called automatically. Otherwise, the evaluator of python implementation will be called.

152 Jan 07, 2023
[Official] Exploring Temporal Coherence for More General Video Face Forgery Detection(ICCV 2021)

Exploring Temporal Coherence for More General Video Face Forgery Detection(FTCN) Yinglin Zheng, Jianmin Bao, Dong Chen, Ming Zeng, Fang Wen Accepted b

57 Dec 28, 2022
Learning Skeletal Articulations with Neural Blend Shapes

This repository provides an end-to-end library for automatic character rigging and blend shapes generation as well as a visualization tool. It is based on our work Learning Skeletal Articulations wit

Peizhuo 504 Dec 30, 2022