Aggregating Nested Transformers: Official JAX Implementation

Overview

NesT is a simple method that aggregates nested local transformers on image blocks. This design gives vision transformers better accuracy, data efficiency, and convergence on the ImageNet benchmark. NesT also scales down to small datasets, where it matches convnet accuracy.
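The block-based data layout can be illustrated with a short JAX sketch. It assumes the feature map is split into a square grid of non-overlapping blocks so that self-attention can run independently inside each block, and it uses a 3x3, stride-2 max pool as a simplified stand-in for the step that aggregates neighboring blocks between hierarchy levels. The names `blockify`, `unblockify`, and `aggregate` are hypothetical; this is not the repository's model code.

```python
import jax
import jax.numpy as jnp


def blockify(x, block_size):
    """Split a [B, H, W, C] feature map into non-overlapping blocks.

    Returns [B, num_blocks, block_size * block_size, C]: one token sequence
    per block, so local self-attention can be applied within each block.
    """
    b, h, w, c = x.shape
    assert h % block_size == 0 and w % block_size == 0
    gh, gw = h // block_size, w // block_size
    x = x.reshape(b, gh, block_size, gw, block_size, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # [B, gh, gw, bs, bs, C]
    return x.reshape(b, gh * gw, block_size * block_size, c)


def unblockify(x, grid_size):
    """Inverse of blockify for a square grid: [B, g*g, bs*bs, C] -> [B, H, W, C]."""
    b, _, t, c = x.shape
    bs = int(round(t ** 0.5))
    x = x.reshape(b, grid_size, grid_size, bs, bs, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # [B, g, bs, g, bs, C]
    return x.reshape(b, grid_size * bs, grid_size * bs, c)


def aggregate(x):
    """Simplified stand-in for block aggregation: 3x3, stride-2 max pool (2x downsample)."""
    return jax.lax.reduce_window(
        x, -jnp.inf, jax.lax.max,
        window_dimensions=(1, 3, 3, 1),
        window_strides=(1, 2, 2, 1),
        padding="SAME",
    )


x = jnp.arange(2 * 8 * 8 * 3, dtype=jnp.float32).reshape(2, 8, 8, 3)
blocks = blockify(x, block_size=4)           # (2, 4, 16, 3): 4 blocks of 16 tokens
restored = unblockify(blocks, grid_size=2)   # (2, 8, 8, 3), equal to x
pooled = aggregate(restored)                 # (2, 4, 4, 3)
```

Because attention only ever sees one block's tokens at a time, its cost grows with the block size rather than the full image size; mixing across blocks happens only at the aggregation step.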

This is not an officially supported Google product.

Pretrained Models and Results

| Model | Accuracy (%) | Checkpoint path |
|---|---|---|
| NesT-B | 83.8 | gs://gresearch/nest-checkpoints/nest-b_imagenet |
| NesT-S | 83.3 | gs://gresearch/nest-checkpoints/nest-s_imagenet |
| NesT-T | 81.5 | gs://gresearch/nest-checkpoints/nest-t_imagenet |

Note: Accuracy is evaluated on the ImageNet2012 validation set.
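The table reports one accuracy number per model. Assuming it is standard top-1 classification accuracy (the usual ImageNet metric), the quantity can be sketched as below. `top1_accuracy` is a hypothetical helper, not the repository's evaluation code.

```python
import jax.numpy as jnp


def top1_accuracy(logits, labels):
    """Fraction of examples whose highest-scoring class equals the label.

    logits: [N, num_classes] float array; labels: [N] integer class ids.
    """
    predictions = jnp.argmax(logits, axis=-1)
    return jnp.mean(predictions == labels)  # bools are averaged as 0/1


logits = jnp.array([[2.0, 0.5, 0.1],
                    [0.2, 0.3, 1.5],
                    [0.9, 0.8, 0.1],
                    [0.1, 0.7, 0.2]])
labels = jnp.array([0, 2, 1, 1])
acc = top1_accuracy(logits, labels)  # 0.75: the third example is misclassified
```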

TensorBoard.dev

See the ImageNet training logs on TensorBoard.dev.

Colab

A Colab notebook is available for testing: https://colab.sandbox.google.com/github/google-research/nested-transformer/blob/main/colab.ipynb

Instructions for Image Classification

Environment setup

virtualenv -p python3 --system-site-packages nestenv
source nestenv/bin/activate

pip install -r requirements.txt
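Before launching a job, it can help to confirm that the installed JAX build actually sees your accelerators; with `CUDA_VISIBLE_DEVICES=0,1`, for example, a GPU build of JAX should report two local devices. A minimal check, using a hypothetical helper name `describe_devices`:

```python
import jax


def describe_devices():
    """Summarize the accelerators visible to JAX in this process."""
    return {
        "backend": jax.default_backend(),           # "cpu", "gpu", or "tpu"
        "local_devices": jax.local_device_count(),  # devices attached to this host
        "global_devices": jax.device_count(),       # devices across all hosts
    }


if __name__ == "__main__":
    print(describe_devices())
```

If `backend` prints `cpu` on a GPU machine, the environment most likely has a CPU-only JAX installed.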

Evaluate on ImageNet

Before the first run, download ImageNet by following the tensorflow_datasets command-line instructions. Optionally, download all pretrained checkpoints:

bash ./checkpoints/download_checkpoints.sh

Run the evaluation script to evaluate NesT-B.

python main.py --config configs/imagenet_nest.py --config.eval_only=True \
  --config.init_checkpoint="./checkpoints/nest-b_imagenet/ckpt.39" \
  --workdir="./checkpoints/nest-b_imagenet_eval"

Train on ImageNet

The default configuration trains NesT-B on a TPUv2 8x8 slice with a per-device batch size of 16.

python main.py --config configs/imagenet_nest.py --jax_backend_target=<TPU_IP_ADDRESS> --jax_xla_backend="tpu_driver" --workdir="./checkpoints/nest-b_imagenet"

Note: See jax/cloud_tpu_colab for info about TPU_IP_ADDRESS.

Train NesT-T on 8 GPUs.

python main.py --config configs/imagenet_nest_tiny.py --workdir="./checkpoints/nest-t_imagenet_8gpu"

The codebase does not support multi-node GPU training (more than 8 GPUs). The models reported in our paper were trained on TPUs with a total batch size of 1024.
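Per-device and total batch sizes are related by simple arithmetic: in data-parallel training, the global batch is the per-device batch times the number of devices. The sketch below, with hypothetical helper names, shows that relation and the `[num_devices, per_device, ...]` layout that `jax.pmap` maps over. It is not taken from this codebase, which may split batches differently.

```python
import jax
import numpy as np


def global_batch_size(per_device_batch_size, num_devices=None):
    """Total examples per step when every device processes the same per-device batch."""
    if num_devices is None:
        num_devices = jax.device_count()  # all devices across all hosts
    return per_device_batch_size * num_devices


def shard_batch(batch, num_devices=None):
    """Reshape a [global_batch, ...] array into [num_devices, per_device, ...].

    jax.pmap maps a function over the leading axis, one slice per device.
    """
    if num_devices is None:
        num_devices = jax.local_device_count()
    if batch.shape[0] % num_devices != 0:
        raise ValueError(
            f"batch size {batch.shape[0]} is not divisible by {num_devices} devices")
    return batch.reshape((num_devices, -1) + batch.shape[1:])


images = np.zeros((128, 32, 32, 3), np.float32)
sharded = shard_batch(images, num_devices=8)  # shape (8, 16, 32, 32, 3)
```

For example, 8 GPUs with a per-device batch of 16 give a global batch of 128; reaching a global batch of 1024 at the same per-device size takes 64 devices.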

Train on CIFAR

# We recommend training on 2 GPUs; NesT-T can be trained on 1 GPU.
CUDA_VISIBLE_DEVICES=0,1 python main.py --config configs/cifar_nest.py --workdir="./checkpoints/nest_cifar"

Cite

@inproceedings{zhang2021aggregating,
  title={Aggregating Nested Transformers},
  author={Zizhao Zhang and Han Zhang and Long Zhao and Ting Chen and Tomas Pfister},
  booktitle={arXiv preprint arXiv:2105.12723},
  year={2021}
}