Robust fine-tuning of zero-shot models

This repository contains code for the paper "Robust fine-tuning of zero-shot models" by Mitchell Wortsman*, Gabriel Ilharco*, Jong Wook Kim, Mike Li, Simon Kornblith, Rebecca Roelofs, Raphael Gontijo-Lopes, Hannaneh Hajishirzi, Ali Farhadi, Hongseok Namkoong, Ludwig Schmidt.

Abstract

Large pre-trained models such as CLIP offer consistent accuracy across a range of data distributions when performing zero-shot inference (i.e., without fine-tuning on a specific dataset). Although existing fine-tuning approaches substantially improve accuracy in-distribution, they also reduce out-of-distribution robustness. We address this tension by introducing a simple and effective method for improving robustness: ensembling the weights of the zero-shot and fine-tuned models. Compared to standard fine-tuning, the resulting weight-space ensembles provide large accuracy improvements out-of-distribution, while matching or improving in-distribution accuracy. On ImageNet and five derived distribution shifts, weight-space ensembles improve out-of-distribution accuracy by 2 to 10 percentage points while increasing in-distribution accuracy by nearly 1 percentage point relative to standard fine-tuning. These improvements come at no additional computational cost during fine-tuning or inference.

Summary figure

[Figure 1]

Compared to standard fine-tuning, weight-space ensembles for fine-tuning (WiSE-FT) improve out-of-distribution (OOD) accuracy without decreasing in-distribution (ID) performance. Top left: Zero-shot CLIP models exhibit high effective robustness and moderate in-distribution accuracy, while standard fine-tuning (end-to-end or with a linear classifier) attains higher ID accuracy and less effective robustness. Top right: Our method linearly interpolates between the zero-shot and fine-tuned models with a mixing coefficient alpha in [0,1]. Bottom: On five distribution shifts derived from ImageNet (ImageNetV2, ImageNet-R, ImageNet Sketch, ObjectNet, and ImageNet-A), WiSE-FT improves average OOD accuracy by 8.7 percentage points (pp) when fine-tuning end-to-end (+2.1 pp when fine-tuning a linear classifier) while maintaining ID accuracy.
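
The interpolation described in the caption can be written as a single formula. The symbols below follow the naming in this README's code snippet (theta_0 for the zero-shot weights, theta_1 for the fine-tuned weights); the subscripted name for the mixed weights is notation introduced here for illustration.

```latex
\theta_{\alpha} = (1 - \alpha)\,\theta_0 + \alpha\,\theta_1, \qquad \alpha \in [0, 1]
```

Setting alpha = 0 recovers the zero-shot model and alpha = 1 recovers the fine-tuned model; intermediate values trade off between the two.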

Code

Overview

WiSE-FT can be implemented in a few lines of code in addition to standard fine-tuning, as shown below. See src/wise_ft.py for more details.

# Load models
zeroshot = ImageClassifier.load(zeroshot_checkpoint)
finetuned = ImageClassifier.load(finetuned_checkpoint)
theta_0 = zeroshot.state_dict()
theta_1 = finetuned.state_dict()

# make sure checkpoints are compatible
assert set(theta_0.keys()) == set(theta_1.keys())

# interpolate between checkpoints with mixing coefficient alpha
theta = {
    key: (1-alpha) * theta_0[key] + alpha * theta_1[key]
    for key in theta_0.keys()
}

# update the model according to the new weights
finetuned.load_state_dict(theta)

# evaluate
evaluate(finetuned, args)
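
The snippet above depends on the repository's ImageClassifier and evaluate helpers. To try the interpolation step on its own, here is a minimal, dependency-free sketch. The function name interpolate_state_dicts and the toy dictionaries are hypothetical and not part of this repository; the values can be any type that supports scalar multiplication and addition (plain floats here, tensors in practice).

```python
def interpolate_state_dicts(theta_0, theta_1, alpha):
    """Return (1 - alpha) * theta_0 + alpha * theta_1, computed key by key.

    Hypothetical helper for illustration; not part of the wise-ft code base.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    # Both checkpoints must expose exactly the same parameter names.
    if set(theta_0) != set(theta_1):
        raise KeyError("checkpoints have different keys")
    return {
        key: (1 - alpha) * theta_0[key] + alpha * theta_1[key]
        for key in theta_0
    }


# Toy "checkpoints" with float entries standing in for weight tensors.
theta_0 = {"w": 0.0, "b": 2.0}  # plays the role of the zero-shot weights
theta_1 = {"w": 4.0, "b": 6.0}  # plays the role of the fine-tuned weights

mixed = interpolate_state_dicts(theta_0, theta_1, 0.25)
print(mixed)  # {'w': 1.0, 'b': 3.0}
```

With alpha = 0 the function returns a copy of theta_0, and with alpha = 1 a copy of theta_1, matching the endpoints of the interpolation.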

Install dependencies

conda env create
conda activate wiseft

Add directory to PYTHONPATH:

cd wise-ft
export PYTHONPATH="$PYTHONPATH:$PWD"

Download data

When necessary, please refer to datasets.md for instructions on how to download datasets.

Run WiSE-FT

Sample command when zero-shot and fine-tuned models are available:

python src/wise_ft.py   \
    --eval-datasets=ImageNet,ImageNetV2,ImageNetR,ImageNetA,ImageNetSketch  \
    --load=models/zeroshot.pt,models/finetuned.pt  \
    --results-db=results.jsonl  \
    --save=models/wiseft  \
    --data-location=~/data \
    --alpha 0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0

Sample command for running WiSE-FT from scratch using ViT-B/32:

python src/wise_ft.py   \
    --train-dataset=ImageNet  \
    --epochs=10  \
    --lr=0.00003  \
    --batch-size=512  \
    --cache-dir=cache  \
    --model=ViT-B/32  \
    --eval-datasets=ImageNet,ImageNetV2,ImageNetR,ImageNetA,ImageNetSketch  \
    --template=openai_imagenet_template  \
    --results-db=results.jsonl  \
    --save=models/wiseft/ViTB32  \
    --data-location=~/data \
    --alpha 0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0

Note: the flag --freeze-encoder controls what is fine-tuned. With the flag, only a linear classifier is fine-tuned; without it, all weights are fine-tuned (end-to-end).
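
Both commands above pass --results-db=results.jsonl. A rough sketch for inspecting such a file follows, assuming it holds one JSON object per line (as the .jsonl extension suggests). The field names "alpha" and "accuracy" and the helper names are hypothetical, and the sample numbers are placeholders, not results from the paper; check the actual keys in your own results file before adapting this.

```python
import json


def load_results(lines):
    """Parse JSON-lines text into a list of dicts, skipping blank lines."""
    return [json.loads(line) for line in lines if line.strip()]


def best_alpha(records, metric, alpha_key="alpha"):
    """Return the (alpha, value) pair with the highest value of `metric`.

    Records missing either field are ignored.
    """
    scored = [
        (record[alpha_key], record[metric])
        for record in records
        if alpha_key in record and metric in record
    ]
    return max(scored, key=lambda pair: pair[1])


# Placeholder records with hypothetical field names (not real results).
sample_lines = [
    '{"alpha": 0.0, "accuracy": 0.1}',
    '{"alpha": 0.5, "accuracy": 0.3}',
    '{"alpha": 1.0, "accuracy": 0.2}',
]

best = best_alpha(load_results(sample_lines), "accuracy")
print(best)  # (0.5, 0.3)
```

To read a real file, pass an open file handle instead of the list, for example load_results(open("results.jsonl")).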

Plotting results

Sample command for generating a scatter plot:

python src/scatter_plot.py  \
    --eval-datasets=ImageNetV2,ImageNetR,ImageNetA,ImageNetSketch  \
    --results-db=results.jsonl  \
    --save plots

We show samples of expected behavior below when running the commands above using ViT-B/16 (models can be downloaded here):

[Scatter plots: ImageNet-Sketch, ImageNet-A, ImageNet-R, ImageNetV2, ObjectNet]

Citing

If you found this repository useful, please consider citing:

@article{wortsman2021robust,
  title={Robust fine-tuning of zero-shot models},
  author={Wortsman, Mitchell and Ilharco, Gabriel and Kim, Jong Wook and Li, Mike and Kornblith, Simon and Roelofs, Rebecca and Gontijo-Lopes, Raphael and Hajishirzi, Hannaneh and Farhadi, Ali and Namkoong, Hongseok and Schmidt, Ludwig},
  journal={arXiv preprint arXiv:2109.01903},
  note={\url{https://arxiv.org/abs/2109.01903}},
  year={2021}
}