PyTorch implementation of the ICASSP 2022 paper Attention Probe: Vision Transformer Distillation in the Wild

License: MIT

Jiahao Wang, Mingdeng Cao, Shuwei Shi, Baoyuan Wu, Yujiu Yang
In ICASSP 2022

This repository contains the PyTorch implementation of the ICASSP 2022 paper Attention Probe: Vision Transformer Distillation in the Wild.

Overview

  • We propose the Attention Probe, a special section of the attention map, which draws on a large amount of unlabeled data in the wild to accomplish data-free distillation of vision transformers. Instead of generating images from the teacher network using a series of priors, we identify the images most relevant to the given pre-trained network and task in a large unlabeled dataset (e.g., Flickr) and use them for knowledge distillation.
  • We propose a simple yet efficient distillation algorithm, called probe distillation, which is based on the Attention Probe and distills the student model using intermediate features of the teacher model.
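
The selection idea in the first point can be illustrated with a minimal sketch. It assumes that every wild image has already been given a scalar relevance score and that a higher score means a more relevant image; how the Attention Probe actually computes that score is not described in this README, so the `scores` dictionary and the `select_top_k` helper below are hypothetical and are not part of the repository.

```python
import heapq


def select_top_k(scores, num_select):
    """Return the ids of the num_select highest-scoring images, best first.

    scores: dict mapping image id -> relevance score. Both the scores and the
    "higher is better" convention are assumptions made for this sketch.
    """
    if num_select < 0:
        raise ValueError("num_select must be non-negative")
    # heapq.nlargest avoids sorting the whole pool when num_select is small.
    best = heapq.nlargest(num_select, scores.items(), key=lambda kv: kv[1])
    return [image_id for image_id, _ in best]


# Toy pool of four wild images with made-up scores.
wild_scores = {"a.jpg": 0.12, "b.jpg": 0.87, "c.jpg": 0.45, "d.jpg": 0.91}
print(select_top_k(wild_scores, num_select=2))  # ['d.jpg', 'b.jpg']
```

In the commands later in this README, the number of images to keep is passed through `--num_select`.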

Prerequisite

We use PyTorch 1.7.1 and CUDA 11.0. You can install them with:

pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html

The code should also work with other PyTorch and CUDA versions.

Usage

Data Preparation

First, convert the CIFAR-10, CIFAR-100, and Tiny-ImageNet datasets to the ImageNet-style folder layout. For CIFAR-10, run:

python process_cifar10.py

For CIFAR-100, run:

python process_cifar100.py

For Tiny-ImageNet, run:

python process_tinyimagenet.py
python process_move_file.py

The dataset dir should have the following structure:

dir/
  train/
    ...
  val/
    n01440764/
      ILSVRC2012_val_00000293.JPEG
      ...
    ...
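
A quick way to confirm the conversion worked is to check the layout before training. The sketch below is not part of the repository; the `check_imagenet_layout` helper is a hypothetical name, and it only assumes the structure shown above: a `train/` and a `val/` directory, each containing one non-empty folder per class.

```python
from pathlib import Path


def check_imagenet_layout(root):
    """Return a list of problems found; an empty list means the layout looks right."""
    root = Path(root)
    problems = []
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            problems.append(f"missing split directory: {split}")
            continue
        # Every sub-directory of a split is treated as one class folder.
        class_dirs = sorted(d for d in split_dir.iterdir() if d.is_dir())
        if not class_dirs:
            problems.append(f"no class folders in: {split}")
        for class_dir in class_dirs:
            if not any(f.is_file() for f in class_dir.iterdir()):
                problems.append(f"empty class folder: {split}/{class_dir.name}")
    return problems


# "dir" is the placeholder root used in the structure above.
for problem in check_imagenet_layout("dir"):
    print(problem)
```

The check looks only at folder structure; it does not open or validate the image files themselves.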

Train a normal teacher network

In this step, train a standard teacher transformer model, which is later used to select valuable data from the wild. We train the teacher models with the timm PyTorch library:

timm

Our pretrained teacher models (CIFAR-10, CIFAR-100, ImageNet, Tiny-ImageNet, MNIST) can be downloaded from here:

Pretrained teacher models

Select valuable data from the wild

Next, use the Attention Probe method to select valuable data from the wild dataset.

To select valuable data for CIFAR-10, run:

bash training.sh
(CUDA_VISIBLE_DEVICES=0 python DFND_DeiT-train.py --dataset cifar10 --data_cifar $root_cifar10 --data_imagenet $root_wild --num_select 650000 --teacher_dir $teacher_cifar10 --selected_file $selected_cifar10 --output_dir $output_student_cifar10 --nb_classes 10 --lr_S 7.5e-4 --attnprobe_sel --attnprobe_dist )

For CIFAR-100, run:

bash training.sh
(CUDA_VISIBLE_DEVICES=0 python DFND_DeiT-train.py --dataset cifar100 --data_cifar $root_cifar100 --data_imagenet $root_wild --num_select 650000 --teacher_dir $teacher_cifar100 --selected_file $selected_cifar100 --output_dir $output_student_cifar100 --nb_classes 100 --lr_S 8.5e-4 --attnprobe_sel --attnprobe_dist)

For Tiny-ImageNet, run:

bash training_tinyimagenet.sh

For ImageNet, run:

bash training_imagenet.sh

After this step, the files class_weights.pth, pred_out.pth, value_blk3.pth, value_blk7.pth, and value_out.pth appear in the '/selected/cifar10/' or '/selected/cifar100/' directory. At that point the data selection is complete.
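
Before moving on to distillation, it can help to verify that the selection step wrote all five files. The following sketch is not part of the repository; `missing_selection_files` is a hypothetical helper, and it checks only that the files exist, not what they contain.

```python
from pathlib import Path

# File names as listed in this README for the selection step.
EXPECTED_FILES = (
    "class_weights.pth",
    "pred_out.pth",
    "value_blk3.pth",
    "value_blk7.pth",
    "value_out.pth",
)


def missing_selection_files(selected_dir):
    """Return the expected file names that are absent from selected_dir."""
    selected_dir = Path(selected_dir)
    return [name for name in EXPECTED_FILES if not (selected_dir / name).is_file()]


missing = missing_selection_files("/selected/cifar10/")
if missing:
    print("selection step incomplete, missing:", ", ".join(missing))
else:
    print("all selection outputs found")
```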

Probe Knowledge Distillation for Student networks

Next, distill the student model on the selected data, using intermediate features of the teacher model.

bash training.sh
(CIFAR-10 run: CUDA_VISIBLE_DEVICES=0 python DFND_DeiT-train.py --dataset cifar10 --data_cifar $root_cifar10 --data_imagenet $root_wild --num_select 650000 --teacher_dir $teacher_cifar10 --selected_file $selected_cifar10 --output_dir $output_student_cifar10 --nb_classes 10 --lr_S 7.5e-4 --attnprobe_sel --attnprobe_dist)

(CIFAR-100 run: CUDA_VISIBLE_DEVICES=0,1,2,3 python DFND_DeiT-train.py --dataset cifar100 --data_cifar $root_cifar100 --data_imagenet $root_wild --num_select 650000 --teacher_dir $teacher_cifar100 --selected_file $selected_cifar100 --output_dir $output_student_cifar100 --nb_classes 100 --lr_S 8.5e-4 --attnprobe_sel --attnprobe_dist)

For Tiny-ImageNet, run:

bash training_tinyimagenet.sh

For ImageNet, run:

bash training_imagenet.sh

The distilled student transformer model is saved in the '/output/cifar10/student/' or '/output/cifar100/student/' directory.
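
The CIFAR-10 and CIFAR-100 invocations of DFND_DeiT-train.py differ only in a few dataset-specific values, so they are easy to mix up. As a rough sketch, the argument list can be assembled from one table. The flags, class counts, and learning rates are copied from the commands in this README; the `build_command` helper and the example paths are hypothetical and stand in for the `$root_...`, `$teacher_...`, `$selected_...`, and `$output_student_...` shell variables.

```python
import shlex

# Dataset-specific values as they appear in the commands in this README.
DATASET_SETTINGS = {
    "cifar10": {"nb_classes": 10, "lr_S": "7.5e-4"},
    "cifar100": {"nb_classes": 100, "lr_S": "8.5e-4"},
}


def build_command(dataset, data_cifar, data_wild, teacher_dir,
                  selected_file, output_dir, num_select=650000):
    """Assemble the argument list for DFND_DeiT-train.py for one dataset."""
    settings = DATASET_SETTINGS[dataset]
    return [
        "python", "DFND_DeiT-train.py",
        "--dataset", dataset,
        "--data_cifar", data_cifar,
        "--data_imagenet", data_wild,
        "--num_select", str(num_select),
        "--teacher_dir", teacher_dir,
        "--selected_file", selected_file,
        "--output_dir", output_dir,
        "--nb_classes", str(settings["nb_classes"]),
        "--lr_S", settings["lr_S"],
        "--attnprobe_sel", "--attnprobe_dist",
    ]


# Placeholder paths; replace them with your own locations.
cmd = build_command(
    "cifar100", "/data/cifar100", "/data/wild", "/teacher/cifar100",
    "/selected/cifar100/", "/output/cifar100/student/",
)
print(shlex.join(cmd))
```

The printed line can be pasted into a shell, prefixed with `CUDA_VISIBLE_DEVICES=...` to choose the GPUs as in the commands above.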

Our distilled student models (CIFAR-10, CIFAR-100, ImageNet, Tiny-ImageNet, MNIST) can be downloaded from here: Distilled student models

Results

Citation

@inproceedings{
wang2022attention,
title={Attention Probe: Vision Transformer Distillation in the Wild},
author={Jiahao Wang and Mingdeng Cao and Shuwei Shi and Baoyuan Wu and Yujiu Yang},
booktitle={International Conference on Acoustics, Speech and Signal Processing},
year={2022},
url={https://2022.ieeeicassp.org/}
}

Acknowledgement

Owner
IIGROUP
The Intelligent Interaction Group at Tsinghua University