The official PyTorch implementation for NCSNv2 (NeurIPS 2020)

Overview

Improved Techniques for Training Score-Based Generative Models

This repo contains the official implementation for the paper Improved Techniques for Training Score-Based Generative Models.

by Yang Song and Stefano Ermon, Stanford AI Lab.

Note: The method has been extended by the subsequent work Score-Based Generative Modeling through Stochastic Differential Equations (code) that allows better sample quality and exact log-likelihood computation.


We significantly improve the method proposed in Generative Modeling by Estimating Gradients of the Data Distribution. Score-based generative models are flexible neural networks trained to capture the score function of an underlying data distribution—a vector field pointing to directions where the data density increases most rapidly. We present new techniques to improve the performance of score-based generative models, scaling them to high resolution images that are previously impossible. Without requiring adversarial training, they can produce sharp and diverse image samples that rival GANs.

samples

(From left to right: Our samples on FFHQ 256px, LSUN bedroom 128px, LSUN tower 128px, LSUN church_outdoor 96px, and CelebA 64px.)

Running Experiments

Dependencies

Run the following to install all necessary python packages for our code.

pip install -r requirements.txt

Project structure

main.py is the file that you should run for both training and sampling. Execute python main.py --help to get its usage description:

usage: main.py [-h] --config CONFIG [--seed SEED] [--exp EXP] --doc DOC
               [--comment COMMENT] [--verbose VERBOSE] [--test] [--sample]
               [--fast_fid] [--resume_training] [-i IMAGE_FOLDER] [--ni]

optional arguments:
  -h, --help            show this help message and exit
  --config CONFIG       Path to the config file
  --seed SEED           Random seed
  --exp EXP             Path for saving running related data.
  --doc DOC             A string for documentation purpose. Will be the name
                        of the log folder.
  --comment COMMENT     A string for experiment comment
  --verbose VERBOSE     Verbose level: info | debug | warning | critical
  --test                Whether to test the model
  --sample              Whether to produce samples from the model
  --fast_fid            Whether to do fast fid test
  --resume_training     Whether to resume training
  -i IMAGE_FOLDER, --image_folder IMAGE_FOLDER
                        The folder name of samples
  --ni                  No interaction. Suitable for Slurm Job launcher

Configuration files are in config/. You don't need to include the prefix config/ when specifying --config . All files generated when running the code is under the directory specified by --exp. They are structured as:

<exp> # a folder named by the argument `--exp` given to main.py
├── datasets # all dataset files
├── logs # contains checkpoints and samples produced during training
│   └── <doc> # a folder named by the argument `--doc` specified to main.py
│      ├── checkpoint_x.pth # the checkpoint file saved at the x-th training iteration
│      ├── config.yml # the configuration file for training this model
│      ├── stdout.txt # all outputs to the console during training
│      └── samples # all samples produced during training
├── fid_samples # contains all samples generated for fast fid computation
│   └── <i> # a folder named by the argument `-i` specified to main.py
│      └── ckpt_x # a folder of image samples generated from checkpoint_x.pth
├── image_samples # contains generated samples
│   └── <i>
│       └── image_grid_x.png # samples generated from checkpoint_x.pth       
└── tensorboard # tensorboard files for monitoring training
    └── <doc> # this is the log_dir of tensorboard

Training

For example, we can train an NCSNv2 on LSUN bedroom by running the following

python main.py --config bedroom.yml --doc bedroom

Log files will be saved in <exp>/logs/bedroom.

Sampling

If we want to sample from NCSNv2 on LSUN bedroom, we can edit bedroom.yml to specify the ckpt_id under the group sampling, and then run the following

python main.py --sample --config bedroom.yml -i bedroom

Samples will be saved in <exp>/image_samples/bedroom.

We can interpolate between different samples (see more details in the paper). Just set interpolation to true and an appropriate n_interpolations under the group of sampling in bedroom.yml. We can also perform other tasks such as inpainting. Usages should be quite obvious if you read the code and configuration files carefully.

Computing FID values quickly for a range of checkpoints

We can specify begin_ckpt and end_ckpt under the fast_fid group in the configuration file. For example, by running the following command, we can generate a small number of samples per checkpoint within the range begin_ckpt-end_ckpt for a quick (and rough) FID evaluation.

python main.py --fast_fid --config bedroom.yml -i bedroom

You can find samples in <exp>/fid_samples/bedroom.

Pretrained Checkpoints

Link: https://drive.google.com/drive/folders/1217uhIvLg9ZrYNKOR3XTRFSurt4miQrd?usp=sharing

You can produce samples using it on all datasets we tested in the paper. It assumes the --exp argument is set to exp.

References

If you find the code/idea useful for your research, please consider citing

@inproceedings{song2020improved,
  author    = {Yang Song and Stefano Ermon},
  editor    = {Hugo Larochelle and
               Marc'Aurelio Ranzato and
               Raia Hadsell and
               Maria{-}Florina Balcan and
               Hsuan{-}Tien Lin},
  title     = {Improved Techniques for Training Score-Based Generative Models},
  booktitle = {Advances in Neural Information Processing Systems 33: Annual Conference
               on Neural Information Processing Systems 2020, NeurIPS 2020, December
               6-12, 2020, virtual},
  year      = {2020}
}

and/or our previous work

@inproceedings{song2019generative,
  title={Generative Modeling by Estimating Gradients of the Data Distribution},
  author={Song, Yang and Ermon, Stefano},
  booktitle={Advances in Neural Information Processing Systems},
  pages={11895--11907},
  year={2019}
}
PyTorch implementation of Constrained Policy Optimization

PyTorch implementation of Constrained Policy Optimization (CPO) This repository has a simple to understand and use implementation of CPO in PyTorch. A

Sapana Chaudhary 25 Dec 08, 2022
Implementation of "A Deep Learning Loss Function based on Auditory Power Compression for Speech Enhancement" by pytorch

This repository is used to suspend the results of our paper "A Deep Learning Loss Function based on Auditory Power Compression for Speech Enhancement"

ScorpioMiku 19 Sep 30, 2022
Blender Add-On for slicing meshes with planes

MeshSlicer Blender Add-On for slicing meshes with multiple overlapping planes at once. This is a simple Blender addon to slice a silmple mesh with mul

52 Dec 12, 2022
Location-Sensitive Visual Recognition with Cross-IOU Loss

The trained models are temporarily unavailable, but you can train the code using reasonable computational resource. Location-Sensitive Visual Recognit

Kaiwen Duan 146 Dec 25, 2022
Mind the Trade-off: Debiasing NLU Models without Degrading the In-distribution Performance

Models for natural language understanding (NLU) tasks often rely on the idiosyncratic biases of the dataset, which make them brittle against test cases outside the training distribution.

Ubiquitous Knowledge Processing Lab 22 Jan 02, 2023
MQBench Quantization Aware Training with PyTorch

MQBench Quantization Aware Training with PyTorch I am using MQBench(Model Quantization Benchmark)(http://mqbench.tech/) to quantize the model for depl

Ling Zhang 29 Nov 18, 2022
Transformer in Vision

Transformer-in-Vision Recent Transformer-based CV and related works. Welcome to comment/contribute! Keep updated. Resource SCENIC: A JAX Library for C

Yong-Lu Li 1.1k Dec 30, 2022
Object tracking implemented with YOLOv4, DeepSort, and TensorFlow.

Object tracking implemented with YOLOv4, DeepSort, and TensorFlow. YOLOv4 is a state of the art algorithm that uses deep convolutional neural networks to perform object detections. We can take the ou

The AI Guy 1.1k Dec 29, 2022
TensorFlow CNN for fast style transfer

Fast Style Transfer in TensorFlow Add styles from famous paintings to any photo in a fraction of a second! It takes 100ms on a 2015 Titan X to style t

1 Dec 14, 2021
Asterisk is a framework to generate high-quality training datasets at scale

Asterisk is a framework to generate high-quality training datasets at scale

Mona Nashaat 44 Apr 25, 2022
PyTorch evaluation code for Delving Deep into the Generalization of Vision Transformers under Distribution Shifts.

Out-of-distribution Generalization Investigation on Vision Transformers This repository contains PyTorch evaluation code for Delving Deep into the Gen

Chongzhi Zhang 72 Dec 13, 2022
Scalable, Portable and Distributed Gradient Boosting (GBDT, GBRT or GBM) Library, for Python, R, Java, Scala, C++ and more. Runs on single machine, Hadoop, Spark, Dask, Flink and DataFlow

eXtreme Gradient Boosting Community | Documentation | Resources | Contributors | Release Notes XGBoost is an optimized distributed gradient boosting l

Distributed (Deep) Machine Learning Community 23.6k Dec 31, 2022
The Habitat-Matterport 3D Research Dataset - the largest-ever dataset of 3D indoor spaces.

Habitat-Matterport 3D Dataset (HM3D) The Habitat-Matterport 3D Research Dataset is the largest-ever dataset of 3D indoor spaces. It consists of 1,000

Meta Research 62 Dec 27, 2022
A collection of Jupyter notebooks to play with NVIDIA's StyleGAN3 and OpenAI's CLIP for a text-based guided image generation.

StyleGAN3 CLIP-based guidance StyleGAN3 + CLIP StyleGAN3 + inversion + CLIP This repo is a collection of Jupyter notebooks made to easily play with St

Eugenio Herrera 176 Dec 30, 2022
Pcos-prediction - Predicts the likelihood of Polycystic Ovary Syndrome based on patient attributes and symptoms

PCOS Prediction 🥼 Predicts the likelihood of Polycystic Ovary Syndrome based on

Samantha Van Seters 1 Jan 10, 2022
High-fidelity 3D Model Compression based on Key Spheres

High-fidelity 3D Model Compression based on Key Spheres This repository contains the implementation of the paper: Yuanzhan Li, Yuqi Liu, Yujie Lu, Siy

5 Oct 11, 2022
Unsupervised MRI Reconstruction via Zero-Shot Learned Adversarial Transformers

Official TensorFlow implementation of the unsupervised reconstruction model using zero-Shot Learned Adversarial TransformERs (SLATER). (https://arxiv.

ICON Lab 22 Dec 22, 2022
An Implementation of Transformer in Transformer in TensorFlow for image classification, attention inside local patches

Transformer-in-Transformer An Implementation of the Transformer in Transformer paper by Han et al. for image classification, attention inside local pa

Rishit Dagli 40 Jul 25, 2022
Automatic Number Plate Recognition using Contours and Convolution Neural Networks (CNN)

Cite our paper if you find this project useful https://www.ijariit.com/manuscripts/v7i4/V7I4-1139.pdf Abstract Image processing technology is used in

Adithya M 2 Jun 28, 2022
An off-line judger supporting distributed problem repositories

Thaw 中文 | English Thaw is an off-line judger supporting distributed problem repositories. Everyone can use Thaw release problems with license on GitHu

countercurrent_time 2 Jan 09, 2022