PyTorch reimplementation of Diffusion Models

Overview

PyTorch pretrained Diffusion Models

A PyTorch reimplementation of Denoising Diffusion Probabilistic Models with checkpoints converted from the author's TensorFlow implementation.
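For orientation, the DDPM forward (noising) process that these models learn to invert, and the ancestral sampling step used to undo it, can be written as follows. The notation (variance schedule \beta_t, noise predictor \epsilon_\theta, and the choice \sigma_t^2 = \beta_t) follows the DDPM paper, not this repository's code:

```latex
\begin{aligned}
q(x_t \mid x_{t-1}) &= \mathcal{N}\big(x_t;\ \sqrt{1-\beta_t}\,x_{t-1},\ \beta_t I\big),
\qquad \alpha_t = 1-\beta_t,\quad \bar\alpha_t = \prod_{s=1}^{t}\alpha_s, \\
q(x_t \mid x_0) &= \mathcal{N}\big(x_t;\ \sqrt{\bar\alpha_t}\,x_0,\ (1-\bar\alpha_t) I\big), \\
x_{t-1} &= \frac{1}{\sqrt{\alpha_t}}\Big(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t, t)\Big) + \sigma_t z,
\qquad z \sim \mathcal{N}(0, I),\quad \sigma_t^2 = \beta_t .
\end{aligned}
```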

Quickstart

Running

pip install -e git+https://github.com/pesser/pytorch_diffusion.git#egg=pytorch_diffusion
pytorch_diffusion_demo

will start a Streamlit demo. It is recommended to run the demo with a GPU available.


Usage

Diffusion models with pretrained weights for cifar10, lsun_bedroom, lsun_cat or lsun_church can be loaded as follows:

from pytorch_diffusion import Diffusion

diffusion = Diffusion.from_pretrained("lsun_church")
samples = diffusion.denoise(4)
diffusion.save(samples, "lsun_church_sample_{:02}.png")
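Conceptually, denoise runs the DDPM reverse process: start from Gaussian noise and apply T denoising steps driven by the U-Net's noise prediction. The pure-Python sketch below illustrates that loop on a flat list of pixel values. It is not this repository's implementation: eps_model is a hypothetical stand-in for the U-Net, and the linear schedule (1e-4 to 0.02 over 1000 steps) and the choice sigma_t^2 = beta_t are taken from the DDPM paper rather than verified against this codebase.

```python
import math
import random
from typing import Callable, List

# Hypothetical noise predictor signature: (x_t, t) -> predicted noise, same length as x_t.
EpsModel = Callable[[List[float], int], List[float]]


def linear_beta_schedule(num_steps: int = 1000,
                         beta_start: float = 1e-4,
                         beta_end: float = 0.02) -> List[float]:
    """Linearly spaced noise variances beta_1..beta_T (DDPM paper defaults)."""
    step = (beta_end - beta_start) / (num_steps - 1)
    return [beta_start + i * step for i in range(num_steps)]


def ddpm_sample(eps_model: EpsModel, size: int, betas: List[float],
                rng: random.Random) -> List[float]:
    """Ancestral sampling: draw x_T ~ N(0, I), then step down to x_0."""
    alphas = [1.0 - b for b in betas]
    alpha_bars, prod = [], 1.0
    for a in alphas:
        prod *= a
        alpha_bars.append(prod)  # cumulative product alpha_bar_t

    x = [rng.gauss(0.0, 1.0) for _ in range(size)]  # pure noise x_T
    for t in reversed(range(len(betas))):
        eps = eps_model(x, t)
        coef = betas[t] / math.sqrt(1.0 - alpha_bars[t])
        # Model mean of p(x_{t-1} | x_t): (x_t - coef * eps) / sqrt(alpha_t)
        mean = [(xi - coef * ei) / math.sqrt(alphas[t]) for xi, ei in zip(x, eps)]
        if t > 0:
            sigma = math.sqrt(betas[t])  # sigma_t^2 = beta_t
            x = [m + sigma * rng.gauss(0.0, 1.0) for m in mean]
        else:
            x = mean  # no noise is added at the final step
    return x


if __name__ == "__main__":
    # A dummy predictor that always predicts zero noise, only to exercise the loop.
    zero_eps = lambda x, t: [0.0] * len(x)
    sample = ddpm_sample(zero_eps, size=4, betas=linear_beta_schedule(), rng=random.Random(0))
    print(sample)
```

With a real trained U-Net in place of zero_eps, each step removes a little of the predicted noise, which is why sampling needs all T network evaluations.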

Prefix the name with ema_ to load the exponential moving average (EMA) of the weights, which produces better results. The U-Net model used for denoising is available via diffusion.model and can also be instantiated on its own:

from pytorch_diffusion import Model

model = Model(resolution=32,
              in_channels=3,
              out_ch=3,
              ch=128,
              ch_mult=(1,2,2,2),
              num_res_blocks=2,
              attn_resolutions=(16,),
              dropout=0.1)

This configuration example corresponds to the model used on CIFAR-10.
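To see what these arguments imply, the sketch below lists the feature-map resolution and channel count of each U-Net level, and whether self-attention is used there. It assumes, as in the DDPM U-Net, that level i has ch * ch_mult[i] channels and that the resolution is halved after every level except the last; unet_levels is a hypothetical helper, not part of pytorch_diffusion.

```python
from typing import List, Sequence, Tuple


def unet_levels(resolution: int, ch: int, ch_mult: Sequence[int],
                attn_resolutions: Sequence[int]) -> List[Tuple[int, int, bool]]:
    """Return (resolution, channels, uses_attention) for each downsampling level."""
    levels = []
    curr_res = resolution
    for i, mult in enumerate(ch_mult):
        levels.append((curr_res, ch * mult, curr_res in attn_resolutions))
        if i != len(ch_mult) - 1:  # assumed: no downsampling after the last level
            curr_res //= 2
    return levels


# CIFAR-10 configuration from above.
for res, channels, attn in unet_levels(32, 128, (1, 2, 2, 2), (16,)):
    print(f"{res}x{res}: {channels} channels, attention={attn}")
```

Under these assumptions, attn_resolutions=(16,) means attention is applied only at the 16x16 level.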

Producing samples

If you installed directly from GitHub, you can find the cloned repository in <venv path>/src/pytorch_diffusion for virtual environments and in <cwd>/src/pytorch_diffusion for global installs. From there, you can run

python pytorch_diffusion/diffusion.py <name> <bs> <nb>

where <name> is one of cifar10, lsun_bedroom, lsun_cat or lsun_church, optionally prefixed with ema_; <bs> is the batch size and <nb> the number of batches. This produces <bs> * <nb> samples from the PyTorch model and saves them to results/<name>/.

Results

Evaluating 50k samples with torch-fidelity gives the following FID scores (lower is better):

Dataset             EMA  Framework   Model           FID
CIFAR10 Train       no   PyTorch     cifar10         12.13775
CIFAR10 Train       no   TensorFlow  tf_cifar10      12.30003
CIFAR10 Train       yes  PyTorch     ema_cifar10     3.21213
CIFAR10 Train       yes  TensorFlow  tf_ema_cifar10  3.245872
CIFAR10 Validation  no   PyTorch     cifar10         14.30163
CIFAR10 Validation  no   TensorFlow  tf_cifar10      14.44705
CIFAR10 Validation  yes  PyTorch     ema_cifar10     5.274105
CIFAR10 Validation  yes  TensorFlow  tf_ema_cifar10  5.325035
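In the table, the EMA rows use the ema_ checkpoints, whose weights are an exponential moving average of the parameters over training. A minimal sketch of that update on a plain dict of scalar parameters follows; the helper name and the default decay of 0.9999 (a common choice for diffusion models) are assumptions for illustration, not code from this repository.

```python
from typing import Dict


def ema_update(ema_params: Dict[str, float], params: Dict[str, float],
               decay: float = 0.9999) -> None:
    """In-place EMA step: ema <- decay * ema + (1 - decay) * current."""
    for name, value in params.items():
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * value


# Toy example: the raw parameter oscillates, while the EMA copy smooths it out.
params = {"w": 0.0}
ema = dict(params)  # the EMA copy starts from the initial weights
for step in range(1000):
    params["w"] = 1.0 if step % 2 == 0 else -1.0  # stand-in for noisy optimizer updates
    ema_update(ema, params, decay=0.99)
print(params["w"], ema["w"])
```

The averaged copy is only used for evaluation and sampling; training keeps updating the raw weights.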

To reproduce, generate 50k samples (batch size 500, 100 batches) from the converted PyTorch models provided in this repo, with <Model> set to cifar10 or ema_cifar10, using

python pytorch_diffusion/diffusion.py <Model> 500 100

and with

python -c "import convert as m; m.sample_tf(500, 100, which=['cifar10', 'ema_cifar10'])"

for the original TensorFlow models.
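For reference, FID is the Fréchet distance between two Gaussians fitted to Inception-v3 features of real and generated images, with means \mu_r, \mu_g and covariances \Sigma_r, \Sigma_g. The Train and Validation rows above presumably differ only in which CIFAR-10 split supplies the reference statistics \mu_r, \Sigma_r:

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\Big(\Sigma_r + \Sigma_g - 2\,\big(\Sigma_r \Sigma_g\big)^{1/2}\Big)
```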

Running conversions

The converted PyTorch checkpoints are provided for download. If you want to convert them yourself, follow the steps described here.

Setup

This section assumes your working directory is the root of this repository. Download the pretrained TensorFlow checkpoints; they should follow the original directory structure:

diffusion_models_release/
  diffusion_cifar10_model/
    model.ckpt-790000.data-00000-of-00001
    model.ckpt-790000.index
    model.ckpt-790000.meta
  diffusion_lsun_bedroom_model/
    ...
  ...

Set the environment variable TFROOT to the directory where you want to store the author's repository, e.g.

export TFROOT=".."

Clone the diffusion repository,

git clone https://github.com/hojonathanho/diffusion.git ${TFROOT}/diffusion

and install their required dependencies (pip install -r ${TFROOT}/diffusion/requirements.txt). Then add the following to your PYTHONPATH:

export PYTHONPATH=".:./scripts:${TFROOT}/diffusion:${TFROOT}/diffusion/scripts:${PYTHONPATH}"

Testing operations

To test the PyTorch implementations of the required operations against their TensorFlow counterparts with random initialization and random inputs, run

python -c "import convert as m; m.test_ops()"
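As an illustration of the kind of operation being compared, here is a pure-Python sketch of the sinusoidal timestep embedding the DDPM U-Net uses to condition on t (sines for the first half of the dimensions, cosines for the second). We assume, without having checked convert.py, that an op like this is among those covered by test_ops; timestep_embedding is our own sketch, not a function from either codebase.

```python
import math
from typing import List


def timestep_embedding(t: int, dim: int) -> List[float]:
    """Sinusoidal embedding of an integer timestep t into `dim` values."""
    half = dim // 2
    scale = math.log(10000.0) / (half - 1)  # requires dim >= 4
    freqs = [math.exp(-scale * i) for i in range(half)]  # from 1 down to 1e-4
    args = [t * f for f in freqs]
    emb = [math.sin(a) for a in args] + [math.cos(a) for a in args]
    if dim % 2 == 1:
        emb.append(0.0)  # zero-pad odd dimensions
    return emb


print(timestep_embedding(0, 8))
```

A TensorFlow/PyTorch op test would evaluate both implementations on the same random timesteps and compare the outputs elementwise within a small tolerance.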

Converting checkpoints

To load the pretrained TensorFlow models, copy their weights into the PyTorch models, check for equality on random inputs and finally save the corresponding PyTorch checkpoints, run

python -c "import convert as m; m.transplant_cifar10()"
python -c "import convert as m; m.transplant_cifar10(ema=True)"
python -c "import convert as m; m.transplant_lsun_bedroom()"
python -c "import convert as m; m.transplant_lsun_bedroom(ema=True)"
python -c "import convert as m; m.transplant_lsun_cat()"
python -c "import convert as m; m.transplant_lsun_cat(ema=True)"
python -c "import convert as m; m.transplant_lsun_church()"
python -c "import convert as m; m.transplant_lsun_church(ema=True)"
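One detail any such transplant must handle is tensor layout: TensorFlow stores 2D convolution kernels as (height, width, in_channels, out_channels) and Dense kernels as (in, out), while torch.nn.Conv2d expects (out_channels, in_channels, height, width) and torch.nn.Linear expects (out, in). The NumPy sketch below shows those permutations plus a max-abs-difference check of the kind used to confirm equality; the helper names and the tolerance are illustrative assumptions, not the functions in convert.py.

```python
import numpy as np


def tf_conv_kernel_to_torch(kernel: np.ndarray) -> np.ndarray:
    """(H, W, C_in, C_out) TensorFlow layout -> (C_out, C_in, H, W) PyTorch layout."""
    return np.ascontiguousarray(np.transpose(kernel, (3, 2, 0, 1)))


def tf_dense_kernel_to_torch(kernel: np.ndarray) -> np.ndarray:
    """(C_in, C_out) TensorFlow Dense kernel -> (C_out, C_in) nn.Linear weight."""
    return np.ascontiguousarray(kernel.T)


def outputs_match(a: np.ndarray, b: np.ndarray, atol: float = 1e-5) -> bool:
    """Compare two outputs computed on the same random input; atol is illustrative."""
    return a.shape == b.shape and float(np.max(np.abs(a - b))) <= atol


rng = np.random.default_rng(0)
tf_kernel = rng.standard_normal((3, 3, 64, 128)).astype(np.float32)
pt_weight = tf_conv_kernel_to_torch(tf_kernel)
print(pt_weight.shape)  # (128, 64, 3, 3)
```

The same random-input comparison, applied to full model outputs rather than single weights, is what catches a missed or wrong permutation.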

PyTorch checkpoints will be saved in

diffusion_models_converted/
  diffusion_cifar10_model/
    model-790000.ckpt
  ema_diffusion_cifar10_model/
    model-790000.ckpt
  diffusion_lsun_bedroom_model/
    model-2388000.ckpt
  ema_diffusion_lsun_bedroom_model/
    model-2388000.ckpt
  diffusion_lsun_cat_model/
    model-1761000.ckpt
  ema_diffusion_lsun_cat_model/
    model-1761000.ckpt
  diffusion_lsun_church_model/
    model-4432000.ckpt
  ema_diffusion_lsun_church_model/
    model-4432000.ckpt
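Given this layout, a small helper can map the model names used elsewhere in this README (e.g. ema_lsun_church) to checkpoint files. converted_checkpoint_path is a hypothetical helper, not part of the package; the step counts are copied from the tree above.

```python
from pathlib import Path

# Training step of each released checkpoint, taken from the directory listing above.
CHECKPOINT_STEPS = {
    "cifar10": 790000,
    "lsun_bedroom": 2388000,
    "lsun_cat": 1761000,
    "lsun_church": 4432000,
}


def converted_checkpoint_path(name: str, root: str = "diffusion_models_converted") -> Path:
    """Map e.g. 'ema_lsun_church' to <root>/ema_diffusion_lsun_church_model/model-4432000.ckpt."""
    ema = name.startswith("ema_")
    base = name[len("ema_"):] if ema else name
    if base not in CHECKPOINT_STEPS:
        raise ValueError(f"unknown model name: {name!r}")
    folder = ("ema_" if ema else "") + f"diffusion_{base}_model"
    return Path(root) / folder / f"model-{CHECKPOINT_STEPS[base]}.ckpt"


print(converted_checkpoint_path("ema_lsun_church"))
```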

Sample TensorFlow models

To produce N samples from each of the pretrained TensorFlow models, run

python -c "import convert as m; m.sample_tf(N)"

Pass a list of model names via the keyword argument which (e.g. which=['cifar10', 'ema_cifar10']) to select the models to sample from. Samples will be saved in results/.

Owner
Patrick Esser