TimeSHAP explains Recurrent Neural Network predictions.

Overview

TimeSHAP

TimeSHAP is a model-agnostic, recurrent explainer that builds upon KernelSHAP and extends it to the sequential domain. TimeSHAP computes event/timestamp-, feature-, and cell-level attributions. As sequences can be arbitrarily long, TimeSHAP also implements a pruning algorithm based on Shapley values that finds a subset of consecutive, recent events that contribute the most to the decision.
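To make event- and feature-level attributions concrete, here is a minimal sketch that computes exact Shapley values by brute-force enumeration for a tiny sequence, assuming that an absent event (or feature) is replaced by a baseline event. TimeSHAP itself builds on KernelSHAP to approximate these values, since enumeration grows exponentially with the number of players. The helper names (`exact_shapley`, `event_value_fn`, `feature_value_fn`) are hypothetical and not part of the TimeSHAP API.

```python
import itertools
import math

import numpy as np


def exact_shapley(value_fn, n_players):
    """Exact Shapley values by enumerating every coalition (only viable for small n_players)."""
    phis = np.zeros(n_players)
    for i in range(n_players):
        others = [j for j in range(n_players) if j != i]
        for size in range(len(others) + 1):
            # Shapley weight |S|! (n - |S| - 1)! / n!
            weight = (math.factorial(size) * math.factorial(n_players - size - 1)
                      / math.factorial(n_players))
            for subset in itertools.combinations(others, size):
                coalition = set(subset)
                phis[i] += weight * (value_fn(coalition | {i}) - value_fn(coalition))
    return phis


def event_value_fn(f, x, baseline_event):
    """Players are events (rows of x); absent events are replaced by the baseline event."""
    def value(present):
        seq = np.tile(baseline_event, (x.shape[0], 1)).astype(float)
        rows = sorted(present)
        seq[rows] = x[rows]
        return float(f(seq[None])[0, 0])  # f expects (1, seq_len, n_features)
    return value


def feature_value_fn(f, x, baseline_event):
    """Players are features (columns of x); an absent feature takes its baseline value in every event."""
    def value(present):
        seq = np.tile(baseline_event, (x.shape[0], 1)).astype(float)
        cols = sorted(present)
        seq[:, cols] = x[:, cols]
        return float(f(seq[None])[0, 0])
    return value
```

For a linear model, each event's attribution reduces to the sum over features of weight times (value minus baseline), which makes a convenient sanity check; the attributions also sum to the score difference between the instance and the all-baseline sequence.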

This repository is the code implementation of the TimeSHAP algorithm presented in the paper "TimeSHAP: Explaining Recurrent Models through Sequence Perturbations", published at KDD 2021.

The paper is available at https://doi.org/10.1145/3447548.3467166, and a video presentation is also available.

Install TimeSHAP

Clone the repository into a local directory using:

git clone https://github.com/feedzai/timeshap.git

Install the package using pip:

pip install timeshap

To test your installation, start a Python session in your terminal using:

python

Then import TimeSHAP:

import timeshap

TimeSHAP in 30 seconds

Inputs

  • Model being explained;
  • Instance(s) to explain;
  • Background instance.
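The background instance is the reference that replaces perturbed events. Below is a minimal sketch, assuming a training set already shaped as (#sequences, #sequence length, #features), of two common choices, an average event and an average sequence (the baselines named in the tutorial call and issues below), plus a check of how far the baseline's score is from the instance's score. All helper names are hypothetical.

```python
import numpy as np


def average_event(train):
    """Mean feature vector over every event of every sequence: shape (1, n_features)."""
    return train.reshape(-1, train.shape[-1]).mean(axis=0, keepdims=True)


def average_sequence(train):
    """Position-wise mean over sequences: shape (seq_len, n_features)."""
    return train.mean(axis=0)


def baseline_score_gap(f, instance, baseline):
    """Absolute score difference between an instance (seq_len, n_features) and a baseline.

    A single baseline event is repeated along the time axis to match the instance length.
    """
    if baseline.shape[0] == 1:
        baseline = np.repeat(baseline, instance.shape[0], axis=0)
    instance_score = f(instance[None])[0, 0]
    baseline_score = f(baseline[None])[0, 0]
    return float(abs(instance_score - baseline_score))
```

Because Shapley values sum to this score difference, a tiny gap leaves little to attribute; this is the situation the "Score difference between baseline and instance is too low < 0.1" message quoted in the issues below warns about.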

Outputs

  • Local pruning output (explaining a single instance);
  • Local event explanations (explaining a single instance);
  • Local feature explanations (explaining a single instance);
  • Global pruning statistics (explaining multiple instances);
  • Global event explanations (explaining multiple instances);
  • Global feature explanations (explaining multiple instances).

Model Interface

For TimeSHAP to explain a model, an entry point must be provided. This entry point must be a Callable that receives a 3-D numpy array of shape (#sequences, #sequence length, #features) and returns a 2-D numpy array of shape (#sequences, 1) with the score of each sequence. Optionally, to make TimeSHAP more efficient, the entry point can also return the model's hidden state together with the score (where applicable).
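A minimal sketch of an entry point that satisfies this interface, using a toy numpy RNN in place of a real PyTorch or TensorFlow model. The hidden-state variant takes an optional starting hidden state and returns it alongside the scores; the exact hidden-state format TimeSHAP expects is an assumption here, so check the tutorials before relying on it. The test illustrates why returning the hidden state can save work: resuming from a prefix's hidden state gives the same score as re-running the whole sequence.

```python
import numpy as np

rng = np.random.default_rng(0)
N_FEATURES, N_HIDDEN = 4, 8
# Randomly initialised weights of a toy Elman RNN (illustrative only).
W_in = rng.normal(scale=0.5, size=(N_FEATURES, N_HIDDEN))
W_rec = rng.normal(scale=0.5, size=(N_HIDDEN, N_HIDDEN))
w_out = rng.normal(size=(N_HIDDEN, 1))


def rnn_with_hidden_state(x, hidden_state=None):
    """(n_sequences, seq_len, n_features) -> ((n_sequences, 1) scores, final hidden state)."""
    x = np.asarray(x, dtype=float)
    h = np.zeros((x.shape[0], N_HIDDEN)) if hidden_state is None else hidden_state
    for t in range(x.shape[1]):
        h = np.tanh(x[:, t, :] @ W_in + h @ W_rec)
    scores = 1.0 / (1.0 + np.exp(-(h @ w_out)))  # sigmoid score in (0, 1)
    return scores, h


def f(x):
    """Plain entry point: returns only the (n_sequences, 1) score array."""
    scores, _ = rnn_with_hidden_state(x)
    return scores
```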

TimeSHAP can explain any black-box model that complies with this interface, including PyTorch and TensorFlow models, both exemplified in our tutorials (PyTorch, TensorFlow).

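Examples provided in our tutorials:

  • TensorFlow

import tensorflow as tf

# `inputs` and `ff2` are the input and output layers defined in the tutorial
model = tf.keras.models.Model(inputs=inputs, outputs=ff2)
f = lambda x: model.predict(x)

  • PyTorch (example where the model receives and returns hidden states)

from timeshap.wrappers import TorchModelWrapper

# `model` is the trained PyTorch model from the tutorial
model_wrapped = TorchModelWrapper(model)
f_hs = lambda x, y=None: model_wrapped.predict_last_hs(x, y)

Model Wrappers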

To facilitate the interface between models and TimeSHAP, TimeSHAP implements ModelWrappers. These wrappers, used in the PyTorch tutorial notebook, give greater flexibility in which models can be explained, as they handle:

  • Batching logic: useful for very large inputs or a large number of samples (nsamples) that cannot fit in GPU memory at once and must therefore be processed in batches;
  • Input format/type: useful when your model does not work with numpy arrays, as is the case with our PyTorch example.
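Both ideas can be sketched without a deep-learning framework. `BatchedPredictor` below is a hypothetical class, not TimeSHAP's TorchModelWrapper: it casts the input to the dtype the model expects, feeds it in fixed-size chunks, and reassembles a (#sequences, 1) score array.

```python
import numpy as np


class BatchedPredictor:
    """Hypothetical wrapper: converts input dtype and predicts in fixed-size batches."""

    def __init__(self, predict_fn, batch_size=1024, dtype=np.float32):
        self.predict_fn = predict_fn  # callable on a (batch, seq_len, n_features) array
        self.batch_size = batch_size
        self.dtype = dtype

    def __call__(self, x):
        x = np.asarray(x, dtype=self.dtype)
        outputs = [
            np.asarray(self.predict_fn(x[start:start + self.batch_size]))
            for start in range(0, x.shape[0], self.batch_size)
        ]
        if not outputs:  # empty input: no batches to run
            return np.empty((0, 1), dtype=self.dtype)
        return np.concatenate(outputs, axis=0).reshape(-1, 1)
```

In a GPU setting, `predict_fn` would also move each batch to the device and back; the batch size trades peak memory for the number of forward passes.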

TimeSHAP Explanation Methods

TimeSHAP offers several methods, depending on the desired explanations. Local methods provide a detailed view of the model's decision for the specific sequence being explained. Global methods aggregate the local explanations over a given dataset to present a global view of the model.

Local Explanations

Pruning

local_pruning() performs the pruning algorithm on a given sequence with a user-defined tolerance and returns the pruning index along with the information needed for plotting.

plot_temp_coalition_pruning() plots the pruning algorithm information calculated by local_pruning().
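The idea behind the pruning step can be sketched as a two-player game at each candidate split: the older events form one player and the most recent events the other, absent players are replaced by a baseline event, and the exact two-player Shapley value of the older part is compared against the tolerance. This is a simplified sketch of the idea, assuming a single baseline event and an absolute-value tolerance check; it is not the library's implementation, and `two_player_shapley` and `prune_start` are hypothetical names.

```python
import numpy as np


def two_player_shapley(f, x, baseline_event, split):
    """Exact Shapley values of the players older = x[:split] and recent = x[split:]."""
    def value(older_present, recent_present):
        seq = np.array(x, dtype=float)
        if not older_present:
            seq[:split] = baseline_event
        if not recent_present:
            seq[split:] = baseline_event
        return float(f(seq[None])[0, 0])

    v_none, v_all = value(False, False), value(True, True)
    v_older, v_recent = value(True, False), value(False, True)
    phi_older = 0.5 * ((v_older - v_none) + (v_all - v_recent))
    phi_recent = 0.5 * ((v_recent - v_none) + (v_all - v_older))
    return phi_older, phi_recent


def prune_start(f, x, baseline_event, tol):
    """Index of the first kept event: grow the recent window until older events matter less than tol."""
    seq_len = x.shape[0]
    for n_recent in range(1, seq_len):
        split = seq_len - n_recent
        phi_older, _ = two_player_shapley(f, x, baseline_event, split)
        if abs(phi_older) <= tol:
            return split
    return 0  # no pruning possible: keep the whole sequence
```

With a toy model that only looks at the last two events of a six-event sequence, the older events stop mattering once the recent window covers those two events, so the sketch keeps events 4 and 5.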

Event level explanations

local_event() calculates event-level explanations of a given sequence with user-given parameters and returns the respective event-level explanations.

plot_event_heatmap() plots the event-level explanations calculated by local_event().

Feature level explanations

local_feat() calculates feature-level explanations of a given sequence with user-given parameters and returns the respective feature-level explanations.

plot_feat_barplot() plots the feature-level explanations calculated by local_feat().

Cell level explanations

local_cell_level() calculates cell-level explanations of a given sequence from the respective event- and feature-level explanations and user-given parameters, returning the respective cell-level explanations.

plot_cell_level() plots the cell-level explanations calculated by local_cell_level().
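Cell-level explanations reuse the event- and feature-level results to decide which (event, feature) cells are worth attributing individually. Below is a minimal sketch of that selection step, assuming that top_x_events and top_x_feats (the cell_dict keys in the tutorial call quoted in the issues below) pick the events and features with the largest absolute attributions. How TimeSHAP then groups the remaining cells is not reproduced here, and `select_cells` is a hypothetical helper.

```python
import numpy as np


def select_cells(event_phi, feat_phi, top_x_events, top_x_feats):
    """Return the (event, feature) index pairs formed by the most relevant events and features."""
    top_events = np.argsort(-np.abs(event_phi))[:top_x_events]
    top_feats = np.argsort(-np.abs(feat_phi))[:top_x_feats]
    return [(int(event), int(feat)) for event in sorted(top_events) for feat in sorted(top_feats)]
```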

Local Report

local_report() calculates TimeSHAP local explanations for a given sequence and plots them.

Global Explanations

Global pruning statistics

prune_all() performs the pruning algorithm on multiple given sequences.

pruning_statistics() calculates the pruning statistics for several user-given pruning tolerances using the pruning data calculated by prune_all(), returning a pandas.DataFrame with the statistics.
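As an illustration of what such statistics can look like, here is a pandas sketch that summarises, per tolerance, the original and pruned sequence lengths. The input layout (one row per sequence and tolerance, with hypothetical columns tolerance, original_length, pruned_length) and the chosen statistics are assumptions, not the exact DataFrame that pruning_statistics() returns.

```python
import pandas as pd


def summarise_pruning(results: pd.DataFrame) -> pd.DataFrame:
    """Mean/std of pruned length and mean original length for each pruning tolerance."""
    return (
        results.groupby("tolerance")
        .agg(
            mean_original_length=("original_length", "mean"),
            mean_pruned_length=("pruned_length", "mean"),
            std_pruned_length=("pruned_length", "std"),  # sample standard deviation
        )
        .reset_index()
    )
```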

Global event level explanations

event_explain_all() calculates TimeSHAP event-level explanations for multiple instances given user-defined parameters.

plot_global_event() plots the global event-level explanations calculated by event_explain_all().
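One simple way to aggregate local event explanations is to group them by relative position. The pandas sketch below assumes each local result has been collected into rows of (sequence_id, t, shapley_value), where t is the event index counted back from the most recent event (-1). The column names are hypothetical, and plot_global_event() may present the distribution differently.

```python
import pandas as pd


def summarise_events_by_position(local_events: pd.DataFrame) -> pd.DataFrame:
    """Mean, spread and count of event attributions at each relative position t."""
    return (
        local_events.groupby("t")["shapley_value"]
        .agg(["mean", "std", "count"])
        .sort_index(ascending=False)  # most recent event (-1) first
    )
```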

Global feature level explanations

feat_explain_all() calculates TimeSHAP feature-level explanations for multiple instances given user-defined parameters.

plot_global_feat() plots the global feature-level explanations calculated by feat_explain_all().
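A common way to turn many local feature explanations into a global ranking is the mean absolute Shapley value per feature. The sketch assumes hypothetical columns (sequence_id, feature, shapley_value) and is not necessarily how plot_global_feat() aggregates.

```python
import pandas as pd


def rank_features(local_feats: pd.DataFrame) -> pd.Series:
    """Features sorted by mean absolute attribution across the explained sequences."""
    return (
        local_feats.assign(abs_value=local_feats["shapley_value"].abs())
        .groupby("feature")["abs_value"]
        .mean()
        .sort_values(ascending=False)
    )
```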

Global report

global_report() calculates TimeSHAP explanations for multiple instances, aggregating the explanations on two plots and returning them.

Tutorial

To see TimeSHAP's interfaces and methods in action, consult AReM.ipynb. In this tutorial we load an open-source dataset, process it, train a PyTorch recurrent model on it, and use TimeSHAP to explain it, showcasing all of the methods described above.

Additionally, AReM_TF.ipynb trains a TensorFlow model on the same dataset.

Repository Structure

Citing TimeSHAP

@inproceedings{bento2021timeshap,
    author = {Bento, Jo\~{a}o and Saleiro, Pedro and Cruz, Andr\'{e} F. and Figueiredo, M\'{a}rio A.T. and Bizarro, Pedro},
    title = {TimeSHAP: Explaining Recurrent Models through Sequence Perturbations},
    year = {2021},
    isbn = {9781450383325},
    publisher = {Association for Computing Machinery},
    address = {New York, NY, USA},
    url = {https://doi.org/10.1145/3447548.3467166},
    doi = {10.1145/3447548.3467166},
    booktitle = {Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery & Data Mining},
    pages = {2565–2573},
    numpages = {9},
    keywords = {SHAP, Shapley values, TimeSHAP, XAI, RNN, explainability},
    location = {Virtual Event, Singapore},
    series = {KDD '21}
}
Comments
  • Error in running the example notebook (AReM_TF)

    Nice work! I have been trying to run one of the tutorial notebooks in the repository (i.e., AReM_TF), but I faced an error. The notebook chunk that produces the error is:

    from timeshap.explainer import local_report, local_pruning
    
    pruning_dict = {'tol': 0.025}
    event_dict = {'rs': 42, 'nsamples': 320}
    feature_dict = {'rs': 42, 'nsamples': 320, 'feature_names': model_features, 'plot_features': plot_feats}
    cell_dict = {'rs': 42, 'nsamples': 320, 'top_x_feats': 2, 'top_x_events': 2}
    local_report(f, pos_x_data, pruning_dict, event_dict, feature_dict,cell_dict=cell_dict, entity_uuid=positive_sequence_id, entity_col='all_id', baseline=average_event)
    

    and the produced error is as follows: [error screenshots]

    It would be great if you could help me with this error.

    opened by aminnayebi 7
  • Unable to install timeshap package

    Hello @feedzaiadmin , @saleiro ,

    I am unable to run the command "pip install timeshap" on my Ubuntu system. It throws the following error:

    ERROR: Could not find a version that satisfies the requirement timeshap (from versions: none)
    ERROR: No matching distribution found for timeshap

    This seems to be a compatibility issue. I have tried with Python versions 3.6 and 3.8. Is there a specific version that supports it? I look forward to your response.

    Thanks

    opened by vishants98 5
  • How to adapt the transformation function to account for variable sequence length?

    I am trying to use TimeSHAP on my use case. As I understand it, in the AReM example the df_to_numpy function transforms the data so that a prediction is made for the last value of each sequence (see the screenshot below):

    [screenshot: AReM data transformation]

    In the case of the AReM tutorial data, the predictions are based on the whole sequence: all rows (row IDs 1-10) are used for sequence ID 1 (light blue) and the prediction is made for timestamp 10 (dark blue; row ID 10). Then the light orange rows (row IDs 11-20) are used to predict the label marked in dark orange (row ID 20).

    In my use case, the model predicts on a rolling-window basis and I need predictions for every row (not only for each sequence). See the screenshot and explanation below.

    [screenshot: rolling-window transformation]

    Say my rolling window is 6: row IDs 1-6 (light green) are used to predict row ID 7 (dark green), then row IDs 2-7 (light grey) are used to predict row ID 8 (dark grey), etc. When a new sequence starts, we repeat the process, so we take row IDs 11-16 and predict row ID 17, etc. For my use case, it's important to evaluate the predictions for every row ID, not only for the whole sequence.

    The problem I am facing is that when I run the function get_avg_score_with_avg_event on data defined as in the picture above, I get the following error: [error screenshot]

    My data is transformed from 2D into 3D format by the function below: [screenshot of the transformation function]

    My question is whether it's possible to make TimeSHAP work for data transformed as described in my use case. When I use the transformation defined in your df_to_numpy function, I do not get an error; however, it is not adapted to my use case.

    opened by grzechowiak 4
  • Issue in reproducing TimeSHAP Tutorial - TensorFlow - AReM dataset

    Foremost, thanks for the library, really great job!

    I am trying to reproduce your TimeSHAP tutorial for TF and I am having an issue in the Global Explanations section when running the global_report() function (screenshot below):

    [error screenshot]

    The error refers to encoding the \u2264 character, which is the ≤ sign (<=). I tried to solve this myself by modifying pruning.py according to the error, adding encoding="utf-8" to line 326 (with open(file_path, 'a', newline='') as file:), but it didn't solve the problem. Any advice is very welcome!

    Also, for completeness, I want to mention that I had a problem loading the data (error screenshots below). I was able to solve it only by deleting two datasets, cycling/dataset9.csv and cycling/dataset14.csv; after that, the rest of the code worked.

    [error screenshots 1/2 and 2/2]

    opened by grzechowiak 4
  • Timeshap for regression

    I am working on a time series forecasting problem using LSTM. Can I use timeshap for such a regression problem? Do you by chance have a demo for regression?

    Thanks

    opened by mgorjis 3
  • Plot Coalition Pruning is not working

    Hi, I am doing a project as part of a Master's thesis. I was testing your introductory notebook with the TensorFlow implementation, but it does not work because of an error raised by one of your plot functions. The funny thing is that your other notebook, i.e., the one with the PyTorch model, works fine. I tracked the issue down to an inner method called solve_negatives_method in src/timeshap/plot/pruning.py, which raises an InvalidIndexError caused by this specific line:

        df.at[corresponding_row.index, 'Shapley Value'] = corresponding_row['Shapley Value'].values[0] + row['Shapley Value']
    

    As far as I understand, we are passing an Index (usually containing only one value) to the pandas.DataFrame.at accessor, which requires a single row label and a column label. One solution that works is:

        df.at[corresponding_row.index[0], 'Shapley Value'] = corresponding_row['Shapley Value'].values[0] + row['Shapley Value']
    

    I want to thank you very much for the work done, and I would be happy to discuss the results of my thesis with you further. :) Hoping for the best and looking forward to hearing from you, Eric

    opened by Erhtric 2
  • TimeSHAP for text?

    I'm working with a 1-layer GRU for text classification that takes BERT embeddings as input. Each input sequence has the shape (sequence length, BERT embedding dimension). I'm looking for word-level attribution scores for each sequence's prediction. Currently, with the Captum integrated gradients and occlusion explainers, I get attribution scores that almost always highlight the last few words of each sequence. This seems to stem from the directional processing of the GRU; any thoughts?

    Do you think TimeSHAP would be applicable to my use case? I suppose I could consider each word as an event and each embedding dimension as a feature, and then use the event-level local explanations from the library. However, in my case the events (i.e., words) at the beginning of the sequence could be more important than those at the end (i.e., the most recent ones). This violates the assumption behind your approximation (i.e., pruning), so perhaps it's not applicable to text?

    opened by itsmemala 2
  • Intuition around pruning/baseline selection

    While calculating a global report, I'm currently running into errors that "Score difference between baseline and instance is too low < 0.1...Consider choosing another baseline." My baselines have been the average_event and average_sequence. I also notice that occasionally the values in the error change with different pruning tolerance, but not in a consistent way. Do you have advice for dealing with this? Thanks.

    opened by xydisla 2
  • CNN model

    I currently have a CNN model, and previously had to do some strange hacking to get time-series importance values. Your package now sheds a better light on this issue. For multivariate forecasts, CNNs still fare well, sometimes better than RNNs, and from what I can see there is no reason why CNNs wouldn't work with your software. Let me know if I have that wrong.

    opened by firmai 2
  • How to speed up TIMESHAP computation

    Hi all!

    The package itself is really interesting and intuitive to use, but I want to speed up the TimeSHAP computation. Can I use a GPU to calculate Shapley values? I used the device parameter in TorchModelWrapper, but GPU efficiency is too low to accelerate the computation. Any suggestion would be appreciated.

    opened by Changshu135 2
  • Error when executing local_report on TF example: InvalidIndexError: Int64Index([1], dtype='int64')

    Hi all! When executing your TF example notebook without any changes, it fails at this line: local_report(f, pos_x_data, pruning_dict, event_dict, feature_dict, cell_dict=cell_dict, entity_uuid=positive_sequence_id, entity_col='all_id', baseline=average_event)

    with the following error: InvalidIndexError: Int64Index([1], dtype='int64')

    Stack trace:

    TypeError                                 Traceback (most recent call last)
    File ~/opt/anaconda3/envs/timeshap/lib/python3.10/site-packages/pandas/core/indexes/base.py:3621, in Index.get_loc(self, key, method, tolerance)
       3620 try:
    -> 3621     return self._engine.get_loc(casted_key)
       3622 except KeyError as err:
    
    File pandas/_libs/index.pyx:136, in pandas._libs.index.IndexEngine.get_loc()
    
    File pandas/_libs/index.pyx:142, in pandas._libs.index.IndexEngine.get_loc()
    
    TypeError: 'Int64Index([1], dtype='int64')' is an invalid key
    
    During handling of the above exception, another exception occurred:
    
    InvalidIndexError                         Traceback (most recent call last)
    Input In [27], in <cell line: 7>()
          5 feature_dict = {'rs': 42, 'nsamples': 32000, 'feature_names': model_features, 'plot_features': plot_feats}
          6 cell_dict = {'rs': 42, 'nsamples': 32000, 'top_x_feats': 2, 'top_x_events': 2}
    ----> 7 local_report(f, pos_x_data, pruning_dict, event_dict, feature_dict, cell_dict=cell_dict, entity_uuid=positive_sequence_id, entity_col='all_id', baseline=average_event)
    
    File ~/temp/timeshap/src/timeshap/explainer/local_methods.py:139, in local_report(f, data, pruning_dict, event_dict, feature_dict, cell_dict, entity_uuid, entity_col, time_col, model_features, baseline, verbose)
        137 pruning_idx = data.shape[1] + coal_prun_idx
        138 plot_lim = max(abs(coal_prun_idx)+10, 40)
    --> 139 pruning_plot = plot_temp_coalition_pruning(coal_plot_data, coal_prun_idx, plot_lim)
        141 event_data = local_event(f, data, event_dict, entity_uuid, entity_col, baseline, pruning_idx)
        142 event_plot = plot_event_heatmap(event_data)
    
    File ~/temp/timeshap/src/timeshap/plot/pruning.py:53, in plot_temp_coalition_pruning(df, pruned_idx, plot_limit, solve_negatives)
         51 df = df[df['t (event index)'] >= -plot_limit]
         52 if solve_negatives:
    ---> 53     df = solve_negatives_method(df)
         55 base = (alt.Chart(df).encode(
         56     x=alt.X("t (event index)", axis=alt.Axis(title='t (event index)', labelFontSize=15,
         57                           titleFontSize=15)),
       (...)
         70 )
         71 )
         73 area_chart = base.mark_area(opacity=0.5)
    
    File ~/temp/timeshap/src/timeshap/plot/pruning.py:47, in plot_temp_coalition_pruning.<locals>.solve_negatives_method(df)
         45 for idx, row in negative_values.iterrows():
         46     corresponding_row = df[np.logical_and(df['t (event index)'] == row['t (event index)'], ~(df['Coalition'] == row['Coalition']))]
    ---> 47     df.at[corresponding_row.index, 'Shapley Value'] = corresponding_row['Shapley Value'].values[0] + row['Shapley Value']
         48     df.at[idx, 'Shapley Value'] = 0
         49 return df
    
    File ~/opt/anaconda3/envs/timeshap/lib/python3.10/site-packages/pandas/core/indexing.py:2281, in _AtIndexer.__setitem__(self, key, value)
       2278     self.obj.loc[key] = value
       2279     return
    -> 2281 return super().__setitem__(key, value)
    
    File ~/opt/anaconda3/envs/timeshap/lib/python3.10/site-packages/pandas/core/indexing.py:2236, in _ScalarAccessIndexer.__setitem__(self, key, value)
       2233 if len(key) != self.ndim:
       2234     raise ValueError("Not enough indexers for scalar access (setting)!")
    -> 2236 self.obj._set_value(*key, value=value, takeable=self._takeable)
    
    File ~/opt/anaconda3/envs/timeshap/lib/python3.10/site-packages/pandas/core/frame.py:3869, in DataFrame._set_value(self, index, col, value, takeable)
       3867 else:
       3868     series = self._get_item_cache(col)
    -> 3869     loc = self.index.get_loc(index)
       3871 # setitem_inplace will do validation that may raise TypeError
       3872 #  or ValueError
       3873 series._mgr.setitem_inplace(loc, value)
    
    File ~/opt/anaconda3/envs/timeshap/lib/python3.10/site-packages/pandas/core/indexes/base.py:3628, in Index.get_loc(self, key, method, tolerance)
       3623         raise KeyError(key) from err
       3624     except TypeError:
       3625         # If we have a listlike key, _check_indexing_error will raise
       3626         #  InvalidIndexError. Otherwise we fall through and re-raise
       3627         #  the TypeError.
    -> 3628         self._check_indexing_error(key)
       3629         raise
       3631 # GH#42269
    
    File ~/opt/anaconda3/envs/timeshap/lib/python3.10/site-packages/pandas/core/indexes/base.py:5637, in Index._check_indexing_error(self, key)
       5633 def _check_indexing_error(self, key):
       5634     if not is_scalar(key):
       5635         # if key is not a scalar, directly raise an error (the code below
       5636         # would convert to numpy arrays and raise later any way) - GH29926
    -> 5637         raise InvalidIndexError(key)
    
    InvalidIndexError: Int64Index([1], dtype='int64')
    
    opened by JulianKlug 2
Owner
Feedzai