Official code for HH-VAEM

Overview


This repository contains the official PyTorch implementation of the Hierarchical Hamiltonian VAE for Mixed-type Data (HH-VAEM) model and of the sampling-based feature acquisition technique presented in the paper Missing Data Imputation and Acquisition with Deep Hierarchical Models and Hamiltonian Monte Carlo. HH-VAEM is a hierarchical VAE for mixed-type incomplete data that uses Hamiltonian Monte Carlo (HMC) with automatic hyper-parameter tuning to improve approximate inference. The repository also includes the code for the experiments reported in the paper.
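For readers unfamiliar with HMC, the following is a minimal, generic NumPy sketch of one HMC transition: leapfrog integration of Hamiltonian dynamics followed by a Metropolis-Hastings correction. It is not the HH-VAEM sampler. Here the step size and number of leapfrog steps are fixed by hand, whereas HH-VAEM tunes its HMC hyper-parameters automatically and samples the latent variables of a hierarchical VAE. The function names are illustrative only.

```python
import numpy as np


def leapfrog(q, p, grad_log_prob, step_size, n_steps):
    """Simulate Hamiltonian dynamics with the leapfrog integrator (inputs are not modified)."""
    q, p = q.copy(), p.copy()
    p += 0.5 * step_size * grad_log_prob(q)  # initial half step for the momentum
    for _ in range(n_steps - 1):
        q += step_size * p                    # full step for the position
        p += step_size * grad_log_prob(q)     # full step for the momentum
    q += step_size * p                        # last full step for the position
    p += 0.5 * step_size * grad_log_prob(q)  # final half step for the momentum
    return q, p


def hmc_step(q, log_prob, grad_log_prob, step_size, n_steps, rng):
    """One generic HMC transition with fixed (hand-picked, not tuned) hyper-parameters."""
    p0 = rng.standard_normal(q.shape)  # resample momentum from N(0, I)
    q_new, p_new = leapfrog(q, p0, grad_log_prob, step_size, n_steps)
    # Hamiltonian H(q, p) = -log p(q) + 0.5 * |p|^2
    h_old = -log_prob(q) + 0.5 * np.sum(p0 ** 2)
    h_new = -log_prob(q_new) + 0.5 * np.sum(p_new ** 2)
    # Accept with probability min(1, exp(H_old - H_new)) to correct integration error
    if np.log(rng.uniform()) < h_old - h_new:
        return q_new, True
    return q, False
```

Calling `hmc_step` repeatedly with a standard normal target (`log_prob = lambda q: -0.5 * np.sum(q ** 2)`, gradient `-q`) produces a Markov chain whose samples approximate that target.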

If you use this code, please cite the preprint:

@article{peis2022missing,
  title={Missing Data Imputation and Acquisition with Deep Hierarchical Models and Hamiltonian Monte Carlo},
  author={Peis, Ignacio and Ma, Chao and Hern{\'a}ndez-Lobato, Jos{\'e} Miguel},
  journal={arXiv preprint arXiv:2202.04599},
  year={2022}
}

Installation

Installation is straightforward: the following command creates a conda virtual environment named HH-VAEM from the provided environment.yml file:

conda env create -f environment.yml

Usage

Training

The project is built on the PyTorch Lightning research framework: the HH-VAEM model is implemented as a LightningModule and trained with a Trainer. A model can be trained with:

# Example for training HH-VAEM on Boston dataset
python train.py --model HHVAEM --dataset boston --split 0

This automatically downloads the Boston dataset, partitions it into 10 train/test splits and trains HH-VAEM on training split 0. Two folders are created: data/ for storing the datasets and logs/ for model checkpoints and TensorBoard logs. To change the directory where these folders are created, modify the LOGDIR variable in src/configs.py (this can be useful to avoid overloading network file systems).
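The exact splitting scheme is defined in the repository code and is not described in this README. As a rough illustration of what "10 train/test splits" selected by --split means, here is a hypothetical sketch that assumes independent random splits with a 90/10 train/test ratio and a fixed seed; the function name, ratio and seed are all assumptions.

```python
import numpy as np


def make_splits(n_samples, n_splits=10, test_fraction=0.1, seed=0):
    """Hypothetical helper: return a list of (train_idx, test_idx) pairs, one per split.

    Each split is an independent random permutation of the row indices,
    cut into a test part and a train part. The 90/10 ratio is an assumption.
    """
    rng = np.random.default_rng(seed)
    n_test = int(round(test_fraction * n_samples))
    splits = []
    for _ in range(n_splits):
        perm = rng.permutation(n_samples)
        splits.append((perm[n_test:], perm[:n_test]))  # (train, test)
    return splits
```

For example, `train_idx, test_idx = make_splits(506)[0]` would give the indices of split 0 for a dataset with 506 rows, the size of the Boston housing data.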

The following datasets are available:

  • A total of 10 UCI datasets: avocado, boston, energy, wine, diabetes, concrete, naval, yatch, bank or insurance.
  • The MNIST datasets: mnist or fashion_mnist.
  • More datasets can be easily added to src/datasets.py.

For each dataset, the corresponding parameter configuration must be added to src/configs.py.

The following models are also available (implemented in src/models/):

  • HHVAEM: the proposed model in the paper.
  • VAEM: the VAEM strategy presented in (Ma et al., 2020), with a Gaussian encoder (without the Partial VAE).
  • HVAEM: a hierarchical VAEM with two layers of latent variables and a Gaussian encoder.
  • HMCVAEM: a VAEM that includes a tuned HMC sampler targeting the true posterior.
  • For the MNIST datasets (non-heterogeneous data), use the corresponding models HHVAE, VAE, HVAE and HMCVAE.

By default, the test stage runs at the end of training. To skip it, pass --test 0; the test can then be run manually with:

# Example for testing HH-VAEM on Boston dataset
python test.py --model HHVAEM --dataset boston --split 0

This loads the trained model and evaluates it on Boston test split 0. Once all splits have been tested, the average results can be obtained with the script in the run/ folder:

# Example for obtaining the average test results with HH-VAEM on Boston dataset
python test_splits.py --model HHVAEM --dataset boston
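Averaging requires every split to have been trained and tested first. A small driver like the following can loop over the 10 splits using only the --model, --dataset and --split flags shown above. It is a hypothetical helper, not part of the repository, and it assumes it is launched from the repository root, where train.py lives.

```python
import subprocess
import sys


def split_commands(script, model, dataset, n_splits=10, extra_args=()):
    """Hypothetical helper: build one command line per split for a repository script."""
    return [
        [sys.executable, script, "--model", model, "--dataset", dataset,
         *extra_args, "--split", str(k)]
        for k in range(n_splits)
    ]


def run_all(commands):
    """Run the commands one after another, stopping at the first failure."""
    for cmd in commands:
        subprocess.run(cmd, check=True)


if __name__ == "__main__":
    # Train (and, by default, test) HH-VAEM on every Boston split.
    run_all(split_commands("train.py", "HHVAEM", "boston"))
```

The same helper can launch the experiment scripts on every split by passing their extra flags, e.g. `split_commands("active_learning.py", "HHVAEM", "boston", extra_args=("--method", "mi"))`.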

Experiments

The experiments in the paper can be executed using:

# Example for running the SAIA experiment with HH-VAEM on Boston dataset
python active_learning.py --model HHVAEM --dataset boston --method mi --split 0

# Example for running the OoD experiment on MNIST, with Fashion-MNIST as OoD data:
python ood.py --model HHVAEM --dataset mnist --dataset_ood fashion_mnist --split 0
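The SAIA (sequential active information acquisition) experiment with --method mi acquires features one at a time, choosing the unobserved feature expected to be most informative about the target. The sketch below shows a single greedy step under loudly stated simplifications: it assumes joint samples of the features and the target are already available (e.g. drawn from a model given the observed features), and it estimates mutual information with a bivariate Gaussian approximation, which is a swapped-in simplification rather than the paper's sampling-based estimator. All names are hypothetical.

```python
import numpy as np


def gaussian_mi(x_samples, y_samples):
    """Estimate I(x; y) from joint samples, assuming (x, y) is bivariate Gaussian.

    For jointly Gaussian scalars, I(x; y) = -0.5 * log(1 - rho^2),
    where rho is the Pearson correlation. This is a simplification.
    """
    rho = np.corrcoef(x_samples, y_samples)[0, 1]
    rho2 = min(rho ** 2, 1.0 - 1e-12)  # guard against log(0) when |rho| == 1
    return -0.5 * np.log(1.0 - rho2)


def select_next_feature(x_samples, y_samples, unobserved):
    """Hypothetical greedy SAIA step: pick the unobserved feature with the largest estimated MI.

    x_samples: (n_samples, n_features) joint draws of all features;
    y_samples: (n_samples,) draws of the target from the same joint samples;
    unobserved: indices of features that have not been acquired yet.
    """
    scores = {i: gaussian_mi(x_samples[:, i], y_samples) for i in unobserved}
    return max(scores, key=scores.get)
```

After a feature is selected, its true value is observed, the samples are redrawn conditioned on the enlarged observed set, and the step is repeated.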

Once these experiments have been run on all splits, you can plot the SAIA error curves or obtain the average OoD metrics with the scripts in the run/ folder:

# Example for plotting the SAIA error curves of VAEM and HH-VAEM on Boston dataset
python active_learning_plots.py --models VAEM HHVAEM --dataset boston

# Example for obtaining the average OoD metrics of HH-VAEM on MNIST, with Fashion-MNIST as OoD data:
python ood_splits.py --model HHVAEM --dataset mnist --dataset_ood fashion_mnist
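This README does not list which OoD metrics ood_splits.py reports. If, as is common for OoD detection, the area under the ROC curve (AUROC) is among them, the following sketch shows how it can be computed from per-example OoD scores. The score definition (higher means "more likely OoD") and the function name are assumptions, not the repository's implementation.

```python
import numpy as np


def auroc(scores_in, scores_ood):
    """Hypothetical helper: AUROC for separating OoD from in-distribution inputs.

    Assumes higher scores mean "more likely OoD". The result is the probability
    that a random OoD score exceeds a random in-distribution score (ties count 1/2).
    """
    s_in = np.asarray(scores_in, dtype=float)[:, None]
    s_ood = np.asarray(scores_ood, dtype=float)[None, :]
    # Compare every (in-distribution, OoD) pair via broadcasting
    return float(np.mean((s_ood > s_in) + 0.5 * (s_ood == s_in)))
```

Averaging such a per-split value over all splits gives a single summary number per model.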


Help

Use the --help option for documentation on the usage of any of the mentioned scripts.

Contributors

Ignacio Peis
Chao Ma
José Miguel Hernández-Lobato

Contact

For further information: [email protected]
