Implementation of the paper "Shapley Explanation Networks"

Overview

Shapley Explanation Networks

Implementation of the paper "Shapley Explanation Networks" at ICLR 2021. Note that this repo heavily uses the experimental feature of named tensors in PyTorch. As it was really confusing to implement the ideas for the authors, we find it tremendously easier to use this feature.

Dependencies

For running only ShapNets, one would mostly only need PyTorch, NumPy, and SciPy.

Usage

For a Shapley Module:

import torch
import torch.nn as nn
from ShapNet.utils import ModuleDimensions
from ShapNet import ShapleyModule

b_size = 3
features = 4
out = 1
dims = ModuleDimensions(
    features=features,
    in_channel=1,
    out_channel=out
)

sm = ShapleyModule(
    inner_function=nn.Linear(features, out),
    dimensions=dims
)
sm(torch.randn(b_size, features), explain=True)

For a Shallow ShapNet

import torch
import torch.nn as nn
from ShapNet.utils import ModuleDimensions
from ShapNet import ShapleyModule, OverlappingShallowShapleyNetwork

batch_size = 32
class_num = 10
dim = 32

overlapping_modules = [
    ShapleyModule(
        inner_function=nn.Sequential(nn.Linear(2, class_num)),
        dimensions=ModuleDimensions(
            features=2, in_channel=1, out_channel=class_num
        ),
    ) for _ in range(dim * (dim - 1) // 2)
]
shallow_shapnet = OverlappingShallowShapleyNetwork(
    list_modules=overlapping_modules
)
inputs = torch.randn(batch_size, dim, ), )
shallow_shapnet(torch.randn(batch_size, dim, ), )
output, bias = shallow_shapnet(inputs, explain=True, )

For a Deep ShapNet

import torch
import torch.nn as nn
from ShapNet.utils import ModuleDimensions
from ShapNet import ShapleyModule, ShallowShapleyNetwork, DeepShapleyNetwork

dim = 32
dim_input_channels = 1
class_num = 10
inputs = torch.randn(32, dim, ), )


dims = ModuleDimensions(
    features=dim,
    in_channel=dim_input_channels,
    out_channel=class_num
)
deep_shapnet = DeepShapleyNetwork(
    list_shapnets=[
        ShallowShapleyNetwork(
            module_dict=nn.ModuleDict({
                "(0, 2)": ShapleyModule(
                    inner_function=nn.Linear(2, class_num),
                    dimensions=ModuleDimensions(
                        features=2, in_channel=1, out_channel=class_num
                    )
                )},
            ),
            dimensions=ModuleDimensions(dim, 1, class_num)
        ),
    ],
)
deep_shapnet(inputs)
outputs = deep_shapnet(inputs, explain=True, )

For a vision model:

import numpy as np
import torch
import torch.nn as nn

# =============================================================================
# Imports {\sc ShapNet}
# =============================================================================
from ShapNet import DeepConvShapNet, ShallowConvShapleyNetwork, ShapleyModule
from ShapNet.utils import ModuleDimensions, NAME_HEIGHT, NAME_WIDTH, \
    process_list_sizes

num_channels = 3
num_classes = 10
height = 32
width = 32
list_channels = [3, 16, 10]
pruning = [0.2, 0.]
kernel_sizes = process_list_sizes([2, (1, 3), ])
dilations = process_list_sizes([1, 2])
paddings = process_list_sizes([0, 0])
strides = process_list_sizes([1, 1])

args = {
    "list_shapnets": [
        ShallowConvShapleyNetwork(
            shapley_module=ShapleyModule(
                inner_function=nn.Sequential(
                    nn.Linear(
                        np.prod(kernel_sizes[i]) * list_channels[i],
                        list_channels[i + 1]),
                    nn.LeakyReLU()
                ),
                dimensions=ModuleDimensions(
                    features=int(np.prod(kernel_sizes[i])),
                    in_channel=list_channels[i],
                    out_channel=list_channels[i + 1])
            ),
            reference_values=None,
            kernel_size=kernel_sizes[i],
            dilation=dilations[i],
            padding=paddings[i],
            stride=strides[i]
        ) for i in range(len(list_channels) - 1)
    ],
    "reference_values": None,
    "residual": False,
    "named_output": False,
    "pruning": pruning
}

dcs = DeepConvShapNet(**args)

Citation

If this is useful, you could cite our work as

@inproceedings{
wang2021shapley,
title={Shapley Explanation Networks},
author={Rui Wang and Xiaoqian Wang and David I. Inouye},
booktitle={International Conference on Learning Representations},
year={2021},
url={https://openreview.net/forum?id=vsU0efpivw}
}
Owner
Prof. David I. Inouye's research lab at Purdue University.
Implementation of Gans

GAN Generative Adverserial Networks are an approach to generative data modelling using Deep learning methods. I have currently implemented : DCGAN on

Sibam Parida 5 Sep 07, 2021
The official implementation of Equalization Loss v1 & v2 (CVPR 2020, 2021) based on MMDetection.

The Equalization Losses for Long-tailed Object Detection and Instance Segmentation This repo is official implementation CVPR 2021 paper: Equalization

Jingru Tan 129 Dec 16, 2022
A dead simple python wrapper for darknet that works with OpenCV 4.1, CUDA 10.1

What Dead simple python wrapper for Yolo V3 using AlexyAB's darknet fork. Works with CUDA 10.1 and OpenCV 4.1 or later (I use OpenCV master as of Jun

Pliable Pixels 6 Jan 12, 2022
People movement type classifier with YOLOv4 detection and SORT tracking.

Movement classification The goal of this project would be movement classification of people, in other words, walking (normal and fast) and running. Yo

4 Sep 21, 2021
A library of scripts that interact with the PythonTurtle module to create games, drawings, and more

TurtleLib TurtleLib is a library of scripts that interact with the PythonTurtle module to create games, drawings, and more! Using the Scripts Copy or

1 Jan 15, 2022
本步态识别系统主要基于GaitSet模型进行实现

本步态识别系统主要基于GaitSet模型进行实现。在尝试部署本系统之前,建立理解GaitSet模型的网络结构、训练和推理方法。 系统的实现效果如视频所示: 演示视频 由于模型较大,部分模型文件存储在百度云盘。 链接提取码:33mb 具体部署过程 1.下载代码 2.安装requirements.txt

16 Oct 22, 2022
AMTML-KD: Adaptive Multi-teacher Multi-level Knowledge Distillation

AMTML-KD: Adaptive Multi-teacher Multi-level Knowledge Distillation

Frank Liu 26 Oct 13, 2022
An implementation of Video Frame Interpolation via Adaptive Separable Convolution using PyTorch

This work has now been superseded by: https://github.com/sniklaus/revisiting-sepconv sepconv-slomo This is a reference implementation of Video Frame I

Simon Niklaus 984 Dec 16, 2022
Semantic similarity computation with different state-of-the-art metrics

Semantic similarity computation with different state-of-the-art metrics Description • Installation • Usage • License Description TaxoSS is a semantic

6 Jun 22, 2022
A library for researching neural networks compression and acceleration methods.

A library for researching neural networks compression and acceleration methods.

Intel Labs 100 Dec 29, 2022
Replication Code for "Self-Supervised Bug Detection and Repair" NeurIPS 2021

Self-Supervised Bug Detection and Repair This is the reference code to replicate the research in Self-Supervised Bug Detection and Repair in NeurIPS 2

Microsoft 85 Dec 24, 2022
E2C implementation in PyTorch

Embed to Control implementation in PyTorch Paper can be found here: https://arxiv.org/abs/1506.07365 You will need a patched version of OpenAI Gym in

Yicheng Luo 42 Dec 12, 2022
Answering Open-Domain Questions of Varying Reasoning Steps from Text

This repository contains the authors' implementation of the Iterative Retriever, Reader, and Reranker (IRRR) model in the EMNLP 2021 paper "Answering Open-Domain Questions of Varying Reasoning Steps

26 Dec 22, 2022
ATOMIC 2020: On Symbolic and Neural Commonsense Knowledge Graphs

(Comet-) ATOMIC 2020: On Symbolic and Neural Commonsense Knowledge Graphs Paper Jena D. Hwang, Chandra Bhagavatula, Ronan Le Bras, Jeff Da, Keisuke Sa

AI2 152 Dec 27, 2022
Implementation of character based convolutional neural network

Character Based CNN This repo contains a PyTorch implementation of a character-level convolutional neural network for text classification. The model a

Ahmed BESBES 248 Nov 21, 2022
Chinese license plate recognition

AgentCLPR 简介 一个基于 ONNXRuntime、AgentOCR 和 License-Plate-Detector 项目开发的中国车牌检测识别系统。 车牌识别效果 支持多种车牌的检测和识别(其中单层车牌识别效果较好): 单层车牌: [[[[373, 282], [69, 284],

AgentMaker 26 Dec 25, 2022
Offcial implementation of "A Hybrid Video Anomaly Detection Framework via Memory-Augmented Flow Reconstruction and Flow-Guided Frame Prediction, ICCV-2021".

HF2-VAD Offcial implementation of "A Hybrid Video Anomaly Detection Framework via Memory-Augmented Flow Reconstruction and Flow-Guided Frame Predictio

76 Dec 21, 2022
Code for SALT: Stackelberg Adversarial Regularization, EMNLP 2021.

SALT: Stackelberg Adversarial Regularization Code for Adversarial Regularization as Stackelberg Game: An Unrolled Optimization Approach, EMNLP 2021. R

Simiao Zuo 10 Jan 10, 2022
Continual Learning of Electronic Health Records (EHR).

Continual Learning of Longitudinal Health Records Repo for reproducing the experiments in Continual Learning of Longitudinal Health Records (2021). Re

Jacob 7 Oct 21, 2022