VACA: Designing Variational Graph Autoencoders for Interventional and Counterfactual Queries

VACA

Code repository for the paper "VACA: Designing Variational Graph Autoencoders for Interventional and Counterfactual Queries (arXiv)". The implementation is based on PyTorch, PyTorch Geometric and PyTorch Lightning. The repository contains the resources needed to run the experiments of the paper. Follow the instructions below to download the German dataset.

Installation

Create a conda environment and activate it:

conda create --name vaca python=3.9 --no-default-packages
conda activate vaca 

Option 1: Import the conda environment

conda env create -f environment.yml

Option 2: Commands

conda install pip
pip install torch torchvision torchaudio
pip install pytorch-lightning
pip install -U scikit-learn
pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.9.0+cpu.html
pip install matplotlib
pip install seaborn

Note: The German dataset is not contained in this repository. The first time you try to train on the German dataset, you will get an error with instructions on how to download and store it. Please follow those instructions so that the code runs smoothly.

Datasets

This repository contains 7 different SCMs:

  • ColliderSCM
  • MGraphSCM
  • ChainSCM
  • TriangleSCM
  • LoanSCM
  • AdultSCM
  • GermanSCM

Additionally, we provide the implementation of the first five SCMs with three different types of structural equations: linear (LIN), non-linear (NLIN) and non-additive (NADD). You can find the implementation of all the datasets inside the folder datasets. To create all datasets at once run python _create_data_toy.py (this is optional since the datasets will be created as needed on the fly).
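As a rough illustration of how the three equation types differ, the sketch below writes one child equation in each form. The functions and coefficients are hypothetical and chosen only for illustration; they are not the equations implemented in the datasets folder.

```python
import math

# Hypothetical structural equations for a node x2 with parent x1 and noise u2.
# Illustrative only: these are NOT the equations shipped with the repository.

def x2_linear(u2, x1):
    # LIN: linear in the parent, noise enters additively
    return 0.5 * x1 + u2

def x2_non_linear(u2, x1):
    # NLIN: non-linear in the parent, noise still enters additively
    return math.tanh(x1) + u2

def x2_non_additive(u2, x1):
    # NADD: the noise interacts with the parent instead of being added
    return x1 * u2

for equation in (x2_linear, x2_non_linear, x2_non_additive):
    print(equation.__name__, equation(u2=0.1, x1=2.0))
```

Each function takes the node's noise first and its parents afterwards, mirroring the argument order used for structural_eq in the custom dataset example.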

How to create your custom Toy Datasets

We also provide a function to create custom ToySCM datasets. Here is an example of an SCM with two nodes:

from datasets.toy import create_toy_dataset
from utils.distributions import *
dataset = create_toy_dataset(root_dir='./my_custom_datasets',
                             name='2graph',
                             eq_type='linear',
                             nodes_to_intervene=['x1'],
                             structural_eq={'x1': lambda u1: u1,
                                            'x2': lambda u2, x1: u2 + x1},
                             noises_distr={'x1': Normal(0,1),
                                           'x2': Normal(0,1)},
                             adj_edges={'x1': ['x2'],
                                        'x2': []},
                             split='train',
                             num_samples=5000,
                             likelihood_names='d_d',
                             lambda_=0.05)
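To make the semantics of this two-node SCM concrete, here is a minimal standalone sketch that samples from the same structural equations (x1 = u1, x2 = u2 + x1, standard normal noise) and from the intervention do(x1 = 2.0). It uses only the Python standard library and does not call the repository code; sample_scm and its parameters are hypothetical names introduced for this illustration.

```python
import random

def sample_scm(num_samples, do_x1=None, seed=0):
    """Sample (x1, x2) from x1 = u1, x2 = u2 + x1 with N(0, 1) noise.

    If do_x1 is given, x1 is fixed to that value, which corresponds to
    the intervention do(x1 = do_x1); x2 still follows its own equation.
    """
    rng = random.Random(seed)
    samples = []
    for _ in range(num_samples):
        u1 = rng.gauss(0.0, 1.0)
        u2 = rng.gauss(0.0, 1.0)
        x1 = u1 if do_x1 is None else do_x1
        x2 = u2 + x1
        samples.append((x1, x2))
    return samples

observational = sample_scm(5000)
interventional = sample_scm(5000, do_x1=2.0)

mean_x2_obs = sum(x2 for _, x2 in observational) / len(observational)
mean_x2_int = sum(x2 for _, x2 in interventional) / len(interventional)
print(f"mean of x2 (observational):  {mean_x2_obs:.3f}")
print(f"mean of x2 under do(x1=2.0): {mean_x2_int:.3f}")
```

Because x2 depends on x1, fixing x1 shifts the distribution of x2: its mean moves from roughly 0 to roughly 2, while the spread contributed by u2 is unchanged.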

Training

To train a model, execute the script main.py. For that, you need to specify three configuration files:

  • dataset_file: Specifies the dataset and the parameters of the dataset. You can overwrite the dataset parameters with -d.
  • model_file: Specifies the model and the parameters of the model as well as the optimizer. You can overwrite the model parameters with -m and the optimizer parameters with -o.
  • trainer_file: Specifies the training parameters of the Trainer object from PyTorch Lightning.

For plotting results use --plots 1. For more information, run python main.py --help.

Examples

To train our VACA algorithm on each of the synthetic graphs with linear structural equations (default value in dataset_ ):

python main.py --dataset_file _params/dataset_adult.yaml --model_file _params/model_vaca.yaml
python main.py --dataset_file _params/dataset_loan.yaml --model_file _params/model_vaca.yaml
python main.py --dataset_file _params/dataset_chain.yaml --model_file _params/model_vaca.yaml
python main.py --dataset_file _params/dataset_collider.yaml --model_file _params/model_vaca.yaml
python main.py --dataset_file _params/dataset_mgraph.yaml --model_file _params/model_vaca.yaml
python main.py --dataset_file _params/dataset_triangle.yaml --model_file _params/model_vaca.yaml

You can also select a different SEM with the -d option:

  • for linear (LIN) equations -d equations_type=linear,
  • for non-linear (NLIN) equations -d equations_type=non-linear,
  • for non-additive (NADD) equations -d equations_type=non-additive.

For example, to train on the triangle graph with a non-linear SEM:

python main.py --dataset_file _params/dataset_triangle.yaml --model_file _params/model_vaca.yaml -d equations_type=non-linear
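If you want to run the same configuration over several graphs, the command lines can be assembled programmatically. The sketch below only builds and prints the commands from the flags documented above; build_command and DATASETS_WITH_EQUATION_TYPES are hypothetical helper names, and the list is assumed to cover the five SCMs that support the three equation types.

```python
import shlex

# Assumed: the five SCMs that come with LIN, NLIN and NADD equations.
DATASETS_WITH_EQUATION_TYPES = ["collider", "mgraph", "chain", "triangle", "loan"]

def build_command(dataset, model="vaca", equations_type=None):
    """Assemble a main.py invocation as a list of arguments."""
    cmd = [
        "python", "main.py",
        "--dataset_file", f"_params/dataset_{dataset}.yaml",
        "--model_file", f"_params/model_{model}.yaml",
    ]
    if equations_type is not None:
        # -d overwrites a dataset parameter from the command line
        cmd += ["-d", f"equations_type={equations_type}"]
    return cmd

for name in DATASETS_WITH_EQUATION_TYPES:
    print(shlex.join(build_command(name, equations_type="non-linear")))
```

To launch the runs instead of printing them, each argument list could be passed to subprocess.run(cmd, check=True) from the repository root.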

We can train our VACA algorithm on the German dataset:

python main.py --dataset_file _params/dataset_german.yaml --model_file _params/model_vaca.yaml

To run the CAREFL model:

python main.py --dataset_file _params/dataset_adult.yaml --model_file _params/model_carefl.yaml
python main.py --dataset_file _params/dataset_loan.yaml --model_file _params/model_carefl.yaml
python main.py --dataset_file _params/dataset_chain.yaml --model_file _params/model_carefl.yaml
python main.py --dataset_file _params/dataset_collider.yaml --model_file _params/model_carefl.yaml
python main.py --dataset_file _params/dataset_mgraph.yaml --model_file _params/model_carefl.yaml
python main.py --dataset_file _params/dataset_triangle.yaml --model_file _params/model_carefl.yaml

To run the MultiCVAE model:

python main.py --dataset_file _params/dataset_adult.yaml --model_file _params/model_mcvae.yaml
python main.py --dataset_file _params/dataset_loan.yaml --model_file _params/model_mcvae.yaml
python main.py --dataset_file _params/dataset_chain.yaml --model_file _params/model_mcvae.yaml
python main.py --dataset_file _params/dataset_collider.yaml --model_file _params/model_mcvae.yaml
python main.py --dataset_file _params/dataset_mgraph.yaml --model_file _params/model_mcvae.yaml
python main.py --dataset_file _params/dataset_triangle.yaml --model_file _params/model_mcvae.yaml

How to load a trained model?

To load a trained model:

  • set the training flag to -i 0,
  • select the configuration file of the trained model, i.e. hparams_full.yaml.

python main.py --yaml_file=PATH/hparams_full.yaml -i 0
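If you have several trained models, a small helper can list the available configuration files. This is a sketch that assumes each run stores its hparams_full.yaml somewhere below the experiments root folder (exper_test is used here as an example name); find_hparams is a hypothetical helper and not part of the repository.

```python
from pathlib import Path

def find_hparams(root):
    """Return every hparams_full.yaml found below root, sorted by path."""
    return sorted(Path(root).rglob("hparams_full.yaml"))

# Print one load command per trained model found under the assumed root.
for path in find_hparams("exper_test"):
    print(f"python main.py --yaml_file={path} -i 0")
```

If the root folder does not exist or contains no trained models, the loop simply prints nothing.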

Load a model and train/evaluate counterfactual fairness

Load your model and add the flag --eval_fair. For example:

python main.py --yaml_file=PATH/hparams_full.yaml -i 0 --eval_fair --show_results

TensorBoard visualization

You can track different metrics during (and after) training using TensorBoard. For example, if the root folder of the experiments is exper_test, we can run the following command in a terminal

tensorboard --logdir exper_test/   

to display the logs of all experiments contained in that folder. Then open http://localhost:6006/ in a browser to visualize the results.

Owner
Pablo Sánchez-Martín
Ph.D. student at Max Planck Institute for Intelligent Systems