PyTorch code of my WACV 2022 paper Improving Model Generalization by Agreement of Learned Representations from Data Augmentation

Overview

Improving Model Generalization by Agreement of Learned Representations from Data Augmentation (WACV 2022)

Paper

ArXiv

Why it matters?

When data augmentation is applied on an input image, a model is forced to learn invariant features to improve model generalization (Figure 1).

Since data augmentation incurs little overhead, why not generate 2 data augmented images (also known as 2 positive samples) from a given input. Then, force the model to agree on the common invariant features to support the correct label (Figure 2). It turns out that maximizing this agreement further improves model model generalization. We call our method AgMax.

Unlike label smoothing, AgMax consistently improves model accuracy. For example on ImageNet1k for 90 epochs, the ResNet50 performance is as follows:

Data Augmentation Baseline Label Smoothing AgMax (Ours)
Standard 76.4 76.8 76.9
CutOut 76.2 76.5 77.1
MixUp 76.5 76.7 77.6
CutMix 76.3 76.4 77.4
AutoAugment (AA) 76.2 76.2 77.1
CutOut+AA 75.7 75.7 76.6
MixUp+AA 75.9 76.5 77.1
CutMix+AA 75.5 75.5 77.0

The figure below demonstrates consistent improvement across different data augmnentation methods:

Install requirements

pip3 install -r requirements.txt

Train

For example, train ResNet50 with AgMax on 2 GPUs for 90 epochs, SGD with lr=0.1 and multistep learning rate scheduler:

CUDA_VISIBLE_DEVICES=0,1 python3 main.py --config=ResNet50-standard-agmax --train \
--multisteplr --dataset=imagenet --epochs=90 --save

Compare the results without AgMax:

CUDA_VISIBLE_DEVICES=0,1 python3 main.py --config=ResNet50-standard --train \
--multisteplr --dataset=imagenet --epochs=90 --save

Test

Using a pre-trained model:

ResNet101 trained with CutMix, AutoAugment and AgMax:

mkdir checkpoints
cd checkpoints
wget https://github.com/roatienza/agmax/releases/download/agmax-0.1.0/imagenet-agmax-mi-ResNet101-cutmix-auto_augment-81.19-mlp-4096.pth
cd ..
python3 main.py --config=ResNet101-auto_augment-cutmix-agmax --eval \
--dataset=imagenet \
--resume imagenet-agmax-mi-ResNet101-cutmix-auto_augment-81.19-mlp-4096.pth

ResNet50 trained with CutMix, AutoAugment and AgMax:

python3 main.py --config=ResNet50-auto_augment-cutmix-agmax --eval --n-units=2048 \
--dataset=imagenet --resume imagenet-agmax-ResNet50-cutmix-auto_augment-79.12-mlp-2048.pth

Other pre-trained models (Baselines):

Citation

If you find this work useful, please cite:

@inproceedings{atienza2022agmax,
  title={Improving Model Generalization by Agreement of Learned Representations from Data Augmentation},
  author={Atienza, Rowel},
  booktitle = {IEEE/CVF Winter Conference on Applications of Computer Vision},
  year={2022},
  pubstate={published},
  tppubtype={inproceedings}
}
You might also like...
Code for the AAAI-2022 paper: Imagine by Reasoning: A Reasoning-Based Implicit Semantic Data Augmentation for Long-Tailed Classification

Imagine by Reasoning: A Reasoning-Based Implicit Semantic Data Augmentation for Long-Tailed Classification (AAAI 2022) Prerequisite PyTorch = 1.2.0 P

Code for the paper Relation Prediction as an Auxiliary Training Objective for Improving Multi-Relational Graph Representations (AKBC 2021).
Code for the paper Relation Prediction as an Auxiliary Training Objective for Improving Multi-Relational Graph Representations (AKBC 2021).

Relation Prediction as an Auxiliary Training Objective for Knowledge Base Completion This repo provides the code for the paper Relation Prediction as

Sharpness-Aware Minimization for Efficiently Improving Generalization
Sharpness-Aware Minimization for Efficiently Improving Generalization

Sharpness-Aware-Minimization-TensorFlow This repository provides a minimal implementation of sharpness-aware minimization (SAM) (Sharpness-Aware Minim

ImageNet-CoG is a benchmark for concept generalization. It provides a full evaluation framework for pre-trained visual representations which measure how well they generalize to unseen concepts.

The ImageNet-CoG Benchmark Project Website Paper (arXiv) Code repository for the ImageNet-CoG Benchmark introduced in the paper "Concept Generalizatio

The official repository for paper ''Domain Generalization for Vision-based Driving Trajectory Generation'' submitted to ICRA 2022
The official repository for paper ''Domain Generalization for Vision-based Driving Trajectory Generation'' submitted to ICRA 2022

DG-TrajGen The official repository for paper ''Domain Generalization for Vision-based Driving Trajectory Generation'' submitted to ICRA 2022. Our Meth

Code for
Code for "ShineOn: Illuminating Design Choices for Practical Video-based Virtual Clothing Try-on", accepted at WACV 2021 Generation of Human Behavior Workshop.

ShineOn: Illuminating Design Choices for Practical Video-based Virtual Clothing Try-on [ Paper ] [ Project Page ] This repository contains the code fo

This is the code for the paper "Jinkai Zheng, Xinchen Liu, Wu Liu, Lingxiao He, Chenggang Yan, Tao Mei: Gait Recognition in the Wild with Dense 3D Representations and A Benchmark. (CVPR 2022)"

Gait3D-Benchmark This is the code for the paper "Jinkai Zheng, Xinchen Liu, Wu Liu, Lingxiao He, Chenggang Yan, Tao Mei: Gait Recognition in the Wild

Imposter-detector-2022 - HackED 2022 Team 3IQ - 2022 Imposter Detector
Imposter-detector-2022 - HackED 2022 Team 3IQ - 2022 Imposter Detector

HackED 2022 Team 3IQ - 2022 Imposter Detector By Aneeljyot Alagh, Curtis Kan, Jo

[WACV 2020] Reducing Footskate in Human Motion Reconstruction with Ground Contact Constraints

Reducing Footskate in Human Motion Reconstruction with Ground Contact Constraints Official implementation for Reducing Footskate in Human Motion Recon

Comments
  • Question about the mutual information in the agmax loss

    Question about the mutual information in the agmax loss

    The agmax loss is computed by "agreement_loss, dl = agmax_loss(y, target, self.args.dl_weight)" , I do not know what does the argeement_loss and dl mean, and do not know the relationship between these two loss and the calculation of mutual information. Could you give me some help?

    opened by YananGu 4
Owner
Rowel Atienza
Rowel Atienza
[Preprint] "Chasing Sparsity in Vision Transformers: An End-to-End Exploration" by Tianlong Chen, Yu Cheng, Zhe Gan, Lu Yuan, Lei Zhang, Zhangyang Wang

Chasing Sparsity in Vision Transformers: An End-to-End Exploration Codes for [Preprint] Chasing Sparsity in Vision Transformers: An End-to-End Explora

VITA 64 Dec 08, 2022
auto-tuning momentum SGD optimizer

YellowFin YellowFin is an auto-tuning optimizer based on momentum SGD which requires no manual specification of learning rate and momentum. It measure

Jian Zhang 288 Nov 19, 2022
Its a Plant Leaf Disease Detection System based on Machine Learning.

My_Project_Code Its a Plant Leaf Disease Detection System based on Machine Learning. I have used Tomato Leaves Dataset from kaggle. This system detect

Sanskriti Sidola 3 Jun 15, 2022
Code to replicate the key results from Exploring the Limits of Out-of-Distribution Detection

Exploring the Limits of Out-of-Distribution Detection In this repository we're collecting replications for the key experiments in the Exploring the Li

Stanislav Fort 35 Jan 03, 2023
RATE: Overcoming Noise and Sparsity of Textual Features in Real-Time Location Estimation (CIKM'17)

RATE: Overcoming Noise and Sparsity of Textual Features in Real-Time Location Estimation This is the implementation of RATE: Overcoming Noise and Spar

Yu Zhang 5 Feb 10, 2022
​TextWorld is a sandbox learning environment for the training and evaluation of reinforcement learning (RL) agents on text-based games.

TextWorld A text-based game generator and extensible sandbox learning environment for training and testing reinforcement learning (RL) agents. Also ch

Microsoft 983 Dec 23, 2022
ROMP: Monocular, One-stage, Regression of Multiple 3D People, ICCV21

Monocular, One-stage, Regression of Multiple 3D People ROMP, accepted by ICCV 2021, is a concise one-stage network for multi-person 3D mesh recovery f

Yu Sun 937 Jan 04, 2023
noisy labels; missing labels; semi-supervised learning; entropy; uncertainty; robustness and generalisation.

ProSelfLC: CVPR 2021 ProSelfLC: Progressive Self Label Correction for Training Robust Deep Neural Networks For any specific discussion or potential fu

amos_xwang 57 Dec 04, 2022
LightHuBERT: Lightweight and Configurable Speech Representation Learning with Once-for-All Hidden-Unit BERT

LightHuBERT LightHuBERT: Lightweight and Configurable Speech Representation Learning with Once-for-All Hidden-Unit BERT | Github | Huggingface | SUPER

WangRui 46 Dec 29, 2022
Code for "Adversarial Training for a Hybrid Approach to Aspect-Based Sentiment Analysis

HAABSAStar Code for "Adversarial Training for a Hybrid Approach to Aspect-Based Sentiment Analysis". This project builds on the code from https://gith

1 Sep 14, 2020
CompilerGym is a library of easy to use and performant reinforcement learning environments for compiler tasks

CompilerGym is a library of easy to use and performant reinforcement learning environments for compiler tasks

Facebook Research 721 Jan 03, 2023
Vector AI — A platform for building vector based applications. Encode, query and analyse data using vectors.

Vector AI is a framework designed to make the process of building production grade vector based applications as quickly and easily as possible. Create

Vector AI 267 Dec 23, 2022
Acoustic mosquito detection code with Bayesian Neural Networks

HumBugDB Acoustic mosquito detection with Bayesian Neural Networks. Extract audio or features from our large-scale dataset on Zenodo. This repository

31 Nov 28, 2022
TSDF++: A Multi-Object Formulation for Dynamic Object Tracking and Reconstruction

TSDF++: A Multi-Object Formulation for Dynamic Object Tracking and Reconstruction TSDF++ is a novel multi-object TSDF formulation that can encode mult

ETHZ ASL 130 Dec 29, 2022
Run Keras models in the browser, with GPU support using WebGL

**This project is no longer active. Please check out TensorFlow.js.** The Keras.js demos still work but is no longer updated. Run Keras models in the

Leon Chen 4.9k Dec 29, 2022
Fairness Metrics: All you need to know

Fairness Metrics: All you need to know Testing machine learning software for ethical bias has become a pressing current concern. Recent research has p

Anonymous2020 1 Jan 17, 2022
ReAct: Out-of-distribution Detection With Rectified Activations

ReAct: Out-of-distribution Detection With Rectified Activations This is the source code for paper ReAct: Out-of-distribution Detection With Rectified

38 Dec 05, 2022
A toolkit for document-level event extraction, containing some SOTA model implementations

❤️ A Toolkit for Document-level Event Extraction with & without Triggers Hi, there 👋 . Thanks for your stay in this repo. This project aims at buildi

Tong Zhu(朱桐) 159 Dec 22, 2022
Augmentation for Single-Image-Super-Resolution

SRAugmentation Augmentation for Single-Image-Super-Resolution Implimentation CutBlur Cutout CutMix Cutup CutMixup Blend RGBPermutation Identity OneOf

Yubo 6 Jun 27, 2022