Hierarchical Few-Shot Generative Models

Overview

Hierarchical Few-Shot Generative Models

Giorgio Giannone, Ole Winther

This repo contains code and experiments for the paper Hierarchical Few-Shot Generative Models.


Settings

Clone the repo:

git clone https://github.com/georgosgeorgos/hierarchical-few-shot-generative-models
cd hierarchical-few-shot-generative-models

Create and activate the conda env:

conda env create -f environment.yml
conda activate hfsgm

The code has been tested on Ubuntu 18.04, Python 3.6 and CUDA 11.3

We use wandb for visualization. The first time you run the code you will need to login.

Data

We provide preprocessed Omniglot dataset.

From the main folder, copy the data in data/omniglot_ns/:

wget https://github.com/georgosgeorgos/hierarchical-few-shot-generative-models/releases/download/Omniglot/omni_train_val_test.pkl

For CelebA you need to download the dataset from here.

Dataset

In dataset we provide utilities to process and augment datasets in the few-shot setting. Each dataset is a large collection of small sets. Sets can be created dynamically. The dataset/base.py file collects basic info about the datasets. For binary datasets (omniglot_ns.py) we augment using flipping and rotations. For RGB datasets (celeba.py) we use only flipping.

Experiment

In experiment we implement scripts for model evaluation, experiments and visualizations.

  • attention.py - visualize attention weights and heads for models with learnable aggregations (LAG).
  • cardinality.py - compute ELBOs for different input set size: [1, 2, 5, 10, 20].
  • classifier_mnist.py - few-shot classifiers on MNIST.
  • kl_layer.py - compute KL over z and c for each layer in latent space.
  • marginal.py - compute approximate log-marginal likelihood with 1K importance samples.
  • refine_vis.py - visualize refined samples.
  • sampling_rgb.py - reconstruction, conditional, refined, unconditional sampling for RGB datasets.
  • sampling_transfer.py - reconstruction, conditional, refined, unconditional sampling on transfer datasets.
  • sampling.py - reconstruction, conditional, refined, unconditional sampling for binary datasets.
  • transfer.py - compute ELBOs on MNIST, DoubleMNIST, TripleMNIST.

Model

In model we implement baselines and model variants.

  • base.py - base class for all the models.
  • vae.py - Variational Autoencoder (VAE).
  • ns.py - Neural Statistician (NS).
  • tns.py - NS with learnable aggregation (NS-LAG).
  • cns.py - NS with convolutional latent space (CNS).
  • ctns.py - CNS with learnable aggregation (CNS-LAG).
  • hfsgm.py - Hierarchical Few-Shot Generative Model (HFSGM).
  • thfsgm.py - HFSGM with learnable aggregation (HFSGM-LAG).
  • chfsgm.py - HFSGM with convolutional latent space (CHFSGM).
  • cthfsgm.py - CHFSGM with learnable aggregation (CHFSGM-LAG).

Script

Scripts used for training the models in the paper.

To run a CNS on Omniglot:

sh script/main_cns.sh GPU_NUMBER omniglot_ns

Train a model

To train a generic model run:

python main.py --name {VAE, NS, CNS, CTNS, CHFSGM, CTHFSGM} \
               --model {vae, ns, cns, ctns, chfsgm, cthfsgm} \
               --augment \
               --dataset omniglot_ns \
               --likelihood binary \
               --hidden-dim 128 \
               --c-dim 32 \
               --z-dim 32 \
               --output-dir /output \
               --alpha-step 0.98 \
               --alpha 2 \
               --adjust-lr \
               --scheduler plateau \
               --sample-size {2, 5, 10} \
               --sample-size-test {2, 5, 10} \
               --num-classes 1 \
               --learning-rate 1e-4 \
               --epochs 400 \
               --batch-size 100 \
               --tag (optional string)

If you do not want to save logs, use the flag --dry_run. This flag will call utils/trainer_dry.py instead of trainer.py.


Acknowledgments

A lot of code and ideas borrowed from:

You might also like...
Official PyTorch Implementation of Hypercorrelation Squeeze for Few-Shot Segmentation, arXiv 2021
Official PyTorch Implementation of Hypercorrelation Squeeze for Few-Shot Segmentation, arXiv 2021

Hypercorrelation Squeeze for Few-Shot Segmentation This is the implementation of the paper "Hypercorrelation Squeeze for Few-Shot Segmentation" by Juh

Implementation of Cross Transformer for spatially-aware few-shot transfer, in Pytorch
Implementation of Cross Transformer for spatially-aware few-shot transfer, in Pytorch

Cross Transformers - Pytorch (wip) Implementation of Cross Transformer for spatially-aware few-shot transfer, in Pytorch Install $ pip install cross-t

Official repository for Few-shot Image Generation via Cross-domain Correspondence (CVPR '21)
Official repository for Few-shot Image Generation via Cross-domain Correspondence (CVPR '21)

Few-shot Image Generation via Cross-domain Correspondence Utkarsh Ojha, Yijun Li, Jingwan Lu, Alexei A. Efros, Yong Jae Lee, Eli Shechtman, Richard Zh

[CVPR 2021] Few-shot 3D Point Cloud Semantic Segmentation
[CVPR 2021] Few-shot 3D Point Cloud Semantic Segmentation

Few-shot 3D Point Cloud Semantic Segmentation Created by Na Zhao from National University of Singapore Introduction This repository contains the PyTor

Few-Shot Graph Learning for Molecular Property Prediction

Few-shot Graph Learning for Molecular Property Prediction Introduction This is the source code and dataset for the following paper: Few-shot Graph Lea

Few-shot Relation Extraction via Bayesian Meta-learning on Relation Graphs

Few-shot Relation Extraction via Bayesian Meta-learning on Relation Graphs This is an implemetation of the paper Few-shot Relation Extraction via Baye

The implementation of PEMP in paper
The implementation of PEMP in paper "Prior-Enhanced Few-Shot Segmentation with Meta-Prototypes"

Prior-Enhanced network with Meta-Prototypes (PEMP) This is the PyTorch implementation of PEMP. Overview of PEMP Meta-Prototypes & Adaptive Prototypes

Code and data of the ACL 2021 paper: Few-Shot Text Ranking with Meta Adapted Synthetic Weak Supervision

MetaAdaptRank This repository provides the implementation of meta-learning to reweight synthetic weak supervision data described in the paper Few-Shot

Adaptive Prototype Learning and Allocation for Few-Shot Segmentation (CVPR 2021)
Adaptive Prototype Learning and Allocation for Few-Shot Segmentation (CVPR 2021)

ASGNet The code is for the paper "Adaptive Prototype Learning and Allocation for Few-Shot Segmentation" (accepted to CVPR 2021) [arxiv] Overview data/

Releases(Omniglot)
Owner
Giorgio Giannone
Science is built up with data, as a house is with stones. But a collection of data is no more a science than a heap of stones is a house. (J.H. Poincaré)
Giorgio Giannone
A code repository associated with the paper A Benchmark for Rough Sketch Cleanup by Chuan Yan, David Vanderhaeghe, and Yotam Gingold from SIGGRAPH Asia 2020.

A Benchmark for Rough Sketch Cleanup This is the code repository associated with the paper A Benchmark for Rough Sketch Cleanup by Chuan Yan, David Va

33 Dec 18, 2022
A Next Generation ConvNet by FaceBookResearch Implementation in PyTorch(Original) and TensorFlow.

ConvNeXt A Next Generation ConvNet by FaceBookResearch Implementation in PyTorch(Original) and TensorFlow. A FacebookResearch Implementation on A Conv

Raghvender 2 Feb 14, 2022
Deep Ensembling with No Overhead for either Training or Testing: The All-Round Blessings of Dynamic Sparsity

[ICLR 2022] Deep Ensembling with No Overhead for either Training or Testing: The All-Round Blessings of Dynamic Sparsity by Shiwei Liu, Tianlong Chen, Zahra Atashgahi, Xiaohan Chen, Ghada Sokar, Elen

VITA 18 Dec 31, 2022
The aim of the game, as in the original one, is to find a specific image from a group of different images of a person's face

GUESS WHO Main Links: [Github] [App] Related Links: [CLIP] [Celeba] The aim of the game, as in the original one, is to find a specific image from a gr

Arnau - DIMAI 3 Jan 04, 2022
RLDS stands for Reinforcement Learning Datasets

RLDS RLDS stands for Reinforcement Learning Datasets and it is an ecosystem of tools to store, retrieve and manipulate episodic data in the context of

Google Research 135 Jan 01, 2023
FeTaQA: Free-form Table Question Answering

FeTaQA: Free-form Table Question Answering FeTaQA is a Free-form Table Question Answering dataset with 10K Wikipedia-based {table, question, free-form

Language, Information, and Learning at Yale 40 Dec 13, 2022
Parameter Efficient Deep Probabilistic Forecasting

PEDPF Parameter Efficient Deep Probabilistic Forecasting (PEDPF) is a repository containing code to run experiments for several deep learning based pr

Olivier Sprangers 10 Jun 13, 2022
Resources for the "Evaluating the Factual Consistency of Abstractive Text Summarization" paper

Evaluating the Factual Consistency of Abstractive Text Summarization Authors: Wojciech Kryściński, Bryan McCann, Caiming Xiong, and Richard Socher Int

Salesforce 165 Dec 21, 2022
A modular, open and non-proprietary toolkit for core robotic functionalities by harnessing deep learning

A modular, open and non-proprietary toolkit for core robotic functionalities by harnessing deep learning Website • About • Installation • Using OpenDR

OpenDR 304 Dec 28, 2022
NICE-GAN — Official PyTorch Implementation Reusing Discriminators for Encoding: Towards Unsupervised Image-to-Image Translation

NICE-GAN-pytorch - Official PyTorch implementation of NICE-GAN: Reusing Discriminators for Encoding: Towards Unsupervised Image-to-Image Translation

Runfa Chen 208 Nov 25, 2022
:hot_pepper: R²SQL: "Dynamic Hybrid Relation Network for Cross-Domain Context-Dependent Semantic Parsing." (AAAI 2021)

R²SQL The PyTorch implementation of paper Dynamic Hybrid Relation Network for Cross-Domain Context-Dependent Semantic Parsing. (AAAI 2021) Requirement

huybery 60 Dec 31, 2022
SBINN: Systems-biology informed neural network

SBINN: Systems-biology informed neural network The source code for the paper M. Daneker, Z. Zhang, G. E. Karniadakis, & L. Lu. Systems biology: Identi

Lu Group 15 Nov 19, 2022
Code for "PV-RAFT: Point-Voxel Correlation Fields for Scene Flow Estimation of Point Clouds", CVPR 2021

PV-RAFT This repository contains the PyTorch implementation for paper "PV-RAFT: Point-Voxel Correlation Fields for Scene Flow Estimation of Point Clou

Yi Wei 43 Dec 05, 2022
A pytorch implementation of faster RCNN detection framework (Use detectron2, it's a masterpiece)

Notice(2019.11.2) This repo was built back two years ago when there were no pytorch detection implementation that can achieve reasonable performance.

Ruotian(RT) Luo 1.8k Jan 01, 2023
Avatarify Python - Avatars for Zoom, Skype and other video-conferencing apps.

Avatarify Python - Avatars for Zoom, Skype and other video-conferencing apps.

Ali Aliev 15.3k Jan 05, 2023
Wafer Fault Detection using MlOps Integration

Wafer Fault Detection using MlOps Integration This is an end to end machine learning project with MlOps integration for predicting the quality of wafe

Sethu Sai Medamallela 0 Mar 11, 2022
PyTorch implementation of paper "IBRNet: Learning Multi-View Image-Based Rendering", CVPR 2021.

IBRNet: Learning Multi-View Image-Based Rendering PyTorch implementation of paper "IBRNet: Learning Multi-View Image-Based Rendering", CVPR 2021. IBRN

Google Interns 371 Jan 03, 2023
TensorFlow implementation of ENet

TensorFlow-ENet TensorFlow implementation of ENet: A Deep Neural Network Architecture for Real-Time Semantic Segmentation. This model was tested on th

Kwotsin 255 Oct 17, 2022
This repository provides a PyTorch implementation and model weights for HCSC (Hierarchical Contrastive Selective Coding)

HCSC: Hierarchical Contrastive Selective Coding This repository provides a PyTorch implementation and model weights for HCSC (Hierarchical Contrastive

YUANFAN GUO 111 Dec 20, 2022
Automatic differentiation with weighted finite-state transducers.

GTN: Automatic Differentiation with WFSTs Quickstart | Installation | Documentation What is GTN? GTN is a framework for automatic differentiation with

100 Dec 29, 2022