[CVPR 2022] Official code for the paper: "A Stitch in Time Saves Nine: A Train-Time Regularizing Loss for Improved Neural Network Calibration"

Overview

MDCA Calibration

This is the official PyTorch implementation for the paper: "A Stitch in Time Saves Nine: A Train-Time Regularizing Loss for Improved Neural Network Calibration".

Abstract

Deep Neural Networks (DNNs) make overconfident mistakes which can prove to be probematic in deployment in safety critical applications. Calibration is aimed to enhance trust in DNNs. The goal of our proposed Multi-Class Difference in Confidence and Accuracy (MDCA) loss is to align the probability estimates in accordance with accuracy thereby enhancing the trust in DNN decisions. MDCA can be used in case of image classification, image segmentation, and natural language classification tasks.

Teaser

Above image shows comparison of classwise reliability diagrams of Cross-Entropy vs. our proposed method.

Requirements

  • Python 3.8
  • PyTorch 1.8

Directly install using pip

 pip install -r requirements.txt

Training scripts:

Refer to the scripts folder to train for every model and dataset. Overall the command to train looks like below where each argument can be changed accordingly on how to train. Also refer to dataset/__init__.py and models/__init__.py for correct arguments to train with. Argument parser can be found in utils/argparser.py.

Train with cross-entropy:

python train.py --dataset cifar10 --model resnet56 --schedule-steps 80 120 --epochs 160 --loss cross_entropy 

Train with FL+MDCA: Also mention the gamma (for Focal Loss) and beta (Weight assigned to MDCA) to train FL+MDCA with

python train.py --dataset cifar10 --model resnet56 --schedule-steps 80 120 --epochs 160 --loss FL+MDCA --gamma 1.0 --beta 1.0 

Train with NLL+MDCA:

python train.py --dataset cifar10 --model resnet56 --schedule-steps 80 120 --epochs 160 --loss NLL+MDCA --beta 1.0

Post Hoc Calibration:

To do post-hoc calibration, we can use the following command.

lr and patience value is used for Dirichlet calibration. To change range of grid-search in dirichlet calibration, refer to posthoc_calibrate.py.

python posthoc_calibrate.py --dataset cifar10 --model resnet56 --lr 0.001 --patience 5 --checkpoint path/to/your/trained/model

Other Experiments (Dataset Drift, Dataset Imbalance):

experiments folder contains our experiments on PACS, Rotated MNIST and Imbalanced CIFAR10. Please refer to the scripts provided to run them.

Citation

If you find our work useful in your research, please cite the following:

@InProceedings{StitchInTime,
    author    = {R. Hebbalaguppe, J. Prakash, N. Madan, C. Arora},
    title     = {A Stitch in Time Saves Nine: A Train-Time Regularizing Loss for Improved Neural Network Calibration},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2022}
}

Contact

For questions about our paper or code, please contact any of the authors (@neelabh17, @bicycleman15, @rhebbalaguppe ) or raise an issue on GitHub.

References:

The code is adapted from the following repositories:

[1] bearpaw/pytorch-classification [2] torrvision/focal_calibration [3] Jonathan-Pearce/calibration_library

Owner
MDCA Calibration
MDCA Calibration
A Human-in-the-Loop workflow for creating HD images from text

A Human-in-the-Loop? workflow for creating HD images from text DALL·E Flow is an interactive workflow for generating high-definition images from text

Jina AI 2.5k Jan 02, 2023
Official implementation of Influence-balanced Loss for Imbalanced Visual Classification in PyTorch.

Official implementation of Influence-balanced Loss for Imbalanced Visual Classification in PyTorch.

Seulki Park 70 Jan 03, 2023
Official pytorch implement for “Transformer-Based Source-Free Domain Adaptation”

Official implementation for TransDA Official pytorch implement for “Transformer-Based Source-Free Domain Adaptation”. Overview: Result: Prerequisites:

stanley 54 Dec 22, 2022
DeepLabv3+:Encoder-Decoder with Atrous Separable Convolution语义分割模型在tensorflow2当中的实现

DeepLabv3+:Encoder-Decoder with Atrous Separable Convolution语义分割模型在tensorflow2当中的实现 目录 性能情况 Performance 所需环境 Environment 注意事项 Attention 文件下载 Download

Bubbliiiing 31 Nov 25, 2022
Code for Transformer Hawkes Process, ICML 2020.

Transformer Hawkes Process Source code for Transformer Hawkes Process (ICML 2020). Run the code Dependencies Python 3.7. Anaconda contains all the req

Simiao Zuo 111 Dec 26, 2022
Repository for MDPGT

MD-PGT Repository for implementing and reproducing the results for the paper MDPGT: Momentum-based Decentralized Policy Gradient Tracking. Available E

Xian Yeow Lee 2 Dec 30, 2021
Video Background Music Generation with Controllable Music Transformer (ACM MM 2021 Oral)

CMT Code for paper Video Background Music Generation with Controllable Music Transformer (ACM MM 2021 Best Paper Award) [Paper] [Site] Directory Struc

Zhaokai Wang 198 Dec 27, 2022
The codebase for Data-driven general-purpose voice activity detection.

Data driven GPVAD Repository for the work in TASLP 2021 Voice activity detection in the wild: A data-driven approach using teacher-student training. S

Heinrich Dinkel 75 Nov 27, 2022
Process JSON files for neural recording sessions using Medtronic's BrainSense Percept PC neurostimulator

percept_processing This code processes JSON files for streamed neural data using Medtronic's Percept PC neurostimulator with BrainSense Technology for

Maria Olaru 3 Jun 06, 2022
This is a Pytorch implementation of paper: DropEdge: Towards Deep Graph Convolutional Networks on Node Classification

DropEdge: Towards Deep Graph Convolutional Networks on Node Classification This is a Pytorch implementation of paper: DropEdge: Towards Deep Graph Con

401 Dec 16, 2022
PyTorch implementation for the visual prior component (i.e. perception module) of the Visually Grounded Physics Learner [Li et al., 2020].

VGPL-Visual-Prior PyTorch implementation for the visual prior component (i.e. perception module) of the Visually Grounded Physics Learner (VGPL). Give

Toru 8 Dec 29, 2022
Applications using the GTN library and code to reproduce experiments in "Differentiable Weighted Finite-State Transducers"

gtn_applications An applications library using GTN. Current examples include: Offline handwriting recognition Automatic speech recognition Installing

Facebook Research 68 Dec 29, 2022
It's a powerful version of linebot

CTPS-FINAL Linbot-sever.py 主程式 Algorithm.py 推薦演算法,媒合餐廳端資料與顧客端資料 config.ini 儲存 channel-access-token、channel-secret 資料 Preface 生活在成大將近4年,我們每天的午餐時間看著形形色色

1 Oct 17, 2022
Code for reproducible experiments presented in KSD Aggregated Goodness-of-fit Test.

Code for KSDAgg: a KSD aggregated goodness-of-fit test This GitHub repository contains the code for the reproducible experiments presented in our pape

Antonin Schrab 5 Dec 15, 2022
Coursera - Quiz & Assignment of Coursera

Coursera Assignments This repository is aimed to help Coursera learners who have difficulties in their learning process. The quiz and programming home

浅梦 828 Jan 04, 2023
[CVPR 2021] Scan2Cap: Context-aware Dense Captioning in RGB-D Scans

Scan2Cap: Context-aware Dense Captioning in RGB-D Scans Introduction We introduce the task of dense captioning in 3D scans from commodity RGB-D sensor

Dave Z. Chen 79 Nov 07, 2022
Attention-guided gan for synthesizing IR images

SI-AGAN Attention-guided gan for synthesizing IR images This repository contains the Tensorflow code for "Pedestrian Gender Recognition by Style Trans

1 Oct 25, 2021
Learning Representations that Support Robust Transfer of Predictors

Transfer Risk Minimization (TRM) Code for Learning Representations that Support Robust Transfer of Predictors Prepare the Datasets Preprocess the Scen

Yilun Xu 15 Dec 07, 2022
[ICRA2021] Reconstructing Interactive 3D Scene by Panoptic Mapping and CAD Model Alignment

Interactive Scene Reconstruction Project Page | Paper This repository contains the implementation of our ICRA2021 paper Reconstructing Interactive 3D

97 Dec 28, 2022
Discretized Integrated Gradients for Explaining Language Models (EMNLP 2021)

Discretized Integrated Gradients for Explaining Language Models (EMNLP 2021) Overview of paths used in DIG and IG. w is the word being attributed. The

INK Lab @ USC 17 Oct 27, 2022