Pytorch implementation of "Training a 85.4% Top-1 Accuracy Vision Transformer with 56M Parameters on ImageNet"

Overview

Token Labeling: Training an 85.4% Top-1 Accuracy Vision Transformer with 56M Parameters on ImageNet (arxiv)

This is a Pytorch implementation of our technical report.

Compare

Comparison between the proposed LV-ViT and other recent works based on transformers. Note that we only show models whose model sizes are under 100M.

Training Pipeline

Pipeline

Our codes are based on the pytorch-image-models by Ross Wightman.

LV-ViT Models

Model layer dim Image resolution Param Top 1 Download
LV-ViT-S 16 384 224 26.15M 83.3 link
LV-ViT-S 16 384 384 26.30M 84.4 link
LV-ViT-M 20 512 224 55.83M 84.0 link
LV-ViT-M 20 512 384 56.03M 85.4 link
LV-ViT-L 24 768 448 150.47M 86.2 link

Requirements

torch>=1.4.0 torchvision>=0.5.0 pyyaml timm==0.4.5

data prepare: ImageNet with the following folder structure, you can extract imagenet by this script.

│imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......

Validation

Replace DATA_DIR with your imagenet validation set path and MODEL_DIR with the checkpoint path

CUDA_VISIBLE_DEVICES=0 bash eval.sh /path/to/imagenet/val /path/to/checkpoint

Label data

We provide NFNet-F6 generated dense label map here. As NFNet-F6 are based on pure ImageNet data, no extra training data is involved.

Training

Coming soon

Reference

If you use this repo or find it useful, please consider citing:

@misc{jiang2021token,
      title={Token Labeling: Training an 85.4% Top-1 Accuracy Vision Transformer with 56M Parameters on ImageNet}, 
      author={Zihang Jiang and Qibin Hou and Li Yuan and Daquan Zhou and Xiaojie Jin and Anran Wang and Jiashi Feng},
      year={2021},
      eprint={2104.10858},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Related projects

T2T-ViT, Re-labeling ImageNet.

Comments
  • error: download the pretrained model but couldn't be unzipped

    error: download the pretrained model but couldn't be unzipped

    tar -xvf lvvit_s-26M-384-84-4.pth.tar tar: This does not look like a tar archive tar: Skipping to next header tar: Exiting with failure status due to previous errors

    opened by Williamlizl 10
  • The accuracy of the validation set is 0,and the loss is always around 13

    The accuracy of the validation set is 0,and the loss is always around 13

    Hello! I use ILSVRC2012_img_train and ILSVRC2012_img_val, and use the provided label_top5_train_nfnet from Google Drive. I train lv-vit-s with batch_size 64 without apex for one epoch. Thanks for your advice.

    opened by yifanQi98 7
  • Pretrained weights for LV-ViT-T

    Pretrained weights for LV-ViT-T

    Hi,

    Thanks for sharing your work. Could you also provide the pre-trained weights for the LV-ViT-T model variant, the one that achieves 79.1% top1-acc. as mentioned in Table 1 of your paper?

    All the best, Marc

    opened by marc345 5
  • train error: AttributeError: 'tuple' object has no attribute 'log_softmax'

    train error: AttributeError: 'tuple' object has no attribute 'log_softmax'

    Hi, thanks for you great work. When I train script, some error occurs: AttributeError: 'tuple' object has no attribute 'log_softmax'

    with amp_autocast():   
                output = model(input)  
                loss = loss_fn(output, target)  # error occurs
    
    

    and loss function is train_loss_fn = LabelSmoothingCrossEntropy(smoothing=0.0).cuda()

    by the way: Could you please tell me why we need to specify smoothing=0.0?

    opened by lxy5513 5
  • RuntimeError: CUDA error: device-side assert triggered

    RuntimeError: CUDA error: device-side assert triggered

    I am a green hand of DL. When I run the code of volo with tlt in a single or multi GPU, I get an error as follows: /pytorch/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:312: operator(): block: [0,0,0], thread: [25,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed. Traceback (most recent call last): File "main.py", line 949, in main() File "main.py", line 664, in main optimizers=optimizers) File "main.py", line 773, in train_one_epoch label_size=args.token_label_size) File "/opt/conda/lib/python3.6/site-packages/tlt/data/mixup.py", line 90, in mixup_target y1 = get_labelmaps_with_coords(target, num_classes, on_value=on_value, off_value=off_value, device=device, label_size=label_size) File "/opt/conda/lib/python3.6/site-packages/tlt/data/mixup.py", line 64, in get_labelmaps_with_coords num_classes=num_classes,device=device) File "/opt/conda/lib/python3.6/site-packages/tlt/data/mixup.py", line 16, in get_featuremaps _label_topk[1][:, :, :].long(), RuntimeError: CUDA error: device-side assert triggered.

    I can't fix this problem right now.

    opened by JIAOJIAYUASD 4
  • Generating label for custom dataset

    Generating label for custom dataset

    Hello,

    Thank you for sharing your work. I am currently trying to generate token label to a custom dataset for model lvvit_s, but I keep getting the loss close to 7 and the Accuracy 0 (not pre-trained and using 1 GPU in Google Colab). I also tried using the pre-trained model with --transfer but got 0 in both Loss and Acc . What option should I use for a custom dataset? image

    opened by AleMaiaF 2
  • generate_label.py unable to find model lvvit_s

    generate_label.py unable to find model lvvit_s

    Hi,

    When I tried to run the label generation script for the model lvvit_s it returned an error "RuntimeError: Unknown model".

    Solution: It worked when I added the line "import tlt.models" in the file generate_label.py.

    opened by AleMaiaF 2
  • Can Token labeling reach higher than annotator model?

    Can Token labeling reach higher than annotator model?

    Greetings,

    Thank you for this incredible research.

    I would like to know if it is possible to use Token Labeling to achieve scores higher than that of the annotator model, I believe this was the case with VOLO D5 model where it achieved higher score than NFNet, model used for annotation.

    opened by ErenBalatkan 1
  • label_map does not do the same augmentation (random crop) as the input image

    label_map does not do the same augmentation (random crop) as the input image

    Hi Thanks so much for the nice work! I am curious if you could share the insight on processing of the label_map. If I understand it correctly, after we load image and the corresponding, we shall do the same cropping/ flip/ resize, but in https://github.com/zihangJiang/TokenLabeling/blob/aa438eff9b9fc2daa8c8b4cc6bfaa6e3721f995e/tlt/data/label_transforms_factory.py#L58-L73 Seems only image was cropped, but the label map does not do the same cropping, which make the label map not match with the image?

    Shall we do

            return torchvision_F.resized_crop(
                    img, i, j, h, w, self.size, interpolation
            ), torchvision_F.resized_crop(
                    label_map, i / ratio, j / ratio, h / ratio, w / ratio, self.size, interpolation
            )
    

    Thanks

    opened by haooooooqi 1
  • Python3.6, ok; Python3.8, error

    Python3.6, ok; Python3.8, error

    Test: [ 0/1] Time: 11.293 (11.293) Loss: 0.7043 (0.7043) [email protected]: 42.1875 (42.1875) [email protected]: 100.0000 (100.0000) Test: [ 1/1] Time: 0.108 (5.701) Loss: 0.5847 (0.6689) [email protected]: 89.8148 (56.3187) [email protected]: 100.0000 (100.0000) free(): invalid pointer free(): invalid pointer Traceback (most recent call last): File "/opt/conda/lib/python3.8/runpy.py", line 194, in _run_module_as_main return _run_code(code, main_globals, None, File "/opt/conda/lib/python3.8/runpy.py", line 87, in _run_code exec(code, run_globals) File "/opt/conda/lib/python3.8/site-packages/torch/distributed/launch.py", line 303, in <module> main() File "/opt/conda/lib/python3.8/site-packages/torch/distributed/launch.py", line 294, in main raise subprocess.CalledProcessError(returncode=process.returncode, subprocess.CalledProcessError: Command '['/opt/conda/bin/python3.8', '-u', 'main.py', '--local_rank=1', './dataset/c/c', '--model', 'lvvit_s', '-b', '128', '--apex-amp', '--img-size', '224', '--drop-path', '0.1', '--token-label', '--token-label-size', '14', '--dense-weight', '0.0', '--num-classes', '2', '--finetune', './pretrained/lvvit_s-26M-384-84-4.pth.tar']' died with <Signals.SIGABRT: 6>. [email protected]:/puxin_libochao/TokenLabeling# CUDA_VISIBLE_DEVICES=0,1 bash ./distributed_train.sh 2 ./dataset/c/c --model lvvit_s -b 128 --apex-amp --img-size 224 --drop-path 0.1 --token-label --token-label-size 14 --dense-weight 0.0 --num-classes 2 --finetune ./pretrained/lvvit_s-26M-384-84-4.pth.tar

    opened by Williamlizl 1
  • A Bag of Training Techniques for ViT

    A Bag of Training Techniques for ViT

    Hi, thanks for your wonderful work. I have a question that whether training techniques mentioned in the LV-Vit can be used in other downstream task like object detection? In your paper, I see that many of this techniques are used in ImageNet. Thanks!

    opened by qdd1234 1
  • how to apply token labeling to CNN ?

    how to apply token labeling to CNN ?

    Hello ~ I'm interested in your token labeling technique, So I want to apply this technique in CNN based model because ViT is very heavy to train.

    can I get the your code with CNN token labeling? if you're not give me some detail for implementing

    thank you.

    opened by HoJ00n2 0
  • Model settings for Cifar10

    Model settings for Cifar10

    I am interested if there is any LV-ViT- model setup you have tested for Cifar10. I would like to know the proper setup of all blocks in none pretrained weights settings.

    opened by Aminullah6264 0
Owner
蒋子航
Now a Ph.D. student supervised by Prof. Feng Jiashi in ECE, NUS.
蒋子航
Using deep actor-critic model to learn best strategies in pair trading

Deep-Reinforcement-Learning-in-Stock-Trading Using deep actor-critic model to learn best strategies in pair trading Abstract Partially observed Markov

281 Dec 09, 2022
Viewmaker Networks: Learning Views for Unsupervised Representation Learning

Viewmaker Networks: Learning Views for Unsupervised Representation Learning Alex Tamkin, Mike Wu, and Noah Goodman Paper link: https://arxiv.org/abs/2

Alex Tamkin 31 Dec 01, 2022
Colab notebook for openai/glide-text2im.

GLIDE text2im on Colab This repository provides a Colab notebook to produce images conditioned on text prompts with GLIDE [1]. Usage Run text2im.ipynb

Wok 19 Oct 19, 2022
Face2webtoon - Despite its importance, there are few previous works applying I2I translation to webtoon.

Despite its importance, there are few previous works applying I2I translation to webtoon. I collected dataset from naver webtoon 연애혁명 and tried to transfer human faces to webtoon domain.

이상윤 64 Oct 19, 2022
ConformalLayers: A non-linear sequential neural network with associative layers

ConformalLayers: A non-linear sequential neural network with associative layers ConformalLayers is a conformal embedding of sequential layers of Convo

Prograf-UFF 5 Sep 28, 2022
Learning Continuous Signed Distance Functions for Shape Representation

DeepSDF This is an implementation of the CVPR '19 paper "DeepSDF: Learning Continuous Signed Distance Functions for Shape Representation" by Park et a

Meta Research 1.1k Jan 01, 2023
This repo in the implementation of EMNLP'21 paper "SPARQLing Database Queries from Intermediate Question Decompositions" by Irina Saparina, Anton Osokin

SPARQLing Database Queries from Intermediate Question Decompositions This repo is the implementation of the following paper: SPARQLing Database Querie

Yandex Research 20 Dec 19, 2022
Torchlight2 lan game server tool - A message forwarding tool for Torchlight 2 lan game

Torchlight 2 Lan Game Server Tool A message forwarding tool for Torchlight 2 lan

Huaijun Jiang 3 Nov 01, 2022
TANL: Structured Prediction as Translation between Augmented Natural Languages

TANL: Structured Prediction as Translation between Augmented Natural Languages Code for the paper "Structured Prediction as Translation between Augmen

98 Dec 15, 2022
Tool for installing and updating MiSTer cores and other files

MiSTer Downloader This tool installs and updates all the cores and other extra files for your MiSTer. It also updates the menu core, the MiSTer firmwa

72 Dec 24, 2022
SE-MSCNN: A Lightweight Multi-scaled Fusion Network for Sleep Apnea Detection Using Single-Lead ECG Signals

SE-MSCNN: A Lightweight Multi-scaled Fusion Network for Sleep Apnea Detection Using Single-Lead ECG Signals Abstract Sleep apnea (SA) is a common slee

9 Dec 21, 2022
A pure PyTorch batched computation implementation of "CIF: Continuous Integrate-and-Fire for End-to-End Speech Recognition"

A pure PyTorch batched computation implementation of "CIF: Continuous Integrate-and-Fire for End-to-End Speech Recognition"

張致強 14 Dec 02, 2022
Pytorch implementation of Zero-DCE++

Zero-DCE++ You can find more details here: https://li-chongyi.github.io/Proj_Zero-DCE++.html. You can find the details of our CVPR version: https://li

Chongyi Li 157 Dec 23, 2022
Team nan solution repository for FPT data-centric competition. Data augmentation, Albumentation, Mosaic, Visualization, KNN application

FPT_data_centric_competition - Team nan solution repository for FPT data-centric competition. Data augmentation, Albumentation, Mosaic, Visualization, KNN application

Pham Viet Hoang (Harry) 2 Oct 30, 2022
Easily benchmark PyTorch model FLOPs, latency, throughput, max allocated memory and energy consumption

⏱ pytorch-benchmark Easily benchmark model inference FLOPs, latency, throughput, max allocated memory and energy consumption Install pip install pytor

Lukas Hedegaard 21 Dec 22, 2022
DeOldify - A Deep Learning based project for colorizing and restoring old images (and video!)

DeOldify - A Deep Learning based project for colorizing and restoring old images (and video!)

Jason Antic 15.8k Jan 04, 2023
A Library for Modelling Probabilistic Hierarchical Graphical Models in PyTorch

A Library for Modelling Probabilistic Hierarchical Graphical Models in PyTorch

Korbinian Pöppel 47 Nov 28, 2022
Code for the paper: On Pathologies in KL-Regularized Reinforcement Learning from Expert Demonstrations

Non-Parametric Prior Actor-Critic (N-PPAC) This repository contains the code for On Pathologies in KL-Regularized Reinforcement Learning from Expert D

Cong Lu 5 May 13, 2022
Transparent Transformer Segmentation

Transparent Transformer Segmentation Introduction This repository contains the data and code for IJCAI 2021 paper Segmenting transparent object in the

谢恩泽 140 Jan 02, 2023
Seeing Dynamic Scene in the Dark: High-Quality Video Dataset with Mechatronic Alignment (ICCV2021)

Seeing Dynamic Scene in the Dark: High-Quality Video Dataset with Mechatronic Alignment This is a pytorch project for the paper Seeing Dynamic Scene i

DV Lab 21 Nov 28, 2022