A minimalist implementation of score-based diffusion models

Overview

sdeflow-light

This is a minimalist codebase for training score-based diffusion models (supporting MNIST and CIFAR-10), as used in the following paper:

"A Variational Perspective on Diffusion-Based Generative Models and Score Matching" by Chin-Wei Huang, Jae Hyun Lim and Aaron Courville [arXiv]

See also the concurrent work by Yang Song & Conor Durkan, which uses the same idea to obtain state-of-the-art likelihood estimates.

Experiments on Swissroll

Here's a Colab notebook which contains an example for training a model on the Swissroll dataset.

In this notebook, you'll see how to train the model using score matching loss, how to evaluate the ELBO of the plug-in reverse SDE, and how to sample from it. It also includes a snippet to sample from a family of plug-in reverse SDEs (parameterized by λ) mentioned in Appendix C of the paper.

Below are the trajectories for λ=0 (the reverse SDE used in Song et al.) and λ=1 (the equivalent ODE) when we plug in the learned score / drift function. This corresponds to Figure 5 of the paper.

[Figures: sample trajectories for λ=0 and λ=1]

Experiments on MNIST and CIFAR-10

This repository contains one main training loop (train_img.py). The model is trained by minimizing the denoising score matching loss, computed by the .dsm(x) loss function, and is evaluated with the following ELBO, computed by .elbo_random_t_slice(x):

[Equation image: score-elbo]

where the divergence (sum of the diagonal entries of the Jacobian) is estimated using the Hutchinson trace estimator.
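To illustrate the idea behind the Hutchinson trace estimator, here is a minimal NumPy sketch. It is not the repository's implementation: the function name `hutchinson_divergence` and its parameters are hypothetical, and the Jacobian-vector product is approximated by central finite differences only to keep the example self-contained (an autograd framework would normally compute it by automatic differentiation). The estimator averages v·(Jv) over random Rademacher probes v, whose expectation equals the trace of the Jacobian J.

```python
import numpy as np

def hutchinson_divergence(field, x, n_probes=1, eps=1e-4, rng=None):
    """Estimate div field(x) = trace(Jacobian) with random probes.

    `field` maps a 1-D array to a 1-D array of the same shape.
    The Jacobian-vector product is approximated by central differences
    (an assumption made for this sketch, not the repo's approach).
    """
    rng = np.random.default_rng() if rng is None else rng
    total = 0.0
    for _ in range(n_probes):
        # Rademacher probe: entries are +1 or -1 with equal probability
        v = rng.choice([-1.0, 1.0], size=x.shape)
        # J v, approximated by a central finite difference along v
        jvp = (field(x + eps * v) - field(x - eps * v)) / (2.0 * eps)
        # v . (J v) is an unbiased estimate of trace(J)
        total += float(v @ jvp)
    return total / n_probes

# Toy linear field whose divergence is known exactly: trace(A)
A = np.array([[1.0, 0.5, -0.3],
              [0.2, -2.0, 0.4],
              [0.1, 0.7, 0.5]])
x0 = np.array([0.3, -1.0, 2.0])
estimate = hutchinson_divergence(lambda x: A @ x, x0, n_probes=20000,
                                 rng=np.random.default_rng(0))
```

A single probe already gives an unbiased estimate, which is why one probe per training example is a common choice; more probes only reduce the variance.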

It's a minimalist codebase in the sense that we do not use a fancy optimizer (only Adam with the default setup) or learning rate scheduling. We use the modified U-Net architecture from Denoising Diffusion Probabilistic Models by Jonathan Ho.

A key difference from Song et al. is that instead of parameterizing the score function s, here we parameterize the drift term a (where they are related by a=gs and g is the diffusion coefficient). That is, a is the U-net.
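As a rough sketch of what a denoising score matching loss can look like when the network outputs the drift term a instead of the score s, consider the toy 1-D example below. Everything in it is an assumption made for illustration: the constant noise rate `BETA`, the Ornstein-Uhlenbeck perturbation kernel, the std² weighting, and the names `perturb` and `dsm_loss`. It is not the repository's .dsm(x) method. The predicted score is recovered as s = a / g and compared with the conditional score -eps / std of the Gaussian perturbation kernel.

```python
import numpy as np

BETA = 2.0            # assumed constant noise rate for this toy
G = np.sqrt(BETA)     # diffusion coefficient g

def perturb(x, t, eps):
    """Sample Y_t given Y_0 = x under dY = -0.5*BETA*Y ds + G dB_s."""
    mean_coef = np.exp(-0.5 * BETA * t)
    std = np.sqrt(1.0 - np.exp(-BETA * t))
    return mean_coef * x + std * eps, std

def dsm_loss(a_fn, x, rng, t_min=1e-3, t_max=1.0):
    """Denoising score matching loss for a drift network a = g * s."""
    t = rng.uniform(t_min, t_max, size=x.shape)   # uniform time sampling
    eps = rng.standard_normal(x.shape)
    y, std = perturb(x, t, eps)
    score_pred = a_fn(y, t) / G                   # s = a / g
    # The kernel's conditional score is -eps / std; weighting the squared
    # error by std**2 gives the familiar noise-prediction form.
    return 0.5 * np.mean((std * score_pred + eps) ** 2)

# Toy data: a point mass at x = 1.5, for which the exact drift is known.
def a_exact(y, t):
    mean_coef = np.exp(-0.5 * BETA * t)
    var = 1.0 - np.exp(-BETA * t)
    return G * (-(y - mean_coef * 1.5) / var)

x = np.full(1000, 1.5)
loss_exact = dsm_loss(a_exact, x, np.random.default_rng(0))
loss_zero = dsm_loss(lambda y, t: np.zeros_like(y), x, np.random.default_rng(0))
```

With point-mass data the exact drift drives the loss to zero, while a network that always outputs zero is left with roughly half the variance of the noise.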

Parameterization: Our original generative & inference SDEs are

  • dX = mu dt + sigma dBt
  • dY = (-mu + sigma*a) ds + sigma dBs

We reparameterize them as

  • dX = (ga - f) dt + g dBt
  • dY = f ds + g dBs

by letting mu = ga - f and sigma = g. Substituting into the inference drift gives -mu + sigma*a = -(ga - f) + ga = f, as required. Since f and g are fixed, the only degree of freedom is a. Alternatively, one can parameterize s (e.g. using the U-net) and let a=gs.
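The reparameterized generative SDE dX = (ga - f) dt + g dBt can be simulated with a simple Euler-Maruyama loop. The sketch below is a toy under stated assumptions, not the repository's sampler: the data distribution is a 1-D Gaussian N(M0, V0), the inference SDE is an Ornstein-Uhlenbeck process with constant rate `BETA`, and the score is available in closed form, so a = gs can be evaluated exactly instead of by a trained U-net. All names (`marginal`, `score`, `a`, `sample`) are hypothetical.

```python
import numpy as np

BETA = 2.0              # assumed constant noise rate
T = 1.0                 # time horizon
M0, V0 = 2.0, 0.25      # mean and variance of the toy data distribution
G = np.sqrt(BETA)       # diffusion coefficient g

def f(y):
    """Inference drift f in dY = f ds + g dBs (Ornstein-Uhlenbeck choice)."""
    return -0.5 * BETA * y

def marginal(s):
    """Mean and variance of Y_s when Y_0 ~ N(M0, V0)."""
    decay = np.exp(-BETA * s)
    return M0 * np.sqrt(decay), V0 * decay + (1.0 - decay)

def score(y, s):
    """Closed-form score of the Gaussian marginal of Y_s."""
    mean, var = marginal(s)
    return -(y - mean) / var

def a(x, t):
    """Drift term a = g * s, evaluated at inference time s = T - t."""
    return G * score(x, T - t)

def sample(n_samples=20000, n_steps=500, seed=0):
    """Euler-Maruyama simulation of dX = (g*a - f) dt + g dBt."""
    rng = np.random.default_rng(seed)
    mean_T, var_T = marginal(T)
    # Start from the marginal of Y_T, which plays the role of the prior
    x = mean_T + np.sqrt(var_T) * rng.standard_normal(n_samples)
    dt = T / n_steps
    for i in range(n_steps):
        t = i * dt
        drift = G * a(x, t) - f(x)
        x = x + drift * dt + G * np.sqrt(dt) * rng.standard_normal(n_samples)
    return x

samples = sample()
```

If the parameterization is applied correctly, the empirical mean and variance of `samples` should be close to M0 and V0, up to discretization and Monte Carlo error.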

How it works

Here's an example command line for running an experiment:

python train_img.py --dataroot=[DATAROOT] --saveroot=[SAVEROOT] --expname=[EXPNAME] \
    --dataset=cifar --print_every=2000 --sample_every=2000 --checkpoint_every=2000 --num_steps=1000 \
    --batch_size=128 --lr=0.0001 --num_iterations=100000 --real=True --debias=False

Setting --debias=False samples the time variable uniformly, whereas setting --debias=True uses the non-uniform sampling strategy described in the paper to debias the gradient estimate. Below are the test-set bits-per-dim and the corresponding standard error recorded during training (orange for --debias=True, blue for --debias=False).

[Figures: test-set bits-per-dim and its standard error during training]

Here are some samples (debiased on the right)

[Figures: generated samples, without debiasing (left) and with debiasing (right)]

Training for 100k iterations takes about 14 hours on a V100 GPU.

Owner
Chin-Wei Huang