Code accompanying the paper on "An Empirical Investigation of Domain Generalization with Empirical Risk Minimizers" published at NeurIPS, 2021

Overview

Code for "An Empirical Investigation of Domian Generalization with Empirical Risk Minimizers" (NeurIPS 2021)

Motivation and Introduction

Domain Generalization is a task in machine learning where given a shift in the input data distribution, one is expected to perform well on a test task with a different input data distribution. For example, one might train a digit classifier on MNIST data and ask the model to generalize to predict digits that are rotated by say 30 degrees.

While many approaches have been proposed for this problem, we were intrigued by the results on the DomainBed benchmark which suggested that using the simple, empirical risk minimization (ERM) with a proper hyperparameter sweep leads to performance close to state of the art on Domain Generalization Problems.

What governs the generalization of a trained deep learning model using ERM to a given data distribution? This is the question we seek to answer in our NeurIPS 2021 paper:

An Empirical Investigation of Domain Generalization with Empirical Risk Minimizers. Rama Vedantam, David Lopez-Paz*, David Schwab*.

NeurIPS 2021 (*=Equal Contribution)

This repository contains code used for producing the results in our paper.

Initial Setup

  1. Run source init.sh to install all the dependencies for the project. This will also initialize DomainBed as a submodule for the project

  2. Set requisite paths in setup.sh, and run source setup.sh

Computing Generalization Measures

  • Get set up with the DomainBed codebase and launch a sweep for an initial set of trained models (illustrated below for rotated MNIST dataset):
cd DomainBed/

python -m domainbed.scripts.sweep launch\
       --data_dir=${DOMAINBED_DATA} \
       --output_dir=${DOMAINBED_RUN_DIR}/sweep_fifty_fifty \
       --algorithms=ERM \
       --holdout_fraction=0.5\
       --datasets=RotatedMNIST \
       --n_hparams=1\
       --command_launcher submitit

After this step, we have a set of trained models that we can now look to evaluate and measure. Note that unlike the original domainbed paper we holdout a larger fraction (50%) of the data for evaluation of the measures.

  • Once the sweep finishes, aggregate the different files for use by the domianbed_measures codebase:
python domainbed_measures/write_job_status_file.py \
                --sweep_dir=${DOMAINBED_RUN_DIR}/sweep_fifty_fifty \
                --output_txt="domainbed_measures/scratch/sweep_release.txt"
  • Once this step is complete, we can compute various generalization measures and store them to disk for future analysis using:
SLURM_PARTITION="TO_BE_SET"
python domainbed_measures/compute_gen_correlations.py \
	--algorithm=ERM \
    --job_done_file="domainbed_measures/scratch/sweep_release.txt" \
    --run_dir=${MEASURE_RUN_DIR} \
    --all_measures_one_job \
	--slurm_partition=${SLURM_PARTITION}

Where we utilize slurm on a compute cluster to scale the experiments to thousands of models. If you do not have access to such a cluster with multiple GPUs to parallelize the computation, use --slurm_partition="" above and the code will run on a single GPU (although the results might take a long time to compute!).

  • Finally, once the above code is done, use the following code snippet to aggregate the values of the different generalization measures:
python domainbed_measures/extract_generalization_features.py \
    --run_dir=${MEASURE_RUN_DIR} \
    --sweep_name="_out_ERM_RotatedMNIST"

This step yeilds .csv files where each row corresponds to a given trained model. Each row overall has the following format:

dataset | test_envs | measure 1 | measure 2 | measure 3 | target_err

where:

  • test_envs specifies which environments the model is tested on or equivalently trained on, since the remaining environments are used for training
  • target_err specifies the target error value for regression
  • measure 1 specifies the which measure is being computed, e.g. sharpness or fisher eigen value based measures

In case of the file named, for example, sweeps__out_ERM_RotatedMNIST_canon_False_ood.csv, the validation error within domain wd_out_domain_err is also used as one of the measures and target_err is the out of domain generalization error, and all measures are computed on a held-out set of image inputs from the target domain (for more details see the paper).

Alternatively, in case of the file named, sweeps__out_ERM_RotatedMNIST_canon_False_wd.csv, the target_err is the validation accuracy in domain, and all the measures are computed on the in-distribution held-out images.

  • Using this file one can do a number of interesting regression analyses as reported in the paper for measuring generalization.

For example, to generate the kind of results in Table. 1 of the paper in the joint setting, run the following command options:

python domainbed_measures/analyze_results.py \
    --input_csv="${MEASURE_RUN_DIR}/sweeps__out_ERM_RotatedMNIST_canon_False_ood.csv"\
    --stratified_or_joint="joint"\
    --num_features=2 \
    --fix_one_feature_to_wd

Alternatively, to generate results in the stratified setting, run:

python domainbed_measures/analyze_results.py \
    --input_csv="${MEASURE_RUN_DIR}/sweeps__out_ERM_RotatedMNIST_canon_False_ood.csv"\
    --stratified_or_joint="stratified"\
    --num_features=2 \
    --fix_one_feature_to_wd

Finally, to generate results using a single feature (Alone setting in Table. 1), run:

python domainbed_measures/analyze_results.py \
    --input_csv="${MEASURE_RUN_DIR}/sweeps__out_ERM_RotatedMNIST_canon_False_ood.csv"\
    --num_features=1

Translation of measures from the code to the paper

The following table illustrates all the measures in the paper (Appendix Table. 2) and how they are referred to in the codebase:

Measure Name Code Reference
H-divergence c2st
H-divergence + Source Error c2st_perr
H-divergence MS c2st_per_env
H-divergence MS + Source Error c2st_per_env_perr
H-divergence (train) c2st_train
H-divergence (train) + Source Error c2st_train_perr
H-divergence (train) MS c2st_train_per_env
Entropy-Source or Entropy entropy
Entropy-Target entropy_held_out
Fisher-Eigval-Diff fisher_eigval_sum_diff_ex_75
Fisher-Eigval fisher_eigval_sum_ex_75
Fisher-Align or Fisher (main paper) fisher_eigvec_align_ex_75
HΔH-divergence SS hdh
HΔH-divergence SS + Source Error hdh_perr
HΔH-divergence MS hdh_per_env
HΔH-divergence MS + Source Error hdh_per_env_perr
HΔH-divergence (train) SS hdh_train
HΔH-divergence (train) SS + Source Error hdh_train_perr
Jacobian jacobian_norm
Jacobian Ratio jacobian_norm_relative
Jacobian Diff jacobian_norm_relative_diff
Jacobian Log Ratio jacobian_norm_relative_log_diff
Mixup mixup
Mixup Ratio mixup_relative
Mixup Diff mixup_relative_diff
Mixup Log Ratio mixup_relative_log_diff
MMD-Gaussian mmd_gaussian
MMD-Mean-Cov mmd_mean_cov
L2-Path-Norm. path_norm
Sharpness sharp_mag
H+-divergence SS v_plus_c2st
H+-divergence SS + Source Error v_plus_c2st_perr
H+-divergence MS v_plus_c2st_per_env
H+-divergence MS + Source Error v_plus_c2st_per_env_perr
H+ΔH+-divergence SS v_plus_hdh
H+ΔH+-divergence SS + Source Error v_plus_hdh_perr
H+ΔH+-divergence MS v_plus_hdh_per_env
H+ΔH+-divergence MS + Source Error v_plus_hdh_per_env_perr
Source Error wd_out_domain_err

Acknowledgments

We thank the developers of Decodable Information Bottleneck, Domain Bed and Jonathan Frankle for code we found useful for this project.

License

This source code is released under the Creative Commons Attribution-NonCommercial 4.0 International license, included here.

Owner
Meta Research
Meta Research
Pytorch codes for Feature Transfer Learning for Face Recognition with Under-Represented Data

FTLNet_Pytorch Pytorch codes for Feature Transfer Learning for Face Recognition with Under-Represented Data 1. Introduction This repo is an unofficial

1 Nov 04, 2020
Generic U-Net Tensorflow implementation for image segmentation

Tensorflow Unet Warning This project is discontinued in favour of a Tensorflow 2 compatible reimplementation of this project found under https://githu

Joel Akeret 1.8k Dec 10, 2022
Source code for "Pack Together: Entity and Relation Extraction with Levitated Marker"

PL-Marker Source code for Pack Together: Entity and Relation Extraction with Levitated Marker. Quick links Overview Setup Install Dependencies Data Pr

THUNLP 173 Dec 30, 2022
DCSAU-Net: A Deeper and More Compact Split-Attention U-Net for Medical Image Segmentation

DCSAU-Net: A Deeper and More Compact Split-Attention U-Net for Medical Image Segmentation By Qing Xu, Wenting Duan and Na He Requirements pytorch==1.1

Qing Xu 20 Dec 09, 2022
SCU OlympicsRunning Baseline

Competition 1v1 running Environment check details in Jidi Competition RLChina2021智能体竞赛 做出的修改: 奖励重塑:修改了环境,重新设置了奖励的分配,使得奖励组成不只有零和博弈,还有探索环境的奖励。 算法微调:修改了官

ZiSeoi Wong 2 Nov 23, 2021
CIFAR-10 Photo Classification

Image-Classification CIFAR-10 Photo Classification CIFAR-10_Dataset_Classfication CIFAR-10 Photo Classification Dataset CIFAR is an acronym that stand

ADITYA SHAH 1 Jan 05, 2022
Cleaned up code for DSTC 10: SIMMC 2.0 track: subtask 2: multimodal coreference resolution

UNITER-Based Situated Coreference Resolution with Rich Multimodal Input: arXiv MMCoref_cleaned Code for the MMCoref task of the SIMMC 2.0 dataset. Pre

Yichen (William) Huang 2 Dec 05, 2022
The world's largest toxicity dataset.

The Toxicity Dataset by Surge AI Saving the internet is fun. Combing through thousands of online comments to build a toxicity dataset isn't. That's wh

Surge AI 134 Dec 19, 2022
This repo implements several applications of the proposed generalized Bures-Wasserstein (GBW) geometry on symmetric positive definite matrices.

GBW This repo implements several applications of the proposed generalized Bures-Wasserstein (GBW) geometry on symmetric positive definite matrices. Ap

Andi Han 0 Oct 22, 2021
EDCNN: Edge enhancement-based Densely Connected Network with Compound Loss for Low-Dose CT Denoising

EDCNN: Edge enhancement-based Densely Connected Network with Compound Loss for Low-Dose CT Denoising By Tengfei Liang, Yi Jin, Yidong Li, Tao Wang. Th

workingcoder 115 Jan 05, 2023
Checking fibonacci - Generating the Fibonacci sequence is a classic recursive problem

Fibonaaci Series Generating the Fibonacci sequence is a classic recursive proble

Moureen Caroline O 1 Feb 15, 2022
Repo for the Tutorials of Day1-Day3 of the Nordic Probabilistic AI School 2021 (https://probabilistic.ai/)

ProbAI 2021 - Probabilistic Programming and Variational Inference Tutorial with Pryo Day 1 (June 14) Slides Notebook: students_PPLs_Intro Notebook: so

PGM-Lab 46 Nov 01, 2022
🚀 PyTorch Implementation of "Progressive Distillation for Fast Sampling of Diffusion Models(v-diffusion)"

PyTorch Implementation of "Progressive Distillation for Fast Sampling of Diffusion Models(v-diffusion)" Unofficial PyTorch Implementation of Progressi

Vitaliy Hramchenko 58 Dec 19, 2022
Official repository with code and data accompanying the NAACL 2021 paper "Hurdles to Progress in Long-form Question Answering" (https://arxiv.org/abs/2103.06332).

Hurdles to Progress in Long-form Question Answering This repository contains the official scripts and datasets accompanying our NAACL 2021 paper, "Hur

Kalpesh Krishna 41 Nov 08, 2022
OCR-D wrapper for detectron2 based segmentation models

ocrd_detectron2 OCR-D wrapper for detectron2 based segmentation models Introduction Installation Usage OCR-D processor interface ocrd-detectron2-segm

Robert Sachunsky 13 Dec 06, 2022
VIL-100: A New Dataset and A Baseline Model for Video Instance Lane Detection (ICCV 2021)

Preparation Please see dataset/README.md to get more details about our datasets-VIL100 Please see INSTALL.md to install environment and evaluation too

82 Dec 15, 2022
Fine-tuning StyleGAN2 for Cartoon Face Generation

Cartoon-StyleGAN 🙃 : Fine-tuning StyleGAN2 for Cartoon Face Generation Abstract Recent studies have shown remarkable success in the unsupervised imag

Jihye Back 520 Jan 04, 2023
Code for Robust Contrastive Learning against Noisy Views

Robust Contrastive Learning against Noisy Views This repository provides a PyTorch implementation of the Robust InfoNCE loss proposed in paper Robust

Ching-Yao Chuang 53 Jan 08, 2023
Data augmentation for NLP, accepted at EMNLP 2021 Findings

AEDA: An Easier Data Augmentation Technique for Text Classification This is the code for the EMNLP 2021 paper AEDA: An Easier Data Augmentation Techni

Akbar Karimi 81 Dec 09, 2022
Create images and texts with the First Order Generative Adversarial Networks

First Order Divergence for training GANs This repository contains code accompanying the paper First Order Generative Advesarial Netoworks The majority

Zalando Research 35 Dec 11, 2021