A minimalist implementation of score-based diffusion models

Overview

sdeflow-light

This is a minimalist codebase for training score-based diffusion models (supporting MNIST and CIFAR-10), used in the following paper:

"A Variational Perspective on Diffusion-Based Generative Models and Score Matching" by Chin-Wei Huang, Jae Hyun Lim and Aaron Courville [arXiv]

Also see the concurrent work by Yang Song & Conor Durkan where they used the same idea to obtain state-of-the-art likelihood estimates.

Experiments on Swissroll

Here's a Colab notebook which contains an example for training a model on the Swissroll dataset.

In this notebook, you'll see how to train the model using score matching loss, how to evaluate the ELBO of the plug-in reverse SDE, and how to sample from it. It also includes a snippet to sample from a family of plug-in reverse SDEs (parameterized by λ) mentioned in Appendix C of the paper.

Below are the trajectories of λ=0 (the reverse SDE used in Song et al.) and λ=1 (equivalent ODE) when we plug in the learned score / drift function. This corresponds to Figure 5 of the paper.

[Figures: trajectories for λ=0 and λ=1]

Experiments on MNIST and CIFAR-10

This repository contains one main training loop (train_img.py). The model is trained to minimize the denoising score matching loss by calling the .dsm(x) loss function, and is evaluated with the following ELBO by calling .elbo_random_t_slice(x):

[Equation image: score-elbo]

where the divergence (sum of the diagonal entries of the Jacobian) is estimated using the Hutchinson trace estimator.
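
To illustrate the Hutchinson trace estimator mentioned above, here is a minimal sketch in plain Python. It is not the code used in this repository: the function names (`hutchinson_divergence`, `toy_field`) are hypothetical, and the Jacobian-vector product is approximated with a central finite difference, whereas an autograd framework would compute it directly. The idea is that for a random vector v with independent ±1 entries, the expectation of v^T J v equals the trace of J.

```python
import random


def hutchinson_divergence(func, x, num_samples=1000, eps=1e-4, seed=0):
    """Estimate tr(d func / d x) at x as the average of v^T J v over Rademacher v."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        # Rademacher probe vector: each entry is +1 or -1 with equal probability.
        v = [rng.choice((-1.0, 1.0)) for _ in x]
        # Central finite difference approximates the Jacobian-vector product J v
        # (an assumption for this sketch; autograd would be used in practice).
        x_plus = [xi + eps * vi for xi, vi in zip(x, v)]
        x_minus = [xi - eps * vi for xi, vi in zip(x, v)]
        jv = [(p - m) / (2.0 * eps) for p, m in zip(func(x_plus), func(x_minus))]
        total += sum(vi * ji for vi, ji in zip(v, jv))
    return total / num_samples


def toy_field(x):
    # Jacobian is [[2*x0, 1], [1, 3]], so the exact divergence is 2*x0 + 3.
    return [x[0] ** 2 + x[1], x[0] + 3.0 * x[1]]


estimate = hutchinson_divergence(toy_field, [1.0, 2.0])
print(estimate)  # close to the exact value 5.0, up to Monte Carlo noise
```

A single probe vector already gives an unbiased estimate; averaging over more probes only reduces the variance, which comes from the off-diagonal entries of the Jacobian.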

It's a minimalist codebase in the sense that we do not use a fancy optimizer (we only use Adam with the default setup) or learning rate scheduling. We use the modified U-net architecture from Denoising Diffusion Probabilistic Models by Jonathan Ho.

A key difference from Song et al. is that instead of parameterizing the score function s, here we parameterize the drift term a (where they are related by a=gs and g is the diffusion coefficient). That is, a is the U-net.

Parameterization: Our original generative & inference SDEs are

  • dX = mu dt + sigma dBt
  • dY = (-mu + sigma*a) ds + sigma dBs

We reparameterize them as

  • dX = (ga - f) dt + g dBt
  • dY = f ds + g dBs

by letting mu = ga - f and sigma = g, so that the inference drift becomes -mu + sigma*a = -(ga - f) + ga = f. Since f and g are fixed, we only have one degree of freedom, which is a. Alternatively, one can parameterize s (e.g. using the U-net), and just let a=gs.
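
The reparameterized generative SDE can be simulated with a plain Euler-Maruyama loop. The sketch below is a 1-D toy, not this repository's sampler, and everything in it is an assumption made for illustration: the choice of f and g (a constant-rate variance-preserving SDE), the Gaussian toy data, the closed-form `a` standing in for the trained U-net, the convention that generative time t corresponds to inference time s = T - t, and the `lam` argument, which assumes the λ-family takes the form drift (1 - λ/2) g a - f with diffusion sqrt(1 - λ) g (λ=0 gives the SDE above, λ=1 a deterministic ODE).

```python
import math
import random

BETA = 4.0      # assumed constant noise rate of the toy inference SDE
T = 1.0         # assumed time horizon
DATA_STD = 0.5  # toy data distribution: N(0, DATA_STD^2)


def f(y, s):
    # Drift of the inference SDE dY = f ds + g dBs (toy choice).
    return -0.5 * BETA * y


def g(s):
    # Diffusion coefficient shared by the generative and inference SDEs.
    return math.sqrt(BETA)


def marginal_var(s):
    # Variance of Y_s when Y_0 ~ N(0, DATA_STD^2) under the toy inference SDE.
    decay = math.exp(-BETA * s)
    return DATA_STD ** 2 * decay + (1.0 - decay)


def a(y, s):
    # a = g * score; for a zero-mean Gaussian marginal the score is -y / variance.
    # This closed form stands in for the learned network.
    return g(s) * (-y / marginal_var(s))


def sample(lam=0.0, num_steps=200, rng=None):
    """Euler-Maruyama on dX = ((1 - lam/2) g a - f) dt + sqrt(1 - lam) g dBt."""
    rng = rng or random.Random(0)
    dt = T / num_steps
    x = rng.gauss(0.0, 1.0)  # start from the N(0, 1) prior
    for i in range(num_steps):
        s = T - i * dt  # inference time matching generative time t = i * dt
        drift = (1.0 - 0.5 * lam) * g(s) * a(x, s) - f(x, s)
        noise = math.sqrt(1.0 - lam) * g(s)
        x += drift * dt + noise * math.sqrt(dt) * rng.gauss(0.0, 1.0)
    return x


def sample_variance(lam, num_samples=2000, seed=0):
    # Empirical variance of the generated samples.
    rng = random.Random(seed)
    xs = [sample(lam=lam, rng=rng) for _ in range(num_samples)]
    mean = sum(xs) / len(xs)
    return sum((x - mean) ** 2 for x in xs) / len(xs)
```

Calling `sample_variance(0.0)` and `sample_variance(1.0)` should both give values near DATA_STD^2 = 0.25, up to discretization and Monte Carlo error, since in this toy setting the stochastic and deterministic samplers target the same distribution.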

How it works

Here's an example command line for running an experiment:

python train_img.py --dataroot=[DATAROOT] --saveroot=[SAVEROOT] --expname=[EXPNAME] \
    --dataset=cifar --print_every=2000 --sample_every=2000 --checkpoint_every=2000 --num_steps=1000 \
    --batch_size=128 --lr=0.0001 --num_iterations=100000 --real=True --debias=False

Setting --debias=False uses uniform sampling for the time variable, whereas --debias=True uses the non-uniform sampling strategy described in the paper to debias the gradient estimate. Below are the bits-per-dim and the corresponding standard error on the test set, recorded during training (orange for --debias=True and blue for --debias=False).

[Figures: bits-per-dim and standard error on the test set during training]

Here are some samples (debiased on the right)

[Figures: samples without debiasing (left) and with debiasing (right)]

It takes about 14 hrs to finish 100k iterations on a V100 GPU.
