TResNet: High Performance GPU-Dedicated Architecture

paperV2 | pretrained models

Official PyTorch Implementation

Tal Ridnik, Hussam Lawen, Asaf Noy, Itamar Friedman, Emanuel Ben Baruch, Gilad Sharir
DAMO Academy, Alibaba Group

Abstract

Many deep learning models, developed in recent years, reach higher ImageNet accuracy than ResNet50, with fewer or comparable FLOPS count. While FLOPs are often seen as a proxy for network efficiency, when measuring actual GPU training and inference throughput, vanilla ResNet50 is usually significantly faster than its recent competitors, offering better throughput-accuracy trade-off. In this work, we introduce a series of architecture modifications that aim to boost neural networks' accuracy, while retaining their GPU training and inference efficiency. We first demonstrate and discuss the bottlenecks induced by FLOPs-optimizations. We then suggest alternative designs that better utilize GPU structure and assets. Finally, we introduce a new family of GPU-dedicated models, called TResNet, which achieve better accuracy and efficiency than previous ConvNets. Using a TResNet model, with similar GPU throughput to ResNet50, we reach 80.7% top-1 accuracy on ImageNet. Our TResNet models also transfer well and achieve state-of-the-art accuracy on competitive datasets such as Stanford cars (96.0%), CIFAR-10 (99.0%), CIFAR-100 (91.5%) and Oxford-Flowers (99.1%). They also perform well on multi-label classification and object detection tasks.

29/11/2021 Update - New article released, offering a new classification head with state-of-the-art results

Check out our new project, ML-Decoder, which presents a unified classification head for multi-label, single-label and zero-shot tasks. Backbones with ML-Decoder reach SOTA results while also improving the speed-accuracy tradeoff.

23/4/2021 Update - ImageNet21K Pretraining

In a new article, we share pretrained weights for TResNet models from ImageNet21K training, which dramatically outperform standard pretraining. The TResNet-M model, for example, improves its ImageNet-1K score from 80.7% to 83.1%! This kind of improvement is consistently achieved on all downstream tasks.

28/8/2020: V2 of TResNet Article Released

Sotabench Comparisons

Comparative results from the sotabench benchmark demonstrate that TResNet models give an excellent speed-accuracy tradeoff.

11/6/2020: V1 of TResNet Article Released

The main change: in addition to single-label SOTA results, we also added top results for multi-label classification and object detection tasks using TResNet. For example, we set a new SOTA record on the MS-COCO multi-label dataset, surpassing the previous top result by more than 2.5% mAP!

Backbone                mAP (%)
KSSNet (previous SOTA)  83.7
TResNet-L               86.4
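The MS-COCO multi-label score above is a mean average precision (mAP): an average precision (AP) is computed for each class from the ranked prediction scores, and the per-class APs are averaged. The pure-Python sketch below shows one common way to compute it, assuming AP is the mean precision taken at the rank of each positive sample and that classes with no positives are skipped. The function names are hypothetical, and this is not the evaluation code behind the table.

```python
from typing import List, Sequence


def average_precision(scores: Sequence[float], targets: Sequence[int]) -> float:
    """AP for one class: mean of precision@k taken at the rank k of each positive."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)  # best score first
    hits = 0
    precisions = []
    for rank, idx in enumerate(order, start=1):
        if targets[idx]:
            hits += 1
            precisions.append(hits / rank)  # precision among the top `rank` predictions
    return sum(precisions) / len(precisions) if precisions else 0.0


def mean_average_precision(scores: List[List[float]], targets: List[List[int]]) -> float:
    """mAP in percent; scores and targets are [num_samples][num_classes] lists."""
    num_classes = len(scores[0])
    aps = []
    for c in range(num_classes):
        class_scores = [row[c] for row in scores]
        class_targets = [row[c] for row in targets]
        if any(class_targets):  # skip classes with no positive samples
            aps.append(average_precision(class_scores, class_targets))
    if not aps:
        raise ValueError("no class has a positive sample")
    return 100.0 * sum(aps) / len(aps)


if __name__ == "__main__":
    scores = [[0.9, 0.1], [0.8, 0.7], [0.2, 0.6]]
    targets = [[1, 0], [0, 1], [1, 1]]
    print(round(mean_average_precision(scores, targets), 2))  # 91.67
```

In the toy example, class 0 has one negative ranked between its two positives (AP = 5/6), while class 1 ranks both positives first (AP = 1), so the mAP is 91.67.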

2/6/2020: CVPR-Kaggle competitions

We participated and won top places in two major CVPR-Kaggle competitions:

  • 2nd place in Herbarium 2020 competition, out of 153 teams.
  • 7th place in Plant-Pathology 2020 competition, out of 1317 teams.

TResNet was a vital part of our solution in both competitions, allowing us to work at high resolutions and reach top scores while running fast, efficient experiments.

Main Article Results

TResNet Models

Accuracy and GPU throughput of TResNet models on ImageNet, compared to ResNet50. All measurements were done on an Nvidia V100 GPU with mixed precision. All models were trained at an input resolution of 224.

Model           Top Training Speed  Top Inference Speed  Max Train   Top-1 Acc.
                (img/sec)           (img/sec)            Batch Size  (%)
ResNet50        805                 2830                 288         79.0
EfficientNetB1  440                 2740                 196         79.2
TResNet-M       730                 2930                 512         80.8
TResNet-L       345                 1390                 316         81.5
TResNet-XL      250                 1060                 240         82.0
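The speeds above are throughputs in images per second. Below is a minimal, framework-agnostic sketch of how such a number can be measured: run a few warm-up batches, then time a fixed number of batches. The helpers `benchmark_batch_size` and `measure_throughput` are hypothetical; the 0.9 fraction mirrors the note in the comparison section that speeds were measured at 90% of the maximal possible batch size. With PyTorch on a GPU, `sync` would typically be `torch.cuda.synchronize`, since CUDA kernels launch asynchronously and the clock would otherwise stop before the work is done.

```python
import time
from typing import Callable, Optional


def benchmark_batch_size(max_batch_size: int, fraction: float = 0.9) -> int:
    """Batch size used for timing: a fraction of the largest batch that fits in memory."""
    return max(1, int(max_batch_size * fraction))


def measure_throughput(
    step: Callable[[], object],
    batch_size: int,
    warmup_iters: int = 10,
    timed_iters: int = 50,
    sync: Optional[Callable[[], None]] = None,
) -> float:
    """Return images/sec for `step`, a callable that processes one batch of `batch_size` images."""
    for _ in range(warmup_iters):  # warm-up: lazy initialization, caches, kernel autotuning
        step()
    if sync is not None:
        sync()  # ensure warm-up work has finished before the clock starts
    start = time.perf_counter()
    for _ in range(timed_iters):
        step()
    if sync is not None:
        sync()  # wait for queued asynchronous work before stopping the clock
    elapsed = max(time.perf_counter() - start, 1e-9)  # guard against a zero clock reading
    return batch_size * timed_iters / elapsed


if __name__ == "__main__":
    bs = benchmark_batch_size(512)  # TResNet-M's max train batch size from the table -> 460
    # Stand-in workload; a real benchmark would run the model's forward (and backward) pass here.
    print(bs, f"{measure_throughput(lambda: sum(range(10_000)), batch_size=bs):.0f} img/sec")
```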

Comparison To Other Networks

Comparison of ResNet50 to top modern networks with similar top-1 ImageNet accuracy. All measurements were done on an Nvidia V100 GPU with mixed precision. To obtain optimal speeds, training and inference were measured at 90% of the maximal possible batch size. Except for TResNet-M, all models' ImageNet scores were taken from a public repository specializing in top implementations of modern networks. All models use an input resolution of 224, except EfficientNet-B1, which uses 240.

Model           Top Training Speed  Top Inference Speed  Top-1 Acc.  FLOPs
                (img/sec)           (img/sec)            (%)         (G)
ResNet50        805                 2830                 79.0        4.1
ResNet50-D      600                 2670                 79.3        4.4
ResNeXt50       490                 1940                 79.4        4.3
EfficientNetB1  440                 2740                 79.2        0.6
SEResNeXt50     400                 1770                 79.9        4.3
MixNet-L        400                 1400                 79.0        0.5
TResNet-M       730                 2930                 80.8        5.5
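To make the FLOPs-versus-speed argument concrete, the short script below only reuses the numbers from the table above; nothing new is measured. It ranks the models by FLOPs and by inference throughput, and multiplies inference img/sec by per-image GFLOPs to estimate the compute rate implied by the table's FLOPs column. The script and its names are illustrative and not part of this repository.

```python
from typing import Callable, List, Tuple

# (model, training img/sec, inference img/sec, top-1 %, GFLOPs), copied from the table above
Row = Tuple[str, int, int, float, float]
TABLE: List[Row] = [
    ("ResNet50", 805, 2830, 79.0, 4.1),
    ("ResNet50-D", 600, 2670, 79.3, 4.4),
    ("ResNeXt50", 490, 1940, 79.4, 4.3),
    ("EfficientNetB1", 440, 2740, 79.2, 0.6),
    ("SEResNeXt50", 400, 1770, 79.9, 4.3),
    ("MixNet-L", 400, 1400, 79.0, 0.5),
    ("TResNet-M", 730, 2930, 80.8, 5.5),
]


def ranked(key: Callable[[Row], float], descending: bool = False) -> List[str]:
    """Model names sorted by `key`."""
    return [row[0] for row in sorted(TABLE, key=key, reverse=descending)]


def implied_gflops_per_sec(row: Row) -> float:
    """Inference img/sec times per-image GFLOPs: compute rate implied by the table."""
    return row[2] * row[4]


by_flops = ranked(key=lambda r: r[4])                   # fewest FLOPs first
by_speed = ranked(key=lambda r: r[2], descending=True)  # fastest inference first

if __name__ == "__main__":
    print("fewest FLOPs first:     ", by_flops)
    print("fastest inference first:", by_speed)
    for row in TABLE:
        print(f"{row[0]:<15}{implied_gflops_per_sec(row):>8.0f} GFLOP/s")
```

The model with the fewest FLOPs (MixNet-L) is the slowest at inference, while TResNet-M, with the most FLOPs in the table, is the fastest; by this estimate TResNet-M turns FLOPs into throughput nearly ten times more effectively than EfficientNet-B1.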


Transfer Learning SotA Results

Comparison of TResNet to state-of-the-art models on transfer learning datasets (ImageNet-based transfer learning results only). Model inference speed is measured on a V100 GPU with mixed precision. Since no official implementation of GPipe was provided, its inference speed is unknown.

Dataset         Model            Top-1 Acc.  Speed      Input
                                 (%)         (img/sec)  Resolution
CIFAR-10        GPipe            99.0        -          480
CIFAR-10        TResNet-XL       99.0        1060       224
CIFAR-100       EfficientNet-B7  91.7        70         600
CIFAR-100       TResNet-XL       91.5        1060       224
Stanford Cars   EfficientNet-B7  94.7        70         600
Stanford Cars   TResNet-L        96.0        500        368
Oxford-Flowers  EfficientNet-B7  98.8        70         600
Oxford-Flowers  TResNet-L        99.1        500        368

Reproduce Article Scores

We provide code for reproducing the validation top-1 score of TResNet models on ImageNet. First, download pretrained models from here.

Then, run the infer.py script. For example, for tresnet_m (input size 224), run:

python infer.py \
--val_dir=/path/to/imagenet_val_folder \
--model_path=/model/path/to/tresnet_m.pth \
--model_name=tresnet_m \
--input_size=224
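infer.py reports top-1 accuracy on the validation set: the percentage of images whose highest-scoring class equals the ground-truth label. The pure-Python sketch below illustrates that computation on raw per-class scores; `top1_accuracy` is a hypothetical helper, not code taken from infer.py.

```python
from typing import Sequence


def top1_accuracy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Percentage of samples whose highest-scoring class equals the ground-truth label."""
    if len(logits) != len(labels) or not labels:
        raise ValueError("need the same, non-zero number of score rows and labels")
    correct = 0
    for scores, label in zip(logits, labels):
        predicted = max(range(len(scores)), key=scores.__getitem__)  # arg-max over classes
        correct += int(predicted == label)
    return 100.0 * correct / len(labels)


if __name__ == "__main__":
    logits = [[0.1, 2.0, -1.0], [3.0, 0.5, 0.2], [0.0, 0.1, 0.3], [1.0, 1.5, 0.2]]
    labels = [1, 0, 0, 1]
    print(top1_accuracy(logits, labels))  # 75.0 (the third sample predicts class 2)
```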

TResNet Training

Due to IP limitations, we do not provide the exact training code that was used to obtain the article results.

However, TResNet is now an integral part of the popular rwightman/pytorch-image-models repo. Using that repo, you can reach results very similar to the ones stated in the article.

For example, training tresnet_m with rwightman/pytorch-image-models using the command line:

python -u -m torch.distributed.launch --nproc_per_node=8 \
--nnodes=1 --node_rank=0 ./train.py /data/imagenet/ \
-b=190 --lr=0.6 --model-ema --aa=rand-m9-mstd0.5-inc1 \
--num-gpu=8 -j=16 --amp \
--model=tresnet_m --epochs=300 --mixup=0.2 \
--sched='cosine' --reprob=0.4 --remode=pixel

achieved 80.5% top-1 accuracy.
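The command above combines `--lr=0.6`, `--sched='cosine'` and `--epochs=300`, and launches 8 processes with `-b=190` each. Assuming `-b` is the per-process batch size, as in timm's distributed training, each optimizer step sees 8 × 190 = 1520 images. The sketch below shows plain cosine annealing over those values. It is an illustration only: it ignores the warmup, cooldown and minimum-learning-rate settings that the timm scheduler may apply with its defaults, and both function names are hypothetical.

```python
import math


def global_batch_size(per_process_batch: int, num_processes: int) -> int:
    """Images per optimizer step when every process loads its own batch."""
    return per_process_batch * num_processes


def cosine_lr(epoch: float, base_lr: float = 0.6, total_epochs: int = 300, min_lr: float = 0.0) -> float:
    """Plain cosine annealing from base_lr (epoch 0) to min_lr (epoch total_epochs); no warmup."""
    progress = min(max(epoch / total_epochs, 0.0), 1.0)  # clamp to [0, 1]
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    print(global_batch_size(190, 8))  # 1520 images per step for -b=190 on 8 GPUs
    for epoch in (0, 75, 150, 225, 300):
        print(epoch, round(cosine_lr(epoch), 4))
```

With these settings the learning rate starts at 0.6, is halved to 0.3 at epoch 150, and decays to 0 at epoch 300.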

Also, during the merge request, we had interesting discussions and insights regarding TResNet design. I am attaching a PDF version of these discussions; they shed more light on TResNet design considerations and directions for future work.

TResNet discussion and insights

(taken with permission from here)

Tips For Working With Inplace-ABN

See INPLACE_ABN_TIPS.

Citation

@misc{ridnik2020tresnet,
    title={TResNet: High Performance GPU-Dedicated Architecture},
    author={Tal Ridnik and Hussam Lawen and Asaf Noy and Itamar Friedman},
    year={2020},
    eprint={2003.13630},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}

Contact

Feel free to contact me if there are any questions or issues (Tal Ridnik, [email protected]).
