Official implementation of Evo-ViT: Slow-Fast Token Evolution for Dynamic Vision Transformer

Overview

Evo-ViT: Slow-Fast Token Evolution for Dynamic Vision Transformer

This repository contains the PyTorch code for Evo-ViT.

This work proposes a slow-fast token evolution approach that accelerates vanilla vision transformers of both flat and deep-narrow structures without additional pre-training or fine-tuning procedures. For details, please see Evo-ViT: Slow-Fast Token Evolution for Dynamic Vision Transformer by Yifan Xu*, Zhijie Zhang*, Mengdan Zhang, Kekai Sheng, Ke Li, Weiming Dong, Liqing Zhang, Changsheng Xu, and Xing Sun.

Our code is based on pytorch-image-models, DeiT, and LeViT.

Preparation

Download and extract the ImageNet train and val images from http://image-net.org/. The directory structure is the standard layout for torchvision's datasets.ImageFolder, and the training and validation data are expected to be in the train/ folder and the val/ folder, respectively.

/path/to/imagenet/
  train/
    class1/
      img1.jpeg
    class2/
      img2.jpeg
  val/
    class1/
      img3.jpeg
    class2/
      img4.jpeg
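
Before launching a long run, it can help to confirm that the dataset root matches the layout above. The following is a minimal sketch using only the Python standard library; the helper names check_imagefolder_layout and demo_layout_counts are hypothetical and not part of this repository. It assumes train/ and val/ should contain the same class folders, because ImageFolder derives class indices from the sorted folder names of each split.

```python
from pathlib import Path
import tempfile


def check_imagefolder_layout(root):
    """Return {split: number of class folders} for train/ and val/ under root.

    Raises FileNotFoundError if a split directory is missing, and ValueError
    if a split has no class folders or the two splits list different classes.
    """
    root = Path(root)
    classes = {}
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split directory: {split_dir}")
        names = sorted(p.name for p in split_dir.iterdir() if p.is_dir())
        if not names:
            raise ValueError(f"no class folders found in {split_dir}")
        classes[split] = names
    if classes["train"] != classes["val"]:
        raise ValueError("train/ and val/ contain different class folders")
    return {split: len(names) for split, names in classes.items()}


def demo_layout_counts():
    """Build a tiny fake dataset in a temporary directory and check it."""
    with tempfile.TemporaryDirectory() as tmp:
        for split in ("train", "val"):
            for cls in ("class1", "class2"):
                class_dir = Path(tmp) / split / cls
                class_dir.mkdir(parents=True)
                (class_dir / "img.jpeg").touch()  # empty placeholder file
        return check_imagefolder_layout(tmp)
```

For a real dataset, call check_imagefolder_layout("/path/to/imagenet"); for ImageNet-1k both counts should be 1000.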

All distillation experiments use a RegNetY-160 teacher model, which is available at teacher checkpoint.

Install the requirements by running:

pip3 install -r requirements.txt

Note that all experiments in the paper were conducted under CUDA 11.0. If necessary, please install the following packages in an environment with CUDA 11.0: torch-1.7.0-cu110, torchvision-0.8.1-cu110.

Model Zoo

We provide our Evo-ViT models pretrained on ImageNet:

Name              Top-1 Acc (%)   Throughput (img/s)   Url
Evo-ViT-T         72.0            4027                 Google Drive
Evo-ViT-S         79.4            1510                 Google Drive
Evo-ViT-B         81.3            462                  Google Drive
Evo-LeViT-128S    73.0            10135                Google Drive
Evo-LeViT-128     74.4            8323                 Google Drive
Evo-LeViT-192     76.8            6148                 Google Drive
Evo-LeViT-256     78.8            4277                 Google Drive
Evo-LeViT-384     80.7            2412                 Google Drive
Evo-ViT-B*        82.0            139                  Google Drive
Evo-LeViT-256*    81.1            1285                 Google Drive
Evo-LeViT-384*    82.2            712                  Google Drive

The input image resolution is 224 × 224 unless specified otherwise. * denotes an input image resolution of 384 × 384.
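
To choose a checkpoint for a given accuracy target, the numbers in the table above can be put into a small lookup. This is an illustrative sketch only: the MODEL_ZOO list copies the table values by hand, and the fastest_model helper is hypothetical, not part of the repository.

```python
# (name, top-1 accuracy in %, throughput in img/s, input resolution),
# copied by hand from the Model Zoo table.
MODEL_ZOO = [
    ("Evo-ViT-T", 72.0, 4027, 224),
    ("Evo-ViT-S", 79.4, 1510, 224),
    ("Evo-ViT-B", 81.3, 462, 224),
    ("Evo-LeViT-128S", 73.0, 10135, 224),
    ("Evo-LeViT-128", 74.4, 8323, 224),
    ("Evo-LeViT-192", 76.8, 6148, 224),
    ("Evo-LeViT-256", 78.8, 4277, 224),
    ("Evo-LeViT-384", 80.7, 2412, 224),
    ("Evo-ViT-B*", 82.0, 139, 384),
    ("Evo-LeViT-256*", 81.1, 1285, 384),
    ("Evo-LeViT-384*", 82.2, 712, 384),
]


def fastest_model(min_top1, resolution=None):
    """Return the highest-throughput entry with top-1 >= min_top1.

    If resolution is given, only models at that input resolution are
    considered. Returns None when no model meets the target.
    """
    candidates = [
        m for m in MODEL_ZOO
        if m[1] >= min_top1 and (resolution is None or m[3] == resolution)
    ]
    if not candidates:
        return None
    return max(candidates, key=lambda m: m[2])


print(fastest_model(79.0))
print(fastest_model(82.0))
```

The throughput figures are the ones reported in the table, so the ranking only holds for hardware comparable to what the authors used.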

Usage

Evaluation

To evaluate a pre-trained model, run:

python3 main_deit.py --model evo_deit_small_patch16_224 --eval --resume /path/to/checkpoint.pth --batch-size 256 --data-path /path/to/imagenet

Training with input resolution of 224

To train Evo-ViT on ImageNet on a single node with 8 GPUs for 300 epochs, run:

Evo-ViT-T

python3 -m torch.distributed.launch --nproc_per_node=8 --use_env main_deit.py --model evo_deit_tiny_patch16_224 --drop-path 0 --batch-size 256 --data-path /path/to/imagenet --output_dir /path/to/save

Evo-ViT-S

python3 -m torch.distributed.launch --nproc_per_node=8 --use_env main_deit.py --model evo_deit_small_patch16_224 --batch-size 128 --data-path /path/to/imagenet --output_dir /path/to/save

The loss sometimes becomes NaN in the early training epochs of DeiT-B, as described in this issue. Our solution is to reduce the batch size to 128, load a warmup checkpoint trained for 9 epochs, and train Evo-ViT for the remaining 291 epochs. To train Evo-ViT-B on ImageNet on a single node with 8 GPUs for 300 epochs, run:

python3 -m torch.distributed.launch --nproc_per_node=8 --use_env main_deit.py --model evo_deit_base_patch16_224 --batch-size 128 --data-path /path/to/imagenet --output_dir /path/to/save --resume /path/to/warmup_checkpoint.pth

To train Evo-LeViT-128 on ImageNet on a single node with 8 GPUs for 300 epochs, run:

python3 -m torch.distributed.launch --nproc_per_node=8 --use_env main_levit.py --model EvoLeViT_128 --batch-size 256 --data-path /path/to/imagenet --output_dir /path/to/save

The other Evo-LeViT models are trained with the same command, changing only the --model argument.

Training with input resolution of 384

To train Evo-ViT-B* on ImageNet on 2 nodes with 8 GPUs each for 300 epochs, run:

python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=$NODE_SIZE --node_rank=$NODE_RANK --master_port=$MASTER_PORT --master_addr=$MASTER_ADDR main_deit.py --model evo_deit_base_patch16_384 --input-size 384 --batch-size 64 --data-path /path/to/imagenet --output_dir /path/to/save

To train Evo-ViT-S* on ImageNet on a single node with 8 GPUs for 300 epochs, run:

python3 -m torch.distributed.launch --nproc_per_node=8 --use_env main_deit.py --model evo_deit_small_patch16_384 --batch-size 128 --input-size 384 --data-path /path/to/imagenet --output_dir /path/to/save

To train Evo-LeViT-384* on ImageNet on a single node with 8 GPUs for 300 epochs, run:

python3 -m torch.distributed.launch --nproc_per_node=8 --use_env main_levit.py --model EvoLeViT_384_384 --input-size 384 --batch-size 128 --data-path /path/to/imagenet --output_dir /path/to/save

The other Evo-LeViT* models are trained with the same command as Evo-LeViT-384*.
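
The single-node commands above differ only in the script, model name, batch size, and input size, so they can be assembled programmatically. The sketch below is a hypothetical helper, not part of this repository; it uses only flags that appear in the commands above. The global_batch_size function assumes that --batch-size is the per-process batch size, as in the DeiT training script this code is based on; please verify this against main_deit.py and main_levit.py before relying on it.

```python
import shlex


def build_train_command(script, model, batch_size, data_path, output_dir,
                        gpus=8, input_size=None, resume=None):
    """Assemble a single-node training command as an argument list."""
    cmd = [
        "python3", "-m", "torch.distributed.launch",
        f"--nproc_per_node={gpus}", "--use_env", script,
        "--model", model,
        "--batch-size", str(batch_size),
    ]
    if input_size is not None:
        cmd += ["--input-size", str(input_size)]
    if resume is not None:
        cmd += ["--resume", resume]
    cmd += ["--data-path", data_path, "--output_dir", output_dir]
    return cmd


def global_batch_size(per_process_batch, gpus_per_node=8, nodes=1):
    """Total images per optimizer step, ASSUMING --batch-size is per process."""
    return per_process_batch * gpus_per_node * nodes


cmd = build_train_command(
    "main_levit.py", "EvoLeViT_384_384", 128,
    "/path/to/imagenet", "/path/to/save", input_size=384,
)
print(shlex.join(cmd))                 # shell-ready string (Python 3.8+)
print(global_batch_size(128))          # single node, 8 GPUs
print(global_batch_size(64, nodes=2))  # 2 nodes with 8 GPUs each
```

The returned list can be passed directly to subprocess.run(cmd, check=True), which avoids shell quoting problems with paths that contain spaces.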

Testing inference throughput

To test inference throughput, first modify the model name in line 153 of benchmark.py. Then, run:

python3 benchmark.py

The default input resolution is 224. To test inference throughput with an input resolution of 384, add the parameter "--img_size 384".
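
For readers who want to see what an img/s figure measures, here is a generic timing sketch. It is not the code in benchmark.py; measure_throughput and its parameters are hypothetical names chosen for illustration. The callable step stands in for one forward pass over a batch.

```python
import time


def measure_throughput(step, batch_size, warmup=5, iters=20):
    """Return images per second for `step`, a callable that processes one batch.

    Warmup iterations are run first and excluded from the timing. When timing
    GPU work with PyTorch, `step` should end with torch.cuda.synchronize(),
    since CUDA kernels are launched asynchronously.
    """
    for _ in range(warmup):
        step()
    start = time.perf_counter()
    for _ in range(iters):
        step()
    elapsed = time.perf_counter() - start
    return batch_size * iters / elapsed


def fake_step():
    """Stand-in for a model forward pass: sleeps for about one millisecond."""
    time.sleep(0.001)


print(f"{measure_throughput(fake_step, batch_size=64):.0f} img/s")
```

Measured values depend on hardware, batch size, and software versions, so they are comparable with the Model Zoo table only under a matching setup.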

Visualization of token selection

The visualization code is modified from DynamicViT.

To visualize a batch of ImageNet val images, run:

python3 visualize.py --model evo_deit_small_vis_patch16_224 --resume /path/to/checkpoint.pth --output_dir /path/to/save --data-path /path/to/imagenet --batch-size 64 

To visualize a single image, run:

python3 visualize.py --model evo_deit_small_vis_patch16_224 --resume /path/to/checkpoint.pth --output_dir /path/to/save --img-path ./imgs/a.jpg --save-name evo_test

Add the parameter '--layer-wise-prune' if the visualized model was not trained with the layer-to-stage training strategy.

The visualization results of Evo-ViT-S are as follows:

[Figure: token selection visualization results of Evo-ViT-S]

Citation

If you find our work useful in your research, please consider citing:

@article{xu2021evo,
  title={Evo-ViT: Slow-Fast Token Evolution for Dynamic Vision Transformer},
  author={Xu, Yifan and Zhang, Zhijie and Zhang, Mengdan and Sheng, Kekai and Li, Ke and Dong, Weiming and Zhang, Liqing and Xu, Changsheng and Sun, Xing},
  journal={arXiv preprint arXiv:2108.01390},
  year={2021}
}