Implementation of state-of-the-art vision transformers with TensorFlow

Overview

ViT Tensorflow

This repository contains TensorFlow implementations of state-of-the-art vision transformers (a category of computer vision models first introduced in An Image is Worth 16x16 Words). It is inspired by the work of lucidrains in vit-pytorch. I hope you enjoy these implementations :)

Models

Requirements

pip install tensorflow

Vision Transformer

The Vision Transformer was introduced in An Image is Worth 16x16 Words. This model classifies images with a pure-attention Transformer encoder and no convolutions.

Usage

Defining the Model

from vit import ViT
import tensorflow as tf

vitClassifier = ViT(
                    num_classes=1000,
                    patch_size=16,
                    num_of_patches=(224//16)**2,
                    d_model=128,
                    heads=2,
                    num_layers=4,
                    mlp_rate=2,
                    dropout_rate=0.1,
                    prediction_dropout=0.3,
)
Params
  • num_classes: int
    number of classes used for the final classification head
  • patch_size: int
    patch_size used for the tokenization
  • num_of_patches: int
    number of patches after tokenization, used for the positional encoding. In general it can be computed as (((h-patch_size)//patch_size) + 1)*(((w-patch_size)//patch_size) + 1), where h and w are the height and width of the image. When the height and width are divisible by patch_size, the simpler formula (h//patch_size)*(w//patch_size) gives the same result (see the short example after this list)
  • d_model: int
    hidden dimension of the Transformer encoder and the dimension used for the patch embedding
  • heads: int
    number of heads used for the multi-head attention mechanism
  • num_layers: int
    number of blocks in the Transformer encoder
  • mlp_rate: int
    the rate of expansion in the feed-forward block of each transformer block (the dimension after expansion is mlp_rate * d_model)
  • dropout_rate: float
    dropout rate used in the multi-head attention mechanism
  • prediction_dropout: float
    dropout rate used in the final prediction head of the model
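
For example, both formulas give the same num_of_patches for a 224 x 224 input with patch_size=16 (a quick sanity check, not part of the library):

h , w , patch_size = 224 , 224 , 16

# general formula (works even when h or w is not divisible by patch_size)
print((((h - patch_size) // patch_size) + 1) * (((w - patch_size) // patch_size) + 1)) # 196

# simplified formula (only valid when h and w are divisible by patch_size)
print((h // patch_size) * (w // patch_size)) # 196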

Inference

sampleInput = tf.random.normal(shape=(1 , 224 , 224 , 3))
output = vitClassifier(sampleInput , training=False)
print(output.shape) # (1 , 1000)

Training

vitClassifier.compile(
              loss=tf.keras.losses.SparseCategoricalCrossentropy(),
              optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              metrics=[
                       tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
                       tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5 , name="top_5_accuracy"),
              ])

vitClassifier.fit(
              trainingData, #Tensorflow dataset of images and labels in shape of ((b , h , w , 3) , (b,))
              validation_data=valData, #The same as training
              epochs=100,)
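
The trainingData and valData passed above can be any tf.data dataset that yields (image, label) batches. Below is a minimal sketch with dummy tensors, only to illustrate the expected shapes (replace it with your real data pipeline):

import tensorflow as tf

# dummy tensors just to show the expected shapes; use your real images and labels instead
images = tf.random.normal(shape=(8 , 224 , 224 , 3))
labels = tf.random.uniform(shape=(8,) , maxval=1000 , dtype=tf.int32)

trainingData = tf.data.Dataset.from_tensor_slices((images , labels)).batch(4)
valData = trainingData # placeholder; use a separate validation split in practice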

Convolutional Vision Transformer

The Convolutional Vision Transformer (CvT) was introduced here. This model uses a hierarchical (multi-stage) architecture with a convolutional embedding at the beginning of each stage. It also uses Convolutional Transformer Blocks, which improve on the original Vision Transformer by adding the inductive bias of CNNs to the architecture.
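
As a rough illustration of the convolutional embedding idea, the sketch below tokenizes an image with a 'same'-padded Conv2D and flattens the resulting grid into tokens. The class and argument names here are illustrative and are not the layer names used in cvt.py:

import tensorflow as tf

class ConvTokenEmbedding(tf.keras.layers.Layer):
    def __init__(self, projectionDim, windowSize, strides):
        super().__init__()
        self.conv = tf.keras.layers.Conv2D(projectionDim, windowSize, strides=strides, padding="same")
        self.norm = tf.keras.layers.LayerNormalization()

    def call(self, images):
        featureMap = self.conv(images) # (b , h' , w' , projectionDim)
        tokens = tf.reshape(featureMap,
                            (tf.shape(featureMap)[0] , -1 , featureMap.shape[-1]))
        return self.norm(tokens) # (b , h'*w' , projectionDim)

embedding = ConvTokenEmbedding(projectionDim=64, windowSize=(7 , 7), strides=(4 , 4))
print(embedding(tf.random.normal((1 , 224 , 224 , 3))).shape) # (1 , 3136 , 64)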

Usage

Defining the Model

from cvt import CvT , CvTStage
import tensorflow as tf

cvtModel = CvT(
num_of_classes=1000, 
stages=[
        CvTStage(projectionDim=64, 
                 heads=1, 
                 embeddingWindowSize=(7 , 7), 
                 embeddingStrides=(4 , 4), 
                 layers=1,
                 projectionWindowSize=(3 , 3), 
                 projectionStrides=(2 , 2), 
                 ffnRate=4,
                 dropoutRate=0.1),
        CvTStage(projectionDim=192,
                 heads=3,
                 embeddingWindowSize=(3 , 3), 
                 embeddingStrides=(2 , 2),
                 layers=1, 
                 projectionWindowSize=(3 , 3), 
                 projectionStrides=(2 , 2), 
                 ffnRate=4,
                 dropoutRate=0.1),
        CvTStage(projectionDim=384,
                 heads=6,
                 embeddingWindowSize=(3 , 3),
                 embeddingStrides=(2 , 2),
                 layers=1,
                 projectionWindowSize=(3 , 3),
                 projectionStrides=(2 , 2), 
                 ffnRate=4,
                 dropoutRate=0.1)
],
dropout=0.5)
CvT Params
  • num_of_classes: int
    number of classes used in the final prediction layer
  • stages: list of CvTStage
    list of cvt stages
  • dropout: float
    dropout rate used for the prediction head
CvTStage Params
  • projectionDim: int
    dimension used for the multi-head attention mechanism and the convolutional embedding
  • heads: int
    number of heads in the multi-head attention mechanism
  • embeddingWindowSize: tuple(int , int)
    window size used for the convolutional embedding
  • embeddingStrides: tuple(int , int)
    strides used for the convolutional embedding (see the short example after this list for how the strides shrink the feature map)
  • layers: int
    number of convolutional transformer blocks
  • projectionWindowSize: tuple(int , int)
    window size used for the convolutional projection in each convolutional transformer block
  • projectionStrides: tuple(int , int)
    strides used for the convolutional projection in each convolutional transformer block
  • ffnRate: int
    expansion rate of the mlp block in each convolutional transformer block
  • dropoutRate: float
    dropout rate used in each convolutional transformer block
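
As a rough guide (assuming 'same'-padded convolutional embeddings), each stage divides the spatial resolution by its embeddingStrides, so the three stages defined above produce token grids of roughly 56x56, 28x28 and 14x14 for a 224 x 224 input:

h = w = 224
for strides in [(4 , 4) , (2 , 2) , (2 , 2)]: # embeddingStrides of the three stages above
    h , w = h // strides[0] , w // strides[1]
    print(h , w) # 56 56 -> 28 28 -> 14 14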

Inference

sampleInput = tf.random.normal(shape=(1 , 224 , 224 , 3))
output = cvtModel(sampleInput , training=False)
print(output.shape) # (1 , 1000)

Training

cvtModel.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        metrics=[
                 tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
                 tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5 , name="top_5_accuracy"),
        ])

cvtModel.fit(
        trainingData, #Tensorflow dataset of images and labels in shape of ((b , h , w , 3) , (b,))
        validation_data=valData, #The same as training
        epochs=100,)

Pyramid Vision Transformer V1

Pyramid Vision Transformer V1 was introduced here. This model stacks multiple Transformer encoders to form the first convolution-free multi-scale backbone for various visual tasks, including image segmentation and object detection. In addition, the paper introduces a new attention mechanism called Spatial Reduction Attention (SRA) to reduce the quadratic complexity of multi-head attention.
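
A simplified sketch of the idea behind SRA: the key/value tokens are reshaped back to a feature map and spatially downsampled (here with a strided convolution) before standard multi-head attention, so K and V contain reductionFactor**2 times fewer tokens. The function below is illustrative only and does not mirror the exact layers in pvt_v1.py:

import tensorflow as tf

def spatial_reduction(tokens, h, w, d_model, reduction_factor):
    # reshape the token sequence back to a feature map and downsample it,
    # so K and V have (h*w)/reduction_factor**2 tokens instead of h*w
    b = tf.shape(tokens)[0]
    featureMap = tf.reshape(tokens, (b , h , w , d_model))
    reduced = tf.keras.layers.Conv2D(d_model, reduction_factor, strides=reduction_factor)(featureMap)
    return tf.reshape(reduced, (b , (h // reduction_factor) * (w // reduction_factor) , d_model))

tokens = tf.random.normal((1 , 56 * 56 , 64))
print(spatial_reduction(tokens, 56, 56, 64, reduction_factor=2).shape) # (1 , 784 , 64)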

Usage

Defining the Model

from pvt_v1 import PVT , PVTStage
import tensorflow as tf

pvtModel = PVT(
num_of_classes=1000, 
stages=[
        PVTStage(d_model=64,
                 patch_size=(2 , 2),
                 heads=1,
                 reductionFactor=2,
                 mlp_rate=2,
                 layers=2, 
                 dropout_rate=0.1),
        PVTStage(d_model=128,
                 patch_size=(2 , 2),
                 heads=2, 
                 reductionFactor=2, 
                 mlp_rate=2, 
                 layers=2, 
                 dropout_rate=0.1),
        PVTStage(d_model=320,
                 patch_size=(2 , 2),
                 heads=5, 
                 reductionFactor=2, 
                 mlp_rate=2, 
                 layers=2, 
                 dropout_rate=0.1),
],
dropout=0.5)
PVT Params
  • num_of_classes: int
    number of classes used in the final prediction layer
  • stages: list of PVTStage
    list of pvt stages
  • dropout: float
    dropout rate used for the prediction head
PVTStage Params
  • d_model: int
    dimension used for the SRA mechanism and the patch embedding
  • patch_size: tuple(int , int)
    window size used for the patch embedding
  • heads: int
    number of heads in the SRA mechanism
  • reductionFactor: int
    reduction factor used for downsampling the K and V in the SRA mechanism
  • mlp_rate: int
    expansion rate used in the feed-forward block
  • layers: int
    number of transformer encoders
  • dropout_rate: float
    dropout rate used in each transformer encoder

Inference

sampleInput = tf.random.normal(shape=(1 , 224 , 224 , 3))
output = pvtModel(sampleInput , training=False)
print(output.shape) # (1 , 1000)

Training

pvtModel.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        metrics=[
                 tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
                 tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5 , name="top_5_accuracy"),
        ])

pvtModel.fit(
        trainingData, #Tensorflow dataset of images and labels in shape of ((b , h , w , 3) , (b,))
        validation_data=valData, #The same as training
        epochs=100,)

Pyramid Vision Transformer V2

Pyramid Vision Transformer V2 was introduced here. This model is an improved version of PVT V1. The improvements in this version are as follows:

  1. It uses overlapping patch embeddings implemented with padded convolutions
  2. It uses convolutional feed-forward blocks, which add a depth-wise convolution after the first fully-connected layer (sketched below)
  3. It uses a fixed pooling instead of convolutions for downsampling the K and V in the SRA attention mechanism (the new attention mechanism is called Linear SRA)
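
A hedged sketch of the convolutional feed-forward block from point 2: the tokens are expanded by a dense layer, reshaped back to a feature map, mixed locally with a depth-wise convolution, and projected back to d_model. The names are illustrative and may not match pvt_v2.py:

import tensorflow as tf

class ConvFeedForward(tf.keras.layers.Layer):
    def __init__(self, d_model, mlp_rate, mlp_windowSize):
        super().__init__()
        self.expand = tf.keras.layers.Dense(d_model * mlp_rate, activation="gelu")
        self.depthwise = tf.keras.layers.DepthwiseConv2D(mlp_windowSize, padding="same")
        self.project = tf.keras.layers.Dense(d_model)

    def call(self, tokens, h, w):
        x = self.expand(tokens) # (b , h*w , d_model*mlp_rate)
        x = tf.reshape(x, (tf.shape(x)[0] , h , w , x.shape[-1])) # back to a feature map
        x = self.depthwise(x) # depth-wise conv mixes local spatial context
        x = tf.reshape(x, (tf.shape(x)[0] , h * w , x.shape[-1]))
        return self.project(x)

ffn = ConvFeedForward(d_model=64, mlp_rate=2, mlp_windowSize=(3 , 3))
print(ffn(tf.random.normal((1 , 56 * 56 , 64)), 56, 56).shape) # (1 , 3136 , 64)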

Usage

Defining the Model

from pvt_v2 import PVTV2 , PVTV2Stage
import tensorflow as tf

pvtV2Model = PVTV2(
num_of_classes=1000, 
stages=[
        PVTV2Stage(d_model=64,
                   windowSize=(2 , 2), 
                   heads=1,
                   poolingSize=(7 , 7), 
                   mlp_rate=2, 
                   mlp_windowSize=(3 , 3), 
                   layers=2, 
                   dropout_rate=0.1),
        PVTV2Stage(d_model=128, 
                   windowSize=(2 , 2),
                   heads=2,
                   poolingSize=(7 , 7), 
                   mlp_rate=2, 
                   mlp_windowSize=(3 , 3), 
                   layers=2,
                   dropout_rate=0.1),
        PVTV2Stage(d_model=320,
                   windowSize=(2 , 2), 
                   heads=5, 
                   poolingSize=(7 , 7), 
                   mlp_rate=2, 
                   mlp_windowSize=(3 , 3), 
                   layers=2, 
                   dropout_rate=0.1),
],
dropout=0.5)
PVTV2 Params
  • num_of_classes: int
    number of classes used in the final prediction layer
  • stages: list of PVTV2Stage
    list of pvt v2 stages
  • dropout: float
    dropout rate used for the prediction head
PVTV2Stage Params
  • d_model: int
    dimension used for the Linear SRA mechanism and the convolutional patch embedding
  • windowSize: tuple(int , int)
    window size used for the convolutional patch embedding
  • heads: int
    number of heads in the Linear SRA mechanism
  • poolingSize: tuple(int , int)
    size of the K and V after the fixed pooling
  • mlp_rate: int
    expansion rate used in the convolutional feed-forward block
  • mlp_windowSize: tuple(int , int)
    the window size used for the depth-wise convolution in the convolutional feed-forward block
  • layers: int
    number of transformer encoders
  • dropout_rate: float
    dropout rate used in each transformer encoder

Inference

sampleInput = tf.random.normal(shape=(1 , 224 , 224 , 3))
output = pvtV2Model(sampleInput , training=False)
print(output.shape) # (1 , 1000)

Training

pvtV2Model.compile(
          loss=tf.keras.losses.SparseCategoricalCrossentropy(),
          optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
          metrics=[
                   tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
                   tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5 , name="top_5_accuracy"),
          ])

pvtV2Model.fit(
          trainingData, #Tensorflow dataset of images and labels in shape of ((b , h , w , 3) , (b,))
          validation_data=valData, #The same as training
          epochs=100,)

DeiT

DeiT was introduced in Training Data-Efficient Image Transformers & Distillation Through Attention. Because the original Vision Transformer lacks any inductive bias (unlike CNNs), it is data hungry: a lot of data is required to train it to surpass state-of-the-art CNNs such as ResNet. In this paper, the authors therefore use a pre-trained CNN such as ResNet as a teacher during training, together with a special loss function that performs distillation through attention.
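
Below is a rough sketch of what soft and hard distillation losses look like for a pair of student/teacher logits; alpha, temperature and hard mirror the constructor arguments further down, but the exact loss implemented in deit.py may differ in details:

import tensorflow as tf

def distillation_loss(labels, student_logits, teacher_logits, alpha=0.5, temperature=1.0, hard=False):
    ce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    ce_loss = ce(labels, student_logits) # usual cross-entropy with the true labels
    if hard:
        # hard-label distillation: cross-entropy with the teacher's argmax predictions
        distill = ce(tf.argmax(teacher_logits, axis=-1), student_logits)
    else:
        # soft distillation: KL divergence between temperature-softened distributions
        kl = tf.keras.losses.KLDivergence()
        distill = (temperature ** 2) * kl(tf.nn.softmax(teacher_logits / temperature),
                                          tf.nn.softmax(student_logits / temperature))
    return (1 - alpha) * ce_loss + alpha * distill

student = tf.random.normal((4 , 1000))
teacher = tf.random.normal((4 , 1000))
labels = tf.random.uniform((4,) , maxval=1000 , dtype=tf.int32)
print(distillation_loss(labels, student, teacher).numpy())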

Usage

Defining the Model

from deit import DeiT
import tensorflow as tf

teacherModel = tf.keras.applications.ResNet50(include_top=True, 
                                              weights="imagenet", 
                                              input_shape=(224 , 224 , 3))

deitModel = DeiT(
                 num_classes=1000,
                 patch_size=16,
                 num_of_patches=(224//16)**2,
                 d_model=128,
                 heads=2,
                 num_layers=4,
                 mlp_rate=2,
                 teacherModel=teacherModel,
                 temperature=1.0, 
                 alpha=0.5,
                 hard=False, 
                 dropout_rate=0.1,
                 prediction_dropout=0.3,
)
Params
  • num_classes: int
    number of classes used for the final classification head
  • patch_size: int
    patch_size used for the tokenization
  • num_of_patches: int
    number of patches after tokenization, used for the positional encoding. In general it can be computed as (((h-patch_size)//patch_size) + 1)*(((w-patch_size)//patch_size) + 1), where h and w are the height and width of the image. When the height and width are divisible by patch_size, the simpler formula (h//patch_size)*(w//patch_size) gives the same result
  • d_model: int
    hidden dimension of the Transformer encoder and the dimension used for the patch embedding
  • heads: int
    number of heads used for the multi-head attention mechanism
  • num_layers: int
    number of blocks in the Transformer encoder
  • mlp_rate: int
    the rate of expansion in the feed-forward block of each transformer block (the dimension after expansion is mlp_rate * d_model)
  • teacherModel: Tensorflow Model
    the teacher model used for distillation during training. This is a pre-trained CNN model with the same input_shape and output_shape as the Transformer
  • temperature: float
    the temperature parameter in the loss
  • alpha: float
    the coefficient balancing the Kullback–Leibler (KL) divergence loss and the cross-entropy loss
  • hard: bool
    whether to use hard-label distillation (True) or soft distillation (False)
  • dropout_rate: float
    dropout rate used in the multi-head attention mechanism
  • prediction_dropout: float
    dropout rate used in the final prediction head of the model

Inference

sampleInput = tf.random.normal(shape=(1 , 224 , 224 , 3))
output = deitModel(sampleInput , training=False)
print(output.shape) # (1 , 1000)

Training

#Note that the loss is defined inside the model and no loss should be passed here
deitModel.compile(
         optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
         metrics=[
                  tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
                  tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5 , name="top_5_accuracy"),
         ])

deitModel.fit(
         trainingData, #Tensorflow dataset of images and labels in shape of ((b , h , w , 3) , (b , num_classes))
         validation_data=valData, #The same as training
         epochs=100,)
Owner

Mohammadmahdi NouriBorji