Ladder Variational Autoencoders (LVAE) in PyTorch

Overview

Ladder Variational Autoencoders (LVAE)

PyTorch implementation of Ladder Variational Autoencoders (LVAE) [1]:

                 LVAE equation

where the variational distributions q at each layer are multivariate Normal with diagonal covariance.

Significant differences from [1] include:

  • skip connections in the generative path: conditioning on all layers above rather than only on the layer above (see for example [2])
  • spatial (convolutional) latent variables
  • free bits [3] instead of beta annealing [4]

Install requirements and run MNIST example

pip install -r requirements.txt
CUDA_VISIBLE_DEVICES=0 python main.py --zdims 32 32 32 --downsample 1 1 1 --nonlin elu --skip --blocks-per-layer 4 --gated --freebits 0.5 --learn-top-prior --data-dep-init --seed 42 --dataset static_mnist

Dependencies include boilr (a framework for PyTorch) and multiobject (which provides multi-object datasets with PyTorch dataloaders).

Likelihood results

Log likelihood bounds on the test set (average over 4 random seeds).

dataset num layers -ELBO - log p(x)
[100 iws]
- log p(x)
[1000 iws]
binarized MNIST 3 82.14 79.47 79.24
binarized MNIST 6 80.74 78.65 78.52
binarized MNIST 12 80.50 78.50 78.30
multi-dSprites (0-2) 12 26.9 23.2
SVHN 15 4012 (1.88) 3973 (1.87)
CIFAR10 3 7651 (3.59) 7591 (3.56)
CIFAR10 6 7321 (3.44) 7268 (3.41)
CIFAR10 15 7128 (3.35) 7068 (3.32)
CelebA 20 20026 (2.35) 19913 (2.34)

Note:

  • Bits per dimension in brackets.
  • 'iws' stands for importance weighted samples. More samples means tighter log likelihood lower bound. The bound converges to the actual log likelihood as the number of samples goes to infinity [5]. Note that the model is always trained with the ELBO (1 sample).
  • Each pixel in the images is modeled independently. The likelihood is Bernoulli for binary images, and discretized mixture of logistics with 10 components [6] otherwise.
  • One day I'll get around to evaluating the IW bound on all datasets with 10000 samples.

Supported datasets

  • Statically binarized MNIST [7], see Hugo Larochelle's website http://www.cs.toronto.edu/~larocheh/public/datasets/
  • SVHN
  • CIFAR10
  • CelebA rescaled and cropped to 64x64 – see code for details. The path in experiment.data.DatasetLoader has to be modified
  • binary multi-dSprites: 64x64 RGB shapes (0 to 2) in each image

Samples

Binarized MNIST

MNIST samples

Multi-dSprites

multi-dSprites samples

SVHN

SVHN samples

CIFAR

CIFAR samples

CelebA

CelebA samples

Hierarchical representations

Here we try to visualize the representations learned by individual layers. We can get a rough idea of what's going on at layer i as follows:

  • Sample latent variables from all layers above layer i (Eq. 1).

  • With these variables fixed, take S conditional samples at layer i (Eq. 2). Note that they are all conditioned on the same samples. These correspond to one row in the images below.

  • For each of these samples (each small image in the images below), pick the mode/mean of the conditional distribution of each layer below (Eq. 3).

  • Finally, sample an image x given the latent variables (Eq. 4).

Formally:

                

where s = 1, ..., S denotes the sample index.

The equations above yield S sample images conditioned on the same values of z for layers i+1 to L. These S samples are shown in one row of the images below. Notice that samples from each row are almost identical when the variability comes from a low-level layer, as such layers mostly model local structure and details. Higher layers on the other hand model global structure, and we observe more and more variability in each row as we move to higher layers. When the sampling happens in the top layer (i = L), all samples are completely independent, even within a row.

Binarized MNIST: layers 4, 8, 10, and 12 (top layer)

MNIST layers 4   MNIST layers 8

MNIST layers 10   MNIST layers 12

SVHN: layers 4, 10, 13, and 15 (top layer)

SVHN layers 4   SVHN layers 10

SVHN layers 13   SVHN layers 15

CIFAR: layers 3, 7, 10, and 15 (top layer)

CIFAR layers 3   CIFAR layers 7

CIFAR layers 10   CIFAR layers 15

CelebA: layers 6, 11, 16, and 20 (top layer)

CelebA layers 6

CelebA layers 11

CelebA layers 16

CelebA layers 20

Multi-dSprites: layers 3, 7, 10, and 12 (top layer)

MNIST layers 4   MNIST layers 8

MNIST layers 10   MNIST layers 12

Experimental details

I did not perform an extensive hyperparameter search, but this worked pretty well:

  • Downsampling by a factor of 2 in the beginning of inference. After that, activations are downsampled 4 times for 64x64 images (CelebA and multi-dSprites), and 3 times otherwise. The spatial size of the final feature map is always 2x2. Between these downsampling steps there is approximately the same number of stochastic layers.
  • 4 residual blocks between stochastic layers. Haven't tried with more than 4 though, as models become quite big and we get diminishing returns.
  • The deterministic parts of bottom-up and top-down architecture are (almost) perfectly mirrored for simplicity.
  • Stochastic layers have spatial random variables, and the number of rvs per "location" (i.e. number of channels of the feature map after sampling from a layer) is 32 in all layers.
  • All other feature maps in deterministic paths have 64 channels.
  • Skip connections in the generative model (--skip).
  • Gated residual blocks (--gated).
  • Learned prior of the top layer (--learn-top-prior).
  • A form of data-dependent initialization of weights (--data-dep-init). See code for details.
  • freebits=1.0 in experiments with more than 6 stochastic layers, and 0.5 for smaller models.
  • For everything else, see _add_args() in experiment/experiment_manager.py.

With these settings, the number of parameters is roughly 1M per stochastic layer. I tried to control for this by experimenting e.g. with half the number of layers but twice the number of residual blocks, but it looks like the number of stochastic layers is what matters the most.

References

[1] CK Sønderby, T Raiko, L Maaløe, SK Sønderby, O Winther. Ladder Variational Autoencoders, NIPS 2016

[2] L Maaløe, M Fraccaro, V Liévin, O Winther. BIVA: A Very Deep Hierarchy of Latent Variables for Generative Modeling, NeurIPS 2019

[3] DP Kingma, T Salimans, R Jozefowicz, X Chen, I Sutskever, M Welling. Improved Variational Inference with Inverse Autoregressive Flow, NIPS 2016

[4] I Higgins, L Matthey, A Pal, C Burgess, X Glorot, M Botvinick, S Mohamed, A Lerchner. beta-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework, ICLR 2017

[5] Y Burda, RB Grosse, R Salakhutdinov. Importance Weighted Autoencoders, ICLR 2016

[6] T Salimans, A Karpathy, X Chen, DP Kingma. PixelCNN++: Improving the PixelCNN with Discretized Logistic Mixture Likelihood and Other Modifications, ICLR 2017

[7] H Larochelle, I Murray. The neural autoregressive distribution estimator, AISTATS 2011

Owner
Andrea Dittadi
PhD student at DTU Compute | representation learning, deep generative models
Andrea Dittadi
DTCN IJCAI - Sequential prediction learning framework and algorithm

DTCN This is the implementation of our paper "Sequential Prediction of Social Me

Bobby 2 Jan 24, 2022
A simple, high level, easy-to-use open source Computer Vision library for Python.

ZoomVision : Slicing Aid Detection A simple, high level, easy-to-use open source Computer Vision library for Python. Installation Installing dependenc

Nurettin Sinanoğlu 2 Mar 04, 2022
The codes and related files to reproduce the results for Image Similarity Challenge Track 2.

The codes and related files to reproduce the results for Image Similarity Challenge Track 2.

Wenhao Wang 89 Jan 02, 2023
The Power of Scale for Parameter-Efficient Prompt Tuning

The Power of Scale for Parameter-Efficient Prompt Tuning Implementation of soft embeddings from https://arxiv.org/abs/2104.08691v1 using Pytorch and H

Kip Parker 208 Dec 30, 2022
Face and other object detection using OpenCV and ML Yolo

Object-and-Face-Detection-Using-Yolo- Opencv and YOLO object and face detection is implemented. You only look once (YOLO) is a state-of-the-art, real-

Happy N. Monday 3 Feb 15, 2022
A highly modular PyTorch framework with a focus on Neural Architecture Search (NAS).

UniNAS A highly modular PyTorch framework with a focus on Neural Architecture Search (NAS). under development (which happens mostly on our internal Gi

Cognitive Systems Research Group 19 Nov 23, 2022
Article Reranking by Memory-enhanced Key Sentence Matching for Detecting Previously Fact-checked Claims.

MTM This is the official repository of the paper: Article Reranking by Memory-enhanced Key Sentence Matching for Detecting Previously Fact-checked Cla

ICTMCG 13 Sep 17, 2022
University of Rochester 2021 Summer REU focusing on music sentiment transfer using CycleGAN

Music-Sentiment-Transfer University of Rochester 2021 Summer REU focusing on music sentiment transfer using CycleGAN Poster: Music Sentiment Transfer

Miles Sigel 2 Jan 24, 2022
g9.py - Torch interactive graphics

g9.py - Torch interactive graphics A Torch toy in the browser. Demo at https://srush.github.io/g9py/ This is a shameless copy of g9.js, written in Pyt

Sasha Rush 13 Nov 16, 2022
MonoScene: Monocular 3D Semantic Scene Completion

MonoScene: Monocular 3D Semantic Scene Completion MonoScene: Monocular 3D Semantic Scene Completion] [arXiv + supp] | [Project page] Anh-Quan Cao, Rao

298 Jan 08, 2023
This repo holds code for TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation

TransUNet This repo holds code for TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation Usage

1.4k Jan 04, 2023
The repo contains the code to train and evaluate a system which extracts relations and explanations from dialogue.

The repo contains the code to train and evaluate a system which extracts relations and explanations from dialogue. How do I cite D-REX? For now, cite

Alon Albalak 6 Mar 31, 2022
FedML: A Research Library and Benchmark for Federated Machine Learning

FedML: A Research Library and Benchmark for Federated Machine Learning 📄 https://arxiv.org/abs/2007.13518 News 2021-02-01 (Award): #NeurIPS 2020# Fed

FedML-AI 2.3k Jan 08, 2023
Captcha-tensorflow - Image Captcha Solving Using TensorFlow and CNN Model. Accuracy 90%+

Captcha Solving Using TensorFlow Introduction Solve captcha using TensorFlow. Learn CNN and TensorFlow by a practical project. Follow the steps, run t

Jackon Yang 869 Jan 06, 2023
The official code for PRIMER: Pyramid-based Masked Sentence Pre-training for Multi-document Summarization

PRIMER The official code for PRIMER: Pyramid-based Masked Sentence Pre-training for Multi-document Summarization. PRIMER is a pre-trained model for mu

AI2 111 Dec 18, 2022
Official implementation for "QS-Attn: Query-Selected Attention for Contrastive Learning in I2I Translation" (CVPR 2022)

QS-Attn: Query-Selected Attention for Contrastive Learning in I2I Translation (CVPR2022) https://arxiv.org/abs/2203.08483 Unpaired image-to-image (I2I

Xueqi Hu 50 Dec 16, 2022
(AAAI2020)Grapy-ML: Graph Pyramid Mutual Learning for Cross-dataset Human Parsing

Grapy-ML: Graph Pyramid Mutual Learning for Cross-dataset Human Parsing This repository contains pytorch source code for AAAI2020 oral paper: Grapy-ML

54 Aug 04, 2022
An imperfect information game is a type of game with asymmetric information

DecisionHoldem An imperfect information game is a type of game with asymmetric information. Compared with perfect information game, imperfect informat

Decision AI 25 Dec 23, 2022
Moiré Attack (MA): A New Potential Risk of Screen Photos [NeurIPS 2021]

Moiré Attack (MA): A New Potential Risk of Screen Photos [NeurIPS 2021] This repository is the official implementation of Moiré Attack (MA): A New Pot

Dantong Niu 22 Dec 24, 2022
Code for SALT: Stackelberg Adversarial Regularization, EMNLP 2021.

SALT: Stackelberg Adversarial Regularization Code for Adversarial Regularization as Stackelberg Game: An Unrolled Optimization Approach, EMNLP 2021. R

Simiao Zuo 10 Jan 10, 2022