[CVPR 2022] Unsupervised Image-to-Image Translation with Generative Prior

Overview

GP-UNIT - Official PyTorch Implementation

This repository provides the official PyTorch implementation for the following paper:

Unsupervised Image-to-Image Translation with Generative Prior
Shuai Yang, Liming Jiang, Ziwei Liu and Chen Change Loy
In CVPR 2022.
Project Page | Paper | Supplementary Video

Abstract: Unsupervised image-to-image translation aims to learn the translation between two visual domains without paired data. Despite the recent progress in image translation models, it remains challenging to build mappings between complex domains with drastic visual discrepancies. In this work, we present a novel framework, Generative Prior-guided UNsupervised Image-to-image Translation (GP-UNIT), to improve the overall quality and applicability of the translation algorithm. Our key insight is to leverage the generative prior from pre-trained class-conditional GANs (e.g., BigGAN) to learn rich content correspondences across various domains. We propose a novel coarse-to-fine scheme: we first distill the generative prior to capture a robust coarse-level content representation that can link objects at an abstract semantic level, based on which fine-level content features are adaptively learned for more accurate multi-level content correspondences. Extensive experiments demonstrate the superiority of our versatile framework over state-of-the-art methods in robust, high-quality and diversified translations, even for challenging and distant domains.

Updates

  • [03/2022] Paper and supplementary video are released.
  • [04/2022] Code and dataset are released.
  • [03/2022] This website is created.

Installation

Clone this repo:

git clone https://github.com/williamyang1991/GP-UNIT.git
cd GP-UNIT

Dependencies:

We have tested on:

  • CUDA 10.1
  • PyTorch 1.7.0
  • Pillow 8.0.1; Matplotlib 3.3.3; opencv-python 4.4.0; Faiss 1.7.0; tqdm 4.54.0

All dependencies for defining the environment are provided in environment/gpunit_env.yaml. We recommend running this repository using Anaconda:

conda env create -f ./environment/gpunit_env.yaml

Because we use CUDA 10.1, this environment installs PyTorch 1.7.0 (see Line 16, Line 113, Line 120 and Line 121 of gpunit_env.yaml). Please install the PyTorch build that matches your own CUDA version, following https://pytorch.org/.
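
If you install the dependencies by hand rather than from the YAML file, a small check like the one below can flag packages whose versions differ from the tested ones. This is only a sketch using the standard library: the distribution names (`torch`, `Pillow`, `opencv-python`, and so on) are assumptions about how the packages are registered in your environment, and Faiss is left out because its distribution name varies between builds.

```python
from importlib import metadata

# Versions the authors report testing with (from the list above).
# The distribution names are assumptions; adjust them to your installation.
TESTED = {
    "torch": "1.7.0",
    "Pillow": "8.0.1",
    "matplotlib": "3.3.3",
    "opencv-python": "4.4.0",
    "tqdm": "4.54.0",
}


def compare_versions(tested=TESTED):
    """Return {name: (tested, installed_or_None)} for every package that differs."""
    report = {}
    for name, wanted in tested.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        # A prefix match tolerates build suffixes such as "1.7.0+cu101".
        if installed is None or not installed.startswith(wanted):
            report[name] = (wanted, installed)
    return report


if __name__ == "__main__":
    for name, (wanted, installed) in compare_versions().items():
        print(f"{name}: tested with {wanted}, found {installed or 'not installed'}")
```

A differing version is not necessarily a problem; the list only records what was tested.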


(1) Dataset Preparation

The human face, animal face and artistic human face datasets can be downloaded from their official pages. The bird, dog and car datasets can be built from ImageNet with our provided script.

Task | Used Dataset
Male←→Female | CelebA-HQ: divided into male and female subsets by StarGANv2
Dog←→Cat←→Wild | AFHQ provided by StarGANv2
Face←→Cat or Dog | CelebA-HQ and AFHQ
Bird←→Dog | 4 classes of birds and 4 classes of dogs in ImageNet291. Please refer to dataset preparation for building ImageNet291 from ImageNet
Bird←→Car | 4 classes of birds and 4 classes of cars in ImageNet291. Please refer to dataset preparation for building ImageNet291 from ImageNet
Face→MetFace | CelebA-HQ and MetFaces

(2) Inference for Latent-Guided and Exemplar-Guided Translation

Inference Notebook


To help users get started, we provide a Jupyter notebook at ./notebooks/inference_playground.ipynb that allows one to visualize the performance of GP-UNIT. The notebook will download the necessary pretrained models and run inference on the images in ./data/.

Web Demo

Try the Replicate web demo here.

Pretrained Models

Pretrained models can be downloaded from Google Drive or Baidu Cloud (access code: cvpr):

Task | Pretrained Models
Prior Distillation | content encoder
Male←→Female | generators for male2female and female2male
Dog←→Cat←→Wild | generators for dog2cat, cat2dog, dog2wild, wild2dog, cat2wild and wild2cat
Face←→Cat or Dog | generators for face2cat, cat2face, dog2face and face2dog
Bird←→Dog | generators for bird2dog and dog2bird
Bird←→Car | generators for bird2car and car2bird
Face→MetFace | generator for face2metface

The saved checkpoints are under the following folder structure:

checkpoint
|--content_encoder.pt     % Content encoder
|--bird2car.pt            % Bird-to-Car translation model
|--bird2dog.pt            % Bird-to-Dog translation model
...
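
Before running inference, it can help to confirm that the checkpoints you need are in place. The helper below is a minimal sketch, not part of the repository; `missing_checkpoints` is a hypothetical name, and the file names in the example are taken from the folder structure above.

```python
from pathlib import Path


def missing_checkpoints(names, root="./checkpoint"):
    """Return the checkpoint file names that are not present under root."""
    root = Path(root)
    return [name for name in names if not (root / name).is_file()]


if __name__ == "__main__":
    # File names follow the folder structure shown above.
    needed = ["content_encoder.pt", "bird2car.pt"]
    for name in missing_checkpoints(needed):
        print(f"missing: ./checkpoint/{name}")
```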

Latent-Guided Translation

Translate a content image to the target domain with randomly sampled latent styles:

python inference.py --generator_path PRETRAINED_GENERATOR_PATH --content_encoder_path PRETRAINED_ENCODER_PATH \
                    --content CONTENT_IMAGE_PATH --batch STYLE_NUMBER --device DEVICE

By default, the script uses ./checkpoint/dog2cat.pt as PRETRAINED_GENERATOR_PATH, ./checkpoint/content_encoder.pt as PRETRAINED_ENCODER_PATH, and cuda as DEVICE to run on the GPU. To run on the CPU, use --device cpu.

Take Dog→Cat as an example, run:

python inference.py --content ./data/afhq/images512x512/test/dog/flickr_dog_000572.jpg --batch 6

Six results translation_flickr_dog_000572_N.jpg (N=0~5) are saved in the folder ./output/. A corresponding overview image translation_flickr_dog_000572_overview.jpg is additionally saved to illustrate the input content image and the six results.
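
The naming pattern described above can be reproduced with a few lines of Python, which is handy when a downstream script needs to collect the results. This is a sketch under assumptions: `latent_output_names` is a hypothetical helper, and it assumes the outputs always use the `.jpg` extension and the default output folder.

```python
from pathlib import Path


def latent_output_names(content_path, batch, out_dir="./output"):
    """List the files expected after a latent-guided run, following the naming above."""
    stem = Path(content_path).stem  # e.g. "flickr_dog_000572"
    out_dir = Path(out_dir)
    # One result per sampled style, numbered from 0 to batch - 1.
    results = [out_dir / f"translation_{stem}_{n}.jpg" for n in range(batch)]
    overview = out_dir / f"translation_{stem}_overview.jpg"
    return results + [overview]


if __name__ == "__main__":
    content = "./data/afhq/images512x512/test/dog/flickr_dog_000572.jpg"
    for path in latent_output_names(content, batch=6):
        print(path)
```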

Evaluation Metrics: We use the code of StarGANv2 to calculate FID and Diversity with LPIPS in our paper.

Exemplar-Guided Translation

Translate a content image to the target domain in the style of a style image by additionally specifying --style:

python inference.py --generator_path PRETRAINED_GENERATOR_PATH --content_encoder_path PRETRAINED_ENCODER_PATH \
                    --content CONTENT_IMAGE_PATH --style STYLE_IMAGE_PATH --device DEVICE

Take Dog→Cat as an example, run:

python inference.py --content ./data/afhq/images512x512/test/dog/flickr_dog_000572.jpg --style ./data/afhq/images512x512/test/cat/flickr_cat_000418.jpg

The result translation_flickr_dog_000572_to_flickr_cat_000418.jpg is saved in the folder ./output/. A corresponding overview image translation_flickr_dog_000572_to_flickr_cat_000418_overview.jpg is additionally saved to illustrate the input content image, the style image, and the result.
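
To translate a whole folder of content images, the command can be wrapped in a short driver script. The sketch below only uses the flags documented in this section; `build_inference_cmd` and `translate_folder` are hypothetical helpers that are not part of the repository, and the script is assumed to be run from the repository root.

```python
import subprocess
import sys
from pathlib import Path


def build_inference_cmd(content, generator_path=None, style=None, batch=None, device=None):
    """Assemble an inference.py command line from the flags documented above."""
    cmd = [sys.executable, "inference.py", "--content", str(content)]
    if generator_path is not None:
        cmd += ["--generator_path", str(generator_path)]
    if style is not None:
        cmd += ["--style", str(style)]  # exemplar-guided translation
    if batch is not None:
        cmd += ["--batch", str(batch)]  # number of sampled latent styles
    if device is not None:
        cmd += ["--device", device]
    return cmd


def translate_folder(folder, pattern="*.jpg", **options):
    """Run inference.py once per matching image, stopping at the first failure."""
    for image in sorted(Path(folder).glob(pattern)):
        subprocess.run(build_inference_cmd(image, **options), check=True)


if __name__ == "__main__":
    # Print the command for one image instead of running it.
    print(" ".join(build_inference_cmd("dog.jpg", style="cat.jpg", device="cpu")))
```

Omitted options fall back to the script defaults, so `translate_folder("./data/afhq/images512x512/test/dog", batch=6)` would mirror the latent-guided example for every image in that folder.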

Another example of Cat→Wild, run:

python inference.py --generator_path ./checkpoint/cat2wild.pt --content ./data/afhq/images512x512/test/cat/flickr_cat_000418.jpg --style ./data/afhq/images512x512/test/wild/flickr_wild_001112.jpg

The overview image is as follows:


(3) Training GP-UNIT

Download the supporting models to the ./checkpoint/ folder:

Model | Description
content_encoder.pt | Our pretrained content encoder which distills BigGAN prior from the synImageNet291 dataset.
model_ir_se50.pth | Pretrained IR-SE50 model taken from TreB1eN for ID loss.

Train Image-to-Image Translation Network

python train.py --task TASK --batch BATCH_SIZE --iter ITERATIONS \
                --source_paths SPATH1 SPATH2 ... SPATHS --source_num SNUM1 SNUM2 ... SNUMS \
                --target_paths TPATH1 TPATH2 ... TPATHT --target_num TNUM1 TNUM2 ... TNUMT

where SPATH1~SPATHS are paths to S folders containing images from the source domain (e.g., S classes of ImageNet birds), and SNUMi is the number of images in SPATHi used for training. TPATHi and TNUMi are defined similarly for the target domain. By default, BATCH_SIZE=16 and ITERATIONS=75000. If --source_num/--target_num is not specified, all images in the folders are used.

The trained model is saved as ./checkpoint/TASK-ITERATIONS.pt. Intermediate results are saved in ./log/TASK/.

The default settings do not necessarily give the best results. Training can be further customized with additional command line options:

  • --style_layer (default: 4): the discriminator layer used to compute the feature matching loss. We found that setting style_layer=5 gives better performance on human faces.
  • --use_allskip (default: False): whether to use dynamic skip connections to compute the reconstruction loss. For tasks involving close domains like gender translation, season transfer and face stylization, using use_allskip gives better results.
  • --use_idloss (default: False): whether to use the identity loss. For Cat/Dog→Face and Face→MetFace tasks, we use this loss.
  • --not_flip_style (default: False): whether to disable random flipping of the style image when extracting the style feature. Random flipping prevents the network from learning position information from the style image.
  • --mitigate_style_bias (default: False): whether to resample style features when training the sampling network. For imbalanced datasets with minority groups, mitigate_style_bias oversamples the style features that are far from the mean style feature of the whole dataset. This leads to more diversified latent-guided translation at the cost of slight image quality degradation. We use it on CelebA-HQ and AFHQ-related tasks.
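
The idea behind --mitigate_style_bias, oversampling style features that lie far from the dataset mean, can be illustrated with a toy example. The sketch below is not the repository's implementation: the README does not state the exact weighting, so weights proportional to the Euclidean distance from the mean are an assumption, and both function names are hypothetical.

```python
import math
import random


def oversampling_weights(style_features):
    """Weight each style feature by its Euclidean distance from the dataset mean."""
    dim = len(style_features[0])
    count = len(style_features)
    mean = [sum(f[i] for f in style_features) / count for i in range(dim)]
    return [math.dist(f, mean) for f in style_features]


def resample_styles(style_features, k, rng=random):
    """Draw k features with replacement, favouring those far from the mean."""
    weights = oversampling_weights(style_features)
    if sum(weights) == 0:  # all features identical: fall back to uniform sampling
        weights = None
    return rng.choices(style_features, weights=weights, k=k)


if __name__ == "__main__":
    # Three features from a majority group and one outlier from a minority group.
    features = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [4.0, 0.0]]
    print(oversampling_weights(features))
    print(resample_styles(features, k=8, rng=random.Random(0)))
```

In this toy setup the outlier receives a larger weight than any single majority feature, so it is drawn more often than under uniform sampling.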

Here are some examples:
(Some of our tasks require the ImageNet291 dataset. Please refer to data preparation.)

Male→Female

python train.py --task male2female --source_paths ./data/celeba_hq/train/male --target_paths ./data/celeba_hq/train/female --style_layer 5 --mitigate_style_bias --use_allskip --not_flip_style

Cat→Dog

python train.py --task cat2dog --source_paths ./data/afhq/images512x512/train/cat --source_num 4000 --target_paths ./data/afhq/images512x512/train/dog --target_num 4000 --mitigate_style_bias

Cat→Face

python train.py --task cat2face --source_paths ./data/afhq/images512x512/train/cat --source_num 4000 --target_paths ./data/ImageNet291/train/1001_face/ --style_layer 5 --mitigate_style_bias --not_flip_style --use_idloss

Bird→Car (translating 4 classes of birds to 4 classes of cars)

python train.py --task bird2car --source_paths ./data/ImageNet291/train/10_bird/ ./data/ImageNet291/train/11_bird/ ./data/ImageNet291/train/12_bird/ ./data/ImageNet291/train/13_bird/ --source_num 600 600 600 600 --target_paths ./data/ImageNet291/train/436_vehicle/ ./data/ImageNet291/train/511_vehicle/ ./data/ImageNet291/train/627_vehicle/ ./data/ImageNet291/train/656_vehicle/ --target_num 600 600 600 600
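
The example commands above all share one shape, so long multi-folder commands such as Bird→Car can be generated programmatically. The sketch below assembles the argument list from the flags documented in this section; `build_train_cmd` is a hypothetical helper, and the length check is an added safeguard, not necessarily something train.py enforces itself.

```python
import sys


def build_train_cmd(task, source_paths, target_paths, source_num=None,
                    target_num=None, batch=None, iterations=None, extra_flags=()):
    """Assemble a train.py command line as a list of arguments."""
    # Each count refers to the folder at the same position, so lengths must match.
    if source_num is not None and len(source_num) != len(source_paths):
        raise ValueError("--source_num needs one count per source folder")
    if target_num is not None and len(target_num) != len(target_paths):
        raise ValueError("--target_num needs one count per target folder")

    cmd = [sys.executable, "train.py", "--task", task]
    if batch is not None:
        cmd += ["--batch", str(batch)]
    if iterations is not None:
        cmd += ["--iter", str(iterations)]
    cmd += ["--source_paths", *source_paths]
    if source_num is not None:
        cmd += ["--source_num", *map(str, source_num)]
    cmd += ["--target_paths", *target_paths]
    if target_num is not None:
        cmd += ["--target_num", *map(str, target_num)]
    cmd += list(extra_flags)  # e.g. ["--mitigate_style_bias"]
    return cmd


if __name__ == "__main__":
    birds = [f"./data/ImageNet291/train/{i}_bird/" for i in (10, 11, 12, 13)]
    cars = [f"./data/ImageNet291/train/{i}_vehicle/" for i in (436, 511, 627, 656)]
    print(" ".join(build_train_cmd("bird2car", birds, cars, [600] * 4, [600] * 4)))
```

The resulting list can be passed to `subprocess.run(cmd, check=True)` from the repository root.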

Train Content Encoder of Prior Distillation

We provide our pretrained model content_encoder.pt at Google Drive or Baidu Cloud (access code: cvpr). This model is obtained by:

python prior_distillation.py --unpaired_data_root ./data/ImageNet291/train/ --paired_data_root ./data/synImageNet291/train/ --unpaired_mask_root ./data/ImageNet291_mask/train/ --paired_mask_root ./data/synImageNet291_mask/train/

The training requires ImageNet291 and synImageNet291 datasets. Please refer to data preparation.


Results

Male-to-Female: close domains

male2female

Cat-to-Dog: related domains

cat2dog

Dog-to-Human and Bird-to-Dog: distant domains

dog2human

bird2dog

Bird-to-Car: extremely distant domains for stress testing

bird2car

Citation

If you find this work useful for your research, please consider citing our paper:

@inproceedings{yang2022Unsupervised,
  title={Unsupervised Image-to-Image Translation with Generative Prior},
  author={Yang, Shuai and Jiang, Liming and Liu, Ziwei and Loy, Chen Change},
  booktitle={CVPR},
  year={2022}
}

Acknowledgments

The code is developed based on StarGAN v2, SPADE and Imaginaire.
