Learning Confidence for Out-of-Distribution Detection in Neural Networks

Overview

Learning Confidence Estimates for Neural Networks

This repository contains the code for the paper Learning Confidence for Out-of-Distribution Detection in Neural Networks. In this work, we demonstrate how to augment neural networks with a confidence estimation branch, which can be used to identify misclassified and out-of-distribution examples.

To learn confidence estimates during training, we provide the neural network with "hints" towards the correct output whenever it exhibits low confidence in its predictions. Hints are provided by pushing the prediction closer to the target distribution via interpolation, where the amount of interpolation depends on the network's confidence that its prediction is correct: the lower the confidence, the larger the hint. To discourage the network from always asking for free hints, a small penalty is applied whenever it is not confident. As a result, the network learns to produce low confidence estimates only when it is likely to make an incorrect prediction.
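The hint mechanism described above can be sketched for a single sample as follows. This is a minimal illustration, not the code in train.py: it assumes a linear interpolation between the predicted probabilities and a one-hot target, and a negative-log penalty on the confidence. The function name `hinted_loss` and the weight `lmbda` are hypothetical; see the paper for the exact formulation.

```python
import math


def hinted_loss(probs, target, confidence, lmbda=0.1, eps=1e-12):
    """Illustrative per-sample loss with confidence-weighted hints.

    probs      -- predicted class probabilities (sum to 1)
    target     -- index of the correct class
    confidence -- value in [0, 1] from the confidence branch
    lmbda      -- weight of the confidence penalty (hypothetical name)
    """
    one_hot = [1.0 if i == target else 0.0 for i in range(len(probs))]

    # Low confidence pulls the prediction toward the target (a "hint").
    hinted = [confidence * p + (1.0 - confidence) * y
              for p, y in zip(probs, one_hot)]

    # Negative log-likelihood of the hinted prediction.
    task_loss = -math.log(hinted[target] + eps)

    # Penalty that grows as confidence drops, so hints are not free.
    penalty = -math.log(confidence + eps)

    return task_loss + lmbda * penalty, task_loss, penalty


if __name__ == "__main__":
    for c in (1.0, 0.5, 0.0):
        total, task, pen = hinted_loss([0.7, 0.2, 0.1], target=1, confidence=c)
        print(f"confidence={c:.1f} task={task:.3f} penalty={pen:.3f} total={total:.3f}")
```

With full confidence the task loss is the ordinary negative log-likelihood and the penalty vanishes. With zero confidence the task loss vanishes, because the hint replaces the prediction entirely, but the penalty becomes large.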

Bibtex:

@article{devries2018learning,
  title={Learning Confidence for Out-of-Distribution Detection in Neural Networks},
  author={DeVries, Terrance and Taylor, Graham W},
  journal={arXiv preprint arXiv:1802.04865},
  year={2018}
}

Results and Usage

We evaluate our method on the task of out-of-distribution detection using three different neural network architectures: DenseNet, WideResNet, and VGG. CIFAR-10 and SVHN are used as the in-distribution datasets, while TinyImageNet, LSUN, iSUN, uniform noise, and Gaussian noise are used as the out-of-distribution datasets. Definitions of evaluation metrics can be found in the paper.

Dependencies

PyTorch v0.3.0
tqdm
visdom
seaborn
Pillow
scikit-learn

Training

Use train.py to train a model with a confidence estimator. During training you can use visdom to view a histogram of confidence estimates on the test set. Training logs are stored in the logs/ folder, and checkpoints are stored in the checkpoints/ folder.

| Args | Options | Description |
|---|---|---|
| dataset | cifar10, svhn | Selects which dataset to train on. |
| model | densenet, wideresnet, vgg13 | Selects which model architecture to use. |
| batch_size | [int] | Number of samples per batch. |
| epochs | [int] | Number of epochs for training. |
| seed | [int] | Random seed. |
| learning_rate | [float] | Learning rate. |
| data_augmentation | | Train with standard data augmentation (random flipping and translation). |
| cutout | [int] | Indicates the patch size to use for Cutout. If 0, Cutout is not used. |
| budget | [float] | Controls how often the network can choose to have low confidence in its prediction. Increasing the budget will bias the output towards low confidence predictions, while decreasing the budget will produce more high confidence predictions. |
| baseline | | Train the model without the confidence branch. |
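To make the role of the budget argument more concrete, here is one way a budget of this kind could steer training. This is an assumption for illustration only, not a description of train.py: the weight on the confidence penalty is nudged up when the average penalty exceeds the budget and nudged down otherwise. The names `update_penalty_weight`, `lmbda`, and `step` are hypothetical.

```python
def update_penalty_weight(lmbda, confidence_penalty, budget, step=0.01):
    """Adjust the penalty weight so the penalty drifts toward the budget.

    lmbda              -- current weight on the confidence penalty
    confidence_penalty -- average penalty observed on the current batch
    budget             -- target level for the penalty (the --budget value)
    step               -- relative size of each adjustment (assumed value)
    """
    if confidence_penalty > budget:
        # Over budget: make low confidence more expensive.
        return lmbda / (1.0 - step)
    # Under budget: make low confidence cheaper.
    return lmbda / (1.0 + step)


if __name__ == "__main__":
    lmbda = 0.1
    for batch_penalty in (0.5, 0.4, 0.2, 0.1):
        lmbda = update_penalty_weight(lmbda, batch_penalty, budget=0.3)
        print(f"penalty={batch_penalty:.1f} -> lmbda={lmbda:.5f}")
```

Under this scheme a larger budget tolerates a larger average penalty, which matches the description above: more low confidence predictions with a higher budget, fewer with a lower one.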

Use the following settings to replicate the experiments from the paper:

VGG13 on CIFAR-10

python train.py --dataset cifar10 --model vgg13 --budget 0.3 --data_augmentation --cutout 16

WideResNet on CIFAR-10

python train.py --dataset cifar10 --model wideresnet --budget 0.3 --data_augmentation --cutout 16

DenseNet on CIFAR-10

python train.py --dataset cifar10 --model densenet --budget 0.3 --epochs 300 --batch_size 64 --data_augmentation --cutout 16

VGG13 on SVHN

python train.py --dataset svhn --model vgg13 --budget 0.3 --learning_rate 0.01 --epochs 160 --data_augmentation --cutout 20

WideResNet on SVHN

python train.py --dataset svhn --model wideresnet --budget 0.3 --learning_rate 0.01 --epochs 160 --data_augmentation --cutout 20

DenseNet on SVHN

python train.py --dataset svhn --model densenet --budget 0.3 --learning_rate 0.01 --epochs 300 --batch_size 64 --data_augmentation --cutout 20

Out-of-distribution detection

Evaluate a trained model with out_of_distribution_detection.py. Before running it, download the out-of-distribution datasets from Shiyu Liang's ODIN GitHub repo and update the data paths in the file to match where you saved the datasets.

| Args | Options | Description |
|---|---|---|
| ind_dataset | cifar10, svhn | Indicates which dataset to use as in-distribution. Should be the same one that the model was trained on. |
| ood_dataset | tinyImageNet_crop, tinyImageNet_resize, LSUN_crop, LSUN_resize, iSUN, Uniform, Gaussian, all | Indicates which dataset to use as the out-of-distribution dataset. |
| model | densenet, wideresnet, vgg13 | Selects which model architecture to use. Should be the same one that the model was trained on. |
| process | baseline, ODIN, confidence, confidence_scaling | Indicates which method to use for out-of-distribution detection. Baseline uses the maximum softmax probability. ODIN applies temperature scaling and input pre-processing to the baseline method. Confidence uses the learned confidence estimates. Confidence scaling applies input pre-processing to the confidence estimates. |
| batch_size | [int] | Number of samples per batch. |
| T | [float] | Temperature to use for temperature scaling. |
| epsilon | [float] | Noise magnitude to use for input pre-processing. |
| checkpoint | [str] | Filename of trained model checkpoint. Assumes the file is in the checkpoints/ folder. A .pt extension is also automatically added to the filename. |
| validation | | Use this flag for fine-tuning T and epsilon. If the flag is on, the script will only evaluate on the first 1000 samples in the out-of-distribution dataset. If the flag is not used, the remaining samples are used for evaluation. Based on the validation procedure from ODIN. |
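The baseline score and the temperature scaling mentioned for the process and T options can be sketched in plain Python. This is a minimal illustration, not the code in out_of_distribution_detection.py; the helper names `softmax` and `max_softmax_score` are hypothetical. Input pre-processing is left out because it requires gradients with respect to the input.

```python
import math


def softmax(logits, T=1.0):
    """Softmax over logits divided by temperature T."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def max_softmax_score(logits, T=1.0):
    """Detection score: the largest softmax probability.

    T=1 gives the plain maximum softmax probability; a larger T
    applies temperature scaling before taking the maximum.
    """
    return max(softmax(logits, T))


if __name__ == "__main__":
    logits = [2.0, 0.0, -1.0]
    print("baseline score:", max_softmax_score(logits))
    print("score with T=1000:", max_softmax_score(logits, T=1000.0))
```

A higher score is read as "more likely in-distribution". A large temperature flattens the distribution, so scores move toward 1 divided by the number of classes.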

Example commands for running the out-of-distribution detection script:

Baseline

python out_of_distribution_detection.py --ind_dataset svhn --ood_dataset all --model vgg13 --process baseline --checkpoint svhn_vgg13_budget_0.0_seed_0

ODIN

python out_of_distribution_detection.py --ind_dataset cifar10 --ood_dataset tinyImageNet_resize --model densenet --process ODIN --T 1000 --epsilon 0.001 --checkpoint cifar10_densenet_budget_0.0_seed_0

Confidence

python out_of_distribution_detection.py --ind_dataset cifar10 --ood_dataset LSUN_crop --model vgg13 --process confidence --checkpoint cifar10_vgg13_budget_0.3_seed_0

Confidence scaling

python out_of_distribution_detection.py --ind_dataset svhn --ood_dataset iSUN --model wideresnet --process confidence_scaling --epsilon 0.001 --checkpoint svhn_wideresnet_budget_0.3_seed_0