Ensembling Off-the-shelf Models for GAN Training

Overview

Vision-aided GAN

video (3m) | website | paper







Can the collective knowledge from a large bank of pretrained vision models be leveraged to improve GAN training? If so, with so many models to choose from, which one(s) should be selected, and in what manner are they most effective?

We find that pretrained computer vision models can significantly improve performance when used in an ensemble of discriminators. We propose an effective selection mechanism, by probing the linear separability between real and fake samples in pretrained model embeddings, choosing the most accurate model, and progressively adding it to the discriminator ensemble. Our method can improve GAN training in both limited data and large-scale settings.

Ensembling Off-the-shelf Models for GAN Training
Nupur Kumari, Richard Zhang, Eli Shechtman, Jun-Yan Zhu
arXiv 2112.09130, 2021

Quantitative Comparison


Our method outperforms recent GAN training methods by a large margin, especially in limited sample setting. For LSUN Cat, we achieve similar FID as StyleGAN2 trained on the full dataset using only $0.7%$ of the dataset. On the full dataset, our method improves FID by 1.5x to 2x on cat, church, and horse categories of LSUN.

Example Results

Below, we show visual comparisons between the baseline StyleGAN2-ADA and our model (Vision-aided GAN) for the same randomly sample latent code.

Interpolation Videos

Latent interpolation results of models trained with our method on AnimalFace Cat (160 images), Dog (389 images), and Bridge-of-Sighs (100 photos).


Requirements

  • 64-bit Python 3.8 and PyTorch 1.8.0 (or later). See https://pytorch.org/ for PyTorch install instructions.
  • Cuda toolkit 11.0 or later.
  • python libraries: see requirements.txt
  • StyleGAN2 code relies heavily on custom PyTorch extensions. For detail please refer to the repo stylegan2-ada-pytorch

Setting up Off-the-shelf Computer Vision models

CLIP(ViT): we modify the model.py function to return intermediate features of the transformer model. To set up follow these steps.

git clone https://github.com/openai/CLIP.git
cp vision-aided-gan/training/clip_model.py CLIP/clip/model.py
cd CLIP
python setup.py install

DINO(ViT): model is automatically downloaded from torch hub.

VGG-16: model is automatically downloaded.

Swin-T(MoBY): Create a pretrained-models directory and save the downloaded model there.

Swin-T(Object Detection): follow the below step for setup. Download the model here and save it in the pretrained-models directory.

git clone https://github.com/SwinTransformer/Swin-Transformer-Object-Detection
cd Swin-Transformer-Object-Detection
pip install mmcv-full==1.3.0 -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.8.0/index.html
python setup.py install

for more details on mmcv installation please refer here

Swin-T(Segmentation): follow the below step for setup. Download the model here and save it in the pretrained-models directory.

git clone https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation.git
cd Swin-Transformer-Semantic-Segmentation
python setup.py install

Face Parsing:download the model here and save in the pretrained-models directory.

Face Normals:download the model here and save in the pretrained-models directory.

Pretrained Models

Our final trained models can be downloaded at this link

To generate images:

python generate.py --outdir=out --trunc=1 --seeds=85,265,297,849 --network=<network.pkl>

The output is stored in out directory controlled by --outdir. Our generator architecture is same as styleGAN2 and can be similarly used in the Python code as described in stylegan2-ada-pytorch.

model evaluation:

python calc_metrics.py --network <network.pkl> --metrics fid50k_full --data <dataset> --clean 1

We use clean-fid library to calculate FID metric. For LSUN Church and LSUN Horse, we calclate the full real distribution statistics. For details on calculating the real distribution statistics, please refer to clean-fid. For default FID evaluation of StyleGAN2-ADA use clean=0.

Datasets

Dataset preparation is same as given in stylegan2-ada-pytorch. Example setup for LSUN Church

LSUN Church

git clone https://github.com/fyu/lsun.git
cd lsun
python3 download.py -c church_outdoor
unzip church_outdoor_train_lmdb.zip
cd ../vision-aided-gan
python dataset_tool.py --source <path-to>/church_outdoor_train_lmdb/ --dest <path-to-datasets>/church1k.zip --max-images 1000  --transform=center-crop --width=256 --height=256

datasets can be downloaded from their repsective websites:

FFHQ, LSUN Categories, AFHQ, AnimalFace Dog, AnimalFace Cat, 100-shot Bridge-of-Sighs

Training new networks

model selection: returns the computer vision model with highest linear probe accuracy for the best FID model in a folder or the given network file.

python model_selection.py --data mydataset.zip --network  <mynetworkfolder or mynetworkpklfile>

example training command for training with a single pretrained network from scratch

python train.py --outdir=training-models/ --data=mydataset.zip --gpus 2 --metrics fid50k_full --kimg 25000 --cfg paper256 --cv input-dino-output-conv_multi_level --cv-loss multilevel_s --augcv ada --ada-target-cv 0.3 --augpipecv bgc --batch 16 --mirror 1 --aug ada --augpipe bgc --snap 25 --warmup 1  

Training configuration corresponding to training with vision-aided-loss:

  • --cv=input-dino-output-conv_multi_level pretrained network and its configuration.
  • --warmup=0 should be enabled when training from scratch. Introduces our loss after training with 500k images.
  • --cv-loss=multilevel what loss to use on pretrained model based discriminator.
  • --augcv=ada performs ADA augmentation on pretrained model based discriminator.
  • --augcv=diffaugment-<policy> performs DiffAugment on pretrained model based discriminator with given poilcy.
  • --augpipecv=bgc ADA augmentation strategy. Note: cutout is always enabled.
  • --ada-target-cv=0.3 adjusts ADA target value for pretrained model based discriminator.
  • --exact-resume=0 enables exact resume along with optimizer state.

Miscellaneous configurations:

  • --appendname='' additional string to append to training directory name.
  • --wandb-log=0 enables wandb logging.
  • --clean=0 enables FID calculation using clean-fid if the real distribution statistics are pre-calculated.

Run python train.py --help for more details and the full list of args.

References

@article{kumari2021ensembling,
  title={Ensembling Off-the-shelf Models for GAN Training},
  author={Kumari, Nupur and Zhang, Richard and Shechtman, Eli and Zhu, Jun-Yan},
  journal={arXiv preprint arXiv:2112.09130},
  year={2021}
}

Acknowledgments

We thank Muyang Li, Sheng-Yu Wang, Chonghyuk (Andrew) Song for proofreading the draft. We are also grateful to Alexei A. Efros, Sheng-Yu Wang, Taesung Park, and William Peebles for helpful comments and discussion. Our codebase is built on stylegan2-ada-pytorch and DiffAugment.

This is the repository for the NeurIPS-21 paper [Contrastive Graph Poisson Networks: Semi-Supervised Learning with Extremely Limited Labels].

CGPN This is the repository for the NeurIPS-21 paper [Contrastive Graph Poisson Networks: Semi-Supervised Learning with Extremely Limited Labels]. Req

10 Sep 12, 2022
Official code for the paper "Self-Supervised Prototypical Transfer Learning for Few-Shot Classification"

Self-Supervised Prototypical Transfer Learning for Few-Shot Classification This repository contains the reference source code and pre-trained models (

EPFL INDY 44 Nov 04, 2022
Code for the paper "Adapting Monolingual Models: Data can be Scarce when Language Similarity is High"

Wietse de Vries • Martijn Bartelds • Malvina Nissim • Martijn Wieling Adapting Monolingual Models: Data can be Scarce when Language Similarity is High

Wietse de Vries 5 Aug 02, 2021
A new data augmentation method for extreme lighting conditions.

Random Shadows and Highlights This repo has the source code for the paper: Random Shadows and Highlights: A new data augmentation method for extreme l

Osama Mazhar 35 Nov 26, 2022
验证码识别 深度学习 tensorflow 神经网络

captcha_tf2 验证码识别 深度学习 tensorflow 神经网络 使用卷积神经网络,对字符,数字类型验证码进行识别,tensorflow使用2.0以上 目前项目还在更新中,诸多bug,欢迎提出issue和PR, 希望和你一起共同完善项目。 实例demo 训练过程 优化器选择: Adam

5 Apr 28, 2022
FEMDA: Robust classification with Flexible Discriminant Analysis in heterogeneous data

FEMDA: Robust classification with Flexible Discriminant Analysis in heterogeneous data. Flexible EM-Inspired Discriminant Analysis is a robust supervised classification algorithm that performs well i

0 Sep 06, 2022
A minimalist tool to display a network graph.

A tool to get a minimalist view of any architecture This tool has only be tested with the models included in this repo. Therefore, I can't guarantee t

Thibault Castells 1 Feb 11, 2022
BigDetection: A Large-scale Benchmark for Improved Object Detector Pre-training

BigDetection: A Large-scale Benchmark for Improved Object Detector Pre-training By Likun Cai, Zhi Zhang, Yi Zhu, Li Zhang, Mu Li, Xiangyang Xue. This

290 Dec 29, 2022
Mask2Former: Masked-attention Mask Transformer for Universal Image Segmentation in TensorFlow 2

Mask2Former: Masked-attention Mask Transformer for Universal Image Segmentation in TensorFlow 2 Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexan

Phan Nguyen 1 Dec 16, 2021
A Topic Modeling toolbox

Topik A Topic Modeling toolbox. Introduction The aim of topik is to provide a full suite and high-level interface for anyone interested in applying to

Anaconda, Inc. (formerly Continuum Analytics, Inc.) 93 Dec 01, 2022
UMPNet: Universal Manipulation Policy Network for Articulated Objects

UMPNet: Universal Manipulation Policy Network for Articulated Objects Zhenjia Xu, Zhanpeng He, Shuran Song Columbia University Robotics and Automation

Columbia Artificial Intelligence and Robotics Lab 33 Dec 03, 2022
Data Preparation, Processing, and Visualization for MoVi Data

MoVi-Toolbox Data Preparation, Processing, and Visualization for MoVi Data, https://www.biomotionlab.ca/movi/ MoVi is a large multipurpose dataset of

Saeed Ghorbani 51 Nov 27, 2022
I explore rock vs. mine prediction using a SONAR dataset

I explore rock vs. mine prediction using a SONAR dataset. Using a Logistic Regression Model for my prediction algorithm, I intend on predicting what an object is based on supervised learning.

Jeff Shen 1 Jan 11, 2022
Code for Paper Predicting Osteoarthritis Progression via Unsupervised Adversarial Representation Learning

Predicting Osteoarthritis Progression via Unsupervised Adversarial Representation Learning (c) Tianyu Han and Daniel Truhn, RWTH Aachen University, 20

Tianyu Han 7 Nov 22, 2022
[CVPR'22] Weakly Supervised Semantic Segmentation by Pixel-to-Prototype Contrast

wseg Overview The Pytorch implementation of Weakly Supervised Semantic Segmentation by Pixel-to-Prototype Contrast. [arXiv] Though image-level weakly

Ye Du 96 Dec 30, 2022
Official implementation of ACMMM'20 paper 'Self-supervised Video Representation Learning Using Inter-intra Contrastive Framework'

Self-supervised Video Representation Learning Using Inter-intra Contrastive Framework Official code for paper, Self-supervised Video Representation Le

Li Tao 103 Dec 21, 2022
Codebase for the paper titled "Continual learning with local module selection"

This repository contains the codebase for the paper Continual Learning via Local Module Composition. Setting up the environemnt Create a new conda env

Oleksiy Ostapenko 20 Dec 10, 2022
Causal-BALD: Deep Bayesian Active Learning of Outcomes to Infer Treatment-Effects from Observational Data.

causal-bald | Abstract | Installation | Example | Citation | Reproducing Results DUE An implementation of the methods presented in Causal-BALD: Deep B

OATML 13 Oct 07, 2022
KakaoBrain KoGPT (Korean Generative Pre-trained Transformer)

KoGPT KoGPT (Korean Generative Pre-trained Transformer) https://github.com/kakaobrain/kogpt https://huggingface.co/kakaobrain/kogpt Model Descriptions

Kakao Brain 799 Dec 28, 2022