Improving Deep Network Debuggability via Sparse Decision Layers

Overview

Improving Deep Network Debuggability via Sparse Decision Layers

This repository contains the code for our paper:

Leveraging Sparse Linear Layers for Debuggable Deep Networks
Eric Wong*, Shibani Santurkar*, Aleksander Madry
Paper: http://arxiv.org/abs/2105.04857
Blog posts: Part1 and Part2

Pipeline overview

@article{wong2021leveraging,
  title={Leveraging Sparse Linear Layers for Debuggable Deep Networks},
  author={Wong, Eric and Santurkar, Shibani and M{\k{a}}dry, Aleksander},
  journal={arXiv preprint arXiv:2105.04857},
  year={2021}
}

Getting started

Our code relies on the MadryLab public robustness library, as well as the glm_saga library which will be automatically installed when you follow the instructions below. The glm_saga library contains a standalone implementation of our sparse GLM solver.

  1. Clone our repo: git clone https://github.com/microsoft/DebuggableDeepNetworks.git

  2. Setup the lucent submodule using: git submodule update --init --recursive

  3. We recommend using conda for dependencies:

    conda env create -f environment.yml
    conda activate debuggable
    

Training sparse decision layers

Contents:

  • main.py fits a sparse decision layer on top of the deep features of the specified pre-trained (language/vision) deep network
  • helpers/ has some helper functions for loading datasets, models, and features
  • language/ has some additional code for handling language models and datasets

To run the settings in our paper, you can use the following commands:

# Sentiment classification
python main.py --dataset sst --dataset-path   --dataset-type language --model-path barissayil/bert-sentiment-analysis-sst --arch bert --out-path ./tmp/sst/ --cache

# Toxic comment classification (biased)
python main.py --dataset jigsaw-toxic --dataset-path   --dataset-type language --model-path unitary/toxic-bert --arch bert --out-path ./tmp/jigsaw-toxic/ --cache --balance

# Toxic comment classification (unbiased)
python main.py --dataset jigsaw-alt-toxic --dataset-path   --dataset-type language --model-path unitary/unbiased-toxic-roberta --arch roberta --out-path ./tmp/unbiased-jigsaw-toxic/ --cache --balance

# Places-10 
python main.py --dataset places-10 --dataset-path  --dataset-type vision --model-path  --arch resnet50 --out-path ./tmp/places/ --cache

# ImageNet
python main.py --dataset imagenet --dataset-path  --dataset-type vision --model-path  --arch resnet50 --out-path ./tmp/imagenet/ --cache

Interpreting deep features

After fitting a sparse GLM with one of the above commands, we provide some notebooks for inspecting and visualizing the resulting features. See inspect_vision_models.ipynb and inspect_language_models.ipynb for the vision and language settings respectively.

Maintainers

Owner
Madry Lab
Towards a Principled Science of Deep Learning
Madry Lab
Evolving neural network parameters in JAX.

Evolving Neural Networks in JAX This repository holds code displaying techniques for applying evolutionary network training strategies in JAX. Each sc

Trevor Thackston 6 Feb 12, 2022
PyTorch implementation of Rethinking Positional Encoding in Language Pre-training

TUPE PyTorch implementation of Rethinking Positional Encoding in Language Pre-training. Quickstart Clone this repository. git clone https://github.com

Jake Tae 5 Jan 27, 2022
FeTaQA: Free-form Table Question Answering

FeTaQA: Free-form Table Question Answering FeTaQA is a Free-form Table Question Answering dataset with 10K Wikipedia-based {table, question, free-form

Language, Information, and Learning at Yale 40 Dec 13, 2022
code for ICCV 2021 paper 'Generalized Source-free Domain Adaptation'

G-SFDA Code (based on pytorch 1.3) for our ICCV 2021 paper 'Generalized Source-free Domain Adaptation'. [project] [paper]. Dataset preparing Download

Shiqi Yang 84 Dec 26, 2022
Implementation of ReSeg using PyTorch

Implementation of ReSeg using PyTorch ReSeg: A Recurrent Neural Network-based Model for Semantic Segmentation Pascal-Part Annotations Pascal VOC 2010

Onur Kaplan 46 Nov 23, 2022
Code for our SIGCOMM'21 paper "Network Planning with Deep Reinforcement Learning".

0. Introduction This repository contains the source code for our SIGCOMM'21 paper "Network Planning with Deep Reinforcement Learning". Notes The netwo

NetX Group 68 Nov 24, 2022
DEMix Layers for Modular Language Modeling

DEMix This repository contains modeling utilities for "DEMix Layers: Disentangling Domains for Modular Language Modeling" (Gururangan et. al, 2021). T

Suchin 43 Nov 11, 2022
Code for a seq2seq architecture with Bahdanau attention designed to map stereotactic EEG data from human brains to spectrograms, using the PyTorch Lightning.

stereoEEG2speech We provide code for a seq2seq architecture with Bahdanau attention designed to map stereotactic EEG data from human brains to spectro

15 Nov 11, 2022
Dataset used in "PlantDoc: A Dataset for Visual Plant Disease Detection" accepted in CODS-COMAD 2020

PlantDoc: A Dataset for Visual Plant Disease Detection This repository contains the Cropped-PlantDoc dataset used for benchmarking classification mode

Pratik Kayal 109 Dec 29, 2022
[ICRA 2022] CaTGrasp: Learning Category-Level Task-Relevant Grasping in Clutter from Simulation

This is the official implementation of our paper: Bowen Wen, Wenzhao Lian, Kostas Bekris, and Stefan Schaal. "CaTGrasp: Learning Category-Level Task-R

Bowen Wen 199 Jan 04, 2023
Source codes of CenterTrack++ in 2021 ICME Workshop on Big Surveillance Data Processing and Analysis

MOT Tracked object bounding box association (CenterTrack++) New association method based on CenterTrack. Two new branches (Tracked Size and IOU) are a

36 Oct 04, 2022
use machine learning to recognize gesture on raspberrypi

Raspberrypi_Gesture-Recognition use machine learning to recognize gesture on raspberrypi 說明 利用 tensorflow lite 訓練手部辨識模型 分辨 "剪刀"、"石頭"、"布" 之手勢 再將訓練模型匯入

1 Dec 10, 2021
Lane follower: Lane-detector (OpenCV) + Object-detector (YOLO5) + CAN-bus

Lane Follower This code is for the lane follower, including perception and control, as shown below. Environment Hardware Industrial Camera Intel-NUC(1

Siqi Fan 3 Jul 07, 2022
Line-level Handwritten Text Recognition (HTR) system implemented with TensorFlow.

Line-level Handwritten Text Recognition with TensorFlow This model is an extended version of the Simple HTR system implemented by @Harald Scheidl and

Hoàng Tùng Lâm (Linus) 72 May 07, 2022
Melanoma Skin Cancer Detection using Convolutional Neural Networks and Transfer Learning🕵🏻‍♂️

This is a Kaggle competition in which we have to identify if the given lesion image is malignant or not for Melanoma which is a type of skin cancer.

Vipul Shinde 1 Jan 27, 2022
Repository for the AugmentedPCA Python package.

Overview This Python package provides implementations of Augmented Principal Component Analysis (AugmentedPCA) - a family of linear factor models that

Billy Carson 6 Dec 07, 2022
Official implementation of "Open-set Label Noise Can Improve Robustness Against Inherent Label Noise" (NeurIPS 2021)

Open-set Label Noise Can Improve Robustness Against Inherent Label Noise NeurIPS 2021: This repository is the official implementation of ODNL. Require

Hongxin Wei 12 Dec 07, 2022
A toolkit for controlling Euro Truck Simulator 2 with python to develop self-driving algorithms.

europilot Overview Europilot is an open source project that leverages the popular Euro Truck Simulator(ETS2) to develop self-driving algorithms. A con

1.4k Jan 04, 2023
NeurIPS 2021 Datasets and Benchmarks Track

AP-10K: A Benchmark for Animal Pose Estimation in the Wild Introduction | Updates | Overview | Download | Training Code | Key Questions | License Intr

AP-10K 82 Dec 11, 2022
image scene graph generation benchmark

Scene Graph Benchmark in PyTorch 1.7 This project is based on maskrcnn-benchmark Highlights Upgrad to pytorch 1.7 Multi-GPU training and inference Bat

Microsoft 303 Dec 27, 2022