Tilted Empirical Risk Minimization (ICLR '21)

Overview

This repository contains the implementation for the paper

Tilted Empirical Risk Minimization

ICLR 2021

Empirical risk minimization (ERM) is typically designed to perform well on the average loss, which can result in estimators that are sensitive to outliers, generalize poorly, or treat subgroups unfairly. While many methods aim to address these problems individually, in this work we explore them through a unified framework: tilted empirical risk minimization (TERM).
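For reference, the t-tilted empirical risk can be written as follows, where f(x_i; θ) is the loss of sample i under parameters θ and N is the number of samples. The notation is as we recall it from the paper; see the paper for the precise statement.

```latex
\tilde{R}(t;\theta) := \frac{1}{t}\log\left(\frac{1}{N}\sum_{i=1}^{N} e^{t f(x_i;\theta)}\right), \quad t \neq 0,
\qquad
\tilde{R}(0;\theta) := \frac{1}{N}\sum_{i=1}^{N} f(x_i;\theta).
```

The t = 0 case is the limit of the first expression and recovers ERM; as t → +∞ and t → −∞ the objective approaches the maximum and the minimum per-sample loss, respectively.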

This repository contains the data, code, and experiments to reproduce our empirical results. We demonstrate that TERM can be used for a multitude of applications, such as enforcing fairness between subgroups, mitigating the effect of outliers, and handling class imbalance. TERM is not only competitive with existing solutions tailored to these individual problems, but can also enable entirely new applications, such as simultaneously addressing outliers and promoting fairness.

Getting started

Dependencies

As we apply TERM to a diverse set of real-world applications, the dependencies vary across applications.

  • If we mention that the code is based on another public codebase, follow the setup instructions of that codebase.
  • Otherwise, the following dependencies are needed (the latest versions should work):
    • python3
    • scikit-learn (imported as sklearn)
    • numpy
    • matplotlib
    • colorsys (part of the Python standard library; no installation needed)
    • seaborn
    • scipy
    • cvxpy (optional)

Properties of TERM

Motivating examples

These figures illustrate TERM as a function of t: (a) finding a point estimate from a set of 2D samples, (b) linear regression with outliers, and (c) logistic regression with imbalanced classes. While positive values of t magnify outliers, negative values suppress them. Setting t=0 recovers the original ERM objective.

(How to generate these figures: run cd TERM/toy_example && jupyter notebook, then run the three notebooks directly.)
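To make the role of t concrete, here is a minimal sketch of the tilted aggregate of a fixed set of per-sample losses. It is not taken from the notebooks; the function name tilted_loss and the example losses are hypothetical. It uses scipy.special.logsumexp so that large |t| does not overflow.

```python
import numpy as np
from scipy.special import logsumexp


def tilted_loss(losses, t):
    """t-tilted aggregate of per-sample losses.

    Computes (1/t) * log(mean(exp(t * losses))) in a numerically stable way;
    t == 0 falls back to the plain average (the ERM objective).
    """
    losses = np.asarray(losses, dtype=float)
    if t == 0:
        return losses.mean()
    # log(mean(exp(t * l))) = logsumexp(t * l) - log(N)
    return (logsumexp(t * losses) - np.log(losses.size)) / t


losses = np.array([0.1, 0.2, 0.3, 5.0])  # the last sample acts as an outlier
for t in [-100, -1, 0, 1, 100]:
    print(t, round(tilted_loss(losses, t), 3))
```

Large negative t drives the value toward the smallest loss (the outlier is effectively ignored), while large positive t drives it toward the largest loss (the outlier dominates).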

A toy problem to visualize the solutions to TERM

TERM objectives for a squared loss problem with N=3. As t moves from −∞ to +∞, t-tilted losses recover min-loss (t → −∞), avg-loss (t = 0), and max-loss (t → +∞), and approximate median-loss (for some t). TERM is smooth for all finite t and convex for positive t.

(How to generate this figure: run cd TERM/properties && jupyter notebook, then run the notebook directly.)
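A minimal sketch of the same kind of toy problem, assuming three hypothetical 1D samples and the squared loss 0.5 * (θ - x_i)^2 (the notebook's actual data and loss scaling may differ). It minimizes the tilted objective over θ for several values of t with scipy.optimize.minimize_scalar.

```python
import numpy as np
from scipy.optimize import minimize_scalar
from scipy.special import logsumexp

# Hypothetical 1D samples (N = 3); the notebook's actual data may differ.
points = np.array([0.0, 1.0, 4.0])


def tilted_objective(theta, t):
    """t-tilted average of the per-sample squared losses 0.5 * (theta - x_i)^2."""
    losses = 0.5 * (theta - points) ** 2
    if t == 0:
        return losses.mean()
    return (logsumexp(t * losses) - np.log(points.size)) / t


def term_solution(t):
    """Minimize the tilted objective over theta in [min(points), max(points)]."""
    res = minimize_scalar(
        tilted_objective,
        bounds=(points.min(), points.max()),
        args=(t,),
        method="bounded",
        options={"xatol": 1e-8},
    )
    return res.x


for t in [-50, -1, 0, 1, 50]:
    print(f"t = {t:>4}: theta* = {term_solution(t):.3f}")
```

At t = 0 the solution is the sample mean, and for large positive t it approaches the midpoint of the extreme samples (the max-loss solution). For negative t the objective is non-convex, so the bounded search returns a local minimizer that sits close to one of the samples (a min-loss solution).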

How to run the code for different applications

1. Robust regression

cd TERM/robust_regression
python regression.py --obj $OBJ --corrupt 1 --noise $NOISE

where $OBJ is the objective and $NOISE is the noise level (see code for options).
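The sketch below illustrates why negative t helps with corrupted labels; it is not the code in regression.py. Differentiating the tilted objective gives a weighted average of per-sample gradients with weights proportional to exp(t * loss_i), so with t < 0 samples with very large loss receive almost no weight. The synthetic data, t = -2, the step size, and the helper names (tilted_weights, term_linear_regression) are illustrative assumptions.

```python
import numpy as np


def tilted_weights(losses, t):
    """w_i proportional to exp(t * loss_i), normalized to sum to 1 (stable softmax)."""
    z = t * losses
    z = z - z.max()
    w = np.exp(z)
    return w / w.sum()


def term_linear_regression(X, y, t, lr=0.1, n_iters=5000):
    """Full-batch gradient descent on the t-tilted squared loss.

    The gradient of the tilted objective is sum_i w_i * grad f_i(theta),
    with f_i = 0.5 * residual_i^2 and w_i from tilted_weights.
    """
    theta = np.zeros(X.shape[1])
    for _ in range(n_iters):
        residuals = X @ theta - y
        w = tilted_weights(0.5 * residuals ** 2, t)
        theta -= lr * X.T @ (w * residuals)
    return theta


rng = np.random.default_rng(0)
n = 200
x = rng.normal(size=n)
X = np.column_stack([x, np.ones(n)])  # columns: slope, intercept
true_theta = np.array([2.0, 0.5])
y = X @ true_theta + 0.1 * rng.normal(size=n)
outliers = rng.choice(n, size=n // 5, replace=False)  # corrupt 20% of labels
y[outliers] = 20.0 + rng.normal(size=outliers.size)

theta_erm = np.linalg.lstsq(X, y, rcond=None)[0]  # t = 0: ordinary least squares
theta_term = term_linear_regression(X, y, t=-2.0)
print("ERM: ", theta_erm.round(3))
print("TERM:", theta_term.round(3))
```

Because the tilted objective is non-convex for t < 0, the result can depend on the starting point; here we simply start from zero.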

2. Robust classification

cd TERM/robust_classification

3. Mitigating noisy annotators

cd TERM/noisy_annotator/pytorch_resnet_cifar10
python trainer.py --t -2  # TERM

4. Fair PCA

cd TERM/fair_pca
jupyter notebook

then run the notebook fair_pca_credit.ipynb.

  • Built upon the public fair PCA codebase.
  • We directly use the preprocessed Credit data exported from the original MATLAB code, saved as data.csv, A.csv, and B.csv under TERM/fair_pca/multi-criteria-dimensionality-reduction-master/data/credit/.
  • Dependencies: same as the fair PCA codebase.
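A minimal sketch of tilting across groups in a PCA setting, on synthetic two-group data (not the Credit dataset). It only evaluates a fixed rank-1 projection taken from standard PCA rather than optimizing a fair one, and weighting each group by its share of samples is an assumption of this sketch. The function names are hypothetical.

```python
import numpy as np
from scipy.special import logsumexp


def group_reconstruction_losses(groups, V):
    """Average squared reconstruction error of each group on the span of V (orthonormal columns)."""
    losses = []
    for X in groups:
        residual = X - X @ V @ V.T
        losses.append(np.mean(np.sum(residual ** 2, axis=1)))
    return np.array(losses)


def tilted_group_loss(losses, shares, t):
    """(1/t) * log(sum_g share_g * exp(t * loss_g)); t == 0 gives the share-weighted mean."""
    if t == 0:
        return float(np.dot(shares, losses))
    return float(logsumexp(t * losses, b=shares) / t)


# Synthetic two-group data (hypothetical): the groups vary along different axes.
rng = np.random.default_rng(0)
group_a = rng.normal(size=(300, 3)) * np.sqrt([5.0, 1.0, 0.1])
group_b = rng.normal(size=(100, 3)) * np.sqrt([0.1, 1.0, 5.0])
groups = [group_a, group_b]
shares = np.array([len(g) for g in groups], dtype=float) / sum(len(g) for g in groups)

# Rank-1 projection from standard PCA on the pooled, approximately zero-mean data.
_, _, Vt = np.linalg.svd(np.vstack(groups), full_matrices=False)
V = Vt[:1].T

losses = group_reconstruction_losses(groups, V)
for t in [0, 1, 5]:
    print(t, losses.round(3), round(tilted_group_loss(losses, shares, t), 3))
```

Standard PCA favors the larger group, so the smaller group has a much higher reconstruction error; increasing t moves the tilted objective from the average toward this worst-group loss, which is the quantity a fair projection would try to reduce.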

5. Handling class imbalance

cd TERM/class_imbalance
python3 -m mnist.mnist_train_tilting --exp tilting  # TERM, common class=99.5%
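One way to see how positive class-level tilting counters imbalance is to look at the per-sample gradient weights it induces. The sketch below assumes a plain average within each class and tilting across classes, (1/t) * log(sum_c p_c * exp(t * R_c)); it is not the implementation in mnist_train_tilting, and the losses are hypothetical.

```python
import numpy as np


def class_tilted_sample_weights(losses, labels, t):
    """Per-sample gradient weights of (1/t) * log(sum_c p_c * exp(t * R_c)).

    R_c is the mean loss of class c and p_c its share of samples. Differentiating
    gives class weight q_c proportional to p_c * exp(t * R_c), split evenly over
    the n_c samples in class c. t == 0 reduces to uniform 1/N weights (ERM).
    """
    classes, counts = np.unique(labels, return_counts=True)
    class_loss = np.array([losses[labels == c].mean() for c in classes])
    log_q = np.log(counts / labels.size) + t * class_loss
    q = np.exp(log_q - log_q.max())  # stable normalization
    q /= q.sum()
    return (q / counts)[np.searchsorted(classes, labels)]


labels = np.array([0] * 995 + [1] * 5)       # common class = 99.5%
losses = np.where(labels == 0, 0.1, 2.0)      # hypothetical per-sample losses
w_erm = class_tilted_sample_weights(losses, labels, t=0.0)
w_term = class_tilted_sample_weights(losses, labels, t=1.0)
print("rare-class sample weight:  ", w_term[labels == 1][0])
print("common-class sample weight:", w_term[labels == 0][0])
```

With t > 0, each sample of the poorly fit rare class receives a much larger weight than under ERM, while the weights still sum to 1.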

6. Variance reduction for generalization

cd TERM/DRO
python variance_reduction.py --obj $OBJ $OTHER_PARAS

where $OBJ is the objective and $OTHER_PARAS are the hyperparameters associated with the objective (see code for options). Appendix E of the paper describes how we select the hyperparameters and lists all hyperparameter values. For instance, for TERM with t=50, run the following:

python variance_reduction.py --obj tilting --t 50  
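A short intuition for the variance-reduction use case: expanding the log of the mean of exp(t * loss) to second order in t gives tilted loss ≈ mean + (t/2) * variance for small |t|, so positive t penalizes the spread of the per-sample losses. The quick numerical check below uses hypothetical exponentially distributed losses and is meant only as intuition.

```python
import numpy as np
from scipy.special import logsumexp


def tilted_loss(losses, t):
    """(1/t) * log(mean(exp(t * losses))), computed stably; t == 0 gives the mean."""
    if t == 0:
        return losses.mean()
    return (logsumexp(t * losses) - np.log(losses.size)) / t


rng = np.random.default_rng(0)
losses = rng.exponential(size=1000)  # hypothetical per-sample losses

t = 0.01
exact = tilted_loss(losses, t)
approx = losses.mean() + 0.5 * t * losses.var()  # second-order expansion in t
print(exact, approx, losses.mean())
```

The approximation error shrinks as O(t^2), so it is only a guide for small t; at t = 50 the tilted objective behaves much more like the maximum loss.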

7. Fair federated learning

cd TERM/fair_flearn
bash run.sh tilting 0 0 term_t0.1_seed0 > term_t0.1_seed0 2>&1 &

8. Hierarchical multi-objective tilting

cd TERM/hierarchical
python mixed_level1.py --imbalance 1 --corrupt 1 --obj tilting --t_in -2 --t_out 10  # TERM_sc
python mixed_level2.py --imbalance 1 --corrupt 1 --obj tilting --t_in 50 --t_out -2 # TERM_ca
  • mixed_level1.py: TERM_{sc}: (sample level, class level)
  • mixed_level2.py: TERM_{ca}: (class level, annotator level)
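A minimal sketch of a two-level tilted objective: losses are first tilted within each group with an inner t, then the group-level values are tilted across groups (weighted by group size) with an outer t. Mapping these to the --t_in and --t_out flags is our assumption about the scripts, and the helper names are hypothetical.

```python
import numpy as np
from scipy.special import logsumexp


def tilt(values, t, weights=None):
    """(1/t) * log(sum_j w_j * exp(t * v_j)); uniform weights by default, t == 0 gives the weighted mean."""
    values = np.asarray(values, dtype=float)
    if weights is None:
        weights = np.full(values.size, 1.0 / values.size)
    if t == 0:
        return float(np.dot(weights, values))
    return float(logsumexp(t * values, b=weights) / t)


def hierarchical_tilted_loss(losses, groups, t_in, t_out):
    """Tilt within each group with t_in, then across groups with t_out."""
    labels, counts = np.unique(groups, return_counts=True)
    inner = np.array([tilt(losses[groups == g], t_in) for g in labels])
    return tilt(inner, t_out, weights=counts / groups.size)


rng = np.random.default_rng(0)
losses = rng.exponential(size=60)         # hypothetical per-sample losses
groups = rng.integers(0, 3, size=60)      # hypothetical group (e.g. class) ids
print(hierarchical_tilted_loss(losses, groups, t_in=-2.0, t_out=10.0))
```

When the inner and outer tilts are equal, the two levels collapse to ordinary single-level tilting over all samples, which is a handy sanity check for an implementation.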

References

Please see the paper for more details of TERM as well as a complete list of related work.

Owner
Tian Li