We present a regularized self-labeling approach to improve the generalization and robustness properties of fine-tuning.

Overview

Overview

This repository provides the implementation for the paper "Improved Regularization and Robustness for Fine-tuning in Neural Networks", which will be presented as a poster paper in NeurIPS'21.

In this work, we propose a regularized self-labeling approach that combines regularization and self-training methods for improving the generalization and robustness properties of fine-tuning. Our approach includes two components:

  • First, we encode layer-wise regularization to penalize the model weights at different layers of the neural net.
  • Second, we add self-labeling that relabels data points based on current neural net's belief and reweights data points whose confidence is low.

Requirements

To install requirements:

pip install -r requirements.txt

Data Preparation

We use seven image datasets in our paper. We list the link for downloading these datasets and describe how to prepare data to run our code below.

  • Aircrafts: download and extract into ./data/aircrafts
    • remove the class 257.clutter out of the data directory
  • CUB-200-2011: download and extract into ./data/CUB_200_2011/
  • Caltech-256: download and extract into ./data/caltech256/
  • Stanford-Cars: download and extract into ./data/StanfordCars/
  • Stanford-Dogs: download and extract into ./data/StanfordDogs/
  • Flowers: download and extract into ./data/flowers/
  • MIT-Indoor: download and extract into ./data/Indoor/

Our code automatically handles the split of the datasets.

Usage

Our algorithm (RegSL) interpolates between layer-wise regularization and self-labeling. Run the following commands for conducting experiments in this paper.

Fine-tuning ResNet-101 on image classification tasks.

python train_constraint.py --model ResNet101 \
    --config configs/config_constraint_indoor.json \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 0.136809975858091 --reg_predictor 6.40780158171339 --scale_factor 2.52883770643206\
    --device 1

python train_constraint.py --model ResNet101 \
    --config configs/config_constraint_aircrafts.json \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 1.18330556653284 --reg_predictor 5.27713618808711 --scale_factor 1.27679969876201\
    --device 1

python train_constraint.py --model ResNet101 \
    --config configs/config_constraint_birds.json \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 0.204403908747731 --reg_predictor 23.7850606577679 --scale_factor 4.73803591794678\
    --device 1

python train_constraint.py --model ResNet101 \
    --config configs/config_constraint_caltech.json \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 0.0867998872549272 --reg_predictor 9.4552942790218 --scale_factor 1.1785989596144\
    --device 1

python train_constraint.py --model ResNet101 \
    --config configs/config_constraint_cars.json \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 1.3340347414257 --reg_predictor 8.26940794089601 --scale_factor 3.47676759842434\
    --device 1

python train_constraint.py --model ResNet101 \
    --config configs/config_constraint_dogs.json \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 0.0561320847651626 --reg_predictor 4.46281825974388 --scale_factor 1.58722606909531\
    --device 1

python train_constraint.py --model ResNet101 \
    --config configs/config_constraint_flower.json \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 0.131991042311165 --reg_predictor 10.7674132173309 --scale_factor 4.98010215976503\
    --device 1

Fine-tuning ResNet-18 under label noise.

python train_label_noise.py --config configs/config_constraint_indoor.json --model ResNet18 \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 7.80246991703043 --reg_predictor 14.077402847906 \
    --noise_rate 0.2 --train_correct_label --reweight_epoch 5 --reweight_temp 2.0 --correct_epoch 10 --correct_thres 0.9 

python train_label_noise.py --config configs/config_constraint_indoor.json --model ResNet18 \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 8.47139398080791 --reg_predictor 19.0191127114923 \
    --noise_rate 0.4 --train_correct_label --reweight_epoch 5 --reweight_temp 2.0 --correct_epoch 10 --correct_thres 0.9 

python train_label_noise.py --config configs/config_constraint_indoor.json --model ResNet18 \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 10.7576018531961 --reg_predictor 19.8157649727473 \
    --noise_rate 0.6 --train_correct_label --reweight_epoch 5 --reweight_temp 2.0 --correct_epoch 10 --correct_thres 0.9 
    
python train_label_noise.py --config configs/config_constraint_indoor.json --model ResNet18 \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 9.2031662757248 --reg_predictor 6.41568500472423 \
    --noise_rate 0.8 --train_correct_label --reweight_epoch 5 --reweight_temp 1.5 --correct_epoch 10 --correct_thres 0.9 

Fine-tuning Vision Transformer on noisy labels.

python train_label_noise.py --config configs/config_constraint_indoor.json \
    --model VisionTransformer --is_vit --img_size 224 --vit_type ViT-B_16 --vit_pretrained_dir pretrained/imagenet21k_ViT-B_16.npz \
    --reg_method none --reg_norm none \
    --lr 0.0001 --device 1 --noise_rate 0.4

python train_label_noise.py --config configs/config_constraint_indoor.json \
    --model VisionTransformer --is_vit --img_size 224 --vit_type ViT-B_16 --vit_pretrained_dir pretrained/imagenet21k_ViT-B_16.npz \
    --reg_method none --reg_norm none \
    --lr 0.0001 --device 1 --noise_rate 0.8

python train_label_noise.py --config configs/config_constraint_indoor.json \
    --model VisionTransformer --is_vit --img_size 224 --vit_type ViT-B_16 --vit_pretrained_dir pretrained/imagenet21k_ViT-B_16.npz \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 0.7488074175044196 --reg_predictor 9.842955837419588 \
    --train_correct_label --reweight_epoch 24 --correct_epoch 18\
    --lr 0.0001 --device 1 --noise_rate 0.4

python train_label_noise.py --config configs/config_constraint_indoor.json \
    --model VisionTransformer --is_vit --img_size 224 --vit_type ViT-B_16 --vit_pretrained_dir pretrained/imagenet21k_ViT-B_16.npz \
    --reg_method constraint --reg_norm frob \
    --reg_extractor 0.1568903647089986 --reg_predictor 1.407080880079702 \
    --train_correct_label --reweight_epoch 18 --correct_epoch 2\
    --lr 0.0001 --device 1 --noise_rate 0.8

Please follow the instructions in ViT-pytorch to download the pre-trained models.

Fine-tuning ResNet-18 on ChestX-ray14 data set.

Run experiments on ChestX-ray14 in reproduce-chexnet path:

cd reproduce-chexnet

python retrain.py --reg_method None --reg_norm None --device 0

python retrain.py --reg_method constraint --reg_norm frob \
    --reg_extractor 5.728564437344309 --reg_predictor 2.5669480884876905 --scale_factor 1.0340072757925474 \
    --device 0

Citation

If you find this repository useful, consider citing our work titled above.

Acknowledgment

Thanks to the authors of the following repositories for providing their implementation publicly available.

Owner
NEU-StatsML-Research
We are a group of faculty and students from the Computer Science College of Northeastern University
NEU-StatsML-Research
PyTorch implementation of paper: HPNet: Deep Primitive Segmentation Using Hybrid Representations.

HPNet This repository contains the PyTorch implementation of paper: HPNet: Deep Primitive Segmentation Using Hybrid Representations. Installation The

Siming Yan 42 Dec 07, 2022
Source code for PairNorm (ICLR 2020)

PairNorm Official pytorch source code for PairNorm paper (ICLR 2020) This code requires pytorch_geometric=1.3.2 usage For SGC, we use original PairNo

62 Dec 08, 2022
Code for the paper “The Peril of Popular Deep Learning Uncertainty Estimation Methods”

Uncertainty Estimation Methods Code for the paper “The Peril of Popular Deep Learning Uncertainty Estimation Methods” Reference If you use this code,

EPFL Machine Learning and Optimization Laboratory 4 Apr 05, 2022
This repository is an unoffical PyTorch implementation of Medical segmentation in 3D and 2D.

Pytorch Medical Segmentation Read Chinese Introduction:Here! Recent Updates 2021.1.8 The train and test codes are released. 2021.2.6 A bug in dice was

EasyCV-Ellis 618 Dec 27, 2022
Hierarchical Memory Matching Network for Video Object Segmentation (ICCV 2021)

Hierarchical Memory Matching Network for Video Object Segmentation Hongje Seong, Seoung Wug Oh, Joon-Young Lee, Seongwon Lee, Suhyeon Lee, Euntai Kim

Hongje Seong 72 Dec 14, 2022
Code for CPM-2 Pre-Train

CPM-2 Pre-Train Pre-train CPM-2 此分支为110亿非 MoE 模型的预训练代码,MoE 模型的预训练代码请切换到 moe 分支 CPM-2技术报告请参考link。 0 模型下载 请在智源资源下载页面进行申请,文件介绍如下: 文件名 描述 参数大小 100000.tar

Tsinghua AI 136 Dec 28, 2022
TensorFlow GNN is a library to build Graph Neural Networks on the TensorFlow platform.

TensorFlow GNN This is an early (alpha) release to get community feedback. It's under active development and we may break API compatibility in the fut

889 Dec 30, 2022
Super Resolution for images using deep learning.

Neural Enhance Example #1 — Old Station: view comparison in 24-bit HD, original photo CC-BY-SA @siv-athens. As seen on TV! What if you could increase

Alex J. Champandard 11.7k Dec 29, 2022
Deep-learning-roadmap - All You Need to Know About Deep Learning - A kick-starter

Deep Learning - All You Need to Know Sponsorship To support maintaining and upgrading this project, please kindly consider Sponsoring the project deve

Instill AI 4.4k Dec 26, 2022
Galileo library for large scale graph training by JD

近年来,图计算在搜索、推荐和风控等场景中获得显著的效果,但也面临超大规模异构图训练,与现有的深度学习框架Tensorflow和PyTorch结合等难题。 Galileo(伽利略)是一个图深度学习框架,具备超大规模、易使用、易扩展、高性能、双后端等优点,旨在解决超大规模图算法在工业级场景的落地难题,提

JD Galileo Team 128 Nov 29, 2022
Code used to generate the results appearing in "Train longer, generalize better: closing the generalization gap in large batch training of neural networks"

Train longer, generalize better - Big batch training This is a code repository used to generate the results appearing in "Train longer, generalize bet

Elad Hoffer 145 Sep 16, 2022
source code for https://arxiv.org/abs/2005.11248 "Accelerating Antimicrobial Discovery with Controllable Deep Generative Models and Molecular Dynamics"

Accelerating Antimicrobial Discovery with Controllable Deep Generative Models and Molecular Dynamics This work will be published in Nature Biomedical

International Business Machines 71 Nov 15, 2022
Implementation of STAM (Space Time Attention Model), a pure and simple attention model that reaches SOTA for video classification

STAM - Pytorch Implementation of STAM (Space Time Attention Model), yet another pure and simple SOTA attention model that bests all previous models in

Phil Wang 109 Dec 28, 2022
Roach: End-to-End Urban Driving by Imitating a Reinforcement Learning Coach

CARLA-Roach This is the official code release of the paper End-to-End Urban Driving by Imitating a Reinforcement Learning Coach by Zhejun Zhang, Alexa

Zhejun Zhang 118 Dec 28, 2022
Summary of related papers on visual attention

This repo is built for paper: Attention Mechanisms in Computer Vision: A Survey paper Vision-Attention-Papers Channel attention Spatial attention Temp

MenghaoGuo 2.1k Dec 30, 2022
Put blind watermark into a text with python

text_blind_watermark Put blind watermark into a text. Can be used in Wechat dingding ... How to Use install pip install text_blind_watermark Alice Pu

郭飞 164 Dec 30, 2022
PyTorch3D is FAIR's library of reusable components for deep learning with 3D data

Introduction PyTorch3D provides efficient, reusable components for 3D Computer Vision research with PyTorch. Key features include: Data structure for

Facebook Research 6.8k Jan 01, 2023
TeachMyAgent is a testbed platform for Automatic Curriculum Learning methods in Deep RL.

TeachMyAgent: a Benchmark for Automatic Curriculum Learning in Deep RL Paper Website Documentation TeachMyAgent is a testbed platform for Automatic Cu

Flowers Team 51 Dec 25, 2022
Safe Bayesian Optimization

SafeOpt - Safe Bayesian Optimization This code implements an adapted version of the safe, Bayesian optimization algorithm, SafeOpt [1], [2]. It also p

Felix Berkenkamp 111 Dec 11, 2022
Official implementation of Monocular Quasi-Dense 3D Object Tracking

Monocular Quasi-Dense 3D Object Tracking Monocular Quasi-Dense 3D Object Tracking (QD-3DT) is an online framework detects and tracks objects in 3D usi

Visual Intelligence and Systems Group 441 Dec 20, 2022