Code for "On the Effects of Batch and Weight Normalization in Generative Adversarial Networks"

Overview

Note: this repo has been discontinued. Please check the code for the newer version of the paper here

Weight Normalized GAN

Code for the paper "On the Effects of Batch and Weight Normalization in Generative Adversarial Networks".

About the code

Two versions are provided here: one for torch and one for PyTorch.

The code used for the experiments in the paper was written in torch and was a bit messy, with a hand-written backward pass for the weight-normalized layers and other stuff used to test various ideas about GANs that are unrelated to the paper. So we decided to clean up the code and port it to PyTorch (read: autograd). However, we were not able to exactly reproduce the results in the paper with the PyTorch code, so we had to port it back to torch to see the difference.

We did find and fix a mathematical bug in the gradient computation (Ouch!) in our implementation of weight normalization, which means that the code used for the paper was incorrect and you might not be able to exactly reproduce the results in the paper with the current code. We need to redo some experiments to make sure everything still works. It seems that now a learning rate of 0.00002 gives very good samples, but training is not very fast in the beginning; 0.0001 speeds up training even more than in the paper but gives worse samples; 0.00005 balances between the two and also gives a lower reconstruction loss than in the paper. The example below uses 0.00002.

That being said, we can still find some differences in the samples generated by the two versions of the code. We think that the torch version is better, so you are advised to use that version for training. But you should definitely read the PyTorch version to get a better idea of how our method works. We checked this time that in the torch code, the computed gradients with respect to the weight vectors are indeed orthogonal to the weight vectors, so hopefully the difference is not caused by another mathematical bug. It could be a numerical issue, since the gradients are not computed in exactly the same way. Or I might have made stupid mistakes, as I have been doing machine learning for only half a year. We are still investigating.
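The orthogonality check mentioned above can be illustrated with a minimal sketch in plain Python. It assumes the generic weight normalization form w = v / ||v|| with no extra scale factor, and it is not taken from either version of the code; the function names and the sample vectors are hypothetical.

```python
import math

def dot(a, b):
    """Inner product of two equal-length lists."""
    return sum(x * y for x, y in zip(a, b))

def grad_wrt_unnormalized(v, grad_w):
    """Backpropagate dL/dw through w = v / ||v|| to obtain dL/dv.

    The result is grad_w with its component along w removed,
    divided by ||v||, so it should be orthogonal to v.
    """
    norm = math.sqrt(dot(v, v))
    w = [x / norm for x in v]
    radial = dot(grad_w, w)  # component of grad_w along w
    return [(g - radial * x) / norm for g, x in zip(grad_w, w)]

# Hypothetical weight vector and upstream gradient.
v = [0.5, -1.2, 2.0]
grad_w = [0.3, 0.7, -0.4]

grad_v = grad_wrt_unnormalized(v, grad_w)
residual = dot(grad_v, v)  # should be zero up to rounding error
print(f"<grad_v, v> = {residual:.2e}")
```

A residual far from zero in such a check would point to a mistake in a hand-written backward pass rather than to a numerical issue.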

Usage

The two versions accept exactly the same set of arguments, except that the torch version has an additional option to set the ID of the GPU to use.

Before training, you need to prepare the data. For torch you need lmdb.torch for LSUN and cifar.torch for CIFAR-10. Split the dataset into training data and test data with split_data.lua/py. Use --running and --final to set the number of test samples for the running test and the final test, respectively.

The LSUN loader creates a cache if there isn't one, which takes some time. The loader for a custom dataset from an image folder requires the images of each class to be in one subfolder, so if you use, say, CelebA, where there are no classes, you need to manually create a dummy class.
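One way to create such a dummy class is to move all images into a single subfolder. The following is a minimal sketch using only the Python standard library; the helper name wrap_in_dummy_class and the class name "dummy" are hypothetical, and the demo runs on a throwaway directory with empty placeholder files rather than on real images.

```python
import shutil
import tempfile
from pathlib import Path

def wrap_in_dummy_class(image_dir, class_name="dummy"):
    """Move every file directly under image_dir into image_dir/class_name."""
    root = Path(image_dir)
    target = root / class_name
    target.mkdir(exist_ok=True)
    moved = 0
    # sorted() materializes the listing before any file is moved.
    for entry in sorted(root.iterdir()):
        if entry.is_file():
            shutil.move(str(entry), str(target / entry.name))
            moved += 1
    return moved

# Demo on a temporary directory with three empty placeholder files.
with tempfile.TemporaryDirectory() as tmp:
    for name in ("000001.jpg", "000002.jpg", "000003.jpg"):
        (Path(tmp) / name).touch()
    moved = wrap_in_dummy_class(tmp)
    remaining = sorted(p.name for p in (Path(tmp) / "dummy").iterdir())
    print(moved, remaining)
```

After this, the original folder contains one class subfolder, which is the layout the folder loader is described as expecting.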

To train, run main.lua/py. The only arguments you must specify are --dataset, --dataroot, --save_path and --image_size. By default it trains a vanilla model. Use --norm batch or --norm weight to try different normalizations.

The width and the height of the images are not required to be equal. Nor do they have to be powers of two. They only have to both be even numbers. Image size settings work as follows: if --crop_size is specified or if both --crop_width and --crop_height are specified, the training samples are first cropped to the center. Then, if --width and --height are both specified, the training samples are resized to that size. Otherwise, they are resized so that the aspect ratio is kept and the length of the shorter edge equals --image_size, and then cropped to a square.
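The order of these operations can be sketched as a small planning function. This is only an illustration of the rules as described in this README, not the loader's actual code: the function name, the rounding, the precedence of --crop_size over --crop_width/--crop_height, and the assumption that the final square has side --image_size are all assumptions. The 178x218 source size is just an example input.

```python
def plan_image_ops(src_w, src_h, image_size, crop_size=None,
                   crop_width=None, crop_height=None,
                   width=None, height=None):
    """Return the list of (operation, width, height) steps for one image."""
    steps = []
    w, h = src_w, src_h

    # Step 1: optional center crop.
    if crop_size is not None:
        w, h = crop_size, crop_size
        steps.append(("center_crop", w, h))
    elif crop_width is not None and crop_height is not None:
        w, h = crop_width, crop_height
        steps.append(("center_crop", w, h))

    # Step 2: explicit resize, or aspect-preserving resize plus square crop.
    if width is not None and height is not None:
        w, h = width, height
        steps.append(("resize", w, h))
    else:
        if w <= h:
            w, h = image_size, round(h * image_size / w)
        else:
            w, h = round(w * image_size / h), image_size
        steps.append(("resize", w, h))
        w = h = image_size
        steps.append(("center_crop", w, h))

    # The README requires both final dimensions to be even.
    if w % 2 or h % 2:
        raise ValueError("final width and height must both be even")
    return steps

print(plan_image_ops(178, 218, 160, crop_size=160))
print(plan_image_ops(178, 218, 64))
print(plan_image_ops(178, 218, 64, width=96, height=128))
```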

If --nlayer is set, that many down/up convolution layers are used. Otherwise such layers are added until the size of the feature map is smaller than 8x8. --nfeature specifies the number of features of the first convolution layer.

Set --load_path to continue a saved training run.

To test a trained model, use --final_test. Make sure to also use a larger --test_steps, since the default value is meant for the running test during training. By default it finds the best model in load_path; to use another network, set --net.

Read the code to see how other arguments work.

Use plot.lua/py to plot the loss curves. The PyTorch version uses PyGnuplot (it sux).

Example

th main.lua --dataset folder --dataroot /path/to/img_align_celeba --crop_size 160 --image_size 160 --code_size 256 --norm weight --lr 0.00002 --save_path /path/to/save/folder

This should give you something like this in 200,000 iterations (image: celeba example).

Additional notes

The WN model might fail in the first handful of iterations. This happens especially often if the network is deeper (on LSUN). Just restart training. If it gets past iteration 5, it should continue to train without trouble. This effect could be reduced by using a smaller learning rate for the first couple of iterations.
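A smaller learning rate for the first few iterations could look like the sketch below. Nothing in this README says the code offers such an option; the function name, the warm-up length of 5 iterations (chosen to match the remark above) and the factor of 0.1 are all assumptions.

```python
def warmup_lr(iteration, base_lr=0.00002, warmup_iters=5, warmup_factor=0.1):
    """Return a reduced learning rate during warm-up, base_lr afterwards."""
    if iteration < warmup_iters:
        return base_lr * warmup_factor
    return base_lr

# Learning rates for the first eight iterations (0-based).
schedule = [warmup_lr(i) for i in range(8)]
print(schedule)
```

In a training loop, the value returned for the current iteration would be assigned to the optimizer's learning rate before each update.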

Extra stuff

On request, an --ls flag was added to use the least squares loss.
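For readers unfamiliar with the term, a common least squares GAN formulation is sketched below in plain Python. Whether --ls uses exactly these targets (1 for real, 0 for fake) and the 0.5 factor is an assumption; check the code for the actual definition. The function names are hypothetical.

```python
def ls_discriminator_loss(d_real, d_fake):
    """Least squares discriminator loss: push real outputs to 1, fake to 0."""
    real_term = sum((d - 1.0) ** 2 for d in d_real) / len(d_real)
    fake_term = sum(d ** 2 for d in d_fake) / len(d_fake)
    return 0.5 * (real_term + fake_term)

def ls_generator_loss(d_fake):
    """Least squares generator loss: push outputs on fake samples to 1."""
    return 0.5 * sum((d - 1.0) ** 2 for d in d_fake) / len(d_fake)

# Hypothetical discriminator outputs on a batch of two samples each.
print(ls_discriminator_loss([1.0, 1.0], [0.0, 0.0]))
print(ls_generator_loss([0.0, 0.0]))
```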

Owner
Sitao Xiang
Computer Graphics PhD student at University of Southern California. Twitter: StormRaiser123