Very deep VAEs in JAX/Flax

An implementation, in JAX and Flax, of the experiments in the paper "Very Deep VAEs Generalize Autoregressive Models and Can Outperform Them on Images", ported from the official OpenAI PyTorch implementation.

I have tried to keep this implementation as close as possible to the original. I was able to re-use a large proportion of the code, including the data input pipeline, which still uses PyTorch. I recommend installing a CPU-only version of PyTorch for this.

Tested with JAX 0.2.10, Flax 0.3.0, PyTorch 1.7.1, NumPy 1.19.2. I also ran training to convergence on cifar10 and reproduced the test ELBO value of 2.87 reported in the paper, using --conv_precision=highest (see "Things to watch out for" below). If anyone asks for trained cifar10 checkpoints, I will be happy to upload them.

From the paper, some model samples and a visualization of how it generates them:

image

Setup

As well as JAX, Flax, NumPy and PyTorch, this implementation depends on Pillow and scikit-learn:

pip install pillow
pip install scikit-learn

Also, you'll have to download the data, depending on which one you want to run:

./setup_cifar10.sh
./setup_imagenet.sh imagenet32
./setup_imagenet.sh imagenet64
./setup_ffhq256.sh
./setup_ffhq1024.sh  /path/to/images1024x1024  # this one depends on you first downloading the subfolder `images_1024x1024` from https://github.com/NVlabs/ffhq-dataset on your own & running `pip install torchvision`

Training models

Hyperparameters all reside in hps.py.

python train.py --hps cifar10
python train.py --hps imagenet32
python train.py --hps imagenet64
python train.py --hps ffhq256
python train.py --hps ffhq1024

TODOs

  • Implement support for 5-bit images, which were used in the paper's FFHQ-256 experiments.

Known differences from the original

  • Instead of using the PyTorch default layer initializers we use the Flax defaults.
  • Renamed rate/distortion to kl/loglikelihood.
  • In multihost configurations, checkpoints are saved to disk on all hosts.
  • Slight changes to DMOL loss.

Things to watch out for

We tried to keep this implementation as close as possible to the author's original PyTorch implementation. There are two potentially confusing things which we chose to preserve. Firstly, the --n_batch command-line argument specifies the per-device batch size, so the effective global batch size scales with the total number of devices; on configurations with multiple GPUs/TPUs and multiple hosts this needs to be taken into account when comparing runs on different configurations. Secondly, some of the default hyperparameter settings in hps.py do not match the settings used for the paper's experiments, which are specified on page 15 of the paper.
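
Since --n_batch is a per-device value, it can help to work out the global batch size before comparing two runs. The following is a minimal sketch of that arithmetic, assuming plain data parallelism in which every device processes its own batch. The helper names are hypothetical and are not part of train.py; in a JAX program the device counts would typically come from jax.local_device_count() and jax.device_count().

```python
def global_batch_size(n_batch: int, devices_per_host: int, n_hosts: int = 1) -> int:
    """Effective batch size per step when --n_batch is a per-device value.

    Hypothetical helper: assumes every device on every host processes
    its own batch of n_batch examples.
    """
    return n_batch * devices_per_host * n_hosts


def per_device_batch(target_global: int, devices_per_host: int, n_hosts: int = 1) -> int:
    """Value to pass as --n_batch to reach a given global batch size."""
    total_devices = devices_per_host * n_hosts
    if target_global % total_devices != 0:
        raise ValueError("target global batch must be divisible by the device count")
    return target_global // total_devices


# The same --n_batch gives very different global batch sizes on different setups.
print(global_batch_size(16, devices_per_host=1))             # one GPU
print(global_batch_size(16, devices_per_host=8, n_hosts=4))  # 32 devices over 4 hosts

# Going the other way: which --n_batch matches a global batch of 512 on 32 devices?
print(per_device_batch(512, devices_per_host=8, n_hosts=4))
```

When moving a run between machines, keeping the global batch size fixed (rather than --n_batch) is usually what makes the two runs comparable.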

In order to reproduce results from the paper on TPU, it may be necessary to set --conv_precision=highest, which simulates GPU-like float32 precision on the TPU. Note that this can result in slower runtime. In my experiments on cifar10 I've found that this setting has about a 1% effect on the final ELBO value and was necessary to reproduce the value 2.87 reported in the paper.
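
To give a feel for why the precision setting can matter, here is a small pure-Python illustration. It assumes that the TPU's default convolution path rounds its inputs to bfloat16, which is what the highest setting (jax.lax.Precision.HIGHEST in JAX) avoids; how --conv_precision is wired up inside this repository is not shown here. The functions to_bfloat16 and dot are hypothetical helpers written for this sketch, and it does not attempt to reproduce the 1% ELBO difference mentioned above.

```python
import random
import struct


def to_bfloat16(x: float) -> float:
    """Round x to the nearest bfloat16 value (round-to-nearest-even).

    bfloat16 keeps the top 16 bits of a float32. Only finite inputs are
    handled; NaN and infinity are out of scope for this sketch.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bias = 0x7FFF + ((bits >> 16) & 1)  # ties go to the even neighbour
    bits = (bits + bias) & 0xFFFF0000   # drop the low 16 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def dot(xs, ys, round_inputs=False):
    """Dot product, optionally rounding both inputs to bfloat16 first.

    Accumulation is done in Python floats (double precision) in both cases,
    so only the effect of input rounding is shown.
    """
    if round_inputs:
        xs = [to_bfloat16(x) for x in xs]
        ys = [to_bfloat16(y) for y in ys]
    return sum(x * y for x, y in zip(xs, ys))


rng = random.Random(0)
xs = [rng.uniform(-1.0, 1.0) for _ in range(1024)]
ys = [rng.uniform(-1.0, 1.0) for _ in range(1024)]

exact = dot(xs, ys)
approx = dot(xs, ys, round_inputs=True)
print(f"full-precision inputs: {exact:.6f}")
print(f"bfloat16 inputs:       {approx:.6f}")
print(f"absolute difference:   {abs(exact - approx):.6f}")
```

Each rounded input carries a relative error of up to about 2**-8, and a deep network stacks many such operations, which is one plausible reason a small per-operation difference can show up in the final ELBO.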

Acknowledgements

This code is very closely based on Rewon Child's implementation, thanks to him for writing that. Thanks to Julius Kunze for tidying the code and fixing some bugs.

Owner
Jamie Townsend