PyTorch implementation of Algorithm 1 of "On the Anatomy of MCMC-Based Maximum Likelihood Learning of Energy-Based Models"

Overview

Code for On the Anatomy of MCMC-Based Maximum Likelihood Learning of Energy-Based Models

This repository will reproduce the main results from our paper:

On the Anatomy of MCMC-Based Maximum Likelihood Learning of Energy-Based Models
Erik Nijkamp*, Mitch Hill*, Tian Han, Song-Chun Zhu, and Ying Nian Wu (*equal contributions)
https://arxiv.org/abs/1903.12370
AAAI 2020.

The files train_data.py and train_toy.py are PyTorch-based implementations of Algorithm 1 for image datasets and toy 2D distributions respectively. Both files will measure and plot the diagnostic values $d_{s_t}$ and $r_t$ described in Section 3 during training. The file eval.py will sample from a saved checkpoint using either unadjusted Langevin dynamics or Metropolis-Hastings adjusted Langevin dynamics. We provide an appendix ebm-anatomy-appendix.pdf that contains further practical considerations and empirical observations.

Config Files

The folder config_locker has several JSON files that reproduce different convergent and non-convergent learning outcomes for image datasets and toy distributions. Config files for evaluation of pre-trained networks are also included. The files data_config.json, toy_config.json, and eval_config.json fully explain the parameters for train_data.py, train_toy.py, and eval.py respectively.

Executable Files

To run an experiment with train_data.py, train_toy.py, or eval.py, just specify a name for the experiment folder and the location of the JSON config file:

# directory for experiment results
EXP_DIR = './name_of/new_folder/'
# json file with experiment config
CONFIG_FILE = './path_to/config.json'

before execution.

Other Files

Network structures are located in nets.py. A download function for Oxford Flowers 102 data, plotting functions, and a toy dataset class can be found in utils.py.

Diagnostics

Energy Difference and Langevin Gradient Magnitude: Both image and toy experiments will plot $d_{s_t}$ and $r_t$ (see Section 3) over training along with correlation plots as in Figure 4 (with ACF rather than PACF).

Landscape Plots: Toy experiments will plot the density and log-density (negative energy) for ground-truth, learned energy, and short-run models. Kernel density estimation is used to obtain the short-run density.

Short-Run MCMC Samples: Image data experiments will periodically visualize the short-run MCMC samples. A batch of persistent MCMC samples will also be saved for implementations that use persistent initialization for short-run sampling.

Long-Run MCMC Samples: Image data experiments have the option to obtain long-run MCMC samples during training. When log_longrun is set to true in a data config file, the training implementation will generate long-run MCMC samples at a frequency determined by log_longrun_freq. The appearance of long-run MCMC samples indicates whether the energy function assigns probability mass in realistic regions of the image space.

Pre-trained Networks

A convergent pre-trained network and non-convergent pre-trained network for the Oxford Flowers 102 dataset are available in the Releases section of the repository. The config files eval_flowers_convergent.json and eval_flowers_convergent_mh.json are set up to evaluate flowers_convergent_net.pth. The config file eval_flowers_nonconvergent.json is set up to evaluate flowers_nonconvergent_net.pth.

Contact

Please contact Mitch Hill ([email protected]) or Erik Nijkamp ([email protected]) for any questions.

You might also like...
Re-implementation of the Noise Contrastive Estimation algorithm for pyTorch, following "Noise-contrastive estimation: A new estimation principle for unnormalized statistical models." (Gutmann and Hyvarinen, AISTATS 2010)

Noise Contrastive Estimation for pyTorch Overview This repository contains a re-implementation of the Noise Contrastive Estimation algorithm, implemen

ppo_pytorch_cpp - an implementation of the proximal policy optimization algorithm for the C++ API of Pytorch
ppo_pytorch_cpp - an implementation of the proximal policy optimization algorithm for the C++ API of Pytorch

PPO Pytorch C++ This is an implementation of the proximal policy optimization algorithm for the C++ API of Pytorch. It uses a simple TestEnvironment t

PyTorch implementation of DreamerV2 model-based RL algorithm

PyDreamer Reimplementation of DreamerV2 model-based RL algorithm in PyTorch. The official DreamerV2 implementation can be found here. Features ... Run

PyTorch implementation of the implicit Q-learning algorithm (IQL)
PyTorch implementation of the implicit Q-learning algorithm (IQL)

Implicit-Q-Learning (IQL) PyTorch implementation of the implicit Q-learning algorithm IQL (Paper) Currently only implemented for online learning. Offl

PyTorch Implementation of the SuRP algorithm by the authors of the AISTATS 2022 paper "An Information-Theoretic Justification for Model Pruning"

PyTorch Implementation of the SuRP algorithm by the authors of the AISTATS 2022 paper "An Information-Theoretic Justification for Model Pruning".

A pytorch reprelication of the model-based reinforcement learning algorithm MBPO
A pytorch reprelication of the model-based reinforcement learning algorithm MBPO

Overview This is a re-implementation of the model-based RL algorithm MBPO in pytorch as described in the following paper: When to Trust Your Model: Mo

An algorithm that handles large-scale aerial photo co-registration, based on SURF, RANSAC and PyTorch autograd.
An algorithm that handles large-scale aerial photo co-registration, based on SURF, RANSAC and PyTorch autograd.

An algorithm that handles large-scale aerial photo co-registration, based on SURF, RANSAC and PyTorch autograd.

Implements pytorch code for the Accelerated SGD algorithm.

AccSGD This is the code associated with Accelerated SGD algorithm used in the paper On the insufficiency of existing momentum schemes for Stochastic O

PyGAD, a Python 3 library for building the genetic algorithm and training machine learning algorithms (Keras & PyTorch).
PyGAD, a Python 3 library for building the genetic algorithm and training machine learning algorithms (Keras & PyTorch).

PyGAD: Genetic Algorithm in Python PyGAD is an open-source easy-to-use Python 3 library for building the genetic algorithm and optimizing machine lear

Comments
  • Step size in Langevin Dynamics

    Step size in Langevin Dynamics

    Hi, in your code, when you do the langevin dynamics, you run x_s_t.data += - f_prime + config['epsilon'] * t.randn_like(x_s_t) However, does this mean that the step size for the gradient f_prim is 1? Should we run x_s_t.data += - 0.5*config['epsilon']**2*f_prime + config['epsilon'] * t.randn_like(x_s_t) instead?

    opened by XavierXiao 1
Releases(v1.0)
Owner
Mitch Hill
Assistant Professor of Statistics and Data Science at UCF
Mitch Hill
Look Closer: Bridging Egocentric and Third-Person Views with Transformers for Robotic Manipulation

Look Closer: Bridging Egocentric and Third-Person Views with Transformers for Robotic Manipulation Official PyTorch implementation for the paper Look

Rishabh Jangir 20 Nov 24, 2022
Knowledge Management for Humans using Machine Learning & Tags

HyperTag HyperTag helps humans intuitively express how they think about their files using tags and machine learning.

Ravn Tech, Inc. 165 Nov 04, 2022
Repository for the NeurIPS 2021 paper: "Exploiting Domain-Specific Features to Enhance Domain Generalization".

meta-Domain Specific-Domain Invariant (mDSDI) Source code implementation for the paper: Manh-Ha Bui, Toan Tran, Anh Tuan Tran, Dinh Phung. "Exploiting

VinAI Research 12 Nov 25, 2022
Unsupervised Image to Image Translation with Generative Adversarial Networks

Unsupervised Image to Image Translation with Generative Adversarial Networks Paper: Unsupervised Image to Image Translation with Generative Adversaria

Hao 71 Oct 30, 2022
A universal memory dumper using Frida

Fridump Fridump (v0.1) is an open source memory dumping tool, primarily aimed to penetration testers and developers. Fridump is using the Frida framew

551 Jan 07, 2023
Unofficial Pytorch Implementation of WaveGrad2

WaveGrad 2 — Unofficial PyTorch Implementation WaveGrad 2: Iterative Refinement for Text-to-Speech Synthesis Unofficial PyTorch+Lightning Implementati

MINDs Lab 104 Nov 29, 2022
A-ESRGAN aims to provide better super-resolution images by using multi-scale attention U-net discriminators.

A-ESRGAN: Training Real-World Blind Super-Resolution with Attention-based U-net Discriminators The authors are hidden for the purpose of double blind

77 Dec 16, 2022
Build tensorflow keras model pipelines in a single line of code. Created by Ram Seshadri. Collaborators welcome. Permission granted upon request.

deep_autoviml Build keras pipelines and models in a single line of code! Table of Contents Motivation How it works Technology Install Usage API Image

AutoViz and Auto_ViML 102 Dec 17, 2022
Pytorch implementation of AngularGrad: A New Optimization Technique for Angular Convergence of Convolutional Neural Networks

AngularGrad Optimizer This repository contains the oficial implementation for AngularGrad: A New Optimization Technique for Angular Convergence of Con

mario 124 Sep 16, 2022
Code for "Finding Regions of Heterogeneity in Decision-Making via Expected Conditional Covariance" at NeurIPS 2021

Finding Regions of Heterogeneity in Decision-Making via Expected Conditional Covariance Justin Lim, Christina X Ji, Michael Oberst, Saul Blecker, Leor

Sontag Lab 3 Feb 03, 2022
Neural Oblivious Decision Ensembles

Neural Oblivious Decision Ensembles A supplementary code for anonymous ICLR 2020 submission. What does it do? It learns deep ensembles of oblivious di

25 Sep 21, 2022
This is a model made out of Neural Network specifically a Convolutional Neural Network model

This is a model made out of Neural Network specifically a Convolutional Neural Network model. This was done with a pre-built dataset from the tensorflow and keras packages. There are other alternativ

9 Oct 18, 2022
g2o: A General Framework for Graph Optimization

g2o - General Graph Optimization Linux: Windows: g2o is an open-source C++ framework for optimizing graph-based nonlinear error functions. g2o has bee

Rainer Kümmerle 2.5k Dec 30, 2022
An Open-Source Package for Information Retrieval.

OpenMatch An Open-Source Package for Information Retrieval. 😃 What's New Top Spot on TREC-COVID Challenge (May 2020, Round2) The twin goals of the ch

THUNLP 439 Dec 27, 2022
Image-to-Image Translation with Conditional Adversarial Networks (Pix2pix) implementation in keras

pix2pix-keras Pix2pix implementation in keras. Original paper: Image-to-Image Translation with Conditional Adversarial Networks (pix2pix) Paper Author

William Falcon 141 Dec 30, 2022
PyTorch implementation of the paper Dynamic Token Normalization Improves Vision Transfromers.

Dynamic Token Normalization Improves Vision Transformers This is the PyTorch implementation of the paper Dynamic Token Normalization Improves Vision T

Wenqi Shao 20 Oct 09, 2022
CellRank's reproducibility repository.

CellRank's reproducibility repository We believe that reproducibility is key and have made it as simple as possible to reproduce our results. Please e

Theis Lab 8 Oct 08, 2022
NEATEST: Evolving Neural Networks Through Augmenting Topologies with Evolution Strategy Training

NEATEST: Evolving Neural Networks Through Augmenting Topologies with Evolution Strategy Training

Göktuğ Karakaşlı 16 Dec 05, 2022
CLIPImageClassifier wraps clip image model from transformers

CLIPImageClassifier CLIPImageClassifier wraps clip image model from transformers. CLIPImageClassifier is initialized with the argument classes, these

Jina AI 6 Sep 12, 2022
Code for Phase diagram of Stochastic Gradient Descent in high-dimensional two-layer neural networks

Phase diagram of Stochastic Gradient Descent in high-dimensional two-layer neural networks Under construction. Description Code for Phase diagram of S

Rodrigo Veiga 3 Nov 24, 2022