Likelihood Training of Schrödinger Bridge using Forward-Backward SDEs Theory [ICLR 2022]

Official PyTorch implementation of the paper "Likelihood Training of Schrödinger Bridge using Forward-Backward SDEs Theory (SB-FBSDE)" which introduces a new class of deep generative models that generalizes score-based models to fully nonlinear forward and backward diffusions.

[Figure: SB-FBSDE results]

This repo is co-maintained by Guan-Horng Liu and Tianrong Chen. Contact us if you have any questions! If you find this library useful, please cite ⬇️

@inproceedings{chen2022likelihood,
  title={Likelihood Training of Schr{\"o}dinger Bridge using Forward-Backward SDEs Theory},
  author={Chen, Tianrong and Liu, Guan-Horng and Theodorou, Evangelos A},
  booktitle={International Conference on Learning Representations},
  year={2022}
}

Installation

This code is developed with Python 3 and PyTorch >= 1.7 (we recommend 1.8.1). First, install the dependencies with Anaconda and activate the environment sb-fbsde with

conda env create --file requirements.yaml python=3
conda activate sb-fbsde
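
To verify that the environment provides a compatible PyTorch build (>= 1.7 as noted above), a quick check from the activated environment is:

python -c "import torch; print(torch.__version__)"   # expect >= 1.7 (1.8.1 recommended)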

Training

# pass --num-FID-sample only for CIFAR-10
python main.py \
  --problem-name <PROBLEM_NAME> \
  --forward-net <FORWARD_NET> \
  --backward-net <BACKWARD_NET> \
  --num-FID-sample <NUM_FID_SAMPLE> \
  --dir <DIR> \
  --log-tb

To train an SB-FBSDE from scratch, run the above command, where

  • PROBLEM_NAME is the dataset. We support gmm (2D mixture of Gaussians), checkerboard (2D toy dataset), mnist, celebA32, celebA64, and cifar10.
  • FORWARD_NET & BACKWARD_NET are the deep networks for the forward and backward drifts. We support Unet, ncsnpp, and a toy network for 2D datasets.
  • NUM_FID_SAMPLE is the number of generated images used to evaluate FID locally. We recommend 10000 for training CIFAR-10. Note that this requires first downloading the FID statistics checkpoint.
  • DIR specifies where the results (e.g. snapshots during training) shall be stored.
  • log-tb enables logging with Tensorboard.
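
As a concrete example, a CIFAR-10 training run combining the options above might look like the following sketch; the output directory output/cifar10-sb is an arbitrary placeholder, and the networks match those used in the evaluation section below:

python main.py \
  --problem-name cifar10 \
  --forward-net Unet \
  --backward-net ncsnpp \
  --num-FID-sample 10000 \
  --dir output/cifar10-sb \
  --log-tb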

Additionally, use --load to restore a previous checkpoint or a pre-trained model. For CIFAR-10 specifically, we support loading the pre-trained NCSN++ as the backward policy of the first SB training stage (this is because the first SB training stage can degenerate to denoising score matching under proper initialization; see Appendix D of our paper for details).
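
For instance, resuming the CIFAR-10 run above from a saved checkpoint might look like the sketch below; <PATH_TO_CHECKPOINT> is a placeholder for the checkpoint file produced by an earlier run (or a pre-trained model), and the remaining flags are assumed to match the original run:

python main.py \
  --problem-name cifar10 \
  --forward-net Unet \
  --backward-net ncsnpp \
  --dir output/cifar10-sb \
  --log-tb \
  --load <PATH_TO_CHECKPOINT>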

Other configurations are detailed in options.py. The default configurations for each dataset are provided in the configs folder.
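
Assuming the options in options.py are parsed with Python's argparse (an assumption, not stated in this README), the full list of flags and their defaults can be printed with:

python main.py --help   # list all available flags and their defaults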

Evaluating the CIFAR-10 Checkpoint

To evaluate SB-FBSDE on CIFAR-10 (we achieve FID 3.01 and NLL 2.96), create a folder checkpoint, then download the model checkpoint and the FID statistics checkpoint either from Google Drive or via the following commands.

mkdir checkpoint && cd checkpoint

# FID statistics checkpoint. This is needed whenever
# FID is computed during training or sampling.
gdown --id 1Tm_5nbUYKJiAtz2Rr_ARUY3KIFYxXQQD 

# SB-FBSDE model checkpoint for reproducing results in the paper.
gdown --id 1Kcy2IeecFK79yZDmnky36k4PR2yGpjyg 
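
The commands above assume the gdown utility is available. If it is not already present in the sb-fbsde environment, it can be installed with pip:

pip install gdown   # CLI used above to fetch checkpoints from Google Drive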

After downloading the checkpoints, run the following commands to compute either NLL or FID. Set the batch size --samp-bs appropriately for your hardware.

# compute NLL
python main.py --problem-name cifar10 --forward-net Unet --backward-net ncsnpp --dir ICLR-2022-reproduce \
  --load checkpoint/ciifar10_sbfbsde_stage_8.npz --compute-NLL --samp-bs <BS>

# compute FID
python main.py --problem-name cifar10 --forward-net Unet --backward-net ncsnpp --dir ICLR-2022-reproduce \
  --load checkpoint/ciifar10_sbfbsde_stage_8.npz --compute-FID --samp-bs <BS> --num-FID-sample 50000 --use-corrector --snr 0.15