Understanding the Generalization Benefit of Model Invariance from a Data Perspective

This is the code for our NeurIPS 2021 paper "Understanding the Generalization Benefit of Model Invariance from a Data Perspective". The code has two major parts: sample covering number estimation and generalization benefit evaluation.

Requirements

  • Python 3.8
  • PyTorch
  • torchvision
  • scikit-learn-extra
  • scipy
  • robustness package (already included in our code)

Our code is based on the robustness package.

Dataset

  • CIFAR-10: download and extract the data into /data/cifar10.
  • R2N2: download the ShapeNet rendered images and put the data into /data/r2n2.

The randomly sampled R2N2 images used for computing sample covering numbers, and the indices of the examples used for each sample size, can be found here.

Estimation of sample covering numbers

To estimate the sample covering numbers of different data transformations, run the following script in /scn.

CUDA_VISIBLE_DEVICES=0 python run_scn.py  --epsilon 3 --transformation crop --cover_number_method fast --data-path /path/to/dataset 

Note that the input is an N x C x H x W tensor, where N is the sample size.
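To give a feel for what a covering number measures, here is a minimal pure-Python sketch of a greedy epsilon-cover under a transformation-aware distance. It is an illustration only, not the algorithm in run_scn.py: the names `orbit_distance` and `greedy_cover`, the use of L2 distance, and the toy "flip" transformation on short vectors are all assumptions made for this example. The size of a greedy cover is an upper bound on the true covering number.

```python
import math

def l2(a, b):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def orbit_distance(x, y, transforms):
    """Smallest L2 distance from any transformed copy of x to y (assumed metric)."""
    return min(l2(t(x), y) for t in transforms)

def greedy_cover(samples, epsilon, transforms):
    """Greedily pick centers until every sample lies within epsilon of some
    transformed center. Returns the indices of the chosen centers."""
    uncovered = list(range(len(samples)))
    centers = []
    while uncovered:
        c = uncovered[0]          # take the first uncovered sample as a new center
        centers.append(c)
        uncovered = [
            i for i in uncovered
            if i != c and orbit_distance(samples[c], samples[i], transforms) > epsilon
        ]
    return centers

# Toy transformations on vectors (hypothetical stand-ins for image transforms).
identity = lambda v: list(v)
flip = lambda v: list(reversed(v))

samples = [[0.0, 1.0], [1.0, 0.0], [0.1, 1.0], [5.0, 5.0]]
no_inv = greedy_cover(samples, 0.5, [identity])
with_flip = greedy_cover(samples, 0.5, [identity, flip])
print(len(no_inv), len(with_flip))
```

In this toy data the flipped copy of the first sample coincides with the second, so allowing flips needs fewer centers than the identity alone. That is the intuition behind comparing covering numbers across transformations.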

Evaluation of generalization benefit

To train the model with the data augmentation method, run the following script in /learn_invariance for the R2N2 dataset:

CUDA_VISIBLE_DEVICES=0 python main.py \
    --dataset r2n2 \
    --data ../data/r2n2/ShapeNetRendering \
    --metainfo-path ../data/r2n2/metainfo_all.json \
    --transforms view  \
    --inv-method aug \
    --out-dir /path/to/out_dir \
    --arch resnet18 --epoch 110 --lr 1e-2 --step-lr 50 \
    --workers 30 --batch-size 128 --exp-name view

or the following script for the CIFAR-10 dataset:

CUDA_VISIBLE_DEVICES=0 python main.py \
    --dataset cifar \
    --data ../data/cifar10 \
    --n-per-class all \
    --transforms crop  \
    --inv-method aug \
    --out-dir /path/to/out_dir \
    --arch resnet18 --epoch 110 --lr 1e-2 --step-lr 50 \
    --workers 30 --batch-size 128 --exp-name crop 

Set --transforms to one of {none, flip, crop, rotate, view} to select the transformation to use.

To train the model with the regularization method, run the following script. Currently, the code only supports the 3D-view transformation on the R2N2 dataset.

CUDA_VISIBLE_DEVICES=0 python main.py \
    --dataset r2n2 \
    --data ../data/r2n2/ShapeNetRendering \
    --metainfo-path ../data/r2n2/metainfo_all.json \
    --transforms view  \
    --inv-method reg \
    --inv-method-beta 1 \
    --out-dir /path/to/out_dir \
    --arch resnet18 --epoch 110 --lr 1e-2 --step-lr 50 \
    --workers 30 --batch-size 128 --exp-name reg_view 
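As a rough illustration of what an invariance regularizer weighted by a coefficient such as --inv-method-beta might look like, the sketch below adds a KL-divergence penalty between the predictions on two views of the same object to the usual cross-entropy loss. The exact loss used in this repository is not stated here, so the KL form, the function names, and the toy logits are all assumptions.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, label):
    """Negative log-probability of the true class."""
    return -math.log(softmax(logits)[label])

def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions with q > 0 everywhere."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def regularized_loss(logits_view_a, logits_view_b, label, beta=1.0):
    """Cross-entropy on one view plus beta times an invariance penalty
    (assumed form: KL between the two views' predicted distributions)."""
    ce = cross_entropy(logits_view_a, label)
    inv = kl_divergence(softmax(logits_view_a), softmax(logits_view_b))
    return ce + beta * inv

# Hypothetical logits for two views of the same object.
same = regularized_loss([2.0, 0.5, -1.0], [2.0, 0.5, -1.0], label=0)
diff = regularized_loss([2.0, 0.5, -1.0], [-1.0, 0.5, 2.0], label=0)
print(same, diff)
```

When the two views yield identical predictions the penalty vanishes and the loss reduces to plain cross-entropy. Disagreement between views raises the loss in proportion to beta.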

To evaluate a trained model's invariance loss and worst-case consistency accuracy, run the following script.

CUDA_VISIBLE_DEVICES=0 python main.py  \
    --dataset r2n2 \
    --data ../data/r2n2/ShapeNetRendering \
    --metainfo-path ../data/r2n2/metainfo_all.json \
    --inv-method reg \
    --arch resnet18 \
    --resume /path/to/checkpoint.pt.best \
    --eval-only 1 \
    --transforms view  \
    --adv-eval 0 \
    --batch-size 2  \
    --no-store 

Note that computing the worst-case consistency accuracy requires loading the 24 view images in the R2N2RenderingsTorch class in dataset_3d.py.
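The sketch below shows one plausible reading of a worst-case accuracy over views: an object counts as correct only if the model predicts its label on every view. This definition, the function names, and the toy predictions (4 views per object instead of 24) are assumptions for illustration; consult the paper and main.py for the metric actually computed.

```python
def worst_case_accuracy(view_predictions, labels):
    """Fraction of objects whose predicted label is correct on all views."""
    correct = sum(
        1 for preds, y in zip(view_predictions, labels)
        if all(p == y for p in preds)
    )
    return correct / len(labels)

def average_accuracy(view_predictions, labels):
    """Fraction of individual (object, view) predictions that are correct."""
    hits = sum(p == y for preds, y in zip(view_predictions, labels) for p in preds)
    total = sum(len(preds) for preds in view_predictions)
    return hits / total

# Hypothetical predictions: 3 objects, 4 views each.
view_predictions = [
    [0, 0, 0, 0],   # correct on every view
    [1, 1, 2, 1],   # wrong on one view, so it fails the worst case
    [2, 2, 2, 2],   # correct on every view
]
labels = [0, 1, 2]

worst = worst_case_accuracy(view_predictions, labels)
avg = average_accuracy(view_predictions, labels)
print(worst, avg)
```

A single misclassified view is enough to fail the worst-case criterion, which is why every view of an object has to be available at evaluation time.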

Owner
PhD student at University of Maryland