Code for "Searching for Efficient Multi-Stage Vision Transformers"

Overview


This repository contains the official PyTorch implementation of "Searching for Efficient Multi-Stage Vision Transformers" and is based on DeiT and timm.


Illustration of the proposed multi-stage ViT-Res network.



Illustration of weight-sharing neural architecture search with multi-architectural sampling.



Accuracy-MACs trade-offs of the proposed ViT-ResNAS. Our networks achieve results comparable to previous work.

Content

  1. Requirements
  2. Data Preparation
  3. Pre-Trained Models
  4. Training ViT-Res
  5. Performing Neural Architecture Search
  6. Evaluation

Requirements

The codebase is tested with 8 V100 (16GB) GPUs.

To install requirements:

    pip install -r requirements.txt

Docker files are provided to set up the environment. Please run:

    cd docker

    sh 1_env_setup.sh
    
    sh 2_build_docker_image.sh
    
    sh 3_run_docker_image.sh

Make sure that the configuration specified in 3_run_docker_image.sh is correct before running the command.

Data Preparation

Download and extract the ImageNet train and val images from http://image-net.org/. The directory structure follows the standard layout for torchvision's datasets.ImageFolder, with the training data in the train/ folder and the validation data in the val/ folder:

/path/to/imagenet/
  train/
    class1/
      img1.jpeg
    class2/
      img2.jpeg
  val/
    class1/
      img3.jpeg
    class2/
      img4.jpeg
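
Before launching a long training run, it can help to confirm that the dataset folder matches the layout above. The following is a minimal sketch using only the Python standard library; the function name check_imagefolder_layout is hypothetical and not part of this repository, and it checks only that train/ and val/ exist and contain non-empty class folders.

```python
from pathlib import Path


def check_imagefolder_layout(root):
    """Return a list of problems found under root (empty if the layout looks right)."""
    problems = []
    root = Path(root)
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            problems.append(f"missing directory: {split_dir}")
            continue
        class_dirs = [d for d in split_dir.iterdir() if d.is_dir()]
        if not class_dirs:
            problems.append(f"no class folders in: {split_dir}")
        for class_dir in class_dirs:
            # Each class folder should hold at least one file (an image).
            if not any(f.is_file() for f in class_dir.iterdir()):
                problems.append(f"empty class folder: {class_dir}")
    return problems
```

Calling check_imagefolder_layout("/path/to/imagenet") returns an empty list when both splits are present and every class folder contains at least one file.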

Pre-Trained Models

Pre-trained weights of super-networks and searched networks can be found here.

Training ViT-Res

To train ViT-Res-Tiny, modify IMAGENET_PATH in scripts/vit-sr-nas/reference_net/tiny.sh and run:

    sh scripts/vit-sr-nas/reference_net/tiny.sh 

We use 8 GPUs for training. If you use a different number of GPUs, modify the number of processes (--nproc_per_node) and adjust the batch size (--batch-size) accordingly.
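
One way to adjust the batch size is to keep the total batch size across all GPUs unchanged. The sketch below assumes that --batch-size is the per-process (per-GPU) batch size, as in DeiT; check the training script to confirm this. The helper name and the value 128 are placeholders for illustration, not values taken from this repository.

```python
def per_gpu_batch_size(ref_gpus, ref_batch_per_gpu, new_gpus):
    """Per-GPU batch size that keeps the total batch size constant."""
    total = ref_gpus * ref_batch_per_gpu
    if total % new_gpus != 0:
        raise ValueError(
            f"total batch size {total} is not divisible by {new_gpus} GPUs"
        )
    return total // new_gpus


# Placeholder numbers: 8 GPUs at 128 images each, moved to 4 GPUs.
print(per_gpu_batch_size(ref_gpus=8, ref_batch_per_gpu=128, new_gpus=4))
```

With fewer GPUs the per-GPU batch grows, so it may no longer fit in memory; in that case a smaller total batch size is unavoidable and other hyperparameters may need retuning.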

Performing Neural Architecture Search

0. Building Sub-Train and Sub-Val Sets

Modify _SOURCE_DIR, _SUB_TRAIN_DIR, and _SUB_VAL_DIR in search_utils/build_subset.py, and run:

    cd search_utils
    
    python build_subset.py
    
    cd ..

1. Super-Network Training

Before running each script, set IMAGENET_PATH to the directory containing the sub-train and sub-val sets.

For ViT-ResNAS-Tiny, run:

    sh scripts/vit-sr-nas/super_net/tiny.sh

For ViT-ResNAS-Small and Medium, run:

    sh scripts/vit-sr-nas/super_net/small.sh

2. Evolutionary Search

Before running each script, set IMAGENET_PATH to the directory containing the sub-train and sub-val sets, and modify MODEL_PATH.

For ViT-ResNAS-Tiny, run:

    sh scripts/vit-sr-nas/evolutionary_search/tiny.sh

For ViT-ResNAS-Small, run:

    sh scripts/vit-sr-nas/evolutionary_search/[email protected]

For ViT-ResNAS-Medium, run:

    sh scripts/vit-sr-nas/evolutionary_search/[email protected]

After running evolutionary search for each network, open summary.txt in the output directory and modify the network_def it reports as follows.

For example, the network_def in summary.txt is ((4, 220), (1, (220, 5, 32), (220, 880), 1), (1, (220, 5, 32), (220, 880), 1), (1, (220, 7, 32), (220, 800), 1), (1, (220, 7, 32), (220, 800), 0), (1, (220, 5, 32), (220, 720), 1), (1, (220, 5, 32), (220, 720), 1), (1, (220, 5, 32), (220, 720), 1), (3, 220, 440), (1, (440, 10, 48), (440, 1760), 1), (1, (440, 10, 48), (440, 1440), 1), (1, (440, 10, 48), (440, 1920), 1), (1, (440, 10, 48), (440, 1600), 1), (1, (440, 12, 48), (440, 1600), 1), (1, (440, 12, 48), (440, 1120), 0), (1, (440, 12, 48), (440, 1440), 1), (3, 440, 880), (1, (880, 16, 64), (880, 3200), 1), (1, (880, 12, 64), (880, 3200), 1), (1, (880, 16, 64), (880, 2880), 1), (1, (880, 12, 64), (880, 3200), 0), (1, (880, 12, 64), (880, 2240), 1), (1, (880, 12, 64), (880, 3520), 0), (1, (880, 14, 64), (880, 2560), 1), (2, 880, 1000)).

Remove every element of the tuple whose first entry is 1 and whose last entry is 0 (e.g. (1, (220, 7, 32), (220, 800), 0)).

A last entry of 0 indicates that the transformer block has been removed from the searched network.

After this modification, the network_def becomes ((4, 220), (1, (220, 5, 32), (220, 880), 1), (1, (220, 5, 32), (220, 880), 1), (1, (220, 7, 32), (220, 800), 1), (1, (220, 5, 32), (220, 720), 1), (1, (220, 5, 32), (220, 720), 1), (1, (220, 5, 32), (220, 720), 1), (3, 220, 440), (1, (440, 10, 48), (440, 1760), 1), (1, (440, 10, 48), (440, 1440), 1), (1, (440, 10, 48), (440, 1920), 1), (1, (440, 10, 48), (440, 1600), 1), (1, (440, 12, 48), (440, 1600), 1), (1, (440, 12, 48), (440, 1440), 1), (3, 440, 880), (1, (880, 16, 64), (880, 3200), 1), (1, (880, 12, 64), (880, 3200), 1), (1, (880, 16, 64), (880, 2880), 1), (1, (880, 12, 64), (880, 2240), 1), (1, (880, 14, 64), (880, 2560), 1), (2, 880, 1000)).
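
The manual edit described above can also be done programmatically. The following is a minimal sketch, not part of this repository: the helper name prune_network_def is hypothetical, and the input is a shortened excerpt of the example network_def rather than the full tuple.

```python
def prune_network_def(network_def):
    """Drop entries whose first element is 1 and whose last element is 0."""
    return tuple(
        layer for layer in network_def
        if not (layer[0] == 1 and layer[-1] == 0)
    )


# Excerpt from the start of the example network_def in summary.txt.
excerpt = (
    (4, 220),
    (1, (220, 5, 32), (220, 880), 1),
    (1, (220, 7, 32), (220, 800), 1),
    (1, (220, 7, 32), (220, 800), 0),  # removed transformer block
    (1, (220, 5, 32), (220, 720), 1),
    (3, 220, 440),
)

pruned = prune_network_def(excerpt)
print(pruned)
```

Entries such as (4, 220) or (3, 220, 440) are kept because their first element is not 1, so only removed transformer blocks are filtered out.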

Then, use the searched network_def for searched network training.

3. Searched Network Training

Before running each script, modify IMAGENET_PATH.

For ViT-ResNAS-Tiny, run:

    sh scripts/vit-sr-nas/searched_net/tiny.sh

For ViT-ResNAS-Small, run:

    sh scripts/vit-sr-nas/searched_net/[email protected]

For ViT-ResNAS-Medium, run:

    sh scripts/vit-sr-nas/searched_net/[email protected]

4. Fine-tuning Trained Networks at Higher Resolution

Before running, modify IMAGENET_PATH and FINETUNE_PATH (pointing to the trained ViT-ResNAS-Medium checkpoint). Then, run:

    sh scripts/vit-sr-nas/finetune/[email protected]

To fine-tune at different resolutions, modify --model, --input-size and --mix-patch-len. We provide models at resolutions 280, 336, and 392 as shown here. Note that --input-size must equal 56 * --mix-patch-len, since the spatial size in ViT-ResNAS is reduced by a factor of 56.
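
As a quick check of the constraint above, the sketch below derives the --mix-patch-len value for a given --input-size. The helper name mix_patch_len_for is hypothetical; only the factor of 56 and the three resolutions come from the text.

```python
SPATIAL_REDUCTION = 56  # overall spatial reduction factor stated above


def mix_patch_len_for(input_size):
    """Return the --mix-patch-len that matches a given --input-size."""
    if input_size % SPATIAL_REDUCTION != 0:
        raise ValueError(
            f"--input-size {input_size} is not a multiple of {SPATIAL_REDUCTION}"
        )
    return input_size // SPATIAL_REDUCTION


for size in (280, 336, 392):
    print(f"--input-size {size} --mix-patch-len {mix_patch_len_for(size)}")
```

For the three provided resolutions this gives patch lengths of 5, 6, and 7 respectively; any other resolution must likewise be a multiple of 56.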

Evaluation

Before running, modify IMAGENET_PATH and MODEL_PATH. Then, run:

    sh scripts/vit-sr-nas/eval/[email protected]

Questions

Please direct questions to Yi-Lun Liao ([email protected]).

License

This repository is released under the CC-BY-NC 4.0 license as found in the LICENSE file.
