PyTorch implementation and pretrained models for XCiT models. See XCiT: Cross-Covariance Image Transformer

Related tags

Deep Learningxcit
Overview

Cross-Covariance Image Transformer (XCiT)

PyTorch implementation and pretrained models for XCiT models. See XCiT: Cross-Covariance Image Transformer

Linear complexity in time and memory

Our XCiT models has a linear complexity w.r.t number of patches/tokens:

Peak Memory (inference) Millisecond/Image (Inference)

Scaling to high resolution inputs

XCiT can scale to high resolution inputs both due to cheaper compute requirement as well as better adaptability to higher resolution at test time (see Figure 3 in the paper)

Detection and Instance Segmentation for Ultra high resolution images (6000x4000)

Detection and Instance segmentation result for an ultra high resolution image 6000x4000 )

XCiT+DINO: High Res. Self-Attention Visualization 🦖

Our XCiT models with self-supervised training using DINO can obtain high resolution attention maps.

xcit_dino.mp4

Self-Attention visualization per head

Below we show the attention maps for each of the 8 heads separately and we can observe that every head specializes in different semantic aspects of the scene for the foreground as well as the background.

Multi_head.mp4

Getting Started

First, clone the repo

git clone https://github.com/facebookresearch/XCiT.git

Then, you can install the required packages including: Pytorch version 1.7.1, torchvision version 0.8.2 and Timm version 0.4.8

pip install -r requirements.txt

Download and extract the ImageNet dataset. Afterwards, set the --data-path argument to the corresponding extracted ImageNet path.

For full details about all the available arguments, you can use

python main.py --help

For detection and segmentation downstream tasks, please check:


Model Zoo

We provide XCiT models pre-trained weights on ImageNet-1k.

§: distillation

Models with 16x16 patch size

Arch params Model
224 224 § 384 §
top-1 weights top-1 weights top-1 weights
xcit_nano_12_p16 3M 69.9% download 72.2% download 75.4% download
xcit_tiny_12_p16 7M 77.1% download 78.6% download 80.9% download
xcit_tiny_24_p16 12M 79.4% download 80.4% download 82.6% download
xcit_small_12_p16 26M 82.0% download 83.3% download 84.7% download
xcit_small_24_p16 48M 82.6% download 83.9% download 85.1% download
xcit_medium_24_p16 84M 82.7% download 84.3% download 85.4% download
xcit_large_24_p16 189M 82.9% download 84.9% download 85.8% download

Models with 8x8 patch size

Arch params Model
224 224 § 384 §
top-1 weights top-1 weights top-1 weights
xcit_nano_12_p8 3M 73.8% download 76.3% download 77.8% download
xcit_tiny_12_p8 7M 79.7% download 81.2% download 82.4% download
xcit_tiny_24_p8 12M 81.9% download 82.6% download 83.7% download
xcit_small_12_p8 26M 83.4% download 84.2% download 85.1% download
xcit_small_24_p8 48M 83.9% download 84.9% download 85.6% download
xcit_medium_24_p8 84M 83.7% download 85.1% download 85.8% download
xcit_large_24_p8 189M 84.4% download 85.4% download 86.0% download

XCiT + DINO Self-supervised models

Arch params k-nn linear download
xcit_small_12_p16 26M 76.0% 77.8% backbone
xcit_small_12_p8 26M 77.1% 79.2% backbone
xcit_medium_24_p16 84M 76.4% 78.8% backbone
xcit_medium_24_p8 84M 77.9% 80.3% backbone

Training

For training using a single node, use the following command

python -m torch.distributed.launch --nproc_per_node=[NUM_GPUS] --use_env main.py --model [MODEL_KEY] --batch-size [BATCH_SIZE] --drop-path [STOCHASTIC_DEPTH_RATIO] --output_dir [OUTPUT_PATH]

For example, the XCiT-S12/16 model can be trained using the following command

python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model xcit_small_12_p16 --batch-size 128 --drop-path 0.05 --output_dir /experiments/xcit_small_12_p16/ --epochs [NUM_EPOCHS]

For multinode training via SLURM you can alternatively use

python run_with_submitit.py --partition [PARTITION_NAME] --nodes 2 --ngpus 8 --model xcit_small_12_p16 --batch-size 64 --drop-path 0.05 --job_dir /experiments/xcit_small_12_p16/ --epochs 400

More details for the hyper-parameters used to train the different models can be found in Table B.1 in the paper.

Evaluation

To evaluate an XCiT model using the checkpoints above or models you trained use the following command:

python main.py --eval --model  --input-size  [--full_crop] --pretrained 

By default we use the --full_crop flag which evaluates the model with a crop ratio of 1.0 instead of 0.875 following CaiT.

For example, the command to evaluate the XCiT-S12/16 using 224x224 images:

python main.py --eval --model xcit_small_12_p16 --input-size 384 --full_crop --pretrained https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p16_224.pth

Acknowledgement

This repository is built using the Timm library and the DeiT repository. The self-supervised training is based on the DINO repository.

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.

Contributing

We actively welcome your pull requests! Please see CONTRIBUTING.md and CODE_OF_CONDUCT.md for more info.

Citation

If you find this repository useful, please consider citing our work:

@misc{elnouby2021xcit,
      title={XCiT: Cross-Covariance Image Transformers}, 
      author={Alaaeldin El-Nouby and Hugo Touvron and Mathilde Caron and Piotr Bojanowski and Matthijs Douze and Armand Joulin and Ivan Laptev and Natalia Neverova and Gabriel Synnaeve and Jakob Verbeek and Hervé Jegou},
      year={2021},
      journal={arXiv preprint arXiv:2106.09681},
}
Owner
Facebook Research
Facebook Research
unofficial pytorch implementation of RefineGAN

RefineGAN unofficial pytorch implementation of RefineGAN (https://arxiv.org/abs/1709.00753) for CSMRI reconstruction, the official code using tensorpa

xinby17 5 Jul 21, 2022
Introduction to CPM

CPM CPM is an open-source program on large-scale pre-trained models, which is conducted by Beijing Academy of Artificial Intelligence and Tsinghua Uni

Tsinghua AI 136 Dec 23, 2022
State of the Art Neural Networks for Generative Deep Learning

pyradox-generative State of the Art Neural Networks for Generative Deep Learning Table of Contents pyradox-generative Table of Contents Installation U

Ritvik Rastogi 8 Sep 29, 2022
Official implementation for "Low-light Image Enhancement via Breaking Down the Darkness"

Low-light Image Enhancement via Breaking Down the Darkness by Qiming Hu, Xiaojie Guo. 1. Dependencies Python3 PyTorch=1.0 OpenCV-Python, TensorboardX

Qiming Hu 30 Jan 01, 2023
A Lighting Pytorch Framework for Recommendation System, Easy-to-use and Easy-to-extend.

Torch-RecHub A Lighting Pytorch Framework for Recommendation Models, Easy-to-use and Easy-to-extend. 安装 pip install torch-rechub 主要特性 scikit-learn风格易用

Mincai Lai 67 Jan 04, 2023
Pytorch Implementation of the paper "Cross-domain Correspondence Learning for Exemplar-based Image Translation"

CoCosNet Pytorch Implementation of the paper "Cross-domain Correspondence Learning for Exemplar-based Image Translation" (CVPR 2020 oral). Update: 202

Lingbo Yang 38 Sep 22, 2021
MADE (Masked Autoencoder Density Estimation) implementation in PyTorch

pytorch-made This code is an implementation of "Masked AutoEncoder for Density Estimation" by Germain et al., 2015. The core idea is that you can turn

Andrej 498 Dec 30, 2022
Numerai tournament example scripts using NN and optuna

numerai_NN_example Numerai tournament example scripts using pytorch NN, lightGBM and optuna https://numer.ai/tournament Performance of my model based

Takahiro Maeda 12 Oct 10, 2022
VGGFace2-HQ - A high resolution face dataset for face editing purpose

The first open source high resolution dataset for face swapping!!! A high resolution version of VGGFace2 for academic face editing purpose

Naiyuan Liu 232 Dec 29, 2022
Self-Supervised Contrastive Learning of Music Spectrograms

Self-Supervised Music Analysis Self-Supervised Contrastive Learning of Music Spectrograms Dataset Songs on the Billboard Year End Hot 100 were collect

27 Dec 10, 2022
A Blender python script for getting asset browser custom preview images for objects and collections.

asset_snapshot A Blender python script for getting asset browser custom preview images for objects and collections. Installation: Click the code butto

Johnny Matthews 44 Nov 29, 2022
A series of Jupyter notebooks with Chinese comment that walk you through the fundamentals of Machine Learning and Deep Learning in python using Scikit-Learn and TensorFlow.

Hands-on-Machine-Learning 目的 这份笔记旨在帮助中文学习者以一种较快较系统的方式入门机器学习, 是在学习Hands-on Machine Learning with Scikit-Learn and TensorFlow这本书的 时候做的个人笔记: 此项目的可取之处 原书的

Baymax 1.5k Dec 21, 2022
Code for the paper "PortraitNet: Real-time portrait segmentation network for mobile device" @ CAD&Graphics2019

PortraitNet Code for the paper "PortraitNet: Real-time portrait segmentation network for mobile device". @ CAD&Graphics 2019 Introduction We propose a

265 Dec 01, 2022
CvT2DistilGPT2 is an encoder-to-decoder model that was developed for chest X-ray report generation.

CvT2DistilGPT2 Improving Chest X-Ray Report Generation by Leveraging Warm-Starting This repository houses the implementation of CvT2DistilGPT2 from [1

The Australian e-Health Research Centre 21 Dec 28, 2022
Repository of best practices for deep learning in Julia, inspired by fastai

FastAI Docs: Stable | Dev FastAI.jl is inspired by fastai, and is a repository of best practices for deep learning in Julia. Its goal is to easily ena

FluxML 532 Jan 02, 2023
Imposter-detector-2022 - HackED 2022 Team 3IQ - 2022 Imposter Detector

HackED 2022 Team 3IQ - 2022 Imposter Detector By Aneeljyot Alagh, Curtis Kan, Jo

Joshua Ji 3 Aug 20, 2022
PyTorch implementation of MSBG hearing loss model and MBSTOI intelligibility metric

PyTorch implementation of MSBG hearing loss model and MBSTOI intelligibility metric This repository contains the implementation of MSBG hearing loss m

BUT <a href=[email protected]"> 9 Nov 08, 2022
Network Enhancement implementation in pytorch

network_enahncement_pytorch Network Enhancement implementation in pytorch Research paper Network Enhancement: a general method to denoise weighted bio

Yen 1 Nov 12, 2021
Tools to create pixel-wise object masks, bounding box labels (2D and 3D) and 3D object model (PLY triangle mesh) for object sequences filmed with an RGB-D camera.

Tools to create pixel-wise object masks, bounding box labels (2D and 3D) and 3D object model (PLY triangle mesh) for object sequences filmed with an RGB-D camera. This project prepares training and t

305 Dec 16, 2022
Weak-supervised Visual Geo-localization via Attention-based Knowledge Distillation

Weak-supervised Visual Geo-localization via Attention-based Knowledge Distillation Introduction WAKD is a PyTorch implementation for our ICPR-2022 pap

2 Oct 20, 2022