Reference implementation for Deep Unsupervised Learning using Nonequilibrium Thermodynamics

Overview

Diffusion Probabilistic Models

This repository provides a reference implementation of the method described in the paper:

Deep Unsupervised Learning using Nonequilibrium Thermodynamics
Jascha Sohl-Dickstein, Eric A. Weiss, Niru Maheswaranathan, Surya Ganguli
International Conference on Machine Learning, 2015
http://arxiv.org/abs/1503.03585

This implementation builds a generative model of data by training a Gaussian diffusion process to transform a noise distribution into a data distribution in a fixed number of time steps. The mean and covariance of the diffusion process are parameterized using deep supervised learning. The resulting model is tractable to train and easy to sample from exactly. It also allows the probability of datapoints to be evaluated cheaply and makes conditional and posterior distributions straightforward to compute.
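The fixed-length Gaussian diffusion described above can be illustrated with a small NumPy sketch of the forward (data to noise) trajectory. Each step draws x^(t) ~ N(sqrt(1 - beta_t) x^(t-1), beta_t I), the Gaussian kernel from the paper. The linear beta schedule, the number of steps, and the function name `forward_diffusion` are assumptions for illustration and are not taken from this repository. Training learns the reverse of this process.

```python
import numpy as np

def forward_diffusion(x0, betas, rng):
    """Sample a forward trajectory x^(0) -> x^(T) under a Gaussian diffusion kernel.

    Hypothetical helper. Each step draws
    x^(t) ~ N(sqrt(1 - beta_t) * x^(t-1), beta_t * I),
    so the data is gradually shrunk toward zero and mixed with unit-variance noise.
    """
    trajectory = [x0]
    x = x0
    for beta in betas:
        noise = rng.standard_normal(x.shape)
        x = np.sqrt(1.0 - beta) * x + np.sqrt(beta) * noise
        trajectory.append(x)
    return trajectory

# Assumed settings: 1000 steps with a linear beta schedule (not the repository's schedule).
rng = np.random.default_rng(0)
betas = np.linspace(1e-4, 0.02, 1000)
x0 = rng.uniform(-1.0, 1.0, size=(16, 784))  # stand-in for a batch of flattened 28x28 images
traj = forward_diffusion(x0, betas, rng)
xT = traj[-1]
print(len(traj), xT.shape, round(float(xT.std()), 2))
```

After enough steps, x^(T) is close to an isotropic unit Gaussian regardless of x^(0). This is the noise distribution the learned reverse process starts from.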

Using the Software

To train a diffusion probabilistic model on the default dataset, MNIST, install the dependencies (see below) and then run python train.py.

Dependencies

  1. Install Blocks and its dependencies following the installation instructions in the Blocks documentation.
  2. Set up Fuel and download MNIST following the instructions in the Fuel documentation.

As of October 16, 2015, this code requires the bleeding-edge (development) versions of both Blocks and Fuel rather than the stable releases. Thanks to David Hofmann for pointing out that the stable releases will not work because of an interface change.

Output

The objective function being minimized is the bound on the negative log likelihood in bits per pixel, minus the negative log likelihood under an identity-covariance Gaussian model. That is, it is the negative of the number in the rightmost column in Table 1 in the paper.
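As a worked illustration of this objective, the sketch below converts a bound given in nats per pixel into bits per pixel and subtracts the negative log likelihood of the same pixels under an identity-covariance Gaussian. The zero mean of that Gaussian, the nats-per-pixel input, and the function names are assumptions of this sketch, not details taken from train.py.

```python
import numpy as np

LOG2 = np.log(2.0)

def gaussian_nll_bits_per_pixel(x):
    """Mean NLL of x under N(0, I), in bits per pixel.

    Assumes x is an array of normalized pixel values; the zero mean is an
    assumption of this sketch.
    """
    nll_nats = 0.5 * np.log(2.0 * np.pi) + 0.5 * x ** 2  # per-pixel NLL in nats
    return float(nll_nats.mean() / LOG2)

def reported_objective(bound_nats_per_pixel, x):
    """Bound on NLL (bits/pixel) minus the identity-covariance Gaussian NLL (bits/pixel).

    Hypothetical helper: negative values mean the diffusion model beats the
    Gaussian baseline.
    """
    return bound_nats_per_pixel / LOG2 - gaussian_nll_bits_per_pixel(x)

x = np.zeros((2, 4))
print(gaussian_nll_bits_per_pixel(x))                # 0.5 * log2(2*pi), about 1.3257
print(reported_objective(0.5 * np.log(2 * np.pi), x))  # about 0.0: no better than the baseline
```

Because the Gaussian baseline is subtracted, the printed value is the negative of the K minus L_null style comparison: lower is better.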

Logging information is printed to the console once per training epoch, including the current value of the objective on the training set.

Figures showing samples from the model, parameters, gradients, and training progress are also output periodically (every 25 epochs by default; see train.py).

The samples from the model are of three types: standard samples, samples that inpaint the left half of masked images, and samples that denoise images with added Gaussian noise (by default, the signal-to-noise ratio is 1). This demonstrates how straightforwardly inpainting, denoising, and sampling from a posterior in general can be performed within this framework.
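One simple way to realize inpainting with a diffusion model is to run the reverse (noise to data) chain and, after every step, overwrite the known pixels with their observed values. The sketch below does this. It also shows one way to corrupt images at a given signal-to-noise ratio, assuming SNR is defined as signal variance over noise variance. `toy_reverse_step` is a hypothetical stand-in for the trained network, and this clamping scheme is an illustration, not necessarily the exact procedure used in this repository.

```python
import numpy as np

def corrupt_with_noise(x, snr, rng):
    """Add Gaussian noise so that var(signal) / var(noise) is about snr (assumed SNR definition)."""
    noise_std = x.std() / np.sqrt(snr)
    return x + noise_std * rng.standard_normal(x.shape)

def toy_reverse_step(x, t, rng):
    """Hypothetical stand-in for one learned reverse step p(x^(t-1) | x^(t)).

    A trained model would predict a mean and covariance for step t; this toy
    version just shrinks toward zero and adds a little noise so the sketch runs.
    """
    return 0.99 * x + 0.01 * rng.standard_normal(x.shape)

def inpaint(observed, known_mask, n_steps, rng, reverse_step=toy_reverse_step):
    """Sample unknown pixels while clamping known pixels at every reverse step."""
    x = rng.standard_normal(observed.shape)  # start from the noise distribution
    for t in range(n_steps, 0, -1):
        x = reverse_step(x, t, rng)
        x = np.where(known_mask, observed, x)  # re-impose the observed pixels
    return x

rng = np.random.default_rng(0)
images = rng.uniform(-1.0, 1.0, size=(4, 28, 28))
mask = np.zeros_like(images, dtype=bool)
mask[:, :, 14:] = True  # right half known; left half is inpainted
result = inpaint(images, mask, n_steps=100, rng=rng)
noisy = corrupt_with_noise(images, snr=1.0, rng=rng)
```

With a trained reverse step, the unknown left half is filled with samples that are consistent with the clamped right half.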

Here are samples generated by this code after 825 training epochs on MNIST, trained using the command run train.py:
[Figure: MNIST samples after 825 training epochs]

Here are samples generated by this code after 1700 training epochs on CIFAR-10, trained using the command run train.py --batch-size 200 --dataset CIFAR10 --model-args "n_hidden_dense_lower=1000,n_hidden_dense_lower_output=5,n_hidden_conv=100,n_layers_conv=6,n_layers_dense_lower=6,n_layers_dense_upper=4,n_hidden_dense_upper=100":
[Figure: CIFAR-10 samples after 1700 training epochs]
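The --model-args value above is a comma-separated list of key=value pairs. The sketch below shows how such a string maps onto keyword arguments. `parse_model_args` is a hypothetical helper written for illustration; it is not the parser that train.py uses.

```python
def parse_model_args(spec):
    """Turn 'key=value,key=value' into a dict, converting numeric values.

    Hypothetical helper illustrating the --model-args format; not train.py's parser.
    """
    kwargs = {}
    for item in spec.split(","):
        item = item.strip()
        if not item:
            continue  # tolerate trailing commas
        key, value = item.split("=", 1)
        key, value = key.strip(), value.strip()
        try:
            kwargs[key] = int(value)
        except ValueError:
            try:
                kwargs[key] = float(value)
            except ValueError:
                kwargs[key] = value  # leave non-numeric values as strings
    return kwargs

spec = ("n_hidden_dense_lower=1000,n_hidden_dense_lower_output=5,n_hidden_conv=100,"
        "n_layers_conv=6,n_layers_dense_lower=6,n_layers_dense_upper=4,n_hidden_dense_upper=100")
args = parse_model_args(spec)
print(args["n_hidden_conv"], args["n_layers_conv"])
```

Parsing values into integers up front means a typo such as n_layers_conv=six surfaces as a string rather than failing silently inside model construction.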

Miscellaneous

Different nonlinearities - In the paper, we used softplus units in the convolutional layers, and tanh units in the dense layers. In this implementation, I use leaky ReLU units everywhere.
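For comparison, the three nonlinearities mentioned above can be written in a few lines of NumPy. The leaky ReLU negative slope of 0.05 is an assumed value; the README does not state the slope used in this implementation.

```python
import numpy as np

def softplus(x):
    # log(1 + exp(x)), computed stably via logaddexp; used in the paper's convolutional layers
    return np.logaddexp(0.0, x)

def leaky_relu(x, negative_slope=0.05):
    # negative_slope is an assumption of this sketch, not a value from the repository
    return np.where(x > 0, x, negative_slope * x)

x = np.array([-2.0, 0.0, 2.0])
print(softplus(x))    # smooth and always positive
print(np.tanh(x))     # bounded in (-1, 1); used in the paper's dense layers
print(leaky_relu(x))  # [-0.1, 0.0, 2.0]
```

Unlike tanh, leaky ReLU does not saturate for large positive inputs. Unlike a plain ReLU, it keeps a nonzero gradient for negative inputs.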

Original source code - This repository is a refactoring of the code used to run the experiments in the published paper. In the spirit of reproducibility, if you email me a request, I am willing to share the original source code. It is poorly commented and held together with duct tape, though. For most applications, you will be better off using the reference implementation provided here.

Contact - I would love to hear from you. Let me know what goes right/wrong! [email protected]
