ML-Decoder: Scalable and Versatile Classification Head

Overview

ML-Decoder: Scalable and Versatile Classification Head

PWC
PWC
PWC


Paper

Official PyTorch Implementation

Tal Ridnik, Gilad Sharir, Avi Ben-Cohen, Emanuel Ben-Baruch, Asaf Noy
DAMO Academy, Alibaba Group

Abstract

In this paper, we introduce ML-Decoder, a new attention-based classification head. ML-Decoder predicts the existence of class labels via queries, and enables better utilization of spatial data compared to global average pooling. By redesigning the decoder architecture, and using a novel group-decoding scheme, ML-Decoder is highly efficient, and can scale well to thousands of classes. Compared to using a larger backbone, ML-Decoder consistently provides a better speed-accuracy trade-off. ML-Decoder is also versatile - it can be used as a drop-in replacement for various classification heads, and generalize to unseen classes when operated with word queries. Novel query augmentations further improve its generalization ability. Using ML-Decoder, we achieve state-of-the-art results on several classification tasks: on MS-COCO multi-label, we reach 91.4% mAP; on NUS-WIDE zero-shot, we reach 31.1% ZSL mAP; and on ImageNet single-label, we reach with vanilla ResNet50 backbone a new top score of 80.7%, without extra data or distillation.

ML-Decoder Implementation

ML-Decoder implementation is available here. It can be easily integrated into any backbone using this example code:

ml_decoder_head = MLDecoder(num_classes) # initilization

spatial_embeddings = self.backbone(input_image) # backbone generates spatial embeddings      
 
logits = ml_decoder_head(spatial_embeddings) # transfrom spatial embeddings to logits

Training Code

We will share a full reproduction code for the article results.

Multi-label Training Code


A reproduction code for MS-COCO multi-label:

python train.py  \
--data=/home/datasets/coco2014/ \
--model_name=tresnet_l \
--image_size=448

Single-label Training Code

Our single-label training code uses the excellent timm repo. Reproduction code is currently from a fork, we will work toward a full merge to the main repo.

git clone https://github.com/mrT23/pytorch-image-models.git

This is the code for A2 configuration training, with ML-Decoder (--use-ml-decoder-head=1):

python -u -m torch.distributed.launch --nproc_per_node=8 \
--nnodes=1 \
--node_rank=0 \
./train.py \
/data/imagenet/ \
--amp \
-b=256 \
--epochs=300 \
--drop-path=0.05 \
--opt=lamb \
--weight-decay=0.02 \
--sched='cosine' \
--lr=4e-3 \
--warmup-epochs=5 \
--model=resnet50 \
--aa=rand-m7-mstd0.5-inc1 \
--reprob=0.0 \
--remode='pixel' \
--mixup=0.1 \
--cutmix=1.0 \
--aug-repeats 3 \
--bce-target-thresh 0.2 \
--smoothing=0 \
--bce-loss \
--train-interpolation=bicubic \
--use-ml-decoder-head=1

ZSL Training Code

Reproduction code for ZSL is WIP.

Citation

@misc{ridnik2021mldecoder,
      title={ML-Decoder: Scalable and Versatile Classification Head}, 
      author={Tal Ridnik and Gilad Sharir and Avi Ben-Cohen and Emanuel Ben-Baruch and Asaf Noy},
      year={2021},
      eprint={2111.12933},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
PESTO: Switching Point based Dynamic and Relative Positional Encoding for Code-Mixed Languages

PESTO: Switching Point based Dynamic and Relative Positional Encoding for Code-Mixed Languages Abstract NLP applications for code-mixed (CM) or mix-li

Mohsin Ali, Mohammed 1 Nov 12, 2021
Performant, differentiable reinforcement learning

deluca Performant, differentiable reinforcement learning Notes This is pre-alpha software and is undergoing a number of core changes. Updates to follo

Google 114 Dec 27, 2022
A Pytorch Implementation of a continuously rate adjustable learned image compression framework.

GainedVAE A Pytorch Implementation of a continuously rate adjustable learned image compression framework, Gained Variational Autoencoder(GainedVAE). N

39 Dec 24, 2022
Automatic caption evaluation metric based on typicality analysis.

SeMantic and linguistic UndeRstanding Fusion (SMURF) Automatic caption evaluation metric described in the paper "SMURF: SeMantic and linguistic UndeRs

Joshua Feinglass 6 Jan 09, 2022
Learning Calibrated-Guidance for Object Detection in Aerial Images

Learning Calibrated-Guidance for Object Detection in Aerial Images arxiv We propose a simple yet effective Calibrated-Guidance (CG) scheme to enhance

51 Sep 22, 2022
Classification Modeling: Probability of Default

Credit Risk Modeling in Python Introduction: If you've ever applied for a credit card or loan, you know that financial firms process your information

Aktham Momani 2 Nov 07, 2022
PyTorch-based framework for Deep Hedging

PFHedge: Deep Hedging in PyTorch PFHedge is a PyTorch-based framework for Deep Hedging. PFHedge Documentation Neural Network Architecture for Efficien

139 Dec 30, 2022
Animate molecular orbital transitions using Psi4 and Blender

Molecular Orbital Transitions (MOT) Animate molecular orbital transitions using Psi4 and Blender Author: Maximilian Paradiz Dominguez, University of A

3 Feb 01, 2022
FairFuzz: AFL extension targeting rare branches

FairFuzz An AFL extension to increase code coverage by targeting rare branches. FairFuzz has a particular advantage on programs with highly nested str

Caroline Lemieux 222 Nov 16, 2022
Keras implementation of Real-Time Semantic Segmentation on High-Resolution Images

Keras-ICNet [paper] Keras implementation of Real-Time Semantic Segmentation on High-Resolution Images. Training in progress! Requisites Python 3.6.3 K

Aitor Ruano 87 Dec 16, 2022
Exact Pareto Optimal solutions for preference based Multi-Objective Optimization

Exact Pareto Optimal solutions for preference based Multi-Objective Optimization

Debabrata Mahapatra 40 Dec 24, 2022
Implementation of the paper All Labels Are Not Created Equal: Enhancing Semi-supervision via Label Grouping and Co-training

SemCo The official pytorch implementation of the paper All Labels Are Not Created Equal: Enhancing Semi-supervision via Label Grouping and Co-training

42 Nov 14, 2022
Automatically download the cwru data set, and then divide it into training data set and test data set

Automatically download the cwru data set, and then divide it into training data set and test data set.自动下载cwru数据集,然后分训练数据集和测试数据集

6 Jun 27, 2022
Configure SRX interfaces with Scrapli

Configure SRX interfaces with Scrapli Overview This example will show how to configure interfaces on Juniper's SRX firewalls. In addition to the Pytho

Calvin Remsburg 1 Jan 07, 2022
Jax/Flax implementation of Variational-DiffWave.

jax-variational-diffwave Jax/Flax implementation of Variational-DiffWave. (Zhifeng Kong et al., 2020, Diederik P. Kingma et al., 2021.) DiffWave with

YoungJoong Kim 37 Dec 16, 2022
Repository for paper "Non-intrusive speech intelligibility prediction from discrete latent representations"

Non-Intrusive Speech Intelligibility Prediction from Discrete Latent Representations Official repository for paper "Non-Intrusive Speech Intelligibili

Alex McKinney 5 Oct 25, 2022
Explainability for Vision Transformers (in PyTorch)

Explainability for Vision Transformers (in PyTorch) This repository implements methods for explainability in Vision Transformers

Jacob Gildenblat 442 Jan 04, 2023
Perform zero-order Hankel Transform for an 1D array (float or real valued).

perform zero-order Hankel Transform for an 1D array (float or real valued). An discrete form of Parseval theorem is guaranteed. Suit for iterative problems.

1 Jan 17, 2022
Joint Detection and Identification Feature Learning for Person Search

Person Search Project This repository hosts the code for our paper Joint Detection and Identification Feature Learning for Person Search. The code is

712 Dec 17, 2022
A deep learning network built with TensorFlow and Keras to classify gender and estimate age.

Convolutional Neural Network (CNN). This repository contains a source code of a deep learning network built with TensorFlow and Keras to classify gend

Pawel Dziemiach 1 Dec 19, 2021