Very deep VAEs in JAX/Flax

Implementation of the experiments in the paper Very Deep VAEs Generalize Autoregressive Models and Can Outperform Them on Images using JAX and Flax, ported from the official OpenAI PyTorch implementation.

I have tried to keep this implementation as close as possible to the original. I was able to re-use a large proportion of the code, including the data input pipeline, which still uses PyTorch. I recommend installing a CPU-only version of PyTorch for this.

Tested with JAX 0.2.10, Flax 0.3.0, PyTorch 1.7.1, NumPy 1.19.2. I also ran training to convergence on cifar10 and reproduced the test ELBO value of 2.87 from the paper using --conv_precision=highest (see "Things to watch out for"). If anyone asks for trained checkpoints for cifar10 I will be happy to upload them.

From the paper, some model samples and a visualization of how it generates them:

[Image: model samples and a visualization of the generation process]

Setup

As well as JAX, Flax, NumPy and PyTorch, this implementation depends on Pillow and scikit-learn:

pip install pillow
pip install scikit-learn

You will also need to download the data for whichever dataset you want to train on:

./setup_cifar10.sh
./setup_imagenet.sh imagenet32
./setup_imagenet.sh imagenet64
./setup_ffhq256.sh
./setup_ffhq1024.sh  /path/to/images1024x1024  # this one depends on you first downloading the subfolder `images_1024x1024` from https://github.com/NVlabs/ffhq-dataset on your own & running `pip install torchvision`

Training models

Hyperparameters all reside in hps.py.

python train.py --hps cifar10
python train.py --hps imagenet32
python train.py --hps imagenet64
python train.py --hps ffhq256
python train.py --hps ffhq1024

TODOs

  • Implement support for 5-bit images, which were used in the paper's FFHQ-256 experiments.

Known differences from the original

  • Instead of using the PyTorch default layer initializers we use the Flax defaults.
  • Renamed rate/distortion to kl/loglikelihood.
  • In multihost configurations, checkpoints are saved to disk on all hosts.
  • Slight changes to DMOL loss.
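
For orientation, the renamed kl and loglikelihood terms relate to the reported ELBO roughly as sketched below. This is not code from this repository: the function name, the argument names and the sign convention (loglikelihood as a log-probability in nats, so usually negative) are assumptions made for illustration. The result is the negative ELBO in bits per dimension, the unit in which the paper reports 2.87 on cifar10.

```python
import math


def nelbo_bits_per_dim(kl_nats, loglikelihood_nats, image_shape):
    """Negative ELBO of one image, converted from nats to bits per dimension.

    kl_nats: KL term summed over all latent variables, in nats.
    loglikelihood_nats: log p(x | z) summed over all pixels and channels.
    image_shape: e.g. (32, 32, 3) for cifar10.
    """
    num_dims = math.prod(image_shape)
    # ELBO = loglikelihood - kl, so the quantity being minimized is kl - loglikelihood.
    nelbo_nats = kl_nats - loglikelihood_nats
    # Divide by ln(2) to go from nats to bits, and by num_dims to normalize.
    return nelbo_nats / (num_dims * math.log(2))
```

Whether the training script logs these terms per image or already per dimension should be checked against train.py before comparing numbers.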

Things to watch out for

We tried to keep this implementation as close as possible to the author's original PyTorch implementation. There are two potentially confusing things which we chose to preserve. Firstly, the --n_batch command line argument specifies the per-device batch size; on configurations with multiple GPUs/TPUs and multiple hosts, this needs to be taken into account when comparing runs across configurations. Secondly, some of the default hyperparameter settings in hps.py do not match the settings used for the paper's experiments, which are specified on page 15 of the paper.
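
To make the first point concrete, the effective (global) batch size is the per-device value multiplied by the total number of devices. The helper below is a minimal sketch with a hypothetical name, and the device counts are illustrative rather than taken from any experiment. In a JAX program the device count would typically come from jax.device_count(), which counts devices across all hosts.

```python
def global_batch_size(n_batch, num_devices):
    """Effective batch size when n_batch is a per-device value.

    n_batch: value passed as --n_batch (per device).
    num_devices: total number of accelerators across all hosts.
    """
    if n_batch <= 0 or num_devices <= 0:
        raise ValueError("n_batch and num_devices must be positive")
    return n_batch * num_devices


# Illustrative configurations (not from the paper):
two_gpus = global_batch_size(n_batch=32, num_devices=2)    # 64
tpu_slice = global_batch_size(n_batch=32, num_devices=32)  # 1024
print(two_gpus, tpu_slice)
```

Two runs with the same --n_batch but different hardware therefore train with different effective batch sizes, which affects how their results should be compared.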

In order to reproduce results from the paper on TPU, it may be necessary to set --conv_precision=highest, which simulates GPU-like float32 precision on the TPU. Note that this can result in slower runtime. In my experiments on cifar10 I've found that this setting has about a 1% effect on the final ELBO value and was necessary to reproduce the value 2.87 reported in the paper.
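
As a rough illustration of what such a setting controls, JAX convolutions accept a precision argument. The sketch below maps a string to jax.lax.Precision and passes it to a plain convolution. The helper names parse_conv_precision and conv2d are hypothetical, and the way train.py actually wires --conv_precision into the model is an assumption here, not something taken from the code.

```python
import jax.numpy as jnp
from jax import lax

# Assumed mapping from flag values to JAX precision levels.
_PRECISIONS = {
    "default": lax.Precision.DEFAULT,
    "high": lax.Precision.HIGH,
    "highest": lax.Precision.HIGHEST,
}


def parse_conv_precision(name):
    """Translate a flag value such as 'highest' into a lax.Precision."""
    return _PRECISIONS[name]


def conv2d(x, kernel, precision_name="default"):
    """Stride-1 SAME convolution on NHWC input with an HWIO kernel."""
    return lax.conv_general_dilated(
        x,
        kernel,
        window_strides=(1, 1),
        padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
        precision=parse_conv_precision(precision_name),
    )


x = jnp.ones((1, 4, 4, 3))
kernel = jnp.ones((3, 3, 3, 8))
y = conv2d(x, kernel, "highest")
print(y.shape)
```

On CPU and GPU the precision argument usually makes little or no difference; on TPU it determines how float32 inputs are handled in the matrix units, which is where the trade-off between speed and accuracy arises.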

Acknowledgements

This code is very closely based on Rewon Child's implementation; thanks to him for writing it. Thanks to Julius Kunze for tidying the code and fixing some bugs.

Owner
Jamie Townsend