SimplEx - Explaining Latent Representations with a Corpus of Examples

Overview

SimplEx - Explaining Latent Representations with a Corpus of Examples

image

Code Author: Jonathan Crabbé ([email protected])

This repository contains the implementation of SimplEx, a method to explain the latent representations of black-box models with the help of a corpus of examples. For more details, please read our NeurIPS 2021 paper: 'Explaining Latent Representations with a Corpus of Examples'.

Installation

  1. Clone the repository
  2. Create a new virtual environment with Python 3.8
  3. Run the following command from the repository folder:
    pip install -r requirements.txt #install requirements

When the packages are installed, SimplEx can directly be used.

Toy example

Bellow, you can find a toy demonstration where we make a corpus decomposition of test examples representations. All the relevant code can be found in the file simplex.

from explainers.simplex import Simplex
from models.base import BlackBox

# Get the model and the examples
model = BlackBox() # Model should have the BlackBox interface
corpus_inputs = get_corpus() # A tensor of corpus inputs
test_inputs = get_test() # A set of inputs to explain

# Compute the corpus and test latent representations
corpus_latents = model.latent_representation(corpus_inputs) 
test_latents = model.latent_representation(test_inputs)

# Initialize SimplEX, fit it on test examples
simplex = Simplex(corpus_examples=corpus_inputs, 
                  corpus_latent_reps=corpus_latents)
simplex.fit(test_examples=test_inputs, 
            test_latent_reps=test_latents,
            reg_factor=0)

# Get the weights of each corpus decomposition
weights = simplex.weights

We get a tensor weights that can be interpreted as follows: weights[i,c] = weight of corpus example c in the decomposition of example i.

We can get the importance of each corpus feature for the decomposition of a given example i in the following way:

# Compute the Integrated Jacobian for a particular example
i = 42
input_baseline = get_baseline() # Baseline tensor of the same shape as corpus_inputs
simplex.jacobian_projections(test_id=i, model=model,
                             input_baseline=input_baseline)

result = simplex.decompose(i)

We get a list result where each element of the list corresponds to a corpus example. This list is sorted by decreasing order of importance in the corpus decomposition. Each element of the list is a tuple structured as follows:

w_c, x_c, proj_jacobian_c = result[c]

Where w_c corresponds to the weight weights[i,c], x_c corresponds to corpus_inputs[c] and proj_jacobian is a tensor such that proj_jacobian_c[k] is the Projected Jacobian of feature k from corpus example c.

Reproducing the paper results

Reproducing MNIST Approximation Quality Experiment

  1. Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 9)
python -m experiments.mnist -experiment "approximation_quality" -cv CV
  1. Run the following script by adding all the values of CV from the previous step
python -m experiments.results.mnist.quality.plot_results -cv_list CV1 CV2 CV3 ...
  1. The resulting plots and data are saved here.

Reproducing Prostate Cancer Approximation Quality Experiment

This experiment requires the access to the private datasets CUTRACT and SEER decribed in the paper.

  1. Copy the files cutract_internal_all.csv and seer_external_imputed_new.csv are in the folder data/Prostate Cancer
  2. Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 9)
python -m experiments.prostate_cancer -experiment "approximation_quality" -cv CV
  1. Run the following script by adding all the values of CV from the previous step
python -m experiments.results.prostate.quality.plot_results -cv_list CV1 CV2 CV3 ...
  1. The resulting plots are saved here.

Reproducing Prostate Cancer Outlier Experiment

This experiment requires the access to the private datasets CUTRACT and SEER decribed in the paper.

  1. Make sure that the files cutract_internal_all.csv and seer_external_imputed_new.csv are in the folder data/Prostate Cancer
  2. Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 9)
python -m experiments.prostate_cancer -experiment "outlier_detection" -cv CV
  1. Run the following script by adding all the values of CV from the previous step
python -m experiments.results.prostate.outlier.plot_results -cv_list CV1 CV2 CV3 ...
  1. The resulting plots are saved here.

Reproducing MNIST Jacobian Projection Significance Experiment

  1. Run the following script
python -m experiments.mnist -experiment "jacobian_corruption" 

2.The resulting plots and data are saved here.

Reproducing MNIST Outlier Detection Experiment

  1. Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 9)
python -m experiments.mnist -experiment "outlier_detection" -cv CV
  1. Run the following script by adding all the values of CV from the previous step
python -m experiments.results.mnist.outlier.plot_results -cv_list CV1 CV2 CV3 ...
  1. The resulting plots and data are saved here.

Reproducing MNIST Influence Function Experiment

  1. Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 4)
python -m experiments.mnist -experiment "influence" -cv CV
  1. Run the following script by adding all the values of CV from the previous step
python -m experiments.results.mnist.influence.plot_results -cv_list CV1 CV2 CV3 ...
  1. The resulting plots and data are saved here.

Note: some problems can appear with the package Pytorch Influence Functions. If this is the case, please change calc_influence_function.py in the following way:

343: influences.append(tmp_influence) ==> influences.append(tmp_influence.cpu())
438: influences_meta['test_sample_index_list'] = sample_list ==> #influences_meta['test_sample_index_list'] = sample_list

Reproducing AR Approximation Quality Experiment

  1. Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 4)
python -m experiments.time_series -experiment "approximation_quality" -cv CV
  1. Run the following script by adding all the values of CV from the previous step
python -m experiments.results.ar.quality.plot_results -cv_list CV1 CV2 CV3 ...
  1. The resulting plots and data are saved here.

Reproducing AR Outlier Detection Experiment

  1. Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 4)
python -m experiments.time_series -experiment "outlier_detection" -cv CV
  1. Run the following script by adding all the values of CV from the previous step
python -m experiments.results.ar.outlier.plot_results -cv_list CV1 CV2 CV3 ...
  1. The resulting plots and data are saved here.

Citing

If you use this code, please cite the associated paper:

Put citation here when ready
Owner
Jonathan Crabbé
I am currently doing a PhD in Explainable AI at the Department of Applied Mathematics and Theoretical Physics (DAMTP) of the University of Cambridge.
Jonathan Crabbé
A collection of models for image<->text generation in ACM MM 2021.

Bi-directional Image and Text Generation UMT-BITG (image & text generator) Unifying Multimodal Transformer for Bi-directional Image and Text Generatio

Multimedia Research 63 Oct 30, 2022
Multiview Neural Surface Reconstruction by Disentangling Geometry and Appearance

Multiview Neural Surface Reconstruction by Disentangling Geometry and Appearance Project Page | Paper | Data This repository contains an implementatio

Lior Yariv 521 Dec 30, 2022
Code repository of the paper Neural circuit policies enabling auditable autonomy published in Nature Machine Intelligence

Neural Circuit Policies Enabling Auditable Autonomy Online access via SharedIt Neural Circuit Policies (NCPs) are designed sparse recurrent neural net

8 Jan 07, 2023
Low-code/No-code approach for deep learning inference on devices

EzEdgeAI A concept project that uses a low-code/no-code approach to implement deep learning inference on devices. It provides a componentized framewor

On-Device AI Co., Ltd. 7 Apr 05, 2022
Neural style transfer as a class in PyTorch

pt-styletransfer Neural style transfer as a class in PyTorch Based on: https://github.com/alexis-jacq/Pytorch-Tutorials Adds: StyleTransferNet as a cl

Tyler Kvochick 31 Jun 27, 2022
This is the official pytorch implementation of AutoDebias, an automatic debiasing method for recommendation.

AutoDebias This is the official pytorch implementation of AutoDebias, a debiasing method for recommendation system. AutoDebias is proposed in the pape

Dong Hande 77 Nov 25, 2022
Equivariant CNNs for the sphere and SO(3) implemented in PyTorch

Equivariant CNNs for the sphere and SO(3) implemented in PyTorch

Jonas Köhler 893 Dec 28, 2022
Individual Treatment Effect Estimation

CAPE Individual Treatment Effect Estimation Run CAPE python train_causal.py --loop 10 -m cape_cau -d NI --i_t 1 Run a baseline model python train_cau

S. Deng 4 Sep 02, 2022
This is a Machine Learning Based Hand Detector Project, It Uses Machine Learning Models and Modules Like Mediapipe, Developed By Google!

Machine Learning Hand Detector This is a Machine Learning Based Hand Detector Project, It Uses Machine Learning Models and Modules Like Mediapipe, Dev

Popstar Idhant 3 Feb 25, 2022
An index of algorithms for learning causality with data

awesome-causality-algorithms An index of algorithms for learning causality with data. Please cite our survey paper if this index is helpful. @article{

Ruocheng Guo 2.3k Jan 08, 2023
A knowledge base construction engine for richly formatted data

Fonduer is a Python package and framework for building knowledge base construction (KBC) applications from richly formatted data. Note that Fonduer is

HazyResearch 386 Dec 05, 2022
Python periodic table module

elemenpy Hello! elements.py is a small Python periodic table module that is used for calling certain information about an element. Installation Instal

Eric Cheng 2 Dec 27, 2021
This repository contains PyTorch models for SpecTr (Spectral Transformer).

SpecTr: Spectral Transformer for Hyperspectral Pathology Image Segmentation This repository contains PyTorch models for SpecTr (Spectral Transformer).

Boxiang Yun 45 Dec 13, 2022
Interpretable-contrastive-word-mover-s-embedding

Interpretable-contrastive-word-mover-s-embedding Paper Datasets Here is a Dropbox link to the datasets used in the paper: https://www.dropbox.com/sh/n

0 Nov 02, 2021
PyTorch implementation of ECCV 2020 paper "Foley Music: Learning to Generate Music from Videos "

Foley Music: Learning to Generate Music from Videos This repo holds the code for the framework presented on ECCV 2020. Foley Music: Learning to Genera

Chuang Gan 30 Nov 03, 2022
MultiSiam: Self-supervised Multi-instance Siamese Representation Learning for Autonomous Driving

MultiSiam: Self-supervised Multi-instance Siamese Representation Learning for Autonomous Driving Code will be available soon. Motivation Architecture

Kai Chen 24 Apr 19, 2022
Like Dirt-Samples, but cleaned up

Clean-Samples Like Dirt-Samples, but cleaned up, with clear provenance and license info (generally a permissive creative commons licence but check the

TidalCycles 39 Nov 30, 2022
Train Scene Graph Generation for Visual Genome and GQA in PyTorch >= 1.2 with improved zero and few-shot generalization.

Scene Graph Generation Object Detections Ground truth Scene Graph Generated Scene Graph In this visualization, woman sitting on rock is a zero-shot tr

Boris Knyazev 93 Dec 28, 2022
Implementation of NÜWA, state of the art attention network for text to video synthesis, in Pytorch

NÜWA - Pytorch (wip) Implementation of NÜWA, state of the art attention network for text to video synthesis, in Pytorch. This repository will be popul

Phil Wang 463 Dec 28, 2022
The open-source and free to use Python package miseval was developed to establish a standardized medical image segmentation evaluation procedure

miseval: a metric library for Medical Image Segmentation EVALuation The open-source and free to use Python package miseval was developed to establish

59 Dec 10, 2022