CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows

Overview

CSWin-Transformer

This repo is the official implementation of "CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows". The code and models for downstream tasks are coming soon.

Introduction

CSWin Transformer (CSWin stands for Cross-Shaped Window), introduced on arXiv, is a new general-purpose backbone for computer vision. It is a hierarchical Transformer that replaces traditional full attention with our newly proposed cross-shaped window self-attention. This mechanism computes self-attention in parallel in the horizontal and vertical stripes that form a cross-shaped window, with each stripe obtained by splitting the input feature into stripes of equal width. With CSWin, we can realize global attention at a limited computation cost.
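
To make the stripe partition concrete, the following NumPy sketch splits a feature map into horizontal and vertical stripes and applies plain softmax attention inside each stripe. It is an illustration under simplifying assumptions, not the repository's implementation: all function names are hypothetical, queries, keys and values are the raw features (no learned projections, no multi-head split), and sending the first half of the channels to horizontal stripes and the second half to vertical stripes is an assumed way of running the two directions in parallel.

```python
import numpy as np


def split_stripes(x, sw, horizontal=True):
    """Partition an (H, W, C) feature map into stripes of width sw.

    Returns an array of shape (num_stripes, tokens_per_stripe, C).
    """
    H, W, C = x.shape
    if horizontal:
        assert H % sw == 0, "H must be divisible by the stripe width"
        # Each horizontal stripe covers sw complete rows.
        return x.reshape(H // sw, sw * W, C)
    assert W % sw == 0, "W must be divisible by the stripe width"
    # Each vertical stripe covers sw complete columns.
    x = x.reshape(H, W // sw, sw, C).transpose(1, 0, 2, 3)
    return x.reshape(W // sw, H * sw, C)


def stripe_attention(x, sw, horizontal):
    """Softmax self-attention restricted to each stripe (q = k = v = x)."""
    s = split_stripes(x, sw, horizontal)                 # (n, t, C)
    scores = s @ s.transpose(0, 2, 1) / np.sqrt(s.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)         # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ s                                   # (n, t, C)


def cross_shaped_attention(x, sw):
    """Horizontal stripes on one channel half, vertical stripes on the other."""
    H, W, C = x.shape
    half = C // 2
    h_out = stripe_attention(x[..., :half], sw, horizontal=True)
    v_out = stripe_attention(x[..., half:], sw, horizontal=False)
    # Undo both partitions so the outputs are (H, W, channels) again.
    h_out = h_out.reshape(H, W, half)
    v_out = v_out.reshape(W // sw, H, sw, C - half).transpose(1, 0, 2, 3)
    v_out = v_out.reshape(H, W, C - half)
    return np.concatenate([h_out, v_out], axis=-1)


x = np.random.rand(4, 6, 8)
print(cross_shaped_attention(x, sw=2).shape)  # (4, 6, 8)
```

In this sketch each token attends only to the tokens of its own stripe, so the number of attention scores per token scales with the stripe size rather than with the whole feature map.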

CSWin Transformer achieves strong performance on ImageNet classification (87.5% top-1 accuracy on val with only 97G FLOPs) and ADE20K semantic segmentation (55.7 mIoU on val), surpassing previous models by a large margin.

Main Results on ImageNet

model    pretrain      resolution  acc@1  #params  FLOPs  22K model  1K model
CSWin-T  ImageNet-1K   224x224     82.8   23M      4.3G   -          model
CSWin-S  ImageNet-1K   224x224     83.6   35M      6.9G   -          model
CSWin-B  ImageNet-1K   224x224     84.2   78M      15.0G  -          model
CSWin-B  ImageNet-1K   384x384     85.5   78M      47.0G  -          model
CSWin-L  ImageNet-22K  224x224     86.5   173M     31.5G  model      model
CSWin-L  ImageNet-22K  384x384     87.5   173M     96.8G  -          model

Main Results on Downstream Tasks

COCO Object Detection

backbone  method              pretrain     lr schd  box mAP  mask mAP  #params  FLOPs
CSWin-T   Mask R-CNN          ImageNet-1K  3x       49.0     43.6      42M      279G
CSWin-S   Mask R-CNN          ImageNet-1K  3x       50.0     44.5      54M      342G
CSWin-B   Mask R-CNN          ImageNet-1K  3x       50.8     44.9      97M      526G
CSWin-T   Cascade Mask R-CNN  ImageNet-1K  3x       52.5     45.3      80M      757G
CSWin-S   Cascade Mask R-CNN  ImageNet-1K  3x       53.7     46.4      92M      820G
CSWin-B   Cascade Mask R-CNN  ImageNet-1K  3x       53.9     46.4      135M     1004G

ADE20K Semantic Segmentation (val)

backbone  method        pretrain      crop size  lr schd  mIoU  mIoU (ms+flip)  #params  FLOPs
CSWin-T   Semantic FPN  ImageNet-1K   512x512    80K      48.2  -               26M      202G
CSWin-S   Semantic FPN  ImageNet-1K   512x512    80K      49.2  -               39M      271G
CSWin-B   Semantic FPN  ImageNet-1K   512x512    80K      49.9  -               81M      464G
CSWin-T   UPerNet       ImageNet-1K   512x512    160K     49.3  50.4            60M      959G
CSWin-S   UPerNet       ImageNet-1K   512x512    160K     50.0  50.8            65M      1027G
CSWin-B   UPerNet       ImageNet-1K   512x512    160K     50.8  51.7            109M     1222G
CSWin-B   UPerNet       ImageNet-22K  640x640    160K     51.8  52.6            109M     1941G
CSWin-L   UPerNet       ImageNet-22K  640x640    160K     53.4  55.7            208M     2745G

Requirements

Install the dependencies (timm==0.3.4, pytorch>=1.4, opencv, ...) by running:

bash install_req.sh

Apex for mixed precision training is used for finetuning. To install apex, run:

git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./

Data preparation: arrange ImageNet in the following folder structure. You can extract ImageNet with this script.

│imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......
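
Before launching a long run it can help to confirm that the dataset really follows this layout. The helper below is a minimal sketch, not part of the repository: the function name `summarize_imagenet` is hypothetical, and it assumes every class folder sits directly under `train/` or `val/` and that images use the `.JPEG` suffix shown above.

```python
from pathlib import Path


def summarize_imagenet(root):
    """Count class folders and JPEG files under root/train and root/val."""
    summary = {}
    for split in ("train", "val"):
        split_dir = Path(root) / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split directory: {split_dir}")
        class_dirs = sorted(d for d in split_dir.iterdir() if d.is_dir())
        n_images = sum(
            1
            for d in class_dirs
            for f in d.iterdir()
            if f.suffix.upper() == ".JPEG"
        )
        summary[split] = {"classes": len(class_dirs), "images": n_images}
    return summary
```

Calling `summarize_imagenet("<data path>")` returns a dictionary with the number of class folders and images per split, which can be compared against the expected dataset size.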

Train

Train the three lite variants: CSWin-Tiny, CSWin-Small and CSWin-Base:

bash train.sh 8 --data <data path> --model CSWin_64_12211_tiny_224 -b 256 --lr 2e-3 --weight-decay .05 --amp --img-size 224 --warmup-epochs 20 --model-ema-decay 0.99984 --drop-path 0.2
bash train.sh 8 --data <data path> --model CSWin_64_24322_small_224 -b 256 --lr 2e-3 --weight-decay .05 --amp --img-size 224 --warmup-epochs 20 --model-ema-decay 0.99984 --drop-path 0.4
bash train.sh 8 --data <data path> --model CSWin_96_24322_base_224 -b 128 --lr 1e-3 --weight-decay .1 --amp --img-size 224 --warmup-epochs 20 --model-ema-decay 0.99992 --drop-path 0.5

If you want to train our CSWin on images with 384x384 resolution, please use '--img-size 384'.

If GPU memory is insufficient, use '-b 128 --lr 1e-3 --model-ema-decay 0.99992' or enable gradient checkpointing with '--use-chk'.
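
The reduced-memory setting is consistent with a linear rule: halving the batch size from 256 to 128 halves the learning rate (2e-3 to 1e-3) and halves the gap between the EMA decay and 1 (0.99984 to 0.99992). The sketch below encodes that rule. It is inferred from these two settings only, so values for other batch sizes are an extrapolation rather than a recommendation from the authors, and `scale_hparams` is a hypothetical helper name.

```python
def scale_hparams(batch_size, base_batch_size=256, base_lr=2e-3,
                  base_ema_decay=0.99984):
    """Scale lr and EMA decay linearly with batch size (assumed rule)."""
    ratio = batch_size / base_batch_size
    lr = base_lr * ratio
    # A smaller batch means more updates per epoch, so the EMA should
    # forget more slowly: shrink (1 - decay) by the same ratio.
    ema_decay = 1.0 - (1.0 - base_ema_decay) * ratio
    return lr, ema_decay


lr, ema_decay = scale_hparams(128)
print(f"--lr {lr:g} --model-ema-decay {ema_decay:.5f}")
```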

Finetune

Finetune CSWin-Base with 384x384 resolution:

bash finetune.sh 8 --data <data path> --model CSWin_96_24322_base_384 -b 32 --lr 5e-6 --min-lr 5e-7 --weight-decay 1e-8 --amp --img-size 384 --warmup-epochs 0 --model-ema-decay 0.9998 --finetune <pretrained 224 model> --epochs 20 --mixup 0.1 --cooldown-epochs 10 --drop-path 0.7 --ema-finetune --lr-scale 1 --cutmix 0.1

Finetune ImageNet-22K pretrained CSWin-Large with 224x224 resolution:

bash finetune.sh 8 --data <data path> --model CSWin_144_24322_large_224 -b 64 --lr 2.5e-4 --min-lr 5e-7 --weight-decay 1e-8 --amp --img-size 224 --warmup-epochs 0 --model-ema-decay 0.9996 --finetune <22k-pretrained model> --epochs 30 --mixup 0.01 --cooldown-epochs 10 --interpolation bicubic  --lr-scale 0.05 --drop-path 0.2 --cutmix 0.3 --use-chk --fine-22k --ema-finetune

If GPU memory is insufficient, enable gradient checkpointing with '--use-chk'.

Cite CSWin Transformer

@misc{dong2021cswin,
      title={CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows},
      author={Xiaoyi Dong and Jianmin Bao and Dongdong Chen and Weiming Zhang and Nenghai Yu and Lu Yuan and Dong Chen and Baining Guo},
      year={2021},
      eprint={2107.00652},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Acknowledgement

This repository is built using the timm library and the DeiT repository.

License

This project is licensed under the license found in the LICENSE file in the root directory of this source tree.

Microsoft Open Source Code of Conduct

Contact Information

For help or issues using CSWin Transformer, please submit a GitHub issue.

For other communications related to CSWin Transformer, please contact Jianmin Bao ([email protected]) or Dong Chen ([email protected]).

Comments
  • About the patches_resolution of the segmentation model

    Hello, this work is interesting, but I have some questions about the 'patches_resolution' of the segmentation model. I notice that the long side of the cross-shaped windows is the 'patches_resolution' rather than the real feature resolution. For example, in stage 3, the long side is 224 / 16 = 14. Do I understand it correctly? Does that make it impossible to exchange information outside the 'patches_resolution'?

    opened by danczs 2
  • The results of downstream task by my realization are poor

    I have some questions:

    1. Is the split size still [1 2 7 7]?
    2. Is the last-stage branch_num 2 or 1? In downstream tasks the feature resolution in the last stage does not necessarily equal 7 (the split size). If branch_num is not 1, the pretrained weight sizes do not match.
    3. Is the padding in my implementation right?

    pad_l = pad_t = 0
    pad_r = (W_sp - W % W_sp) % W_sp
    pad_b = (H_sp - H % H_sp) % H_sp
    q = q.transpose(-2,-1).contiguous().view(B, H, W, C)
    k = q.transpose(-2,-1).contiguous().view(B, H, W, C)
    v = q.transpose(-2,-1).contiguous().view(B, H, W, C)
    if pad_r > 0 or pad_b > 0:
        q = F.pad(q, (0, 0, pad_l, pad_r, pad_t, pad_b))
        k = F.pad(k, (0, 0, pad_l, pad_r, pad_t, pad_b))
        v = F.pad(v, (0, 0, pad_l, pad_r, pad_t, pad_b))
    _, Hp, Wp, _ = q.shape
    opened by Sunting78 2
  • Experiment setting for semantic segmentation

    Hi, thank you for the code. I implemented CSWin-T with FPN for semantic segmentation on ADE20K but couldn't reach the mIoU of 48.2 reported in the table. The maximum I could get was 39.9 mIoU. It would be great if you could share the exact experiment settings you used. Thanks.

    opened by AnukritiSinghh 2
  • CSWin significantly slower than Swin?

    Greetings,

    From my benchmarks I have noticed that CSWin seems to be significantly slower than Swin at inference. Is this the expected behavior? While I can get predictions in as little as 20 milliseconds with Swin Large 384, it takes more than 900 milliseconds with CSWin_144_24322_large_384.

    I performed the tests using FP16, torchscript, optimize_for_inference and torch.inference_mode.

    opened by ErenBalatkan 2
  • Pretrained settings for object detection

    Hi, I'm impressed by your excellent work.

    I have a question.

    I wonder which type of the pre-trained weights (224x224 or 384x384 finetuned) is used for object detection.

    I know both 224x224 and 384x384 are pre-trained on ImageNet-1k.

    opened by youngwanLEE 2
  • Error about building 384 models

    The code at https://github.com/microsoft/CSWin-Transformer/blob/d8be74a7833898f7bd9c77eb8c051d1b8bd5d753/models/cswin.py#L407 should be:

    model = CSWinTransformer(img_size=384, patch_size=4, embed_dim=96, depth=[2,4,32,2],

    The same applies to https://github.com/microsoft/CSWin-Transformer/blob/d8be74a7833898f7bd9c77eb8c051d1b8bd5d753/models/cswin.py#L414

    opened by TingquanGao 2
  • The problem is shown in the figure

    model = CSWinTransformer(patch_size=4, embed_dim=96, depth=[2,4,32,2], split_size=[1,2,12,12], num_heads=[4,8,16,32], mlp_ratio=4.).cuda().eval()
    inp = torch.rand(1, 3, 224, 224).cuda()
    outs = model(inp)
    for out in outs:
        print(out.shape)

    RuntimeError: shape '[1, 192, 1, 14, 1, 12]' is invalid for input of size 37632

    Why?

    opened by rui-cf 2
  • about the test result on imagenet-2012

    Hi! I tested the released CSWin-Tiny-224 pretrained weights. These are my data transforms during testing:

    DEFAULT_CROP_SIZE = 0.9
    scale_size = int(math.floor(image_size / DEFAULT_CROP_SIZE))
    transform = transforms.Compose(
        [
            transforms.Resize(scale_size, interpolation=3)  # 3: bicubic
            if image_size == 224
            else transforms.Resize(image_size, interpolation=3),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
        ]
    )

    I can only get 80.5% on the ImageNet-2012 dataset, which is inconsistent with the results reported in this repo. Did I miss some details about the data augmentation during testing?

    opened by rentainhe 1
  • CSwin Code for Segmentation with MMSegmentation

    Hi, Thank you for your work! I wonder if you plan to release the mmsegmentation code you used for the downstream segmentation task, just like in Swin-Transformer repo.

    Best,

    opened by WalBouss 1
  • How do you produce Table 9 (ablation on different attention mechanisms) in the paper?

    Hi, thanks for your nice work. I'm comparing different attention mechanisms and want to follow your experimental settings. I ran into two problems:

    1. Why is the reported mIoU 41.9 for Swin-T in Table 9, while it is 46.1 in the Swin paper?
    2. Can you provide detailed experimental settings for semantic segmentation and object detection in Table 9?
    opened by rayleizhu 0
  • about the setting of --use_chk

    Passing --use_chk enables torch.utils.checkpoint to save GPU memory. I wonder whether this could hurt the final performance. Thanks a lot!

    opened by go-ahead-maker 0
  • Input image normalization parameters for semantic segmentation.

    Hey, guys, cool work!

    Unfortunately, it is not obvious to me which normalization parameters (mean, std) you used when training the semantic segmentation model on ADE20K. Are they still IMAGENET_DEFAULT_MEAN, or did you use other values?

    opened by NikAleksFed 0
  • Using transfer to train over food101

    Hi all! I'm trying to train a model for food101 using the CSWin_64_12211_tiny_224 model with its pretrained weights. The thing is, during execution it looks like it is training from scratch rather than reusing the pretrained weights. By this I mean the initial top-5 accuracy is around 5%, but I expected it to be higher than this.

    For this I loaded the pretrained model, changed its classification layer in a separate script, and saved it for use as follows:

    model = create_model(
        'CSWin_64_12211_tiny_224',
        pretrained=True,
        num_classes=1000,
        drop_rate=0.0,
        drop_connect_rate=None,  # DEPRECATED, use drop_path
        drop_path_rate=0.2,
        drop_block_rate=None,
        global_pool=None,
        bn_tf=False,
        bn_momentum=None,
        bn_eps=None,
        checkpoint_path='',
        img_size=224,
        use_chk=True)
    chk_path = './pretrained/cswin_tiny_224.pth'
    load_checkpoint(model, chk_path)
    model.reset_classifier(101, 'max')

    These are some of the runs I tried:

    bash finetune.sh 1 --data ../food-101 --model CSWin_64_12211_tiny_224 -b 32 --lr 5e-6 --min-lr 5e-7 --weight-decay 1e-8 --amp --img-size 224 --warmup-epochs 0 --model-ema-decay 0.9998 --epochs 20 --mixup 0.1 --cooldown-epochs 10 --drop-path 0.7 --ema-finetune --lr-scale 1 --cutmix 0.1 --use-chk --num-classes 101 --pretrained --finetune ./pretrained/CSWin_64_12211_tiny_224101.pth

    bash finetune.sh 1 --data ../food-101 --model CSWin_64_12211_tiny_224 -b 32 --lr 2e-3 --weight-decay .05 --amp --img-size 224 --warmup-epochs 0 --model-ema-decay 0.9998 --epochs 20 --cooldown-epochs 10 --drop-path 0.2 --ema-finetune --cutmix 0.1 --use-chk --num-classes 101 --initial-checkpoint ./pretrained/CSWin_64_12211_tiny_224101.pth --lr-scale 1.0 --output ./full_base

    Is there something I'm missing or a proper way I should try this?

    Thanks in advance for any help! :)

    opened by andreynz691 0
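
Several of the questions above come down to whether each stage's feature resolution is divisible by its split size. The reported shape error fits this reading: 37632 = 192 x 14 x 14, so the stage-3 feature map of a 224x224 input is 14x14, which cannot be cut into stripes of width 12. The sketch below checks this arithmetic ahead of time. It is a hypothetical helper, not repository code, and it assumes overall strides of 4, 8, 16 and 32 for the four stages (only the stage-3 stride of 16 is stated above). The repository may treat the last stage differently, so a flag on stage 4 is only a hint.

```python
def check_split_sizes(img_size, split_size, strides=(4, 8, 16, 32)):
    """Return (stage, resolution, split) for stages that do not divide evenly."""
    problems = []
    for stage, (stride, sp) in enumerate(zip(strides, split_size), start=1):
        res = img_size // stride  # feature-map side length at this stage
        if res % sp != 0:
            problems.append((stage, res, sp))
    return problems


print(check_split_sizes(224, [1, 2, 7, 7]))    # []
print(check_split_sizes(384, [1, 2, 12, 12]))  # []
print(check_split_sizes(224, [1, 2, 12, 12]))  # [(3, 14, 12), (4, 7, 12)]
```

Under these assumptions, split sizes of 12 match a 384x384 input, which is in line with the suggestion above to pass img_size=384 when building the 384 models.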