Invert and perturb GAN images for test-time ensembling

Overview

GAN Ensembling

Project Page | Paper | Bibtex

Ensembling with Deep Generative Views.
Lucy Chai, Jun-Yan Zhu, Eli Shechtman, Phillip Isola, Richard Zhang
CVPR 2021

Prerequisites

  • Linux
  • Python 3
  • NVIDIA GPU + CUDA CuDNN

Table of Contents:

  1. Colab - run a limited demo version without local installation
  2. Setup - download required resources
  3. Quickstart - short demonstration code snippet
  4. Notebooks - jupyter notebooks for visualization
  5. Pipeline - details on full pipeline

We project an input image into the latent space of a pre-trained GAN and perturb it slightly to obtain modifications of the input image. These alternative views from the GAN are ensembled at test-time, together with the original image, in a downstream classification task.
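A minimal sketch of the test-time combination step follows. It assumes the classifier's softmax outputs on the original image and on each GAN view are already computed, and that they are merged by a convex combination with a single weight `alpha` on the original image. The function name and this weighting scheme are illustrative assumptions, not code taken from the repository.

```python
import numpy as np

def ensemble_predictions(p_original, p_views, alpha):
    """Combine class probabilities from the original image and its GAN views.

    p_original: (num_classes,) softmax output on the original image.
    p_views:    (num_views, num_classes) softmax outputs on the GAN views.
    alpha:      weight in [0, 1] on the original image (assumed scheme).
    """
    p_original = np.asarray(p_original, dtype=float)
    p_views = np.asarray(p_views, dtype=float)
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    # Average the views first, so adding more views does not change their total weight.
    p_views_mean = p_views.mean(axis=0)
    return alpha * p_original + (1.0 - alpha) * p_views_mean

p_orig = np.array([0.6, 0.4])
p_views = np.array([[0.3, 0.7], [0.4, 0.6]])
combined = ensemble_predictions(p_orig, p_views, alpha=0.5)
predicted_class = int(combined.argmax())
```

With `alpha = 1` the ensemble reduces to the classifier's prediction on the original image alone; in this toy example the two views shift the prediction from class 0 to class 1.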

To synthesize deep generative views, we first align (Aligned Input) and reconstruct an image by finding the corresponding latent code in StyleGAN2 (GAN Reconstruction). We then investigate different approaches to produce image variations using the GAN, such as style-mixing on fine layers (Style-mix Fine), which predominantly changes color, or coarse layers (Style-mix Coarse), which changes pose.
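Style-mixing can be sketched on a per-layer latent code (often called W+), where each row feeds one synthesis layer: copying rows from a second, randomly sampled code into the early rows gives a coarse mix, and copying into the later rows gives a fine mix. The 18-layer, 512-dimensional shape and the split index of 8 are illustrative assumptions, not the repository's settings, and the random arrays merely stand in for latent codes produced by the generator's mapping network.

```python
import numpy as np

def style_mix(w_plus, w_mix, layers):
    """Return a copy of w_plus whose rows in `layers` are taken from w_mix.

    w_plus: (num_layers, dim) latent code of the reconstructed image.
    w_mix:  (num_layers, dim) latent code of a randomly sampled image.
    layers: slice selecting which synthesis layers receive the mixed style.
    """
    mixed = w_plus.copy()  # leave the reconstruction's code untouched
    mixed[layers] = w_mix[layers]
    return mixed

NUM_LAYERS, DIM, SPLIT = 18, 512, 8  # assumed StyleGAN2-like sizes and coarse/fine split
COARSE = slice(0, SPLIT)             # early layers: changes pose
FINE = slice(SPLIT, NUM_LAYERS)      # later layers: predominantly changes color

rng = np.random.default_rng(0)
# Random stand-ins for mapped latent codes (real codes come from the mapping network).
w_image = rng.standard_normal((NUM_LAYERS, DIM))
w_random = rng.standard_normal((NUM_LAYERS, DIM))

w_fine = style_mix(w_image, w_random, FINE)      # keeps pose, varies color
w_coarse = style_mix(w_image, w_random, COARSE)  # keeps color, varies pose
```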

Colab

This Colab Notebook demonstrates the basic latent code perturbation and classification procedure in a simplified setting on the aligned cat dataset.

Setup

  • Clone this repo:
git clone https://github.com/chail/gan-ensembling.git
cd gan-ensembling

An example of the directory organization is below:

dataset/celebahq/
	images/images/
		000004.png
		000009.png
		000014.png
		...
	latents/
	latents_idinvert/
dataset/cars/
	devkit/
		cars_meta.mat
		cars_test_annos.mat
		cars_train_annos.mat
		...
	images/images/
		00001.jpg
		00002.jpg
		00003.jpg
		...
	latents/
dataset/catface/
	images/
	latents/
dataset/cifar10/
	cifar-10-batches-py/
	latents/

Quickstart

Once the datasets and precomputed resources are downloaded, the following code snippet demonstrates how to perturb GAN images. Additional examples are contained in notebooks/demo.ipynb.

import data
from networks import domain_generator

dataset_name = 'celebahq'
generator_name = 'stylegan2'
attribute_name = 'Smiling'
# validation-time image transform for this dataset
val_transform = data.get_transform(dataset_name, 'imval')
# load_w=True also returns the precomputed latent code for each image
dset = data.get_dataset(dataset_name, 'val', attribute_name, load_w=True, transform=val_transform)
generator = domain_generator.define_generator(generator_name, dataset_name)

index = 100
original_image = dset[index][0][None].cuda()  # [None] adds a batch dimension
latent = dset[index][1][None].cuda()          # precomputed latent code of this image
gan_reconstruction = generator.decode(latent) # image synthesized from the latent code
mix_latent = generator.seed2w(n=4, seed=0)    # 4 random latent codes sampled with seed 0
perturbed_im = generator.perturb_stylemix(latent, 'fine', mix_latent, n=4)  # 4 fine-layer style-mix views

Notebooks

Important: before running the notebooks, set up the required symlinks and add the conda environment to the Jupyter kernels:

bash notebooks/setup_notebooks.sh
python -m ipykernel install --user --name gan-ensembling

The provided notebooks are:

  1. notebooks/demo.ipynb: basic usage example
  2. notebooks/evaluate_ensemble.ipynb: plot classification test accuracy as a function of ensemble weight
  3. notebooks/plot_precomputed_evaluations.ipynb: notebook to generate figures in paper
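A curve of test accuracy against ensemble weight, like the one described for notebooks/evaluate_ensemble.ipynb, can be approximated by sweeping the weight and scoring each setting. This sketch assumes precomputed softmax outputs stored as NumPy arrays and a convex combination with weight `alpha` on the original image; the function name, array shapes, and toy data are hypothetical and do not reflect the notebook's code.

```python
import numpy as np

def accuracy_vs_weight(p_original, p_views, labels, alphas):
    """Ensemble test accuracy for each weight in `alphas`.

    p_original: (N, C) softmax outputs on the original images.
    p_views:    (N, V, C) softmax outputs on V GAN views per image.
    labels:     (N,) integer ground-truth classes.
    """
    p_views_mean = p_views.mean(axis=1)  # (N, C): average over the views of each image
    accuracies = []
    for alpha in alphas:
        combined = alpha * p_original + (1.0 - alpha) * p_views_mean
        accuracies.append(float((combined.argmax(axis=1) == labels).mean()))
    return np.array(accuracies)

# Toy data: 2 images, 2 views each, 2 classes.
p_original = np.array([[0.6, 0.4], [0.45, 0.55]])
p_views = np.array([[[0.3, 0.7], [0.4, 0.6]],
                    [[0.7, 0.3], [0.8, 0.2]]])
labels = np.array([1, 0])
alphas = np.linspace(0.0, 1.0, 5)  # 0.0, 0.25, 0.5, 0.75, 1.0
acc = accuracy_vs_weight(p_original, p_views, labels, alphas)
```

Plotting `alphas` against `acc` gives the curve; the point `alpha = 1` is the classifier applied to the original images alone.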

Full Pipeline

The full pipeline contains three main parts:

  1. optimize latent codes
  2. train classifiers
  3. evaluate the ensemble of GAN-generated images.
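Step 1 finds, for each image, a latent code whose decoded image reconstructs it. As a toy illustration only, the sketch below recovers a latent code by gradient descent on a squared reconstruction error, assuming a fixed linear map stands in for the generator; the repository's actual generator, loss terms, and optimizer settings are not described in this README, and the function name is hypothetical.

```python
import numpy as np

def project_linear(G, x, steps=1000, lr=0.1):
    """Find w minimizing ||G @ w - x||^2 by gradient descent (toy linear generator G)."""
    w = np.zeros(G.shape[1])
    for _ in range(steps):
        residual = G @ w - x         # reconstruction error in "image" space
        grad = 2.0 * G.T @ residual  # gradient of the squared error with respect to w
        w -= lr * grad
    return w

rng = np.random.default_rng(0)
G = rng.standard_normal((64, 8)) / 8.0  # stand-in "generator": 8-dim latent -> 64-dim image
w_true = rng.standard_normal(8)
x = G @ w_true                          # an image the toy generator can produce exactly
w_hat = project_linear(G, x)
```

Unlike this convex toy problem, inverting a real GAN is nonconvex, so the recovered code is in general only an approximate reconstruction.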

Examples for each step of the pipeline are contained in the following scripts:

bash scripts/optimize_latent/examples.sh
bash scripts/train_classifier/examples.sh
bash scripts/eval_ensemble/examples.sh

To add to the pipeline:

  • Data: in the data/ directory, add the dataset in data/__init__.py and create the dataset class and transformation functions. See data/data_*.py for examples.
  • Generator: modify networks/domain_generator.py to add the generator in domain_generator.define_generator. The perturbation ranges for each dataset and generator are specified in networks/perturb_settings.py.
  • Classifier: modify networks/domain_classifiers.py to add the classifier in domain_classifiers.define_classifier.

Acknowledgements

We thank the authors of these repositories:

Citation

If you use this code for your research, please cite our paper:

@inproceedings{chai2021ensembling,
  title={Ensembling with Deep Generative Views.},
  author={Chai, Lucy and Zhu, Jun-Yan and Shechtman, Eli and Isola, Phillip and Zhang, Richard},
  booktitle={CVPR},
  year={2021}
 }