Spectral normalization (SN) is a widely-used technique for improving the stability and sample quality of Generative Adversarial Networks (GANs).

Overview

Why Spectral Normalization Stabilizes GANs: Analysis and Improvements

[paper (NeurIPS 2021)] [paper (arXiv)] [code]

Authors: Zinan Lin, Vyas Sekar, Giulia Fanti

Abstract: Spectral normalization (SN) is a widely-used technique for improving the stability and sample quality of Generative Adversarial Networks (GANs). However, there is currently limited understanding of why SN is effective. In this work, we show that SN controls two important failure modes of GAN training: exploding and vanishing gradients. Our proofs illustrate a (perhaps unintentional) connection with the successful LeCun initialization. This connection helps to explain why the most popular implementation of SN for GANs requires no hyper-parameter tuning, whereas stricter implementations of SN have poor empirical performance out-of-the-box. Unlike LeCun initialization which only controls gradient vanishing at the beginning of training, SN preserves this property throughout training. Building on this theoretical understanding, we propose a new spectral normalization technique: Bidirectional Scaled Spectral Normalization (BSSN), which incorporates insights from later improvements to LeCun initialization: Xavier initialization and Kaiming initialization. Theoretically, we show that BSSN gives better gradient control than SN. Empirically, we demonstrate that it outperforms SN in sample quality and training stability on several benchmark datasets.
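As background for the abstract, the following is a minimal NumPy sketch of the popular form of SN (reshape the convolution kernel into a 2-D matrix and divide it by an estimate of its largest singular value). It is not taken from this repo and does not implement BSSN. The kernel layout (out_ch, in_ch, kh, kw), the function names, and the iteration count are assumptions made for illustration; training code typically runs a single power-iteration step per update and keeps the vector u between steps.

```python
import numpy as np


def spectral_norm_estimate(weight, n_iters=100, seed=0):
    """Estimate the largest singular value of a 2-D matrix by power iteration."""
    rng = np.random.default_rng(seed)
    u = rng.standard_normal(weight.shape[0])
    u /= np.linalg.norm(u)
    for _ in range(n_iters):
        # Alternate between the right and left singular vector estimates.
        v = weight.T @ u
        v /= np.linalg.norm(v) + 1e-12
        u = weight @ v
        u /= np.linalg.norm(u) + 1e-12
    return float(u @ weight @ v)


def spectrally_normalize(kernel):
    """Divide a conv kernel by the spectral norm of its reshaped matrix.

    Assumed layout: (out_ch, in_ch, kh, kw), reshaped to (out_ch, in_ch*kh*kw).
    """
    w2d = kernel.reshape(kernel.shape[0], -1)
    sigma = spectral_norm_estimate(w2d)
    return kernel / sigma, sigma


# Hypothetical kernel, used only to exercise the functions.
kernel = np.random.default_rng(1).standard_normal((8, 4, 3, 3))
normalized, sigma = spectrally_normalize(kernel)
```

After normalization, the reshaped matrix has a largest singular value close to 1, which can be checked with np.linalg.svd.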


This repo contains the code for reproducing the experiments on our BSN and the different SN variants in the paper. The code was tested under Python 2.7.5 and TensorFlow 1.14.0.

Preparing datasets

CIFAR10

Download cifar-10-python.tar.gz from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz (or from other sources).
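To inspect the downloaded archive, a sketch like the one below may help. It reads the training batches straight from the .tar.gz, relying on the standard layout of the CIFAR-10 Python release (pickled dictionaries with data and labels entries in files named data_batch_*). It is written for Python 3, unlike the repo's tested Python 2.7.5 environment, and the function name is hypothetical; the repo's own loader may work differently.

```python
import pickle
import tarfile

import numpy as np


def load_cifar10_train(archive_path):
    """Read all data_batch_* files from cifar-10-python.tar.gz.

    Returns images as (N, 32, 32, 3) uint8 and labels as a 1-D int array.
    """
    images, labels = [], []
    with tarfile.open(archive_path, "r:gz") as tar:
        members = sorted(tar.getmembers(), key=lambda m: m.name)
        for member in members:
            if "data_batch" not in member.name:
                continue  # skip test_batch, batches.meta, readme files
            # encoding="bytes" is needed because the files were pickled in Python 2.
            batch = pickle.load(tar.extractfile(member), encoding="bytes")
            data = np.asarray(batch[b"data"], dtype=np.uint8)
            # Each row is 3072 values: 1024 red, 1024 green, 1024 blue.
            images.append(data.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1))
            labels.extend(batch[b"labels"])
    return np.concatenate(images), np.asarray(labels)
```

Calling load_cifar10_train("cifar-10-python.tar.gz") on the real archive should give 50000 training images.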

STL10

Download stl10_binary.tar.gz from http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz (or from other sources), and put it in dataset_preprocess/STL10 folder. Then run python preprocess.py. This code will resize the images into 48x48x3 format, and save the images in stl10.npy.
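For orientation, the kind of conversion described above can be sketched as follows. This is not the repo's preprocess.py. It assumes the standard STL10 binary layout (uint8 values, each image stored as 3x96x96 in column-major order) and uses 2x2 block averaging as one possible way to reach 48x48; the actual script may use a different resizing method. All function names are hypothetical.

```python
import numpy as np


def decode_stl10(raw):
    """Turn a flat uint8 array from an STL10 *_X.bin file into (N, 96, 96, 3)."""
    images = raw.reshape(-1, 3, 96, 96)
    # Stored channel-first and column-major, so swap to height, width, channel.
    return images.transpose(0, 3, 2, 1)


def downsample_2x(images):
    """Average each 2x2 pixel block: (N, 96, 96, 3) -> (N, 48, 48, 3)."""
    n, h, w, c = images.shape
    blocks = images.reshape(n, h // 2, 2, w // 2, 2, c).astype(np.float32)
    return blocks.mean(axis=(2, 4)).round().astype(np.uint8)


def convert(bin_path, out_path):
    """Read one binary file, downsample it, and save the result as .npy."""
    raw = np.fromfile(bin_path, dtype=np.uint8)
    np.save(out_path, downsample_2x(decode_stl10(raw)))
```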

CelebA

Download img_align_celeba.zip from https://www.kaggle.com/jessicali9530/celeba-dataset (or from other sources), and put it in dataset_preprocess/CelebA folder. Then run python preprocess.py. This code will crop and resize the images into 64x64x3 format, and save the images in celeba.npy.
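A crop-and-resize step of this kind could look like the sketch below. The crop region (a centered square with the side of the shorter image edge) and the nearest-neighbor resizing are assumptions chosen to keep the example dependency-free; the repo's preprocess.py may crop and interpolate differently. The function names are hypothetical, and the input is an already decoded (height, width, 3) uint8 array.

```python
import numpy as np


def center_crop(image, size):
    """Cut a size x size window from the middle of an (H, W, C) image."""
    h, w = image.shape[:2]
    top, left = (h - size) // 2, (w - size) // 2
    return image[top:top + size, left:left + size]


def resize_nearest(image, out_size):
    """Resize an (H, W, C) image to (out_size, out_size, C) by nearest neighbor."""
    h, w = image.shape[:2]
    rows = (np.arange(out_size) * h) // out_size
    cols = (np.arange(out_size) * w) // out_size
    return image[rows[:, None], cols]


def preprocess_face(image, out_size=64):
    """Center-crop to a square, then resize to out_size x out_size."""
    side = min(image.shape[:2])
    return resize_nearest(center_crop(image, side), out_size)
```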

ImageNet

Download ILSVRC2012_img_train.tar from http://www.image-net.org/ (or from other sources), and put it in dataset_preprocess/ImageNet folder. Then run python preprocess.py. This code will crop and resize the images into 128x128x3 format, and save the images in the ILSVRC2012 folder. Each subfolder in the ILSVRC2012 folder corresponds to one class. Each npy file in the subfolders corresponds to an image.
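The resulting folder layout (one subfolder per class, one npy file per image) can be traversed with a short helper such as the following. This is an illustrative sketch, not the repo's data loader; the function names and the choice to number classes by sorted folder name are assumptions.

```python
import os

import numpy as np


def list_imagenet_npy(root):
    """Return (path, class_index) pairs and the sorted list of class folders."""
    classes = sorted(
        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
    )
    samples = []
    for index, name in enumerate(classes):
        folder = os.path.join(root, name)
        for fname in sorted(os.listdir(folder)):
            if fname.endswith(".npy"):
                samples.append((os.path.join(folder, fname), index))
    return samples, classes


def load_batch(samples):
    """Load the listed npy files into one image array and one label array."""
    images = np.stack([np.load(path) for path, _ in samples])
    labels = np.array([label for _, label in samples])
    return images, labels
```

Listing paths first and loading files lazily in small batches avoids holding the whole dataset in memory.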

Training BSN and SN variants

Prerequisites

The code is based on the GPUTaskScheduler library, which automatically schedules jobs among GPU nodes. Please install it first. You may need to change the GPU configurations according to the devices you have. The configurations are set in config.py in each directory. Please refer to GPUTaskScheduler's GitHub page for details on how to configure it properly.

You can also run the code without GPUTaskScheduler: just run python gan.py in the gan subfolders.

CIFAR10, STL10, CelebA

Preparation

Copy the preprocessed datasets from the previous steps into the following paths:

  • CIFAR10: /data/CIFAR10/cifar-10-python.tar.gz.
  • STL10: /data/STL10/cifar-10-stl10.npy.
  • CelebA: /data/CelebA/celeba.npy.

These paths are relative to the code folder of the variant you want to train:

  • Vanilla SN and our proposed BSSN/SSN/BSN without gammas: no_gamma-CNN.
  • SN with the same gammas: same_gamma-CNN.
  • SN with different gammas: diff_gamma-CNN.

Alternatively, you can directly modify the dataset paths in /gan_task.py (relative to the same code folder) to point to the preprocessed datasets.

Running codes

Now you can directly run python main.py in each code folder to train the models.

All the configurable hyper-parameters can be set in config.py. The hyper-parameters in the file are already set for reproducing the results in the paper. Please refer to GPUTaskScheduler's GitHub page for the details of the grammar of this file.

ImageNet

Preparation

Copy the preprocessed folder ILSVRC2012 from the previous steps to /data/imagenet/ILSVRC2012, relative to the code folder of the variant you want to train:

  • Vanilla SN and our proposed BSSN/SSN/BSN without gammas: no_gamma-ResNet.

Alternatively, you can directly modify the dataset path in /gan_task.py (relative to the same code folder) to point to the preprocessed folder ILSVRC2012.

Running codes

Now you can directly run python main.py in the code folder to train the models.

All the configurable hyper-parameters can be set in config.py. The hyper-parameters in the file are already set for reproducing the results in the paper. Please refer to GPUTaskScheduler's GitHub page for the details of the grammar of this file.

The code supports multi-GPU training for speed-up by splitting each data batch equally among multiple GPUs. To enable it, you only need to make minor modifications in config.py. For example, if you have two GPUs with IDs 0 and 1, then all you need to do is (1) change "gpu": ["0"] to "gpu": [["0", "1"]], and (2) change "num_gpus": [1] to "num_gpus": [2]. Note that the number of GPUs might influence the results, because in this implementation the batch normalization layers on different GPUs are independent. In our experiments, we used only one GPU.
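The two edits can be summarized as the fragment below, which shows only the two keys named above; the rest of config.py is omitted and its surrounding structure is not reproduced here. The split_batch helper is a hypothetical illustration of dividing a batch equally among GPUs, not the repo's implementation.

```python
# Values of the two keys for one GPU (as quoted in the text).
single_gpu = {"gpu": ["0"], "num_gpus": [1]}

# Values of the same two keys for two GPUs with IDs 0 and 1.
two_gpus = {"gpu": [["0", "1"]], "num_gpus": [2]}


def split_batch(batch, num_gpus):
    """Split a batch into equal, contiguous per-GPU shards."""
    if len(batch) % num_gpus != 0:
        raise ValueError("batch size must be divisible by the number of GPUs")
    shard = len(batch) // num_gpus
    return [batch[i * shard:(i + 1) * shard] for i in range(num_gpus)]


# Hypothetical batch of 8 sample indices split across two GPUs.
shards = split_batch(list(range(8)), two_gpus["num_gpus"][0])
```

Because each shard is normalized with its own batch statistics, a two-GPU run sees batch normalization over half-size batches, which is one reason results can differ from the single-GPU setting.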

Results

The code generates the following result files/folders:

  • /results/ /worker.log : Standard output and error from the code.
  • /results/ /metrics.csv : Inception Score and FID during training.
  • /results/ /sample/*.png : Generated images during training.
  • /results/ /checkpoint/* : TensorFlow checkpoints.
  • /results/ /time.txt : Training iteration timestamps.
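To inspect metrics.csv after a run, a generic reader such as the sketch below may be enough. It assumes the file starts with a header row; the column names are not documented here, so the column passed to best_row (for example a hypothetical "fid" column) must be taken from the actual header. Both function names are hypothetical.

```python
import csv


def read_metrics(path):
    """Read a CSV file with a header row into a list of dictionaries."""
    with open(path, newline="") as f:
        return list(csv.DictReader(f))


def best_row(rows, column, lower_is_better=True):
    """Return the row with the best value in the given numeric column."""
    def value(row):
        return float(row[column])

    return min(rows, key=value) if lower_is_better else max(rows, key=value)
```

For FID, lower is better; for Inception Score, pass lower_is_better=False.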
Owner
Zinan Lin
Ph.D. student at Electrical and Computer Engineering, Carnegie Mellon University