guided-diffusion

This is the codebase for Diffusion Models Beat GANs on Image Synthesis.

This repository is based on openai/improved-diffusion, with modifications for classifier conditioning and architecture improvements.

Download pre-trained models

We have released checkpoints for the main models in the paper. Before using these models, please review the corresponding model card to understand the intended use and limitations of these models.

Here are the download links for each model checkpoint:

Sampling from pre-trained models

To sample from these models, you can use the classifier_sample.py, image_sample.py, and super_res_sample.py scripts. Here, we provide flags for sampling from all of these models. We assume that you have downloaded the relevant model checkpoints into a folder called models/.

For these examples, we will generate 100 samples with batch size 4. Feel free to change these values.

SAMPLE_FLAGS="--batch_size 4 --num_samples 100 --timestep_respacing 250"

Classifier guidance

Note for these sampling runs that you can set --classifier_scale 0 to sample from the base diffusion model. You may also use the image_sample.py script instead of classifier_sample.py in that case.

  • 64x64 model:
MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond True --diffusion_steps 1000 --dropout 0.1 --image_size 64 --learn_sigma True --noise_schedule cosine --num_channels 192 --num_head_channels 64 --num_res_blocks 3 --resblock_updown True --use_new_attention_order True --use_fp16 True --use_scale_shift_norm True"
python classifier_sample.py $MODEL_FLAGS --classifier_scale 1.0 --classifier_path models/64x64_classifier.pt --model_path models/64x64_diffusion.pt $SAMPLE_FLAGS
  • 128x128 model:
MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond True --diffusion_steps 1000 --image_size 128 --learn_sigma True --noise_schedule linear --num_channels 256 --num_heads 4 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
python classifier_sample.py $MODEL_FLAGS --classifier_scale 0.5 --classifier_path models/128x128_classifier.pt --model_path models/128x128_diffusion.pt $SAMPLE_FLAGS
  • 256x256 model:
MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond True --diffusion_steps 1000 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
python classifier_sample.py $MODEL_FLAGS --classifier_scale 1.0 --classifier_path models/256x256_classifier.pt --model_path models/256x256_diffusion.pt $SAMPLE_FLAGS
  • 256x256 model (unconditional):
MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
python classifier_sample.py $MODEL_FLAGS --classifier_scale 10.0 --classifier_path models/256x256_classifier.pt --model_path models/256x256_diffusion_uncond.pt $SAMPLE_FLAGS
  • 512x512 model:
MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond True --diffusion_steps 1000 --image_size 512 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 False --use_scale_shift_norm True"
python classifier_sample.py $MODEL_FLAGS --classifier_scale 4.0 --classifier_path models/512x512_classifier.pt --model_path models/512x512_diffusion.pt $SAMPLE_FLAGS
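
As a sketch of the note at the top of this section about sampling from the base diffusion model without guidance: the snippet below reuses the 64x64 flags with image_sample.py and leaves out the two classifier-specific flags. It assumes image_sample.py accepts the same model flags as classifier_sample.py, which this README implies but does not state outright. The command is only assembled and printed so that it can be checked before running.

```shell
# Flags copied from the 64x64 example above.
MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond True --diffusion_steps 1000 --dropout 0.1 --image_size 64 --learn_sigma True --noise_schedule cosine --num_channels 192 --num_head_channels 64 --num_res_blocks 3 --resblock_updown True --use_new_attention_order True --use_fp16 True --use_scale_shift_norm True"
SAMPLE_FLAGS="--batch_size 4 --num_samples 100 --timestep_respacing 250"

# Unguided sampling: no --classifier_scale and no --classifier_path.
# Assumption: image_sample.py takes the same model flags as classifier_sample.py.
UNGUIDED_CMD="python image_sample.py $MODEL_FLAGS --model_path models/64x64_diffusion.pt $SAMPLE_FLAGS"

# Print the assembled command for inspection.
echo "$UNGUIDED_CMD"
```

Once the printed command looks right, it can be run directly from the directory that contains the sampling scripts.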

Upsampling

For these runs, we assume you have some base samples in a file 64_samples.npz or 128_samples.npz for the two respective models.

  • 64 -> 256:
MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond True --diffusion_steps 1000 --large_size 256 --small_size 64 --learn_sigma True --noise_schedule linear --num_channels 192 --num_heads 4 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
python super_res_sample.py $MODEL_FLAGS --model_path models/64_256_upsampler.pt --base_samples 64_samples.npz $SAMPLE_FLAGS
  • 128 -> 512:
MODEL_FLAGS="--attention_resolutions 32,16 --class_cond True --diffusion_steps 1000 --large_size 512 --small_size 128 --learn_sigma True --noise_schedule linear --num_channels 192 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
python super_res_sample.py $MODEL_FLAGS --model_path models/128_512_upsampler.pt $SAMPLE_FLAGS --base_samples 128_samples.npz

LSUN models

These models are class-unconditional and correspond to a single LSUN class. Here, we show how to sample from lsun_bedroom.pt, but the other two LSUN checkpoints should work as well:

MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --dropout 0.1 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
python image_sample.py $MODEL_FLAGS --model_path models/lsun_bedroom.pt $SAMPLE_FLAGS

You can sample from lsun_horse_nodropout.pt by changing the dropout flag:

MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --dropout 0.0 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
python image_sample.py $MODEL_FLAGS --model_path models/lsun_horse_nodropout.pt $SAMPLE_FLAGS

Note that for these models, the best samples result from using 1000 timesteps:

SAMPLE_FLAGS="--batch_size 4 --num_samples 100 --timestep_respacing 1000"

Results

This table summarizes our ImageNet results for pure guided diffusion models:

| Dataset | FID | Precision | Recall |
|---|---|---|---|
| ImageNet 64x64 | 2.07 | 0.74 | 0.63 |
| ImageNet 128x128 | 2.97 | 0.78 | 0.59 |
| ImageNet 256x256 | 4.59 | 0.82 | 0.52 |
| ImageNet 512x512 | 7.72 | 0.87 | 0.42 |

This table shows the best results for high resolutions when using upsampling and guidance together:

| Dataset | FID | Precision | Recall |
|---|---|---|---|
| ImageNet 256x256 | 3.94 | 0.83 | 0.53 |
| ImageNet 512x512 | 3.85 | 0.84 | 0.53 |

Finally, here are the unguided results on individual LSUN classes:

| Dataset | FID | Precision | Recall |
|---|---|---|---|
| LSUN Bedroom | 1.90 | 0.66 | 0.51 |
| LSUN Cat | 5.57 | 0.63 | 0.52 |
| LSUN Horse | 2.57 | 0.71 | 0.55 |

Training models

Training diffusion models is described in the parent repository. Training a classifier is similar. We assume you have put training hyperparameters into a TRAIN_FLAGS variable, and classifier hyperparameters into a CLASSIFIER_FLAGS variable. Then you can run:

mpiexec -n N python scripts/classifier_train.py --data_dir path/to/imagenet $TRAIN_FLAGS $CLASSIFIER_FLAGS

Make sure to divide the batch size in TRAIN_FLAGS by the number of MPI processes you are using.

Here are flags for training the 128x128 classifier. You can modify these for training classifiers at other resolutions:

TRAIN_FLAGS="--iterations 300000 --anneal_lr True --batch_size 256 --lr 3e-4 --save_interval 10000 --weight_decay 0.05"
CLASSIFIER_FLAGS="--image_size 128 --classifier_attention_resolutions 32,16,8 --classifier_depth 2 --classifier_width 128 --classifier_pool attention --classifier_resblock_updown True --classifier_use_scale_shift_norm True"
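
To illustrate the batch-size note above: if the 256 in TRAIN_FLAGS is read as the total batch size, each MPI process should receive that total divided by the process count. The sketch below assumes 8 processes, which is an example value and not one given in this README.

```shell
GLOBAL_BATCH_SIZE=256   # batch size listed in TRAIN_FLAGS above
NUM_PROCS=8             # assumed process count (the N in "mpiexec -n N")

# Shell arithmetic is integer-only, so choose NUM_PROCS to divide the batch size evenly.
PER_PROC_BATCH_SIZE=$((GLOBAL_BATCH_SIZE / NUM_PROCS))

TRAIN_FLAGS="--iterations 300000 --anneal_lr True --batch_size $PER_PROC_BATCH_SIZE --lr 3e-4 --save_interval 10000 --weight_decay 0.05"
echo "$TRAIN_FLAGS"
```

With these values each process trains on 32 examples per step, and the same NUM_PROCS would be passed to mpiexec.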

To sample from a 128x128 classifier-guided model using 25-step DDIM:

MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond True --image_size 128 --learn_sigma True --num_channels 256 --num_heads 4 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
CLASSIFIER_FLAGS="--image_size 128 --classifier_attention_resolutions 32,16,8 --classifier_depth 2 --classifier_width 128 --classifier_pool attention --classifier_resblock_updown True --classifier_use_scale_shift_norm True --classifier_scale 1.0 --classifier_use_fp16 True"
SAMPLE_FLAGS="--batch_size 4 --num_samples 50000 --timestep_respacing ddim25 --use_ddim True"
mpiexec -n N python scripts/classifier_sample.py \
    --model_path /path/to/model.pt \
    --classifier_path path/to/classifier.pt \
    $MODEL_FLAGS $CLASSIFIER_FLAGS $SAMPLE_FLAGS

To sample for 250 timesteps without DDIM, replace --timestep_respacing ddim25 with --timestep_respacing 250, and replace --use_ddim True with --use_ddim False.
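
Applying both replacements to the SAMPLE_FLAGS line from the DDIM example would give the following. The batch size and sample count are carried over unchanged.

```shell
# 250 timesteps without DDIM; only the last two flags differ from the DDIM example.
SAMPLE_FLAGS="--batch_size 4 --num_samples 50000 --timestep_respacing 250 --use_ddim False"
echo "$SAMPLE_FLAGS"
```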

Owner
Katherine Crowson