Aggregating Nested Transformer Official Jax Implementation

Overview

NesT is a simple method that aggregates nested local transformers on image blocks. This design gives vision transformers better accuracy, data efficiency, and convergence on the ImageNet benchmark. NesT also scales down to small datasets, where it matches convnet accuracy.
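To make "image blocks" concrete, here is a minimal sketch of splitting a batch of images into non-overlapping square blocks, so that a local transformer could attend within each block independently. It illustrates the partitioning idea only and is not taken from this repository: the names blockify and unblockify are hypothetical, and NumPy is used for portability (the same reshape and transpose calls exist in jax.numpy).

```python
import numpy as np


def blockify(images, block_size):
    """Split (N, H, W, C) images into non-overlapping square blocks.

    Returns an array of shape (N, num_blocks, block_size * block_size, C),
    with blocks ordered row by row over the block grid.
    """
    n, h, w, c = images.shape
    if h % block_size or w % block_size:
        raise ValueError("H and W must be divisible by block_size")
    grid_h, grid_w = h // block_size, w // block_size
    x = images.reshape(n, grid_h, block_size, grid_w, block_size, c)
    # Bring the two grid axes together, then the two within-block axes.
    x = x.transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(n, grid_h * grid_w, block_size * block_size, c)


def unblockify(blocks, grid_h, grid_w, block_size):
    """Inverse of blockify: reassemble blocks into (N, H, W, C) images."""
    n, _, _, c = blocks.shape
    x = blocks.reshape(n, grid_h, grid_w, block_size, block_size, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(n, grid_h * block_size, grid_w * block_size, c)


if __name__ == "__main__":
    images = np.arange(2 * 8 * 8 * 3).reshape(2, 8, 8, 3)
    blocks = blockify(images, block_size=4)
    print(blocks.shape)  # four 4x4 blocks per 8x8 image
```

Because each block is a separate axis entry, attention restricted to a block only ever sees the pixels inside it; merging neighbouring blocks afterwards is the aggregation step that the method's name refers to.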

This is not an officially supported Google product.

Pretrained Models and Results

Model    Accuracy (%)   Checkpoint path
Nest-B   83.8           gs://gresearch/nest-checkpoints/nest-b_imagenet
Nest-S   83.3           gs://gresearch/nest-checkpoints/nest-s_imagenet
Nest-T   81.5           gs://gresearch/nest-checkpoints/nest-t_imagenet

Note: Accuracy is evaluated on the ImageNet2012 validation set.
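The checkpoint paths above follow a single naming pattern, so a script can look them up by model name. The sketch below is one way to do that; the helper checkpoint_path and the PRETRAINED dictionary are hypothetical names introduced here for illustration and are not part of this codebase.

```python
# Bucket prefix and accuracies copied from the results table above.
CHECKPOINT_ROOT = "gs://gresearch/nest-checkpoints"
PRETRAINED = {"nest-b": 83.8, "nest-s": 83.3, "nest-t": 81.5}


def checkpoint_path(model_name):
    """Return the checkpoint path for a model name such as 'NesT-B'."""
    key = model_name.lower()
    if key not in PRETRAINED:
        raise KeyError(f"unknown model {model_name!r}; expected one of {sorted(PRETRAINED)}")
    return f"{CHECKPOINT_ROOT}/{key}_imagenet"


if __name__ == "__main__":
    for name in sorted(PRETRAINED):
        print(name, PRETRAINED[name], checkpoint_path(name))
```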

TensorBoard.dev

See the ImageNet training logs at TensorBoard.dev.

Colab

A Colab notebook is available for testing: https://colab.sandbox.google.com/github/google-research/nested-transformer/blob/main/colab.ipynb

Instructions for Image Classification

Environment setup

virtualenv -p python3 --system-site-packages nestenv
source nestenv/bin/activate

pip install -r requirements.txt

Evaluate on ImageNet

Before the first run, download ImageNet from the command line by following the tensorflow_datasets instructions. Optionally, download all pretrained checkpoints:

bash ./checkpoints/download_checkpoints.sh

Run the evaluation script to evaluate NesT-B.

python main.py --config configs/imagenet_nest.py --config.eval_only=True \
  --config.init_checkpoint="./checkpoints/nest-b_imagenet/ckpt.39" \
  --workdir="./checkpoints/nest-b_imagenet_eval"

Train on ImageNet

The default configuration trains NesT-B on TPUv2 8x8 with a per-device batch size of 16.

python main.py --config configs/imagenet_nest.py --jax_backend_target=<TPU_IP_ADDRESS> --jax_xla_backend="tpu_driver" --workdir="./checkpoints/nest-b_imagenet"

Note: See jax/cloud_tpu_colab for info about TPU_IP_ADDRESS.

Train NesT-T on 8 GPUs.

python main.py --config configs/imagenet_nest_tiny.py --workdir="./checkpoints/nest-t_imagenet_8gpu"

The codebase does not support multi-node GPU training (>8 GPUs). The models reported in our paper were trained on TPUs with a total batch size of 1024.
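In data-parallel training the total batch size is the number of devices multiplied by the per-device batch size. The sketch below works through that arithmetic for the figures quoted in this README: a per-device batch size of 16 and a total of 1024 imply 64 devices, assuming the default configuration matches the paper's setup. The function names are hypothetical and are not part of this codebase.

```python
def global_batch_size(num_devices, per_device_batch_size):
    """Total batch size when every device processes its own shard."""
    if num_devices <= 0 or per_device_batch_size <= 0:
        raise ValueError("both arguments must be positive")
    return num_devices * per_device_batch_size


def devices_for(total_batch_size, per_device_batch_size):
    """Number of devices needed to reach a target total batch size."""
    if total_batch_size % per_device_batch_size:
        raise ValueError("total batch size must be a multiple of the per-device size")
    return total_batch_size // per_device_batch_size


if __name__ == "__main__":
    print(devices_for(1024, 16))        # devices implied by the README's figures
    print(global_batch_size(64, 16))    # total batch size on that many devices
```

When training on fewer devices, keeping the total batch size at 1024 means raising the per-device batch size accordingly, subject to device memory.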

Train on CIFAR

# Training on 2 GPUs is recommended. NesT-T can be trained on 1 GPU.
CUDA_VISIBLE_DEVICES=0,1 python main.py --config configs/cifar_nest.py --workdir="./checkpoints/nest_cifar"

Cite

@inproceedings{zhang2021aggregating,
  title={Aggregating Nested Transformers},
  author={Zizhao Zhang and Han Zhang and Long Zhao and Ting Chen and Tomas Pfister},
  booktitle={arXiv preprint arXiv:2105.12723},
  year={2021}
}