《Rethinking Sptil Dimensions of Vision Trnsformers》(2021)

Related tags

Deep Learningpit
Overview

Rethinking Spatial Dimensions of Vision Transformers

Byeongho Heo, Sangdoo Yun, Dongyoon Han, Sanghyuk Chun, Junsuk Choe, Seong Joon Oh | Paper

NAVER AI LAB

teaser

Abstract

Vision Transformer (ViT) extends the application range of transformers from language processing to computer vision tasks as being an alternative architecture against the existing convolutional neural networks (CNN). Since the transformer-based architecture has been innovative for computer vision modeling, the design convention towards an effective architecture has been less studied yet. From the successful design principles of CNN, we investigate the role of the spatial dimension conversion and its effectiveness on the transformer-based architecture. We particularly attend the dimension reduction principle of CNNs; as the depth increases, a conventional CNN increases channel dimension and decreases spatial dimensions. We empirically show that such a spatial dimension reduction is beneficial to a transformer architecture as well, and propose a novel Pooling-based Vision Transformer (PiT) upon the original ViT model. We show that PiT achieves the improved model capability and generalization performance against ViT. Throughout the extensive experiments, we further show PiT outperforms the baseline on several tasks such as image classification, object detection and robustness evaluation.

Model performance

We compared performance of PiT with DeiT models in various training settings. Throughput (imgs/sec) values are measured in a machine with single V100 gpu with 128 batche size.

Network FLOPs # params imgs/sec Vanilla +CutMix +DeiT +Distill
DeiT-Ti 1.3 G 5.7 M 2564 68.7 68.5 72.2 74.5
PiT-Ti 0.71 G 4.9 M 3030 71.3 72.6 73.0 74.6
PiT-XS 1.4 G 10.6 M 2128 72.4 76.8 78.1 79.1
DeiT-S 4.6 G 22.1 M 980 68.7 76.5 79.8 81.2
PiT-S 2.9 G 23.5 M 1266 73.3 79.0 80.9 81.9
DeiT-B 17.6 G 86.6 M 303 69.3 75.3 81.8 83.4
PiT-B 12.5 G 73.8 M 348 76.1 79.9 82.0 84.0

Pretrained weights

Model name FLOPs accuracy weights
pit_ti 0.71 G 73.0 link
pit_xs 1.4 G 78.1 link
pit_s 2.9 G 80.9 link
pit_b 12.5 G 82.0 link
pit_ti_distilled 0.71 G 74.6 link
pit_xs_distilled 1.4 G 79.1 link
pit_s_distilled 2.9 G 81.9 link
pit_b_distilled 12.5 G 84.0 link

Dependancies

Our implementations are tested on following libraries with Python 3.6.9 and CUDA 10.1.

torch: 1.7.1
torchvision: 0.8.2
timm: 0.3.4
einops: 0.3.0

Install other dependencies using the following command.

pip install -r requirements.txt

How to use models

You can build PiT models directly

import torch
import pit

model = pit.pit_s(pretrained=False)
model.load_state_dict(torch.load('./weights/pit_s_809.pth'))
print(model(torch.randn(1, 3, 224, 224)))

Or using timm function

import torch
import timm
import pit

model = timm.create_model('pit_s', pretrained=False)
model.load_state_dict(torch.load('./weights/pit_s_809.pth'))
print(model(torch.randn(1, 3, 224, 224)))

To use models trained with distillation, you should use _distilled model and weights.

import torch
import pit

model = pit.pit_s_distilled(pretrained=False)
model.load_state_dict(torch.load('./weights/pit_s_distill_819.pth'))
print(model(torch.randn(1, 3, 224, 224)))

License

Copyright 2021-present NAVER Corp.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Citation

@article{heo2021pit,
    title={Rethinking Spatial Dimensions of Vision Transformers},
    author={Byeongho Heo and Sangdoo Yun and Dongyoon Han and Sanghyuk Chun and Junsuk Choe and Seong Joon Oh},
    journal={arXiv: 2103.16302},
    year={2021},
}
Owner
NAVER AI
Official account of NAVER AI, Korea No.1 Industrial AI Research Group
NAVER AI
DeepRec is a recommendation engine based on TensorFlow.

DeepRec Introduction DeepRec is a recommendation engine based on TensorFlow 1.15, Intel-TensorFlow and NVIDIA-TensorFlow. Background Sparse model is a

Alibaba 676 Jan 03, 2023
Keras implementation of Deeplab v3+ with pretrained weights

Keras implementation of Deeplabv3+ This repo is not longer maintained. I won't respond to issues but will merge PR DeepLab is a state-of-art deep lear

1.3k Dec 07, 2022
QR2Pass-project - A proof of concept for an alternative (passwordless) authentication system to a web server

QR2Pass This is a proof of concept for an alternative (passwordless) authenticat

4 Dec 09, 2022
Official codebase used to develop Vision Transformer, MLP-Mixer, LiT and more.

Big Vision This codebase is designed for training large-scale vision models on Cloud TPU VMs. It is based on Jax/Flax libraries, and uses tf.data and

Google Research 701 Jan 03, 2023
A PyTorch Lightning solution to training OpenAI's CLIP from scratch.

train-CLIP 📎 A PyTorch Lightning solution to training CLIP from scratch. Goal ⚽ Our aim is to create an easy to use Lightning implementation of OpenA

Cade Gordon 396 Dec 30, 2022
Mesh TensorFlow: Model Parallelism Made Easier

Mesh TensorFlow - Model Parallelism Made Easier Introduction Mesh TensorFlow (mtf) is a language for distributed deep learning, capable of specifying

1.3k Dec 26, 2022
Source code for EquiDock: Independent SE(3)-Equivariant Models for End-to-End Rigid Protein Docking (ICLR 2022)

Source code for EquiDock: Independent SE(3)-Equivariant Models for End-to-End Rigid Protein Docking (ICLR 2022) Please cite "Independent SE(3)-Equivar

Octavian Ganea 154 Jan 02, 2023
This Repostory contains the pretrained DTLN-aec model for real-time acoustic echo cancellation.

This Repostory contains the pretrained DTLN-aec model for real-time acoustic echo cancellation.

Nils L. Westhausen 182 Jan 07, 2023
OpenMMLab 3D Human Parametric Model Toolbox and Benchmark

Introduction English | 简体中文 MMHuman3D is an open source PyTorch-based codebase for the use of 3D human parametric models in computer vision and comput

OpenMMLab 782 Jan 04, 2023
Predicting the duration of arrival delays for commercial flights.

Flight Delay Prediction Our objective is to predict arrival delays of commercial flights. According to the US Department of Transportation, about 21%

Jordan Silke 1 Jan 11, 2022
Memory-Augmented Model Predictive Control

Memory-Augmented Model Predictive Control This repository hosts the source code for the journal article "Composing MPC with LQR and Neural Networks fo

Fangyu Wu 1 Jun 19, 2022
A fast MoE impl for PyTorch

An easy-to-use and efficient system to support the Mixture of Experts (MoE) model for PyTorch.

Rick Ho 873 Jan 09, 2023
A modified version of DeepMind's Alphafold2 to divide CPU part (MSA and template searching) and GPU part (prediction model)

ParallelFold Author: Bozitao Zhong This is a modified version of DeepMind's Alphafold2 to divide CPU part (MSA and template searching) and GPU part (p

Bozitao Zhong 77 Dec 22, 2022
Volumetric parameterization of the placenta to a flattened template

placenta-flattening A MATLAB algorithm for volumetric mesh parameterization. Developed for mapping a placenta segmentation derived from an MRI image t

Mazdak Abulnaga 12 Mar 14, 2022
Supervised multi-SNE (S-multi-SNE): Multi-view visualisation and classification

S-multi-SNE Supervised multi-SNE (S-multi-SNE): Multi-view visualisation and classification A repository containing the code to reproduce the findings

Theodoulos Rodosthenous 3 Apr 15, 2022
Search and filter videos based on objects that appear in them using convolutional neural networks

Thingscoop: Utility for searching and filtering videos based on their content Description Thingscoop is a command-line utility for analyzing videos se

Anastasis Germanidis 354 Dec 04, 2022
KeypointDeformer: Unsupervised 3D Keypoint Discovery for Shape Control

KeypointDeformer: Unsupervised 3D Keypoint Discovery for Shape Control Tomas Jakab, Richard Tucker, Ameesh Makadia, Jiajun Wu, Noah Snavely, Angjoo Ka

Tomas Jakab 87 Nov 30, 2022
Official implementation of "Robust channel-wise illumination estimation"

This repository provides the official implementation of "Robust channel-wise illumination estimation." accepted in BMVC (2021).

Firas Laakom 4 Nov 08, 2022
Person Re-identification

Person Re-identification Final project of Computer Vision Table of content Person Re-identification Table of content Students: Proposed method Dataset

Nguyễn Hoàng Quân 4 Jun 17, 2021
Asterisk is a framework to generate high-quality training datasets at scale

Asterisk is a framework to generate high-quality training datasets at scale

Mona Nashaat 44 Apr 25, 2022