Tensorflow implementation of Swin Transformer model.

Overview

Swin Transformer (Tensorflow)

Tensorflow reimplementation of Swin Transformer model.

Based on Official Pytorch implementation. image

Requirements

  • tensorflow >= 2.4.1

Pretrained Swin Transformer Checkpoints

ImageNet-1K and ImageNet-22K Pretrained Checkpoints

name pretrain resolution [email protected] #params model
swin_tiny_224 ImageNet-1K 224x224 81.2 28M github
swin_small_224 ImageNet-1K 224x224 83.2 50M github
swin_base_224 ImageNet-22K 224x224 85.2 88M github
swin_base_384 ImageNet-22K 384x384 86.4 88M github
swin_large_224 ImageNet-22K 224x224 86.3 197M github
swin_large_384 ImageNet-22K 384x384 87.3 197M github

Examples

Initializing the model:

from swintransformer import SwinTransformer

model = SwinTransformer('swin_tiny_224', num_classes=1000, include_top=True, pretrained=False)

You can use a pretrained model like this:

import tensorflow as tf
from swintransformer import SwinTransformer

model = tf.keras.Sequential([
  tf.keras.layers.Lambda(lambda data: tf.keras.applications.imagenet_utils.preprocess_input(tf.cast(data, tf.float32), mode="torch"), input_shape=[*IMAGE_SIZE, 3]),
  SwinTransformer('swin_tiny_224', include_top=False, pretrained=True),
  tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')
])

If you use a pretrained model with TPU on kaggle, specify use_tpu option:

import tensorflow as tf
from swintransformer import SwinTransformer

model = tf.keras.Sequential([
  tf.keras.layers.Lambda(lambda data: tf.keras.applications.imagenet_utils.preprocess_input(tf.cast(data, tf.float32), mode="torch"), input_shape=[*IMAGE_SIZE, 3]),
  SwinTransformer('swin_tiny_224', include_top=False, pretrained=True, use_tpu=True),
  tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')
])

Example: TPU training on Kaggle

Citation

@article{liu2021Swin,
  title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows},
  author={Liu, Ze and Lin, Yutong and Cao, Yue and Hu, Han and Wei, Yixuan and Zhang, Zheng and Lin, Stephen and Guo, Baining},
  journal={arXiv preprint arXiv:2103.14030},
  year={2021}
}
Comments
  • no module name 'swintransformer' error

    no module name 'swintransformer' error

    I wounder where the from swintransformer import SwinTransformer come from? I tried to pip install it, it also said that there is no such module. How can I overcome this problem?

    opened by HunarAA 2
  • Pretrained Swin-Transformer for multiple output

    Pretrained Swin-Transformer for multiple output

    Hi rishigami,

    Thank you for the implementation in Tensorflow. I am trying to use the Swin Transformer for a classification problem with multiple outputs. In your guide on how to use a pertained model you put it in a Sequential mode, but in this way I am not able to stack multiple dense layer for the multiple classification, could you help me understand how can I adapt your TF code to my problem, using it in a Functional API way maybe?

    opened by imanuelroz 2
  • NotImplementedError during model save

    NotImplementedError during model save

    I have defined a model as follows:

    def buildModel(LR = LR):
        backbone = SwinTransformer('swin_large_224', num_classes=None, include_top=False, pretrained=True, use_tpu=False)
        
        inp = L.Input(shape=(224,224,3))
        emb = backbone(inp)
        out = L.Dense(1,activation="relu")(emb)
        
        model = tf.keras.Model(inputs=inp,outputs=out)
        optimizer = tf.keras.optimizers.Adam(lr = LR)
        model.compile(loss="mse",optimizer=optimizer,metrics=[tf.keras.metrics.RootMeanSquaredError()])
        return model
    

    Now when I save this model using model.save("./model.hdf5") I get the following error:

    NotImplementedError                       Traceback (most recent call last)
    /tmp/ipykernel_43/131311624.py in <module>
    ----> 1 model.save("model.hdf5")
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in save(self, filepath, overwrite, include_optimizer, save_format, signatures, options, save_traces)
       2000     # pylint: enable=line-too-long
       2001     save.save_model(self, filepath, overwrite, include_optimizer, save_format,
    -> 2002                     signatures, options, save_traces)
       2003 
       2004   def save_weights(self,
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/saving/save.py in save_model(model, filepath, overwrite, include_optimizer, save_format, signatures, options, save_traces)
        152           'or using `save_weights`.')
        153     hdf5_format.save_model_to_hdf5(
    --> 154         model, filepath, overwrite, include_optimizer)
        155   else:
        156     saved_model_save.save(model, filepath, overwrite, include_optimizer,
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/saving/hdf5_format.py in save_model_to_hdf5(model, filepath, overwrite, include_optimizer)
        113 
        114   try:
    --> 115     model_metadata = saving_utils.model_metadata(model, include_optimizer)
        116     for k, v in model_metadata.items():
        117       if isinstance(v, (dict, list, tuple)):
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/saving/saving_utils.py in model_metadata(model, include_optimizer, require_config)
        156   except NotImplementedError as e:
        157     if require_config:
    --> 158       raise e
        159 
        160   metadata = dict(
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/saving/saving_utils.py in model_metadata(model, include_optimizer, require_config)
        153   model_config = {'class_name': model.__class__.__name__}
        154   try:
    --> 155     model_config['config'] = model.get_config()
        156   except NotImplementedError as e:
        157     if require_config:
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/functional.py in get_config(self)
        648 
        649   def get_config(self):
    --> 650     return copy.deepcopy(get_network_config(self))
        651 
        652   @classmethod
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/functional.py in get_network_config(network, serialize_layer_fn)
       1347         filtered_inbound_nodes.append(node_data)
       1348 
    -> 1349     layer_config = serialize_layer_fn(layer)
       1350     layer_config['name'] = layer.name
       1351     layer_config['inbound_nodes'] = filtered_inbound_nodes
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/utils/generic_utils.py in serialize_keras_object(instance)
        248         return serialize_keras_class_and_config(
        249             name, {_LAYER_UNDEFINED_CONFIG_KEY: True})
    --> 250       raise e
        251     serialization_config = {}
        252     for key, item in config.items():
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/utils/generic_utils.py in serialize_keras_object(instance)
        243     name = get_registered_name(instance.__class__)
        244     try:
    --> 245       config = instance.get_config()
        246     except NotImplementedError as e:
        247       if _SKIP_FAILED_SERIALIZATION:
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in get_config(self)
       2252 
       2253   def get_config(self):
    -> 2254     raise NotImplementedError
       2255 
       2256   @classmethod
    
    NotImplementedError: 
    
    opened by Bibhash123 1
  • Invalid argument

    Invalid argument

    this is my basic model

    
    with tpu_strategy.scope():
        model = tf.keras.Sequential([
                            tf.keras.layers.Lambda(lambda data: tf.keras.applications.imagenet_utils.preprocess_input(data, mode="torch"), 
                                                                input_shape=[224,224, 3]),
                            SwinTransformer('swin_tiny_224', include_top=False, pretrained=True, use_tpu=True),
                            tf.keras.layers.Dense(1, activation='sigmoid')
                                            ])
    
    model.compile(loss = tf.keras.losses.BinaryCrossentropy(),
                              optimizer = tf.keras.optimizers.Adam(learning_rate=cfg['LEARNING_RATE']),
                              metrics   = RMSE)
    
    

    I am getting this error,

    (3) Invalid argument: {{function_node __inference_train_function_705020}} Reshape's input dynamic dimension is decomposed into multiple output dynamic dimensions, but the constraint is ambiguous and XLA can't infer the output dimension %reshape.12202 = f32[256,144,576]{2,1,0} reshape(f32[36864,576]{1,0} %transpose.12194), metadata={op_type="Reshape" op_name="sequential_40/swin_large_384/sequential_39/basic_layer_28/sequential_35/swin_transformer_block_169/window_attention_169/layers0/blocks1/attn/qkv/Tensordot"}. [[{{node TPUReplicate/_compile/_17658394825749957328/_4}}]] [[tpu_compile_succeeded_assert/_11424487196827204192/_5/_209]]

    opened by AliKayhanAtay 1
  • relative_position_bias_table initialization

    relative_position_bias_table initialization

    Hi, In the official code, relative_position_bias_table is initialized in a truncated normal distribution. Is that part missing in this repo?

    Official code: https://github.com/microsoft/Swin-Transformer/blob/6bbd83ca617db8480b2fb9b335c476ffaf5afb1a/models/swin_transformer.py#L110

    This implem https://github.com/rishigami/Swin-Transformer-TF/blob/8986ca7b0e1f984437db2d8f17e0ecd87fadcd4f/swintransformer/model.py?_pjax=%23js-repo-pjax-container%2C%20div%5Bitemtype%3D%22http%3A%2F%2Fschema.org%2FSoftwareSourceCode%22%5D%20main%2C%20%5Bdata-pjax-container%5D#L70

    opened by gathierry 1
  • Image size other than default ones doesn't work

    Image size other than default ones doesn't work

    • Notebook: https://colab.research.google.com/drive/1nqYkQCUzShkVdqGxW4TyMrtAb0n5MBZR#scrollTo=G9ZVlphmqD7d Issue:
    • In swin_tiny_224 I've tried multiple of 224, 512x512, multiple of window_size. But nothing seems to work other than the 224x224.
    • Same goes for swin_large_384, only default size 384x384 works.

    I'm wondering if this is expected behavior or not. Is there any way to make it work for non-square image?

    opened by awsaf49 1
  • Added 3D support for SwinTransformerModel, ie for medical imaging tasks

    Added 3D support for SwinTransformerModel, ie for medical imaging tasks

    Tested and working, ie:

    IMAGE_SIZE = [112, 112, 112]
    NUM_CLASSES = 10
    
    model_3d = tf.keras.Sequential([
      swin_transformer_nd.SwinTransformerModel(img_size=IMAGE_SIZE, patch_size=(4, 4, 4), depths=[2, 2, 6]),
      tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')
    ])
    model_3d.compile(tf.keras.optimizers.Adam(), "categorical_crossentropy")
    
    for i in range(100):
        x = np.zeros([1, *IMAGE_SIZE, 1])
        y = tf.zeros([1, NUM_CLASSES])
        
        model_3d.fit(x, y)
        print("Trained on a batch")
    
    opened by MohamadZeina 0
  • Could you provide weights convert script?

    Could you provide weights convert script?

    I tried code and weights you provided, and find the performance is bad. Could you pleaase to provide weights convert script for me to figure out this issue?

    Many thanks

    opened by edwardyehuang 0
  • tf load model is erro

    tf load model is erro

    import tensorflow as tf from swintransformer import SwinTransformer model = tf.keras.Sequential([ tf.keras.layers.Lambda(lambda data: tf.keras.applications.imagenet_utils.preprocess_input(tf.cast(data, tf.float32), mode="torch"), input_shape=[*IMAGE_SIZE, 3]), SwinTransformer('swin_tiny_224', include_top=False, pretrained=True), tf.keras.layers.Dense(NUM_CLASSES, activation='softmax') ])

    tf can't load pre trained model。this step is errro

    opened by jangjiun 0
  • Please run in eager mode or implement the `compute_output_shape` method on your layer (SwinTransformerModel)

    Please run in eager mode or implement the `compute_output_shape` method on your layer (SwinTransformerModel)

    Has anyone tried to use the pretrained model with TimeDistributed layer ?

    model = tf.keras.Sequential([ tf.keras.layers.Lambda(lambda data: tf.keras.applications.imagenet_utils.preprocess_input(tf.cast(data, tf.float32), mode="torch"), 
    input_shape=[224,224, 3]), SwinTransformer('swin_base_224', include_top=False, pretrained=True)])
    
    model_f = models.Sequential()
    	model.add(TimeDistributed(model, input_shape= (8,224,224,3)) 
    
    

    I get the following error:

    NotImplementedError: Exception encountered when calling layer "time_distributed" (type TimeDistributed).
    
    Please run in eager mode or implement the `compute_output_shape` method on your layer (SwinTransformerModel).
    
    Call arguments received by layer "time_distributed" (type TimeDistributed):
      • inputs=tf.Tensor(shape=(None, 8, 224, 224, 3), dtype=float32)
      • training=False
    
    
    opened by atelili 0
Releases(v0.1-tf-swin-weights)
SMIS - Semantically Multi-modal Image Synthesis(CVPR 2020)

Semantically Multi-modal Image Synthesis Project page / Paper / Demo Semantically Multi-modal Image Synthesis(CVPR2020). Zhen Zhu, Zhiliang Xu, Anshen

316 Dec 01, 2022
[CVPR'21] Locally Aware Piecewise Transformation Fields for 3D Human Mesh Registration

Locally Aware Piecewise Transformation Fields for 3D Human Mesh Registration This repository contains the implementation of our paper Locally Aware Pi

sfwang 70 Dec 19, 2022
PyTorch implementation of CVPR 2020 paper (Reference-Based Sketch Image Colorization using Augmented-Self Reference and Dense Semantic Correspondence) and pre-trained model on ImageNet dataset

Reference-Based-Sketch-Image-Colorization-ImageNet This is a PyTorch implementation of CVPR 2020 paper (Reference-Based Sketch Image Colorization usin

Yuzhi ZHAO 11 Jul 28, 2022
Xi Dongbo 78 Nov 29, 2022
NVIDIA Merlin is an open source library providing end-to-end GPU-accelerated recommender systems, from feature engineering and preprocessing to training deep learning models and running inference in production.

NVIDIA Merlin NVIDIA Merlin is an open source library designed to accelerate recommender systems on NVIDIA’s GPUs. It enables data scientists, machine

419 Jan 03, 2023
Evaluation toolkit of the informative tracking benchmark comprising 9 scenarios, 180 diverse videos, and new challenges.

Informative-tracking-benchmark Informative tracking benchmark (ITB) higher diversity. It contains 9 representative scenarios and 180 diverse videos. m

Xin Li 15 Nov 26, 2022
The code of NeurIPS 2021 paper "Scalable Rule-Based Representation Learning for Interpretable Classification".

Rule-based Representation Learner This is a PyTorch implementation of Rule-based Representation Learner (RRL) as described in NeurIPS 2021 paper: Scal

Zhuo Wang 53 Dec 17, 2022
Official code for 'Pixel-wise Energy-biased Abstention Learning for Anomaly Segmentationon Complex Urban Driving Scenes'

PEBAL This repo contains the Pytorch implementation of our paper: Pixel-wise Energy-biased Abstention Learning for Anomaly Segmentation on Complex Urb

Yu Tian 117 Jan 03, 2023
H&M Fashion Image similarity search with Weaviate and DocArray

H&M Fashion Image similarity search with Weaviate and DocArray This example shows how to do image similarity search using DocArray and Weaviate as Doc

Laura Ham 18 Aug 11, 2022
License Plate Detection Application

LicensePlate_Project 🚗 🚙 [Project] 2021.02 ~ 2021.09 License Plate Detection Application Overview 1. 데이터 수집 및 라벨링 차량 번호판 이미지를 직접 수집하여 각 이미지에 대해 '번호판

4 Oct 10, 2022
Autonomous Ground Vehicle Navigation and Control Simulation Examples in Python

Autonomous Ground Vehicle Navigation and Control Simulation Examples in Python THIS PROJECT IS CURRENTLY A WORK IN PROGRESS AND THUS THIS REPOSITORY I

Joshua Marshall 14 Dec 31, 2022
This repository is for our EMNLP 2021 paper "Automated Generation of Accurate & Fluent Medical X-ray Reports"

Introduction: X-Ray Report Generation This repository is for our EMNLP 2021 paper "Automated Generation of Accurate & Fluent Medical X-ray Reports". O

no name 36 Dec 16, 2022
Finite-temperature variational Monte Carlo calculation of uniform electron gas using neural canonical transformation.

CoulombGas This code implements the neural canonical transformation approach to the thermodynamic properties of uniform electron gas. Building on JAX,

FermiFlow 9 Mar 03, 2022
Cross-modal Retrieval using Transformer Encoder Reasoning Networks (TERN). With use of Metric Learning and FAISS for fast similarity search on GPU

Cross-modal Retrieval using Transformer Encoder Reasoning Networks This project reimplements the idea from "Transformer Reasoning Network for Image-Te

Minh-Khoi Pham 5 Nov 05, 2022
[ICCV 2021 Oral] Just Ask: Learning to Answer Questions from Millions of Narrated Videos

Just Ask: Learning to Answer Questions from Millions of Narrated Videos Webpage • Demo • Paper This repository provides the code for our paper, includ

Antoine Yang 87 Jan 05, 2023
Julia package for multiway (inverse) covariance estimation.

TensorGraphicalModels TensorGraphicalModels.jl is a suite of Julia tools for estimating high-dimensional multiway (tensor-variate) covariance and inve

Wayne Wang 3 Sep 23, 2022
Doing fast searching of nearest neighbors in high dimensional spaces is an increasingly important problem

Benchmarking nearest neighbors Doing fast searching of nearest neighbors in high dimensional spaces is an increasingly important problem, but so far t

Erik Bernhardsson 3.2k Jan 03, 2023
TensorFlow, PyTorch and Numpy layers for generating Orthogonal Polynomials

OrthNet TensorFlow, PyTorch and Numpy layers for generating multi-dimensional Orthogonal Polynomials 1. Installation 2. Usage 3. Polynomials 4. Base C

Chuan 29 May 25, 2022
PyTorch version implementation of DORN

DORN_PyTorch This is a PyTorch version implementation of DORN Reference H. Fu, M. Gong, C. Wang, K. Batmanghelich and D. Tao: Deep Ordinal Regression

Zilin.Zhang 3 Apr 27, 2022
Non-Imaging Transient Reconstruction And TEmporal Search (NITRATES)

Non-Imaging Transient Reconstruction And TEmporal Search (NITRATES) This repo contains the full NITRATES pipeline for maximum likelihood-driven discov

13 Nov 08, 2022