Finding an Unsupervised Image Segmenter in each of your Deep Generative Models

Overview

Finding an Unsupervised Image Segmenter in each of your Deep Generative Models

Paper

Description

Recent research has shown that numerous human-interpretable directions exist in the latent space of GANs. In this paper, we develop an automatic procedure for finding directions that lead to foreground-background image separation, and we use these directions to train an image segmentation model without human supervision. Our method is generator-agnostic, producing strong segmentation results with a wide range of different GAN architectures. Furthermore, by leveraging GANs pretrained on large datasets such as ImageNet, we are able to segment images from a range of domains without further training or finetuning. Evaluating our method on image segmentation benchmarks, we compare favorably to prior work while using neither human supervision nor access to the training data. Broadly, our results demonstrate that automatically extracting foreground-background structure from pretrained deep generative models can serve as a remarkably effective substitute for human supervision.

How to run

Dependencies

This code depends on pytorch-pretrained-gans, a repository I developed that exposes a standard interface for a variety of pretrained GANs. Install it with:

pip install git+https://github.com/lukemelas/pytorch-pretrained-gans

The pretrained weights for most GANs are downloaded automatically. For those that are not, I have provided scripts in that repository.

There are also some standard dependencies:

Install them with:

pip install hydra-core==1.1.0dev5 pytorch_lightning albumentations tqdm retry kornia

General Approach

Our unsupervised segmentation approach has two steps: (1) finding a good direction in latent space, and (2) training a segmentation model from data and masks that are generated using this direction.

In detail, this means:

  1. We use optimization/main.py finds a salient direction (or two salient directions) in the latent space of a given pretrained GAN that leads to foreground-background image separation.
  2. We use segmentation/main.py to train a standard segmentation network (a UNet) on generated data. The data can be generated in two ways: (1) you can generate the images on-the-fly during training, or (2) you can generate the images before training the segmentation model using segmentation/generate_and_save.py and then train the segmentation network afterward. The second approach is faster, but requires more disk space (~10GB for 1 million images). We will also provide a pre-generated dataset (coming soon).

Configuration and Logging

We use Hydra for configuration and Weights and Biases for logging. With Hydra, you can specify a config file (found in configs/) with --config-name=myconfig.yaml. You can also override the config from the command line by specifying the overriding arguments (without --). For example, you can enable Weights and Biases with wandb=True and you can name the run with name=myname.

The structure of the configs is as follows:

config
├── data_gen
│   ├── generated.yaml  # <- for generating data with 1 latent direction
│   ├── generated-dual.yaml   # <- for generating data with 2 latent directions
│   ├── generator  # <- different types of GANs for generating data
│   │   ├── bigbigan.yaml
│   │   ├── pretrainedbiggan.yaml
│   │   ├── selfconditionedgan.yaml
│   │   ├── studiogan.yaml
│   │   └── stylegan2.yaml 
│   └── saved.yaml  # <- for using pre-generated data
├── optimize.yaml  # <- for optimization
└── segment.yaml   # <- for segmentation

Code Structure

The code is structured as follows:

src
├── models  # <- segmentation model
│   ├── __init__.py
│   ├── latent_shift_model.py  # <- shifts direction in latent space
│   ├── unet_model.py  # <- segmentation model
│   └── unet_parts.py
├── config  # <- configuration, explained above
│   ├── ... 
├── datasets  # <- classes for loading datasets during segmentation/generation
│   ├── __init__.py
│   ├── gan_dataset.py  # <- for generating dataset
│   ├── saved_gan_dataset.py  # <- for pre-generated dataset
│   └── real_dataset.py  # <- for evaluation datasets (i.e. real images)
├── optimization
│   ├── main.py  # <- main script
│   └── utils.py  # <- helper functions
└── segmentation
    ├── generate_and_save.py  # <- for generating a dataset and saving it to disk
    ├── main.py  # <- main script, uses PyTorch Lightning 
    ├── metrics.py  # <- for mIoU/F-score calculations
    └── utils.py  # <- helper functions

Datasets

The datasets should have the following structure. You can easily add you own datasets or use only a subset of these datasets by modifying config/segment.yaml. You should specify your directory by modifying root in that file on line 19, or by passing data_seg.root=MY_DIR using the command line whenever you call python segmentation/main.py.

├── DUT_OMRON
│   ├── DUT-OMRON-image
│   │   └── ...
│   └── pixelwiseGT-new-PNG
│       └── ...
├── DUTS
│   ├── DUTS-TE
│   │   ├── DUTS-TE-Image
│   │   │   └── ...
│   │   └── DUTS-TE-Mask
│   │       └── ...
│   └── DUTS-TR
│       ├── DUTS-TR-Image
│       │   └── ...
│       └── DUTS-TR-Mask
│           └── ...
├── ECSSD
│   ├── ground_truth_mask
│   │   └── ...
│   └── images
│       └── ...
├── CUB_200_2011
│   ├── train_images
│   │   └── ...
│   ├── train_segmentations
│   │   └── ...
│   ├── test_images
│   │   └── ...
│   └── test_segmentations
│       └── ...
└── Flowers
    ├── train_images
    │   └── ...
    ├── train_segmentations
    │   └── ...
    ├── test_images
    │   └── ...
    └── test_segmentations
        └── ...

The datasets can be downloaded from:

Training

Before training, make sure you understand the general approach (explained above).

Note: All commands are called from within the src directory.

In the example commands below, we use BigBiGAN. You can easily switch out BigBiGAN for another model if you would like to.

Optimization

PYTHONPATH=. python optimization/main.py data_gen/generator=bigbigan name=NAME

This should take less than 5 minutes to run. The output will be saved in outputs/optimization/fixed-BigBiGAN-NAME/DATE/, with the final checkpoint in latest.pth.

Segmentation with precomputed generations

The recommended way of training is to generate the data first and train afterward. An example generation script would be:

PYTHONPATH=. python segmentation/generate_and_save.py \
name=NAME \
data_gen=generated \
data_gen/generator=bigbigan \
data_gen.checkpoint="YOUR_OPTIMIZATION_DIR_FROM_ABOVE/latest.pth" \
data_gen.save_dir="YOUR_OUTPUT_DIR" \
data_gen.save_size=1000000 \
data_gen.kwargs.batch_size=1 \
data_gen.kwargs.generation_batch_size=128

This will generate 1 million image-label pairs and save them to YOUR_OUTPUT_DIR/images. Note that YOUR_OUTPUT_DIR should be an absolute path, not a relative one, because Hydra changes the working directory. You may also want to tune the generation_batch_size to maximize GPU utilization on your machine. It takes around 3-4 hours to generate 1 million images on a single V100 GPU.

Once you have generated data, you can train a segmentation model:

PYTHONPATH=. python segmentation/main.py \
name=NAME \
data_gen=saved \
data_gen.data.root="YOUR_OUTPUT_DIR_FROM_ABOVE"

It takes around 3 hours on 1 GPU to complete 18000 iterations, by which point the model has converged (in fact you can probably get away with fewer steps, I would guess around ~5000).

Segmentation with on-the-fly generations

Alternatively, you can generate data while training the segmentation model. An example script would be:

PYTHONPATH=. python segmentation/main.py \
name=NAME \
data_gen=generated \
data_gen/generator=bigbigan \
data_gen.checkpoint="YOUR_OPTIMIZATION_DIR_FROM_ABOVE/latest.pth" \
data_gen.kwargs.generation_batch_size=128

Evaluation

To evaluate, set the train argument to False. For example:

python train.py \
name="eval" \
train=False \
eval_checkpoint=${checkpoint} \
data_seg.root=${DATASETS_DIR} 

Pretrained models

  • ... are coming soon!

Available GANs

It should be possible to use any GAN from pytorch-pretrained-gans, including:

Citation

@inproceedings{melaskyriazi2021finding,
  author    = {Melas-Kyriazi, Luke and Rupprecht, Christian and Laina, Iro and Vedaldi, Andrea},
  title     = {Finding an Unsupervised Image Segmenter in each of your Deep Generative Models},
  booktitle = arxiv,
  year      = {2021}
}
You might also like...
pytorch implementation of
pytorch implementation of "Contrastive Multiview Coding", "Momentum Contrast for Unsupervised Visual Representation Learning", and "Unsupervised Feature Learning via Non-Parametric Instance-level Discrimination"

Unofficial implementation: MoCo: Momentum Contrast for Unsupervised Visual Representation Learning (Paper) InsDis: Unsupervised Feature Learning via N

pyhsmm - library for approximate unsupervised inference in Bayesian Hidden Markov Models (HMMs) and explicit-duration Hidden semi-Markov Models (HSMMs), focusing on the Bayesian Nonparametric extensions, the HDP-HMM and HDP-HSMM, mostly with weak-limit approximations. The pytorch implementation of  DG-Font: Deformable Generative Networks for Unsupervised Font Generation
The pytorch implementation of DG-Font: Deformable Generative Networks for Unsupervised Font Generation

DG-Font: Deformable Generative Networks for Unsupervised Font Generation The source code for 'DG-Font: Deformable Generative Networks for Unsupervised

Minimal PyTorch implementation of Generative Latent Optimization from the paper
Minimal PyTorch implementation of Generative Latent Optimization from the paper "Optimizing the Latent Space of Generative Networks"

Minimal PyTorch implementation of Generative Latent Optimization This is a reimplementation of the paper Piotr Bojanowski, Armand Joulin, David Lopez-

Deep generative modeling for time-stamped heterogeneous data, enabling high-fidelity models for a large variety of spatio-temporal domains.
Deep generative modeling for time-stamped heterogeneous data, enabling high-fidelity models for a large variety of spatio-temporal domains.

Neural Spatio-Temporal Point Processes [arxiv] Ricky T. Q. Chen, Brandon Amos, Maximilian Nickel Abstract. We propose a new class of parameterizations

Official repository for the ICLR 2021 paper Evaluating the Disentanglement of Deep Generative Models with Manifold Topology

Official repository for the ICLR 2021 paper Evaluating the Disentanglement of Deep Generative Models with Manifold Topology Sharon Zhou, Eric Zelikman

source code for https://arxiv.org/abs/2005.11248 "Accelerating Antimicrobial Discovery with Controllable Deep Generative Models and Molecular Dynamics"

Accelerating Antimicrobial Discovery with Controllable Deep Generative Models and Molecular Dynamics This work will be published in Nature Biomedical

DeepCAD: A Deep Generative Network for Computer-Aided Design Models
DeepCAD: A Deep Generative Network for Computer-Aided Design Models

DeepCAD This repository provides source code for our paper: DeepCAD: A Deep Generative Network for Computer-Aided Design Models Rundi Wu, Chang Xiao,

TAug :: Time Series Data Augmentation using Deep Generative Models

TAug :: Time Series Data Augmentation using Deep Generative Models Note!!! The package is under development so be careful for using in production! Fea

Comments
  • pip install git+https://github.com/lukemelas/pytorch-pretrained-gans

    pip install git+https://github.com/lukemelas/pytorch-pretrained-gans

    Hi, is the repo in the pytorch-pretrained-gans step public or is that the right URL for it? I got prompted for username and password when I tried the pip install git+ and don't see the repo at that URL: https://github.com/lukemelas/pytorch-pretrained-gans (Get 404)

    Thanks.

    opened by ModMorph 2
  • Help producing results with the StyleGAN models

    Help producing results with the StyleGAN models

    Hi there!

    I'm having trouble producing meaningful results on StyleGAN2 on AFHQ. I've been using the default setup and hyperparameters. After 50 iterations (with the default batch size of 32) I get visualisations that look initially promising: (https://i.imgur.com/eR79Wyd.png). But as training progresses, and indeed when it reaches 300 iterations, these are the visualisation results: https://i.imgur.com/36zhBzT.png.

    I've tried playing with the learning rate, and the number of iterations with no success yet. Did you have tips here or ideas as to what might be going wrong here?

    Thanks! James.

    opened by james-oldfield 1
  • bug

    bug

    Firstly, I ran PYTHONPATH=. python optimization/main.py data_gen/generator=bigbigan name=NAME. And then, I ran PYTHONPATH=. python segmentation/generate_and_save.py \ name=NAME \ data_gen=generated \ data_gen/generator=bigbigan \ data_gen.checkpoint="YOUR_OPTIMIZATION_DIR_FROM_ABOVE/latest.pth" \ data_gen.save_dir="YOUR_OUTPUT_DIR" \ data_gen.save_size=1000000 \ data_gen.kwargs.batch_size=1 \ data_gen.kwargs.generation_batch_size=128 When I ran PYTHONPATH=. python segmentation/main.py \ name=NAME \ data_gen=saved \ data_gen.data.root="YOUR_OUTPUT_DIR_FROM_ABOVE" An error occurred. The error is: Traceback (most recent call last): File "segmentation/main.py", line 98, in main kwargs = dict(images_dir=_cfg.images_dir, labels_dir=_cfg.labels_dir, omegaconf.errors.InterpolationResolutionError: KeyError raised while resolving interpolation: "Environment variable '/raid/name/gaochengli/segmentation/src/images' not found" full_key: data_seg.data[0].images_dir object_type=dict According to what you wrote, I modified the root (config/segment.yaml on line 19). Just like this "/raid/name/gaochengli/segmentation/src/images". And the folder contains all data sets,whose name is images. I wonder why such a mistake happened.

    opened by Lee-Gao 1
Owner
Luke Melas-Kyriazi
I'm student at Harvard University studying mathematics and computer science, always open to collaborate on interesting projects!
Luke Melas-Kyriazi
The FIRST GANs-based omics-to-omics translation framework

OmiTrans Please also have a look at our multi-omics multi-task DL freamwork 👀 : OmiEmbed The FIRST GANs-based omics-to-omics translation framework Xi

Xiaoyu Zhang 6 Dec 14, 2022
Official code repository for Continual Learning In Environments With Polynomial Mixing Times

Official code for Continual Learning In Environments With Polynomial Mixing Times Continual Learning in Environments with Polynomial Mixing Times This

Sharath Raparthy 1 Dec 19, 2021
Cross-Modal Contrastive Learning for Text-to-Image Generation

Cross-Modal Contrastive Learning for Text-to-Image Generation This repository hosts the open source JAX implementation of XMC-GAN. Setup instructions

Google Research 94 Nov 12, 2022
StyleGAN - Official TensorFlow Implementation

StyleGAN — Official TensorFlow Implementation Picture: These people are not real – they were produced by our generator that allows control over differ

NVIDIA Research Projects 13.1k Jan 09, 2023
Code release for "Masked-attention Mask Transformer for Universal Image Segmentation"

Mask2Former: Masked-attention Mask Transformer for Universal Image Segmentation Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Ro

Meta Research 1.2k Jan 02, 2023
Optimizing DR with hard negatives and achieving SOTA first-stage retrieval performance on TREC DL Track (SIGIR 2021 Full Paper).

Optimizing Dense Retrieval Model Training with Hard Negatives Jingtao Zhan, Jiaxin Mao, Yiqun Liu, Jiafeng Guo, Min Zhang, Shaoping Ma This repo provi

Jingtao Zhan 99 Dec 27, 2022
MetaTTE: a Meta-Learning Based Travel Time Estimation Model for Multi-city Scenarios

MetaTTE: a Meta-Learning Based Travel Time Estimation Model for Multi-city Scenarios This is the official TensorFlow implementation of MetaTTE in the

morningstarwang 4 Dec 14, 2022
HistoKT: Cross Knowledge Transfer in Computational Pathology

HistoKT: Cross Knowledge Transfer in Computational Pathology Exciting News! HistoKT has been accepted to ICASSP 2022. HistoKT: Cross Knowledge Transfe

Mahdi S. Hosseini 5 Jan 05, 2023
Advancing mathematics by guiding human intuition with AI

Advancing mathematics by guiding human intuition with AI This repo contains two colab notebooks which accompany the paper, available online at https:/

DeepMind 315 Dec 26, 2022
GeDML is an easy-to-use generalized deep metric learning library

GeDML is an easy-to-use generalized deep metric learning library

Borui Zhang 32 Dec 05, 2022
Unofficial implementation (replicates paper results!) of MINER: Multiscale Implicit Neural Representations in pytorch-lightning

MINER_pl Unofficial implementation of MINER: Multiscale Implicit Neural Representations in pytorch-lightning. 📖 Ref readings Laplacian pyramid explan

AI葵 51 Nov 28, 2022
Azion the best solution of Edge Computing in the world.

Azion Edge Function docker action Create or update an Edge Functions on Azion Edge Nodes. The domain name is the key for decision to a create or updat

8 Jul 16, 2022
Codes for “A Deeply Supervised Attention Metric-Based Network and an Open Aerial Image Dataset for Remote Sensing Change Detection”

DSAMNet The pytorch implementation for "A Deeply-supervised Attention Metric-based Network and an Open Aerial Image Dataset for Remote Sensing Change

Mengxi Liu 41 Dec 14, 2022
Low-dose Digital Mammography with Deep Learning

Impact of loss functions on the performance of a deep neural network designed to restore low-dose digital mammography ====== This repository contains

WANG-AXIS 6 Dec 13, 2022
Encode and decode text application

Text Encoder and Decoder Encode and decode text in many ways using this application! Encode in: ASCII85 Base85 Base64 Base32 Base16 Url MD5 Hash SHA-1

Alice 1 Feb 12, 2022
68 keypoint annotations for COFW test data

68 keypoint annotations for COFW test data This repository contains manually annotated 68 keypoints for COFW test data (original annotation of CFOW da

31 Dec 06, 2022
UT-Sarulab MOS prediction system using SSL models

UTMOS: UTokyo-SaruLab MOS Prediction System Official implementation of "UTMOS: UTokyo-SaruLab System for VoiceMOS Challenge 2022" submitted to INTERSP

sarulab-speech 58 Nov 22, 2022
Massively parallel Monte Carlo diffusion MR simulator written in Python.

Disimpy Disimpy is a Python package for generating simulated diffusion-weighted MR signals that can be useful in the development and validation of dat

Leevi 16 Nov 11, 2022
IMBENS: class-imbalanced ensemble learning in Python.

IMBENS: class-imbalanced ensemble learning in Python. Links: [Documentation] [Gallery] [PyPI] [Changelog] [Source] [Download] [知乎/Zhihu] [中文README] [a

Zhining Liu 176 Jan 04, 2023
A library of scripts that interact with the PythonTurtle module to create games, drawings, and more

TurtleLib TurtleLib is a library of scripts that interact with the PythonTurtle module to create games, drawings, and more! Using the Scripts Copy or

1 Jan 15, 2022