Explainability for Vision Transformers (in PyTorch)

Overview


This repository implements methods for explainability in Vision Transformers.

See also https://jacobgil.github.io/deeplearning/vision-transformer-explainability

Currently implemented:

  • Attention Rollout.

  • Gradient Attention Rollout for class-specific explainability. This is our attempt to build upon and improve Attention Rollout.

  • TBD: Attention Flow is a work in progress.
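As a rough illustration of what Attention Rollout computes (following "Quantifying Attention Flow in Transformers", Abnar & Zuidema, 2020), here is a minimal PyTorch sketch. It is not this repository's code: `attention_rollout` and `rollout_to_mask` are hypothetical names, and the sketch assumes you have already collected one attention tensor per layer, shaped `(batch, heads, tokens, tokens)`, with token 0 being the [CLS] token. Each layer's head-averaged attention is mixed with the identity to account for the residual connection, re-normalized, and multiplied through the layers.

```python
import torch


def attention_rollout(attentions):
    """Hypothetical sketch of Attention Rollout.

    attentions: list of per-layer tensors shaped (batch, heads, tokens, tokens),
    ordered from the first layer to the last.
    Returns a (batch, tokens, tokens) rollout matrix.
    """
    result = None
    for attn in attentions:
        # Fuse the heads by averaging (the original Attention Rollout choice).
        fused = attn.mean(dim=1)
        # Model the residual connection: 0.5 * A + 0.5 * I.
        eye = torch.eye(fused.size(-1), dtype=fused.dtype, device=fused.device)
        a = 0.5 * fused + 0.5 * eye
        # Re-normalize so every row sums to 1.
        a = a / a.sum(dim=-1, keepdim=True)
        # Compose with the layers below: rollout_l = A_l @ rollout_{l-1}.
        result = a if result is None else a @ result
    return result


def rollout_to_mask(rollout):
    """Turn the [CLS] row of the rollout into a square patch mask in [0, 1].

    Assumes token 0 is [CLS] and the remaining tokens form a square grid
    (e.g. 196 = 14 x 14 patches for a 224x224 image with 16x16 patches).
    """
    cls_to_patches = rollout[:, 0, 1:]
    side = int(cls_to_patches.size(-1) ** 0.5)
    mask = cls_to_patches.reshape(-1, side, side)
    # Scale each image's mask so its maximum is 1.
    return mask / mask.amax(dim=(-2, -1), keepdim=True)


if __name__ == "__main__":
    torch.manual_seed(0)
    # Fake attentions: 12 layers, 3 heads, 1 [CLS] + 196 patch tokens.
    atts = [torch.softmax(torch.randn(1, 3, 197, 197), dim=-1) for _ in range(12)]
    mask = rollout_to_mask(attention_rollout(atts))
    print(mask.shape)  # torch.Size([1, 14, 14])
```

Because every per-layer matrix is row-stochastic, their product is too, so the [CLS] row can be read as how much the final [CLS] representation draws from each input patch.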

It also includes some tweaks and tricks to get it working:

  • Different attention head fusion methods.
  • Removing the lowest attentions.

Usage

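  • From code:
import torch
from vit_grad_rollout import VITAttentionGradRollout

# Load DeiT-Tiny from torch hub and switch it to inference mode.
model = torch.hub.load('facebookresearch/deit:main',
                       'deit_tiny_patch16_224', pretrained=True)
model.eval()
# input_tensor: a preprocessed image batch, e.g. of shape (1, 3, 224, 224).
grad_rollout = VITAttentionGradRollout(model, discard_ratio=0.9, head_fusion='max')
mask = grad_rollout(input_tensor, category_index=243)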
  • From the command line:
python vit_explain.py --image_path <image path> --head_fusion <mean, min or max> --discard_ratio <number between 0 and 1> --category_index <category index>

If --category_index isn't specified, Attention Rollout is used; otherwise, Gradient Attention Rollout is used.

Note that by default, this uses the 'Tiny' model from "Training data-efficient image transformers & distillation through attention" (DeiT), hosted on Torch Hub.

Where did the Transformer pay attention in this image?

[Figure: input image | vanilla Attention Rollout | Attention Rollout with discard_ratio + max fusion]

Gradient Attention Rollout for class-specific explainability

The attention that flows through the transformer carries information belonging to different classes. Attention Rollout lets us see which locations the network paid attention to, but it tells us nothing about whether it ended up using those locations for the final classification.

We can multiply the attention by the gradient of the target class output and average across the attention heads (while masking out negative values), keeping only the attention that contributes to the target category (or categories).
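A minimal PyTorch sketch of that idea follows. It is an illustration under stated assumptions, not the repository's `VITAttentionGradRollout` implementation: `grad_attention_rollout` is a hypothetical name, and it assumes you already have, for every layer, the attention tensor and its gradient with respect to the target class score (for example, captured with `Tensor.register_hook` while calling `output[0, category_index].backward()`), both shaped `(batch, heads, tokens, tokens)`.

```python
import torch


def grad_attention_rollout(attentions, gradients):
    """Hypothetical sketch of gradient-weighted Attention Rollout.

    attentions, gradients: lists of per-layer tensors shaped
    (batch, heads, tokens, tokens), first layer first.
    Returns a (batch, tokens, tokens) class-specific rollout matrix.
    """
    result = None
    for attn, grad in zip(attentions, gradients):
        # Weight each head's attention by the target-class gradient, then average heads.
        fused = (attn * grad).mean(dim=1)
        # Mask out negative values (attention that works against the target class).
        fused = fused.clamp(min=0)
        # Residual connection, row normalization, and composition as in plain rollout.
        eye = torch.eye(fused.size(-1), dtype=fused.dtype, device=fused.device)
        a = 0.5 * fused + 0.5 * eye
        a = a / a.sum(dim=-1, keepdim=True)
        result = a if result is None else a @ result
    return result
```

If every weighted attention in a layer is negative, only the identity term survives the clamp, so that layer passes the rollout through unchanged rather than adding class-irrelevant attention.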

Where does the Transformer see a dog (category 243) and a cat (category 282)?

Where does the Transformer see a basset hound (category 161) and a parrot (category 87)?

Tricks and Tweaks to get this working

Filtering the lowest attentions in every layer

--discard_ratio

Removes noise by keeping the strongest attentions.
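One way to implement that filtering is sketched below in PyTorch. It assumes the filter runs on a single layer's head-fused attention, shaped `(batch, tokens, tokens)`, before the identity is added, and zeroes the `discard_ratio` fraction of smallest entries per image. The function name `discard_lowest` is hypothetical, and the repository may handle details (such as protecting particular [CLS] entries) differently.

```python
import torch


def discard_lowest(fused, discard_ratio):
    """Zero out the lowest `discard_ratio` fraction of attention weights per image.

    fused: (batch, tokens, tokens) head-fused attention for one layer.
    Returns a new tensor; the input is left unchanged.
    """
    batch = fused.size(0)
    flat = fused.reshape(batch, -1).clone()
    k = int(flat.size(-1) * discard_ratio)
    if k > 0:
        # Indices of the k smallest weights in each image's flattened matrix.
        _, idx = flat.topk(k, dim=-1, largest=False)
        flat.scatter_(-1, idx, 0.0)
    return flat.reshape(fused.shape)
```

Rows that lose all their weight still end up with a positive sum once the identity matrix is added for the residual connection, so the subsequent row normalization stays well defined.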

Results for different values:

Different Attention Head Fusions

The Attention Rollout method suggests taking the average attention across the attention heads, but empirically it looks like taking the minimum value, or the maximum value combined with --discard_ratio, works better.

--head_fusion
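The three fusion options can be sketched as reductions over the head dimension of a `(batch, heads, tokens, tokens)` attention tensor. This is an illustrative helper (`fuse_heads` is a hypothetical name), not the repository's code. Unlike the mean, the element-wise min or max of row-stochastic heads generally no longer has rows summing to 1, so the rows need re-normalizing before the layers are multiplied together.

```python
import torch


def fuse_heads(attn, head_fusion="mean"):
    """Collapse the head dimension of (batch, heads, tokens, tokens) attention."""
    if head_fusion == "mean":
        return attn.mean(dim=1)
    if head_fusion == "max":
        return attn.amax(dim=1)
    if head_fusion == "min":
        return attn.amin(dim=1)
    raise ValueError(f"unsupported head_fusion: {head_fusion!r}")


if __name__ == "__main__":
    attn = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]],
                          [[4.0, 3.0], [2.0, 1.0]]]])  # shape (1, 2, 2, 2)
    # Element-wise max over the two heads: values [[4, 3], [3, 4]].
    print(fuse_heads(attn, "max"))
```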

[Figure: input image | mean fusion | min fusion]


Requirements

pip install timm

Owner
Jacob Gildenblat
Machine learning / Computer Vision developer.