Fastshap: A fast, approximate shap kernel

fastshap was designed to be:

  • Fast: Calculating shap values can take an extremely long time. fastshap uses inner and outer batch assignments to keep the calculations inside vectorized operations as often as possible.
  • Used on Tabular Data: Accepts numpy arrays or pandas DataFrames, and handles categorical variables natively. Currently, only one-dimensional model outputs are supported.
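
The batching idea in the first bullet can be illustrated with a minimal, generic sketch: instead of evaluating a model on one huge expanded array, rows are pushed through in fixed-size chunks, and each chunk is still a single vectorized call. The helper predict_in_batches and the toy model are hypothetical and are not part of the fastshap API.

```python
import numpy as np


def predict_in_batches(model, X, batch_size):
    """Evaluate `model` on X in row batches of at most `batch_size` rows.

    Each call to `model` is still vectorized over its batch, so peak memory
    scales with `batch_size` rather than with the total number of rows.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    outputs = []
    for start in range(0, X.shape[0], batch_size):
        batch = X[start:start + batch_size]
        outputs.append(model(batch))
    return np.concatenate(outputs)


# Hypothetical toy model: a fixed linear function of 3 columns.
weights = np.array([0.5, -1.0, 2.0])


def toy_model(X):
    return X @ weights


X = np.arange(30, dtype=float).reshape(10, 3)
batched = predict_in_batches(toy_model, X, batch_size=4)
print(np.allclose(batched, toy_model(X)))  # True
```

The batch size trades memory for the number of model calls: smaller batches use less memory but call the model more often.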

WARNING: This package specifically offers a kernel explainer, which can calculate approximate shap values of f(X) towards y for any function f. Much faster shap solutions are available specifically for gradient boosted trees.

Installation

This package can be installed using either pip or conda, through conda-forge:

# Using pip
$ pip install fastshap --no-cache-dir

You can also download the latest development version from this repository. If you want to install from GitHub with conda, you must first run conda install pip git.

$ pip install git+https://github.com/AnotherSamWilson/fastshap.git

Basic Usage

We will use the iris dataset for this example. Here, we load the data and train a simple LightGBM regression model that predicts sepal length from the remaining columns, including the categorical species column:

from sklearn.datasets import load_iris
import pandas as pd
import lightgbm as lgb
import numpy as np

# Define our dataset and target variable
data = pd.concat(load_iris(as_frame=True,return_X_y=True),axis=1)
data.rename({"target": "species"}, inplace=True, axis=1)
data["species"] = data["species"].astype("category")
target = data.pop("sepal length (cm)")

# Train our model
dtrain = lgb.Dataset(data=data, label=target)
lgbmodel = lgb.train(
    params={"seed": 1, "verbose": -1},
    train_set=dtrain,
    num_boost_round=10
)

# Define the function we wish to build shap values for.
model = lgbmodel.predict

preds = model(data)

We now have a model that takes a pandas DataFrame and returns predictions. We can create an explainer that uses data as the background dataset to calculate the shap values of any dataset we wish:

import fastshap

ke = fastshap.KernelExplainer(model, data)
sv = ke.calculate_shap_values(data, verbose=False)

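# Each row of sv holds one shap value per feature plus a final column
# for the expected value, so each row sums to the model's prediction.
# np.allclose avoids exact floating-point equality checks.
print(np.allclose(preds, sv.sum(1)))
## True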

Stratifying the Background Set

We can select a subset of our data to act as the background set. By stratifying the background set on the model output, we usually get very similar results while drastically decreasing the calculation time.

ke.stratify_background_set(5)
sv2 = ke.calculate_shap_values(
  data, 
  background_fold_to_use=0,
  verbose=False
)

print(np.abs(sv2 - sv).mean(0))
## [1.74764532e-03 1.61829094e-02 1.99534408e-03 4.02640884e-16
##  1.71084747e-02]

Here we split the background set into 5 folds, stratified by the model output, and used the first fold (background_fold_to_use=0) as the background set. We then compared these shap values with the shap values obtained from the entire dataset: the printed array is the mean absolute difference for each column.
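
One simple way to stratify rows on model output is to sort them by prediction and deal them out round-robin into k folds, so every fold spans the full range of predictions. The sketch below assumes this scheme purely for illustration; it is not necessarily how stratify_background_set is implemented, and stratified_folds is a hypothetical helper.

```python
import numpy as np


def stratified_folds(preds, n_folds):
    """Return a fold id (0..n_folds-1) for each row, stratified on `preds`.

    Rows are ranked by prediction and assigned to folds round-robin, so every
    fold receives rows from the low, middle and high end of the output range.
    """
    order = np.argsort(preds, kind="stable")
    fold_ids = np.empty(len(preds), dtype=int)
    fold_ids[order] = np.arange(len(preds)) % n_folds
    return fold_ids


rng = np.random.default_rng(0)
preds = rng.normal(size=150)  # stand-in for model predictions on 150 rows
folds = stratified_folds(preds, n_folds=5)

print(np.bincount(folds))  # [30 30 30 30 30]
background_fold_0 = np.flatnonzero(folds == 0)  # row indices of the first fold
print(len(background_fold_0))  # 30
```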

Choosing Batch Sizes

If the entire process were vectorized, it would require an array of size (# Samples * # Coalitions * # Background samples, # Columns), where # Coalitions is the total number of coalitions that will be evaluated. Even for small datasets, this becomes enormous, so fastshap splits the process into a series of batches, each of which only materializes a chunk of this array.
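
To see how quickly the fully vectorized array grows, its element count can be computed directly from the formula above. The coalition count used here, 14, is an assumption: it is the number of coalitions of 4 features with sizes 1 to 3 (4 + 6 + 4), matching the 4 iris features in this example; fastshap's actual coalition count depends on its settings.

```python
from math import comb


def full_vectorized_elements(n_samples, n_coalitions, n_background, n_columns):
    """Elements in one (samples * coalitions * background, columns) array."""
    return n_samples * n_coalitions * n_background * n_columns


# Assumed: 4 features, coalitions of size 1, 2 and 3 -> 4 + 6 + 4 = 14.
n_coalitions = sum(comb(4, k) for k in range(1, 4))

# Explain all 150 iris rows against all 150 background rows at once.
print(full_vectorized_elements(150, n_coalitions, 150, 4))  # 1260000
```

At 8 bytes per float64 element that is about 10 MB for iris, and doubling both the number of samples and the number of background rows quadruples it.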

This is a list of the large arrays and their maximum size:

  • Global
    • Mask Matrix (# Coalitions, # Columns) dtype = int8
  • Outer Batch
    • Linear Targets (Total Coalition Combinations, Outer Batch Size) dtype = adaptive
  • Inner Batch
    • Model Evaluation Features (Inner Batch Size * # Background Samples, # Columns) dtype = adaptive

The adaptive dtypes of the arrays above match the dtype of the model output. Therefore, if your model returns float32, these arrays are stored as float32, and the final shap values are returned as float32 as well.
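
Because the adaptive arrays follow the model's output dtype, a model that returns float32 halves their memory compared with float64. A quick way to estimate the footprint of a dense array of a given shape is sketched below; array_nbytes is a hypothetical helper and the shape is illustrative only.

```python
import numpy as np


def array_nbytes(shape, dtype):
    """Bytes needed to store a dense array of `shape` with `dtype`."""
    return int(np.prod(shape)) * np.dtype(dtype).itemsize


shape = (1000, 500)  # illustrative shape, not taken from fastshap
for dtype in ("float64", "float32", "int8"):
    print(dtype, array_nbytes(shape, dtype))
# float64 4000000
# float32 2000000
# int8 500000
```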

These theoretical sizes can be calculated directly so that the user can determine appropriate batch sizes for their machine:

# Combines our background data back into 1 DataFrame
ke.stratify_background_set(1)
(
    mask_matrix_size, 
    linear_target_size, 
    inner_model_eval_set_size
) = ke.get_theoretical_array_expansion_sizes(
    outer_batch_size=150,
    inner_batch_size=150,
    n_coalition_sizes=3,
    background_fold_to_use=None,
)

# np.prod replaces np.product, which is deprecated and removed in NumPy 2.0
print(
  np.prod(linear_target_size) + np.prod(inner_model_eval_set_size)
)
## 92100

For the iris dataset, even if we send the entire set (150 rows) through as one batch, the arrays hold only 92100 elements, which is manageable on most machines. However, this number grows extremely quickly with the number of samples and columns, so we strongly recommend choosing a batch scheme before running this process.
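
The 92100 figure can be reproduced by hand under an assumed set of shapes: linear targets of shape (number of coalitions, outer batch size) and model evaluation features of shape (inner batch size * background rows, number of columns), with coalitions of size 1 to n_coalition_sizes. These shapes are inferred from the printed total rather than read from fastshap's source, so treat this as an approximation. The sketch also picks the largest inner batch size that fits an element budget; expansion_elements and pick_inner_batch_size are hypothetical helpers.

```python
from math import comb


def expansion_elements(n_columns, n_coalition_sizes, n_background,
                       outer_batch_size, inner_batch_size):
    """Assumed element count: linear targets + inner model evaluation set."""
    # Assumed: coalitions of size 1..n_coalition_sizes from n_columns features.
    n_coalitions = sum(comb(n_columns, k) for k in range(1, n_coalition_sizes + 1))
    linear_targets = n_coalitions * outer_batch_size
    inner_eval_set = inner_batch_size * n_background * n_columns
    return linear_targets + inner_eval_set


def pick_inner_batch_size(budget, n_columns, n_coalition_sizes,
                          n_background, outer_batch_size, max_size):
    """Largest inner batch size in [1, max_size] whose total fits `budget`."""
    for size in range(max_size, 0, -1):
        total = expansion_elements(n_columns, n_coalition_sizes,
                                   n_background, outer_batch_size, size)
        if total <= budget:
            return size
    return None  # even a batch of one row exceeds the budget


# Iris example: 4 feature columns, 150 background rows, batches of 150.
print(expansion_elements(4, 3, 150, 150, 150))  # 92100
print(pick_inner_batch_size(50_000, 4, 3, 150, 150, max_size=150))  # 79
```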

Specifying a Custom Linear Model

Any linear model available from sklearn.linear_model can be used to calculate the shap values. If you want sparse shap values, you can use Lasso regression:

from sklearn.linear_model import Lasso

# Use our entire background set
ke.stratify_background_set(1)
sv_lasso = ke.calculate_shap_values(
  data, 
  background_fold_to_use=0,
  linear_model=Lasso(alpha=0.1),
  verbose=False
)

print(sv_lasso[0,:])
## [-0.         -0.33797832 -0.         -0.14634971  5.84333333]

The default model used is sklearn.linear_model.LinearRegression.
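
In Kernel SHAP, the linear model is a weighted regression of model outputs on binary coalition masks, and its coefficients are the shap values. The sketch below assumes the standard Shapley kernel weights, a single background row, and exact enumeration of all coalitions, and solves the weighted least-squares problem with NumPy; it illustrates the technique, not fastshap's internals. For the hypothetical linear toy model used here, the recovered values are w_i * (x_i - b_i), followed by the intercept f(b).

```python
from itertools import product
from math import comb

import numpy as np


def shapley_kernel_weight(n_features, size):
    """Shapley kernel weight for a coalition with `size` of `n_features` present."""
    return (n_features - 1) / (comb(n_features, size) * size * (n_features - size))


def kernel_shap_single_background(f, x, background, endpoint_weight=1e6):
    """Exact Kernel SHAP for one row `x` against a single background row.

    Returns the per-feature shap values followed by the intercept (the
    expected value). The empty and full coalitions have infinite weight in
    theory; a large finite `endpoint_weight` stands in for it here.
    """
    m = len(x)
    Z = np.array(list(product([0, 1], repeat=m)), dtype=float)
    sizes = Z.sum(axis=1).astype(int).tolist()
    weights = np.array([
        endpoint_weight if s in (0, m) else shapley_kernel_weight(m, s)
        for s in sizes
    ])
    # Features in the coalition take values from x, the rest from the background.
    y = f(Z * x + (1 - Z) * background)
    # Weighted least squares: scale rows by sqrt(weight), add an intercept column.
    design = np.column_stack([Z, np.ones(len(Z))])
    sqrt_w = np.sqrt(weights)
    coef, *_ = np.linalg.lstsq(design * sqrt_w[:, None], y * sqrt_w, rcond=None)
    return coef


def linear_model(X):
    # Hypothetical toy model: f(X) = 2*x0 - x1 + 0.5*x2
    return X @ np.array([2.0, -1.0, 0.5])


x = np.array([1.0, 2.0, 3.0])
background = np.array([0.0, 1.0, 1.0])
phi = kernel_shap_single_background(linear_model, x, background)
print(phi)  # approximately [2.0, -1.0, 1.0, -0.5]
```

With scikit-learn, the same fit can be written as LinearRegression().fit(Z, y, sample_weight=weights), whose coef_ and intercept_ play the roles of the shap values and the expected value. Swapping in Lasso adds an L1 penalty that pushes small coefficients to exactly zero, which is consistent with the zeros in the Lasso output above.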

Owner: Samuel Wilson