Pytorch implementation of Rosca, Mihaela, et al. "Variational Approaches for Auto-Encoding Generative Adversarial Networks."

Overview

alpha-GAN

Unofficial pytorch implementation of Rosca, Mihaela, et al. "Variational Approaches for Auto-Encoding Generative Adversarial Networks." arXiv preprint arXiv:1706.04987 (2017).

I've got visually reasonable results on CIFAR-10 (see notebook). As the authors state, alpha-GAN is sensitive to changes in the network architectures. It seems important to keep batch normalization out of the code discriminator (C).

Deviations From The Paper

In the original paper (v1 on arXiv), prior and posterior terms are swapped in the code discriminator loss (equations 16 and 17 in Algorithm 1). Authors have confirmed.

Algorithm 1 in the paper is vague as to how each network should be updated; it doesn't account for SGD. The authors have confirmed that each of the four networks is updated separately in their experiments. However, in this implementation, encoder and generator (E and G networks) are updated jointly and share an optimizer. It may be worth revisiting the sequence and separation of optimizers.

This implementation adds the latent space cycle loss alluded to in the paper via an optional hyperparameter z_lambd. When z_lambd is nonzero, generated and reconstructed x will be run through the encoder and compared to the original sampled and encoded z.

Basic Usage

from alphagan import AlphaGAN

E, G, D, C = ... #torch.nn.Module

model = AlphaGAN(E, G, D, C, lambd=10, latent_dim=128)
if use_gpu:
  model = model.cuda()

X_train, X_valid = ... #torch.utils.data.DataSet

train_loader, valid_loader = ... #torch.utils.data.DataLoader

model.fit(train_loader, valid_loader, n_iter=(2,1,1), n_epochs=4, log_fn=print)

# encode and reconstruct
z_valid, x_recon = model(X_valid[:batch_size])

# sample from the generative model
z, x_gen = model(batch_size, mode='sample')

Supply any torch.nn.Module encoder, generator, discriminator, and code discriminator at construction and any torch.optim.Optimizer constructors and torch.utils.DataLoader to fit().

Examples

alphagan/examples/CIFAR.ipynb

Progress Bars

Install tqdm for progress bars. To get working nested progress bars in jupyter notebooks: pip install -e git+https://github.com/dvm-shlee/[email protected]#egg=tqdm

Owner
Victor Shepardson
Victor Shepardson
Machine Learning Privacy Meter: A tool to quantify the privacy risks of machine learning models with respect to inference attacks, notably membership inference attacks

ML Privacy Meter Machine learning is playing a central role in automated decision making in a wide range of organization and service providers. The da

Data Privacy and Trustworthy Machine Learning Research Lab 357 Jan 06, 2023
Understanding Hyperdimensional Computing for Parallel Single-Pass Learning

Understanding Hyperdimensional Computing for Parallel Single-Pass Learning Authors: Tao Yu* Yichi Zhang* Zhiru Zhang Christopher De Sa *: Equal Contri

Cornell RelaxML 4 Sep 08, 2022
Use .csv files to record, play and evaluate motion capture data.

Purpose These scripts allow you to record mocap data to, and play from .csv files. This approach facilitates parsing of body movement data in statisti

21 Dec 12, 2022
AI that generate music

PianoGPT ai that generate music try it here https://share.streamlit.io/annasajkh/pianogpt/main/main.py or here https://huggingface.co/spaces/Annas/Pia

Annas 28 Nov 27, 2022
[v1 (ISBI'21) + v2] MedMNIST: A Large-Scale Lightweight Benchmark for 2D and 3D Biomedical Image Classification

MedMNIST Project (Website) | Dataset (Zenodo) | Paper (arXiv) | MedMNIST v1 (ISBI'21) Jiancheng Yang, Rui Shi, Donglai Wei, Zequan Liu, Lin Zhao, Bili

683 Dec 28, 2022
Pytorch Implementation of Value Retrieval with Arbitrary Queries for Form-like Documents.

Value Retrieval with Arbitrary Queries for Form-like Documents Introduction Pytorch Implementation of Value Retrieval with Arbitrary Queries for Form-

Salesforce 13 Sep 15, 2022
Final project code: Implementing BicycleGAN, for CIS680 FA21 at University of Pennsylvania

680 Final Project: BicycleGAN Haoran Tang Instructions 1. Training To train the network, please run train.py. Change hyper-parameters and folder paths

Haoran Tang 0 Apr 22, 2022
RetinaFace: Deep Face Detection Library in TensorFlow for Python

RetinaFace is a deep learning based cutting-edge facial detector for Python coming with facial landmarks.

Sefik Ilkin Serengil 512 Dec 29, 2022
Six - a Python 2 and 3 compatibility library

Six is a Python 2 and 3 compatibility library. It provides utility functions for smoothing over the differences between the Python versions with the g

Benjamin Peterson 919 Dec 28, 2022
Learning from Guided Play: A Scheduled Hierarchical Approach for Improving Exploration in Adversarial Imitation Learning Source Code

Learning from Guided Play: A Scheduled Hierarchical Approach for Improving Exploration in Adversarial Imitation Learning Trevor Ablett*, Bryan Chan*,

STARS Laboratory 8 Sep 14, 2022
Cleaned test data list of DukeMTMC-reID, ICCV2021

Cleaned DukeMTMC-reID Cleaned data list of DukeMTMC-reID released with our paper accepted by ICCV 2021: Learning Instance-level Spatial-Temporal Patte

14 Feb 19, 2022
yolov5 deepsort 行人 车辆 跟踪 检测 计数

yolov5 deepsort 行人 车辆 跟踪 检测 计数 实现了 出/入 分别计数。 默认是 南/北 方向检测,若要检测不同位置和方向,可在 main.py 文件第13行和21行,修改2个polygon的点。 默认检测类别:行人、自行车、小汽车、摩托车、公交车、卡车。 检测类别可在 detect

554 Dec 30, 2022
a reimplementation of Holistically-Nested Edge Detection in PyTorch

pytorch-hed This is a personal reimplementation of Holistically-Nested Edge Detection [1] using PyTorch. Should you be making use of this work, please

Simon Niklaus 375 Dec 06, 2022
Implementation of the master's thesis "Temporal copying and local hallucination for video inpainting".

Temporal copying and local hallucination for video inpainting This repository contains the implementation of my master's thesis "Temporal copying and

David Álvarez de la Torre 1 Dec 02, 2022
Dynamic Graph Event Detection

DyGED Dynamic Graph Event Detection Get Started pip install -r requirements.txt TODO Paper link to arxiv, and how to cite. Twitter Weather dataset tra

Mert Koşan 3 May 09, 2022
A light-weight image labelling tool for Python designed for creating segmentation data sets.

An image labelling tool for creating segmentation data sets, for Django and Flask.

117 Nov 21, 2022
AI-generated-characters for Learning and Wellbeing

AI-generated-characters for Learning and Wellbeing Click here for the full project page. This repository contains the source code for the paper AI-gen

MIT Media Lab 214 Jan 01, 2023
Dynamic Bottleneck for Robust Self-Supervised Exploration

Dynamic Bottleneck Introduction This is a TensorFlow based implementation for our paper on "Dynamic Bottleneck for Robust Self-Supervised Exploration"

Bai Chenjia 4 Nov 14, 2022
A Pytree Module system for Deep Learning in JAX

Treex A Pytree-based Module system for Deep Learning in JAX Intuitive: Modules are simple Python objects that respect Object-Oriented semantics and sh

Cristian Garcia 216 Dec 20, 2022
Code + pre-trained models for the paper Keeping Your Eye on the Ball Trajectory Attention in Video Transformers

Motionformer This is an official pytorch implementation of paper Keeping Your Eye on the Ball: Trajectory Attention in Video Transformers. In this rep

Facebook Research 192 Dec 23, 2022