TensorFlow implementation of the Swin Transformer model.

Overview

Swin Transformer (TensorFlow)

TensorFlow reimplementation of the Swin Transformer model.

Based on the official PyTorch implementation.

Requirements

  • tensorflow >= 2.4.1

Pretrained Swin Transformer Checkpoints

ImageNet-1K and ImageNet-22K Pretrained Checkpoints

name            pretrain      resolution  acc@1  #params  model
swin_tiny_224   ImageNet-1K   224x224     81.2   28M      github
swin_small_224  ImageNet-1K   224x224     83.2   50M      github
swin_base_224   ImageNet-22K  224x224     85.2   88M      github
swin_base_384   ImageNet-22K  384x384     86.4   88M      github
swin_large_224  ImageNet-22K  224x224     86.3   197M     github
swin_large_384  ImageNet-22K  384x384     87.3   197M     github

Examples

Initializing the model:

from swintransformer import SwinTransformer

model = SwinTransformer('swin_tiny_224', num_classes=1000, include_top=True, pretrained=False)
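
As a quick sanity check (not part of the original example), you can push a random batch through the freshly initialized model; the values are meaningless and only the output shape matters here:

import tensorflow as tf

# Illustrative smoke test: random values stand in for real images.
dummy_batch = tf.random.uniform([2, 224, 224, 3])
logits = model(dummy_batch)   # model from the snippet above (include_top=True)
print(logits.shape)           # expected: (2, 1000) for num_classes=1000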

You can use a pretrained model like this:

import tensorflow as tf
from swintransformer import SwinTransformer

IMAGE_SIZE = [224, 224]   # input resolution for swin_tiny_224; adjust for other variants
NUM_CLASSES = 1000        # placeholder: replace with the number of classes in your task

model = tf.keras.Sequential([
  tf.keras.layers.Lambda(lambda data: tf.keras.applications.imagenet_utils.preprocess_input(tf.cast(data, tf.float32), mode="torch"), input_shape=[*IMAGE_SIZE, 3]),
  SwinTransformer('swin_tiny_224', include_top=False, pretrained=True),
  tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')
])
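
A minimal training sketch around the model above; train_ds is a hypothetical tf.data.Dataset of (image, one-hot label) batches and is not part of the original example:

model.compile(
  optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
  loss='categorical_crossentropy',   # assumes one-hot labels matching NUM_CLASSES
  metrics=['accuracy']
)
# train_ds: assumed tf.data.Dataset yielding ([batch, *IMAGE_SIZE, 3], [batch, NUM_CLASSES]) pairs
model.fit(train_ds, epochs=5)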

If you use a pretrained model with a TPU on Kaggle, specify the use_tpu option:

import tensorflow as tf
from swintransformer import SwinTransformer

model = tf.keras.Sequential([
  tf.keras.layers.Lambda(lambda data: tf.keras.applications.imagenet_utils.preprocess_input(tf.cast(data, tf.float32), mode="torch"), input_shape=[*IMAGE_SIZE, 3]),
  SwinTransformer('swin_tiny_224', include_top=False, pretrained=True, use_tpu=True),
  tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')
])
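
The use_tpu example assumes the model is built inside a TPU distribution strategy scope, as in the Kaggle example linked below. A hedged sketch of that setup using standard TensorFlow APIs (a TPU runtime must be attached):

import tensorflow as tf
from swintransformer import SwinTransformer

# Standard TPU bootstrapping on Kaggle/Colab; assumes a TPU is available.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
tpu_strategy = tf.distribute.TPUStrategy(resolver)

with tpu_strategy.scope():
  model = tf.keras.Sequential([
    tf.keras.layers.Lambda(lambda data: tf.keras.applications.imagenet_utils.preprocess_input(tf.cast(data, tf.float32), mode="torch"), input_shape=[*IMAGE_SIZE, 3]),
    SwinTransformer('swin_tiny_224', include_top=False, pretrained=True, use_tpu=True),
    tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')
  ])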

Example: TPU training on Kaggle

Citation

@article{liu2021Swin,
  title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows},
  author={Liu, Ze and Lin, Yutong and Cao, Yue and Hu, Han and Wei, Yixuan and Zhang, Zheng and Lin, Stephen and Guo, Baining},
  journal={arXiv preprint arXiv:2103.14030},
  year={2021}
}
Comments
  • no module named 'swintransformer' error

    I wonder where from swintransformer import SwinTransformer comes from. I tried to pip install it, but it said there is no such module. How can I overcome this problem?
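
    One hedged way to make the import work (an assumption, not an official install path) is to use a local clone of this repository directly:

    # Assumes the repository has been cloned next to your script, e.g.:
    #   git clone https://github.com/rishigami/Swin-Transformer-TF
    import sys
    sys.path.append('Swin-Transformer-TF')  # adjust to wherever the clone lives

    from swintransformer import SwinTransformer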

    opened by HunarAA 2
  • Pretrained Swin-Transformer for multiple outputs

    Hi rishigami,

    Thank you for the implementation in TensorFlow. I am trying to use the Swin Transformer for a classification problem with multiple outputs. In your guide on how to use a pretrained model, you put it in a Sequential model, but that way I am not able to attach multiple dense heads for the multiple classifications. Could you help me understand how to adapt your TF code to my problem, perhaps using the Functional API?
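
    A minimal sketch of one way to do this with the Functional API (an illustration, not the maintainer's answer; the two heads and their class counts are hypothetical):

    import tensorflow as tf
    from swintransformer import SwinTransformer

    inputs = tf.keras.Input(shape=(224, 224, 3))
    x = tf.keras.layers.Lambda(lambda data: tf.keras.applications.imagenet_utils.preprocess_input(tf.cast(data, tf.float32), mode="torch"))(inputs)
    features = SwinTransformer('swin_tiny_224', include_top=False, pretrained=True)(x)
    # Two hypothetical classification heads sharing the same backbone features.
    out_a = tf.keras.layers.Dense(10, activation='softmax', name='head_a')(features)
    out_b = tf.keras.layers.Dense(5, activation='softmax', name='head_b')(features)

    model = tf.keras.Model(inputs=inputs, outputs=[out_a, out_b])
    model.compile(optimizer='adam', loss='categorical_crossentropy')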

    opened by imanuelroz 2
  • NotImplementedError during model save

    I have defined a model as follows:

    def buildModel(LR = LR):
        backbone = SwinTransformer('swin_large_224', num_classes=None, include_top=False, pretrained=True, use_tpu=False)
        
        inp = L.Input(shape=(224,224,3))
        emb = backbone(inp)
        out = L.Dense(1,activation="relu")(emb)
        
        model = tf.keras.Model(inputs=inp,outputs=out)
        optimizer = tf.keras.optimizers.Adam(lr = LR)
        model.compile(loss="mse",optimizer=optimizer,metrics=[tf.keras.metrics.RootMeanSquaredError()])
        return model
    

    Now when I save this model using model.save("./model.hdf5") I get the following error:

    NotImplementedError                       Traceback (most recent call last)
    /tmp/ipykernel_43/131311624.py in <module>
    ----> 1 model.save("model.hdf5")
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in save(self, filepath, overwrite, include_optimizer, save_format, signatures, options, save_traces)
       2000     # pylint: enable=line-too-long
       2001     save.save_model(self, filepath, overwrite, include_optimizer, save_format,
    -> 2002                     signatures, options, save_traces)
       2003 
       2004   def save_weights(self,
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/saving/save.py in save_model(model, filepath, overwrite, include_optimizer, save_format, signatures, options, save_traces)
        152           'or using `save_weights`.')
        153     hdf5_format.save_model_to_hdf5(
    --> 154         model, filepath, overwrite, include_optimizer)
        155   else:
        156     saved_model_save.save(model, filepath, overwrite, include_optimizer,
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/saving/hdf5_format.py in save_model_to_hdf5(model, filepath, overwrite, include_optimizer)
        113 
        114   try:
    --> 115     model_metadata = saving_utils.model_metadata(model, include_optimizer)
        116     for k, v in model_metadata.items():
        117       if isinstance(v, (dict, list, tuple)):
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/saving/saving_utils.py in model_metadata(model, include_optimizer, require_config)
        156   except NotImplementedError as e:
        157     if require_config:
    --> 158       raise e
        159 
        160   metadata = dict(
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/saving/saving_utils.py in model_metadata(model, include_optimizer, require_config)
        153   model_config = {'class_name': model.__class__.__name__}
        154   try:
    --> 155     model_config['config'] = model.get_config()
        156   except NotImplementedError as e:
        157     if require_config:
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/functional.py in get_config(self)
        648 
        649   def get_config(self):
    --> 650     return copy.deepcopy(get_network_config(self))
        651 
        652   @classmethod
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/functional.py in get_network_config(network, serialize_layer_fn)
       1347         filtered_inbound_nodes.append(node_data)
       1348 
    -> 1349     layer_config = serialize_layer_fn(layer)
       1350     layer_config['name'] = layer.name
       1351     layer_config['inbound_nodes'] = filtered_inbound_nodes
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/utils/generic_utils.py in serialize_keras_object(instance)
        248         return serialize_keras_class_and_config(
        249             name, {_LAYER_UNDEFINED_CONFIG_KEY: True})
    --> 250       raise e
        251     serialization_config = {}
        252     for key, item in config.items():
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/utils/generic_utils.py in serialize_keras_object(instance)
        243     name = get_registered_name(instance.__class__)
        244     try:
    --> 245       config = instance.get_config()
        246     except NotImplementedError as e:
        247       if _SKIP_FAILED_SERIALIZATION:
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in get_config(self)
       2252 
       2253   def get_config(self):
    -> 2254     raise NotImplementedError
       2255 
       2256   @classmethod
    
    NotImplementedError: 
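
    One hedged workaround (an assumption, not a confirmed fix): the traceback shows that the SwinTransformer layer does not implement get_config, so full-model saving cannot serialize it. Saving only the weights and rebuilding the architecture before loading sidesteps the serialization:

    # Save weights only; the architecture is rebuilt with buildModel() before loading.
    model.save_weights("model.weights.h5")

    restored = buildModel()
    restored.load_weights("model.weights.h5")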
    
    opened by Bibhash123 1
  • Invalid argument

    this is my basic model

    
    with tpu_strategy.scope():
        model = tf.keras.Sequential([
            tf.keras.layers.Lambda(lambda data: tf.keras.applications.imagenet_utils.preprocess_input(data, mode="torch"),
                                   input_shape=[224, 224, 3]),
            SwinTransformer('swin_tiny_224', include_top=False, pretrained=True, use_tpu=True),
            tf.keras.layers.Dense(1, activation='sigmoid')
        ])

    model.compile(loss=tf.keras.losses.BinaryCrossentropy(),
                  optimizer=tf.keras.optimizers.Adam(learning_rate=cfg['LEARNING_RATE']),
                  metrics=RMSE)
    
    

    I am getting this error,

    (3) Invalid argument: {{function_node __inference_train_function_705020}} Reshape's input dynamic dimension is decomposed into multiple output dynamic dimensions, but the constraint is ambiguous and XLA can't infer the output dimension %reshape.12202 = f32[256,144,576]{2,1,0} reshape(f32[36864,576]{1,0} %transpose.12194), metadata={op_type="Reshape" op_name="sequential_40/swin_large_384/sequential_39/basic_layer_28/sequential_35/swin_transformer_block_169/window_attention_169/layers0/blocks1/attn/qkv/Tensordot"}. [[{{node TPUReplicate/_compile/_17658394825749957328/_4}}]] [[tpu_compile_succeeded_assert/_11424487196827204192/_5/_209]]
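
    A hedged note rather than a confirmed diagnosis: XLA on TPU cannot handle the dynamic dimension it reports here, and a common remedy is to give the input pipeline a static batch size:

    # Dropping the last partial batch keeps every batch the same size, so XLA sees a
    # static batch dimension (train_ds and BATCH_SIZE are assumed to exist).
    train_ds = train_ds.batch(BATCH_SIZE, drop_remainder=True)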

    opened by AliKayhanAtay 1
  • relative_position_bias_table initialization

    Hi, in the official code, relative_position_bias_table is initialized with a truncated normal distribution. Is that part missing in this repo?

    Official code: https://github.com/microsoft/Swin-Transformer/blob/6bbd83ca617db8480b2fb9b335c476ffaf5afb1a/models/swin_transformer.py#L110

    This implementation: https://github.com/rishigami/Swin-Transformer-TF/blob/8986ca7b0e1f984437db2d8f17e0ecd87fadcd4f/swintransformer/model.py#L70
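
    A hedged sketch of mirroring the official trunc_normal_(std=.02) initialization in TensorFlow; the window_size and num_heads values below are just the swin_tiny first-stage defaults, used for illustration only:

    import tensorflow as tf

    window_size = (7, 7)
    num_heads = 3
    initializer = tf.keras.initializers.TruncatedNormal(stddev=0.02)
    relative_position_bias_table = tf.Variable(
        initializer(shape=((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)),
        trainable=True,
        name='relative_position_bias_table',
    )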

    opened by gathierry 1
  • Image size other than default ones doesn't work

    • Notebook: https://colab.research.google.com/drive/1nqYkQCUzShkVdqGxW4TyMrtAb0n5MBZR#scrollTo=G9ZVlphmqD7d
    • Issue: in swin_tiny_224 I've tried multiples of 224, 512x512, and multiples of window_size, but nothing seems to work other than 224x224.
    • The same goes for swin_large_384; only the default size 384x384 works.

    I'm wondering if this is expected behavior or not. Is there any way to make it work for non-square images?

    opened by awsaf49 1
  • Added 3D support for SwinTransformerModel, e.g. for medical imaging tasks

    Tested and working, for example:

    import numpy as np
    import tensorflow as tf
    import swin_transformer_nd  # module added in this PR; exact import path may differ

    IMAGE_SIZE = [112, 112, 112]
    NUM_CLASSES = 10
    
    model_3d = tf.keras.Sequential([
      swin_transformer_nd.SwinTransformerModel(img_size=IMAGE_SIZE, patch_size=(4, 4, 4), depths=[2, 2, 6]),
      tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')
    ])
    model_3d.compile(tf.keras.optimizers.Adam(), "categorical_crossentropy")
    
    for i in range(100):
        x = np.zeros([1, *IMAGE_SIZE, 1])
        y = tf.zeros([1, NUM_CLASSES])
        
        model_3d.fit(x, y)
        print("Trained on a batch")
    
    opened by MohamadZeina 0
  • Could you provide a weight conversion script?

    I tried the code and weights you provided, and found that the performance is bad. Could you please provide the weight conversion script so I can figure out this issue?

    Many thanks

    opened by edwardyehuang 0
  • tf load model error

    import tensorflow as tf
    from swintransformer import SwinTransformer

    model = tf.keras.Sequential([
      tf.keras.layers.Lambda(lambda data: tf.keras.applications.imagenet_utils.preprocess_input(tf.cast(data, tf.float32), mode="torch"), input_shape=[*IMAGE_SIZE, 3]),
      SwinTransformer('swin_tiny_224', include_top=False, pretrained=True),
      tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')
    ])

    TensorFlow can't load the pretrained model; this step gives an error.

    opened by jangjiun 0
  • Please run in eager mode or implement the `compute_output_shape` method on your layer (SwinTransformerModel)

    Has anyone tried to use the pretrained model with a TimeDistributed layer?

    model = tf.keras.Sequential([
        tf.keras.layers.Lambda(lambda data: tf.keras.applications.imagenet_utils.preprocess_input(tf.cast(data, tf.float32), mode="torch"),
                               input_shape=[224, 224, 3]),
        SwinTransformer('swin_base_224', include_top=False, pretrained=True)
    ])

    model_f = tf.keras.Sequential()
    model_f.add(tf.keras.layers.TimeDistributed(model, input_shape=(8, 224, 224, 3)))
    
    

    I get the following error:

    NotImplementedError: Exception encountered when calling layer "time_distributed" (type TimeDistributed).
    
    Please run in eager mode or implement the `compute_output_shape` method on your layer (SwinTransformerModel).
    
    Call arguments received by layer "time_distributed" (type TimeDistributed):
      • inputs=tf.Tensor(shape=(None, 8, 224, 224, 3), dtype=float32)
      • training=False
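
    A hedged sketch of the second option the error message suggests, implementing compute_output_shape in a thin wrapper layer; the 1024 feature size is an assumption for swin_base_224 with include_top=False (embed_dim 128 doubled over three merge stages):

    import tensorflow as tf
    from swintransformer import SwinTransformer

    class SwinBackbone(tf.keras.layers.Layer):
        """Wrapper that adds the compute_output_shape TimeDistributed asks for."""
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.backbone = SwinTransformer('swin_base_224', include_top=False, pretrained=True)

        def call(self, inputs):
            return self.backbone(inputs)

        def compute_output_shape(self, input_shape):
            # Assumption: pooled features of size 1024 for swin_base.
            return (input_shape[0], 1024)

    model_f = tf.keras.Sequential([
        tf.keras.layers.TimeDistributed(SwinBackbone(), input_shape=(8, 224, 224, 3))
    ])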
    
    
    opened by atelili 0
Releases: v0.1-tf-swin-weights