VACA: Designing Variational Graph Autoencoders for Interventional and Counterfactual Queries

Code repository for the paper "VACA: Designing Variational Graph Autoencoders for Interventional and Counterfactual Queries (arXiv)". The implementation is based on PyTorch, PyTorch Geometric and PyTorch Lightning. The repository contains the resources needed to run the experiments of the paper. The German dataset is not included; see the note at the end of the Installation section for how to obtain it.

Installation

Create a conda environment and activate it:

conda create --name vaca python=3.9 --no-default-packages
conda activate vaca 

Option 1: Import the conda environment

conda env create -f environment.yml

Option 2: Commands

conda install pip
pip install torch torchvision torchaudio
pip install pytorch-lightning
pip install -U scikit-learn
pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.9.0+cpu.html
pip install matplotlib
pip install seaborn

Note: The German dataset is not contained in this repository. The first time you try to train on the German dataset, you will get an error with instructions on how to download and store it. Follow those instructions so that the code runs smoothly.

Datasets

This repository contains 7 different SCMs:

  • ColliderSCM
  • MGraphSCM
  • ChainSCM
  • TriangleSCM
  • LoanSCM
  • AdultSCM
  • GermanSCM

Additionally, we provide the implementation of the first five SCMs with three different types of structural equations: linear (LIN), non-linear (NLIN) and non-additive (NADD). You can find the implementation of all the datasets inside the folder datasets. To create all datasets at once run python _create_data_toy.py (this is optional since the datasets will be created as needed on the fly).
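To make the three families of structural equations concrete, here is a minimal sketch for a single node x2 with one parent x1 and noise u2. The function name sample_x2 and the specific functional forms (2 * x1, tanh, x1 * u2) are hypothetical choices for illustration only; they are not the equations used in the repository's datasets.

```python
import math
import random


def sample_x2(x1: float, u2: float, eq_type: str) -> float:
    """Toy structural equation for x2 given its parent x1 and noise u2.

    The functional forms are hypothetical; they only illustrate the
    difference between the three families named in this README.
    """
    if eq_type == "linear":
        # LIN: linear in the parent, noise enters additively.
        return 2.0 * x1 + u2
    if eq_type == "non-linear":
        # NLIN: non-linear in the parent, noise still enters additively.
        return math.tanh(x1) + u2
    if eq_type == "non-additive":
        # NADD: the noise interacts with the parent and cannot be
        # separated out as an additive term.
        return x1 * u2
    raise ValueError(f"unknown eq_type: {eq_type!r}")


rng = random.Random(0)
x1 = rng.gauss(0.0, 1.0)
u2 = rng.gauss(0.0, 1.0)
for eq_type in ("linear", "non-linear", "non-additive"):
    print(eq_type, sample_x2(x1, u2, eq_type))
```

In the additive cases the noise can be recovered from an observation by subtracting the parent's contribution, which is not possible in general for the non-additive case.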

How to create your custom Toy Datasets

We also provide a function to create custom ToySCM datasets. Here is an example of an SCM with 2 nodes, where x1 is a parent of x2:

from datasets.toy import create_toy_dataset
from utils.distributions import *
dataset = create_toy_dataset(root_dir='./my_custom_datasets',
                             name='2graph',
                             eq_type='linear',
                             nodes_to_intervene=['x1'],
                             structural_eq={'x1': lambda u1: u1,
                                            'x2': lambda u2, x1: u2 + x1},
                             noises_distr={'x1': Normal(0,1),
                                           'x2': Normal(0,1)},
                             adj_edges={'x1': ['x2'],
                                        'x2': []},
                             split='train',
                             num_samples=5000,
                             likelihood_names='d_d',
                             lambda_=0.05)
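As a rough illustration of the kinds of queries this 2-node SCM supports, the sketch below samples from it directly and answers an interventional and a counterfactual query in closed form. It does not use the repository's code: the helper names sample_observational, intervene and counterfactual are hypothetical, and the closed-form counterfactual relies on the equations being x1 = u1 and x2 = u2 + x1 as in the example above.

```python
import random


def sample_observational(rng: random.Random) -> tuple:
    """Draw (x1, x2) from the SCM x1 = u1, x2 = u2 + x1 with N(0, 1) noise."""
    u1, u2 = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
    x1 = u1
    x2 = u2 + x1
    return x1, x2


def intervene(rng: random.Random, x1_value: float) -> tuple:
    """Draw (x1, x2) under do(x1 = x1_value): x1 is fixed, x2 gets fresh noise."""
    u2 = rng.gauss(0.0, 1.0)
    return x1_value, u2 + x1_value


def counterfactual(x1_obs: float, x2_obs: float, x1_value: float) -> tuple:
    """Counterfactual of an observed sample had x1 been x1_value."""
    # Abduction: recover the noise of x2 that explains the observation.
    u2 = x2_obs - x1_obs
    # Action and prediction: fix x1 and recompute x2 with the same noise.
    return x1_value, u2 + x1_value


rng = random.Random(0)
print(sample_observational(rng))
print(intervene(rng, 1.0))
print(counterfactual(1.0, 3.0, 0.0))
```

The interventional query draws new noise for x2, while the counterfactual query reuses the noise inferred from the observed sample.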

Training

To train a model, execute the script main.py. You need to specify three configuration files:

  • dataset_file: Specifies the dataset and its parameters. You can overwrite the dataset parameters with -d.
  • model_file: Specifies the model and its parameters, as well as the optimizer. You can overwrite the model parameters with -m and the optimizer parameters with -o.
  • trainer_file: Specifies the training parameters of the Trainer object from PyTorch Lightning.

For plotting results use --plots 1. For more information, run python main.py --help.
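Conceptually, an override such as -d equations_type=non-linear replaces one entry of the parameters read from the configuration file. The sketch below shows that idea with plain dictionaries. It is not the parsing code of main.py: the helpers parse_overrides and apply_overrides and the key name in dataset_params are assumptions made for illustration, and only the key equations_type is taken from this README.

```python
def parse_overrides(pairs):
    """Turn ['key=value', ...] into a dict, keeping values as strings."""
    overrides = {}
    for pair in pairs:
        key, sep, value = pair.partition("=")
        if not sep or not key:
            raise ValueError(f"expected key=value, got {pair!r}")
        overrides[key] = value
    return overrides


def apply_overrides(params, overrides):
    """Return a copy of params with the overrides applied on top."""
    merged = dict(params)
    merged.update(overrides)
    return merged


# Hypothetical contents of a dataset configuration file.
dataset_params = {"name": "triangle", "equations_type": "linear"}
merged = apply_overrides(dataset_params, parse_overrides(["equations_type=non-linear"]))
print(merged)
```

Run python main.py --help to see how the script itself expects these options to be written.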

Examples

To train our VACA algorithm on each of the synthetic graphs with linear structural equations (default value in dataset_ ):

python main.py --dataset_file _params/dataset_adult.yaml --model_file _params/model_vaca.yaml
python main.py --dataset_file _params/dataset_loan.yaml --model_file _params/model_vaca.yaml
python main.py --dataset_file _params/dataset_chain.yaml --model_file _params/model_vaca.yaml
python main.py --dataset_file _params/dataset_collider.yaml --model_file _params/model_vaca.yaml
python main.py --dataset_file _params/dataset_mgraph.yaml --model_file _params/model_vaca.yaml
python main.py --dataset_file _params/dataset_triangle.yaml --model_file _params/model_vaca.yaml

You can also select a different type of structural equations with the -d option:

  • for linear (LIN) equations, -d equations_type=linear,
  • for non-linear (NLIN) equations, -d equations_type=non-linear,
  • for non-additive (NADD) equations, -d equations_type=non-additive.

For example, to train on the triangle graph with non-linear structural equations:

python main.py --dataset_file _params/dataset_triangle.yaml --model_file _params/model_vaca.yaml -d equations_type=non-linear

We can train our VACA algorithm on the German dataset:

python main.py --dataset_file _params/dataset_german.yaml --model_file _params/model_vaca.yaml

To run the CAREFL model:

python main.py --dataset_file _params/dataset_adult.yaml --model_file _params/model_carefl.yaml
python main.py --dataset_file _params/dataset_loan.yaml --model_file _params/model_carefl.yaml
python main.py --dataset_file _params/dataset_chain.yaml --model_file _params/model_carefl.yaml
python main.py --dataset_file _params/dataset_collider.yaml --model_file _params/model_carefl.yaml
python main.py --dataset_file _params/dataset_mgraph.yaml --model_file _params/model_carefl.yaml
python main.py --dataset_file _params/dataset_triangle.yaml --model_file _params/model_carefl.yaml

To run the MultiCVAE model:

python main.py --dataset_file _params/dataset_adult.yaml --model_file _params/model_mcvae.yaml
python main.py --dataset_file _params/dataset_loan.yaml --model_file _params/model_mcvae.yaml
python main.py --dataset_file _params/dataset_chain.yaml --model_file _params/model_mcvae.yaml
python main.py --dataset_file _params/dataset_collider.yaml --model_file _params/model_mcvae.yaml
python main.py --dataset_file _params/dataset_mgraph.yaml --model_file _params/model_mcvae.yaml
python main.py --dataset_file _params/dataset_triangle.yaml --model_file _params/model_mcvae.yaml
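The command lists above follow one pattern, so they can be generated rather than typed out. The following is a minimal sketch that builds the same command lines from the dataset and model names used in this README; the helper build_command is hypothetical and is not part of the repository.

```python
import shlex

DATASETS = ["adult", "loan", "chain", "collider", "mgraph", "triangle"]
MODELS = ["vaca", "carefl", "mcvae"]


def build_command(dataset, model, equations_type=None):
    """Build the argument list for one training run of main.py."""
    cmd = [
        "python", "main.py",
        "--dataset_file", f"_params/dataset_{dataset}.yaml",
        "--model_file", f"_params/model_{model}.yaml",
    ]
    if equations_type is not None:
        # Optional dataset override, as in the -d examples of this README.
        cmd += ["-d", f"equations_type={equations_type}"]
    return cmd


# Print every model/dataset combination listed above.
for model in MODELS:
    for dataset in DATASETS:
        print(shlex.join(build_command(dataset, model)))
```

To actually launch a run, an argument list can be passed to subprocess.run(cmd, check=True) from the repository root.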

How to load a trained model?

To load a trained model:

  • disable training with the flag -i 0,
  • select the configuration file of the trained model, i.e. hparams_full.yaml.

python main.py --yaml_file=PATH/hparams_full.yaml -i 0

Load a model and train/evaluate counterfactual fairness

Load your model and add the flag --eval_fair. For example:

python main.py --yaml_file=PATH/hparams_full.yaml -i 0 --eval_fair --show_results

TensorBoard visualization

You can track different metrics during (and after) training using TensorBoard. For example, if the root folder of the experiments is exper_test, run the following command in a terminal to display the logs of all experiments in that folder:

tensorboard --logdir exper_test/

Then open http://localhost:6006/ in a browser to visualize the results.

Owner
Pablo Sánchez-Martín
Ph.D. student at Max Planck Institute for Intelligent Systems