PyTorch code for SENTRY: Selective Entropy Optimization via Committee Consistency for Unsupervised DA

Related tags

Deep LearningSENTRY
Overview

PyTorch Code for SENTRY: Selective Entropy Optimization via Committee Consistency for Unsupervised Domain Adaptation

Viraj Prabhu, Shivam Khare, Deeksha Kartik, Judy Hoffman

Many existing approaches for unsupervised domain adaptation (UDA) focus on adapting under only data distribution shift and offer limited success under additional cross-domain label distribution shift. Recent work based on self-training using target pseudolabels has shown promise, but on challenging shifts pseudolabels may be highly unreliable and using them for self-training may cause error accumulation and domain misalignment. We propose Selective Entropy Optimization via Committee Consistency (SENTRY), a UDA algorithm that judges the reliability of a target instance based on its predictive consistency under a committee of random image transformations. Our algorithm then selectively minimizes predictive entropy to increase confidence on highly consistent target instances, while maximizing predictive entropy to reduce confidence on highly inconsistent ones. In combination with pseudolabel-based approximate target class balancing, our approach leads to significant improvements over the state-of-the-art on 27/31 domain shifts from standard UDA benchmarks as well as benchmarks designed to stress-test adaptation under label distribution shift.

method

Table of Contents

Setup and Dependencies

  1. Create an anaconda environment with Python 3.6: conda create -n sentry python=3.6.8 and activate: conda activate sentry
  2. Navigate to the code directory: cd code/
  3. Install dependencies: pip install -r requirements.txt

And you're all set up!

Usage

Download data

Data for SVHN->MNIST is downloaded automatically via PyTorch. Data for other benchmarks can be downloaded from the following links. The splits used for our experiments are already included in the data/ folder):

  1. DomainNet
  2. OfficeHome
  3. VisDA2017 (only train and validation needed)

Pretrained checkpoints

To reproduce numbers reported in the paper, we include a a few pretrained checkpoints. We include checkpoints (source and adapted) for SVHN to MNIST (DIGITS) in the checkpoints directory. Source and adapted checkpoints for Clipart to Sketch adaptation (from DomainNet) and Real_World to Product adaptation (from OfficeHome RS-UT) can be downloaded from this link, and should be saved to the checkpoints/source and checkpoints/SENTRY directory as appropriate.

Train and adapt model

  • Natural label distribution shift: Adapt a model from to for a given (where benchmark may be DomainNet, OfficeHome, VisDA, or DIGITS), as follows:
python train.py --id <experiment_id> \
                --source <source> \
                --target <target> \
                --img_dir <image_directory> \
                --LDS_type <LDS_type> \
                --load_from_cfg True \
                --cfg_file 'config/<benchmark>/<cfg_file>.yml' \
                --use_cuda True

SENTRY hyperparameters are provided via a sentry.yml config file in the corresponding config/<benchmark> folder (On DIGITS, we also provide a config for baseline adaptation via DANN). The list of valid source/target domains per-benchmark are:

  • DomainNet: real, clipart, sketch, painting
  • OfficeHome_RS_UT: Real_World, Clipart, Product
  • OfficeHome: Real_World, Clipart, Product, Art
  • VisDA2017: visda_train, visda_test
  • DIGITS: Only svhn (source) to mnist (target) adaptation is currently supported.

Pass in the path to the parent folder containing dataset images via the --img_dir <name_of_directory> flag (eg. --img_dir '~/data/DomainNet'). Pass in the label distribution shift type via the --LDS_type flag: For DomainNet, OfficeHome (standard), and VisDA2017, pass in --LDS_type 'natural' (default). For OfficeHome RS-UT, pass in --LDS_type 'RS_UT'. For DIGITS, pass in --LDS_type as one of IF1, IF20, IF50, or IF100, to load a manually long-tailed target training split with a given imbalance factor (IF), as described in Table 4 of the paper.

To load a pretrained DA checkpoint instead of training your own, additionally pass --load_da True and --id <benchmark_name> to the script above. Finally, the training script will log performance metrics to the console (average and aggregate accuracy), and additionally plot and save some per-class performance statistics to the results/ folder.

Note: By default this code runs on GPU. To run on CPU pass: --use_cuda False

Reference

If you found this code useful, please consider citing:

@article{prabhu2020sentry
   author = {Prabhu, Viraj and Khare, Shivam and Kartik, Deeksha and Hoffman, Judy},
   title = {SENTRY: Selective Entropy Optimization via Committee Consistency for Unsupervised Domain Adaptation},
   year = {2020},
   journal = {arXiv preprint: 2012.11460},
}

Acknowledgements

We would like to thank the developers of PyTorch for building an excellent framework, in addition to the numerous contributors to all the open-source packages we use.

License

MIT

PhysCap: Physically Plausible Monocular 3D Motion Capture in Real Time

PhysCap: Physically Plausible Monocular 3D Motion Capture in Real Time The implementation is based on SIGGRAPH Aisa'20. Dependencies Python 3.7 Ubuntu

soratobtai 124 Dec 08, 2022
Pipeline code for Sequential-GAM(Genome Architecture Mapping).

Sequential-GAM Pipeline code for Sequential-GAM(Genome Architecture Mapping). mapping whole_preprocess.sh include the whole processing of mapping. usa

3 Nov 03, 2022
A C implementation for creating 2D voronoi diagrams

Branch OSX/Linux Windows master dev jc_voronoi A fast C/C++ header only implementation for creating 2D Voronoi diagrams from a point set Uses Fortune'

Mathias Westerdahl 481 Dec 29, 2022
Code for paper " AdderNet: Do We Really Need Multiplications in Deep Learning?"

AdderNet: Do We Really Need Multiplications in Deep Learning? This code is a demo of CVPR 2020 paper AdderNet: Do We Really Need Multiplications in De

HUAWEI Noah's Ark Lab 915 Jan 01, 2023
Code repository for "Reducing Underflow in Mixed Precision Training by Gradient Scaling" presented at IJCAI '20

Reducing Underflow in Mixed Precision Training by Gradient Scaling This project implements the gradient scaling method to improve the performance of m

Ruizhe Zhao 5 Apr 14, 2022
[CVPR'22] Weakly Supervised Semantic Segmentation by Pixel-to-Prototype Contrast

wseg Overview The Pytorch implementation of Weakly Supervised Semantic Segmentation by Pixel-to-Prototype Contrast. [arXiv] Though image-level weakly

Ye Du 96 Dec 30, 2022
Reproducing-BowNet: Learning Representations by Predicting Bags of Visual Words

Reproducing-BowNet Our reproducibility effort based on the 2020 ML Reproducibility Challenge. We are reproducing the results of this CVPR 2020 paper:

6 Mar 16, 2022
Official PyTorch code for Hierarchical Conditional Flow: A Unified Framework for Image Super-Resolution and Image Rescaling (HCFlow, ICCV2021)

Hierarchical Conditional Flow: A Unified Framework for Image Super-Resolution and Image Rescaling (HCFlow, ICCV2021) This repository is the official P

Jingyun Liang 159 Dec 30, 2022
OHLC Average Prediction of Apple Inc. Using LSTM Recurrent Neural Network

Stock Price Prediction of Apple Inc. Using Recurrent Neural Network OHLC Average Prediction of Apple Inc. Using LSTM Recurrent Neural Network Dataset:

Nouroz Rahman 410 Jan 05, 2023
wlad 2 Dec 19, 2022
El-Gamal on Elliptic Curve (Python)

El-Gamal-on-EC El-Gamal on Elliptic Curve (Python) References: https://docsdrive.com/pdfs/ansinet/itj/2005/299-306.pdf https://arxiv.org/ftp/arxiv/pap

3 May 04, 2022
The official repository for BaMBNet

BaMBNet-Pytorch Paper

Junjun Jiang 18 Dec 04, 2022
GANTheftAuto is a fork of the Nvidia's GameGAN

Description GANTheftAuto is a fork of the Nvidia's GameGAN, which is research focused on emulating dynamic game environments. The early research done

Harrison 801 Dec 27, 2022
The Submission for SIMMC 2.0 Challenge 2021

The Submission for SIMMC 2.0 Challenge 2021 challenge website Requirements python 3.8.8 pytorch 1.8.1 transformers 4.8.2 apex for multi-gpu nltk Prepr

5 Jul 26, 2022
Sleep staging from ECG, assisted with EEG

Sleep_Staging_Knowledge Distillation This codebase implements knowledge distillation approach for ECG based sleep staging assisted by EEG based sleep

2 Dec 12, 2022
Machine learning algorithms for many-body quantum systems

NetKet NetKet is an open-source project delivering cutting-edge methods for the study of many-body quantum systems with artificial neural networks and

NetKet 413 Dec 31, 2022
SalGAN: Visual Saliency Prediction with Generative Adversarial Networks

SalGAN: Visual Saliency Prediction with Adversarial Networks Junting Pan Cristian Canton Ferrer Kevin McGuinness Noel O'Connor Jordi Torres Elisa Sayr

Image Processing Group - BarcelonaTECH - UPC 347 Nov 22, 2022
Unofficial implementation (replicates paper results!) of MINER: Multiscale Implicit Neural Representations in pytorch-lightning

MINER_pl Unofficial implementation of MINER: Multiscale Implicit Neural Representations in pytorch-lightning. 📖 Ref readings Laplacian pyramid explan

AI葵 51 Nov 28, 2022
GenGNN: A Generic FPGA Framework for Graph Neural Network Acceleration

GenGNN: A Generic FPGA Framework for Graph Neural Network Acceleration Stefan Abi-Karam*, Yuqi He*, Rishov Sarkar*, Lakshmi Sathidevi, Zihang Qiao, Co

Sharc-Lab 19 Dec 15, 2022
Official repository for the paper, MidiBERT-Piano: Large-scale Pre-training for Symbolic Music Understanding.

MidiBERT-Piano Authors: Yi-Hui (Sophia) Chou, I-Chun (Bronwin) Chen Introduction This is the official repository for the paper, MidiBERT-Piano: Large-

137 Dec 15, 2022