[CVPR 2022] Official code for the paper: "A Stitch in Time Saves Nine: A Train-Time Regularizing Loss for Improved Neural Network Calibration"

Overview

MDCA Calibration

This is the official PyTorch implementation for the paper: "A Stitch in Time Saves Nine: A Train-Time Regularizing Loss for Improved Neural Network Calibration".

Abstract

Deep Neural Networks (DNNs) make overconfident mistakes which can prove to be probematic in deployment in safety critical applications. Calibration is aimed to enhance trust in DNNs. The goal of our proposed Multi-Class Difference in Confidence and Accuracy (MDCA) loss is to align the probability estimates in accordance with accuracy thereby enhancing the trust in DNN decisions. MDCA can be used in case of image classification, image segmentation, and natural language classification tasks.

Teaser

Above image shows comparison of classwise reliability diagrams of Cross-Entropy vs. our proposed method.

Requirements

  • Python 3.8
  • PyTorch 1.8

Directly install using pip

 pip install -r requirements.txt

Training scripts:

Refer to the scripts folder to train for every model and dataset. Overall the command to train looks like below where each argument can be changed accordingly on how to train. Also refer to dataset/__init__.py and models/__init__.py for correct arguments to train with. Argument parser can be found in utils/argparser.py.

Train with cross-entropy:

python train.py --dataset cifar10 --model resnet56 --schedule-steps 80 120 --epochs 160 --loss cross_entropy 

Train with FL+MDCA: Also mention the gamma (for Focal Loss) and beta (Weight assigned to MDCA) to train FL+MDCA with

python train.py --dataset cifar10 --model resnet56 --schedule-steps 80 120 --epochs 160 --loss FL+MDCA --gamma 1.0 --beta 1.0 

Train with NLL+MDCA:

python train.py --dataset cifar10 --model resnet56 --schedule-steps 80 120 --epochs 160 --loss NLL+MDCA --beta 1.0

Post Hoc Calibration:

To do post-hoc calibration, we can use the following command.

lr and patience value is used for Dirichlet calibration. To change range of grid-search in dirichlet calibration, refer to posthoc_calibrate.py.

python posthoc_calibrate.py --dataset cifar10 --model resnet56 --lr 0.001 --patience 5 --checkpoint path/to/your/trained/model

Other Experiments (Dataset Drift, Dataset Imbalance):

experiments folder contains our experiments on PACS, Rotated MNIST and Imbalanced CIFAR10. Please refer to the scripts provided to run them.

Citation

If you find our work useful in your research, please cite the following:

@InProceedings{StitchInTime,
    author    = {R. Hebbalaguppe, J. Prakash, N. Madan, C. Arora},
    title     = {A Stitch in Time Saves Nine: A Train-Time Regularizing Loss for Improved Neural Network Calibration},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2022}
}

Contact

For questions about our paper or code, please contact any of the authors (@neelabh17, @bicycleman15, @rhebbalaguppe ) or raise an issue on GitHub.

References:

The code is adapted from the following repositories:

[1] bearpaw/pytorch-classification [2] torrvision/focal_calibration [3] Jonathan-Pearce/calibration_library

Owner
MDCA Calibration
MDCA Calibration
Code for Generating Disentangled Arguments with Prompts: A Simple Event Extraction Framework that Works

GDAP Code for Generating Disentangled Arguments with Prompts: A Simple Event Extraction Framework that Works Environment Python (verified: v3.8) CUDA

45 Oct 29, 2022
Code for our CVPR2021 paper coordinate attention

Coordinate Attention for Efficient Mobile Network Design (preprint) This repository is a PyTorch implementation of our coordinate attention (will appe

Qibin (Andrew) Hou 726 Jan 05, 2023
Code release for "Detecting Twenty-thousand Classes using Image-level Supervision".

Detecting Twenty-thousand Classes using Image-level Supervision Detic: A Detector with image classes that can use image-level labels to easily train d

Meta Research 1.3k Jan 04, 2023
GAT - Graph Attention Network (PyTorch) 💻 + graphs + 📣 = ❤️

GAT - Graph Attention Network (PyTorch) 💻 + graphs + 📣 = ❤️ This repo contains a PyTorch implementation of the original GAT paper ( 🔗 Veličković et

Aleksa Gordić 1.9k Jan 09, 2023
Mmdetection3d Noted - MMDetection3D is an open source object detection toolbox based on PyTorch

MMDetection3D is an open source object detection toolbox based on PyTorch

Jiangjingwen 13 Jan 06, 2023
Lightweight, Portable, Flexible Distributed/Mobile Deep Learning with Dynamic, Mutation-aware Dataflow Dep Scheduler; for Python, R, Julia, Scala, Go, Javascript and more

Apache MXNet (incubating) for Deep Learning Master Docs License Apache MXNet (incubating) is a deep learning framework designed for both efficiency an

ROCm Software Platform 29 Nov 16, 2022
Trajectory Prediction with Graph-based Dual-scale Context Fusion

DSP: Trajectory Prediction with Graph-based Dual-scale Context Fusion Introduction This is the project page of the paper Lu Zhang, Peiliang Li, Jing C

HKUST Aerial Robotics Group 103 Jan 04, 2023
Food recognition model using convolutional neural network & computer vision

Food recognition model using convolutional neural network & computer vision. The goal is to match or beat the DeepFood Research Paper

Hemanth Chandran 1 Jan 13, 2022
The Noise Contrastive Estimation for softmax output written in Pytorch

An NCE implementation in pytorch About NCE Noise Contrastive Estimation (NCE) is an approximation method that is used to work around the huge computat

Kaiyu Shi 287 Nov 25, 2022
Flow is a computational framework for deep RL and control experiments for traffic microsimulation.

Flow Flow is a computational framework for deep RL and control experiments for traffic microsimulation. See our website for more information on the ap

867 Jan 02, 2023
Like ThreeJS but for Python and based on wgpu

pygfx A render engine, inspired by ThreeJS, but for Python and targeting Vulkan/Metal/DX12 (via wgpu). Introduction This is a Python render engine bui

139 Jan 07, 2023
Code for the AAAI-2022 paper: Imagine by Reasoning: A Reasoning-Based Implicit Semantic Data Augmentation for Long-Tailed Classification

Imagine by Reasoning: A Reasoning-Based Implicit Semantic Data Augmentation for Long-Tailed Classification (AAAI 2022) Prerequisite PyTorch = 1.2.0 P

16 Dec 14, 2022
SASM - simple crossplatform IDE for NASM, MASM, GAS and FASM assembly languages

SASM (SimpleASM) - простая кроссплатформенная среда разработки для языков ассемблера NASM, MASM, GAS, FASM с подсветкой синтаксиса и отладчиком. В SA

Dmitriy Manushin 5.6k Jan 06, 2023
Rank1 Conversation Emotion Detection Task

Rank1-Conversation_Emotion_Detection_Task accuracy macro-f1 recall 0.826 0.7544 0.719 基于预训练模型和时序预测模型的对话情感探测任务 1 摘要 针对对话情感探测任务,本文将其分为文本分类和时间序列预测两个子任务,分

Yuchen Han 2 Nov 28, 2021
existing and custom freqtrade strategies supporting the new hyperstrategy format.

freqtrade-strategies Description Existing and self-developed strategies, rewritten to support the new HyperStrategy format from the freqtrade-develop

39 Aug 20, 2021
Pytorch implementation of ProjectedGAN

ProjectedGAN-pytorch Pytorch implementation of ProjectedGAN (https://arxiv.org/abs/2111.01007) Note: this repository is still under developement. @InP

Dominic Rampas 17 Dec 14, 2022
Online-compatible Unsupervised Non-resonant Anomaly Detection Repository

Online-compatible Unsupervised Non-resonant Anomaly Detection Repository Repository containing all scripts used in the studies of Online-compatible Un

0 Nov 09, 2021
A Moonraker plug-in for real-time compensation of frame thermal expansion

Frame Expansion Compensation A Moonraker plug-in for real-time compensation of frame thermal expansion. Installation Credit to protoloft, from whom I

58 Jan 02, 2023
COVID-Net Open Source Initiative

The COVID-Net models provided here are intended to be used as reference models that can be built upon and enhanced as new data becomes available

Linda Wang 1.1k Dec 26, 2022
Contrastive Fact Verification

VitaminC This repository contains the dataset and models for the NAACL 2021 paper: Get Your Vitamin C! Robust Fact Verification with Contrastive Evide

47 Dec 19, 2022