An Implementation of Transformer in Transformer in TensorFlow for image classification, attention inside local patches

Overview


An implementation of the Transformer in Transformer paper by Han et al. for image classification, with attention inside local patches. Transformer in Transformer pairs pixel-level attention with patch-level attention for image classification; this implementation is written in TensorFlow.
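To make the pairing of pixel-level and patch-level attention concrete, here is a minimal sketch of one TNT block as the paper describes it. First, inner attention runs among the pixel tokens of each patch. Then each patch's pixel tokens are flattened, projected, and added to that patch's token. Finally, outer attention runs among the patch tokens. This is not the package's code. Keras' built-in `tf.keras.layers.MultiHeadAttention` stands in for the package's own `Attention` layer. The class name `TNTBlockSketch`, the head count, and the 4x hidden width in `FeedForward` are illustrative assumptions. The class token, positional embeddings, and dropout on attention are omitted.

```python
import tensorflow as tf


class FeedForward(tf.keras.layers.Layer):
    """Two Dense layers with GELU in between (hidden width dim * mult is an assumption)."""

    def __init__(self, dim, mult=4, dropout=0.0, **kwargs):
        super().__init__(**kwargs)
        self.hidden = tf.keras.layers.Dense(dim * mult, activation=tf.nn.gelu)
        self.drop = tf.keras.layers.Dropout(dropout)
        self.proj = tf.keras.layers.Dense(dim)

    def call(self, x, training=False):
        return self.proj(self.drop(self.hidden(x), training=training))


class TNTBlockSketch(tf.keras.layers.Layer):
    """Hypothetical TNT block: inner (pixel) attention, fold into patches, outer (patch) attention."""

    def __init__(self, patch_dim, pixel_dim, pixels_per_patch, heads=4, **kwargs):
        super().__init__(**kwargs)
        self.fold_width = pixels_per_patch * pixel_dim
        # Inner transformer over the pixel tokens of each patch (pre-norm + residual).
        self.inner_norm_attn = tf.keras.layers.LayerNormalization()
        self.inner_attn = tf.keras.layers.MultiHeadAttention(heads, pixel_dim // heads)
        self.inner_norm_ff = tf.keras.layers.LayerNormalization()
        self.inner_ff = FeedForward(pixel_dim)
        # Projects a patch's flattened pixel tokens to the patch dimension.
        self.fold = tf.keras.layers.Dense(patch_dim)
        # Outer transformer over the patch tokens.
        self.outer_norm_attn = tf.keras.layers.LayerNormalization()
        self.outer_attn = tf.keras.layers.MultiHeadAttention(heads, patch_dim // heads)
        self.outer_norm_ff = tf.keras.layers.LayerNormalization()
        self.outer_ff = FeedForward(patch_dim)

    def call(self, patch_tokens, pixel_tokens, training=False):
        # patch_tokens: [B, N, patch_dim]; pixel_tokens: [B * N, M, pixel_dim]
        x = self.inner_norm_attn(pixel_tokens)
        pixel_tokens = pixel_tokens + self.inner_attn(x, x, training=training)
        pixel_tokens = pixel_tokens + self.inner_ff(
            self.inner_norm_ff(pixel_tokens), training=training)

        # Flatten each patch's M pixel tokens and add their projection to the patch token.
        b, n = tf.shape(patch_tokens)[0], tf.shape(patch_tokens)[1]
        folded = tf.reshape(pixel_tokens, [b, n, self.fold_width])
        patch_tokens = patch_tokens + self.fold(folded)

        y = self.outer_norm_attn(patch_tokens)
        patch_tokens = patch_tokens + self.outer_attn(y, y, training=training)
        patch_tokens = patch_tokens + self.outer_ff(
            self.outer_norm_ff(patch_tokens), training=training)
        return patch_tokens, pixel_tokens


block = TNTBlockSketch(patch_dim=512, pixel_dim=24, pixels_per_patch=16)
patches = tf.random.normal([2, 256, 512])     # B=2 images, N=256 patches
pixels = tf.random.normal([2 * 256, 16, 24])  # M=16 pixel tokens per patch
new_patches, new_pixels = block(patches, pixels)
# Both token sets keep their shapes: [2, 256, 512] and [512, 16, 24].
```

The numbers mirror the usage example below: a 256x256 image with 16x16 patches gives 256 patch tokens, and 4x4 pixels give 16 pixel tokens per patch. A full model would stack `depth` such blocks and classify from the patch tokens.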

PyTorch Implementation

Installation

Run the following to install:

pip install tnt-tensorflow

Developing tnt-tensorflow

To install tnt-tensorflow, along with tools you need to develop and test, run the following in your virtualenv:

git clone https://github.com/Rishit-dagli/Transformer-in-Transformer.git
# or clone your own fork

cd Transformer-in-Transformer
pip install -e ".[dev]"

Usage

import tensorflow as tf
from tnt import TNT

tnt = TNT(
    image_size=256,  # size of image
    patch_dim=512,  # dimension of patch token
    pixel_dim=24,  # dimension of pixel token
    patch_size=16,  # patch size
    pixel_size=4,  # pixel size
    depth=5,  # depth
    num_classes=1000,  # output number of classes
    attn_dropout=0.1,  # attention dropout
    ff_dropout=0.1,  # feedforward dropout
)

img = tf.random.uniform(shape=[5, 3, 256, 256])
logits = tnt(img) # (5, 1000)
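The `patch_size` and `pixel_size` arguments determine how the input is tokenized. With `image_size=256` and `patch_size=16` there are (256 / 16)^2 = 256 patches. With `pixel_size=4`, each 16x16 patch is cut into (16 / 4)^2 = 16 pixel tokens of 3 * 4 * 4 = 48 values. Below is a minimal sketch of that split for channels-first input like `img` above. It assumes non-overlapping `pixel_size x pixel_size` sub-patches. The function name `image_to_pixel_tokens` is hypothetical, and the package may implement this step differently. In the model, each pixel token would then be linearly projected to `pixel_dim`.

```python
import tensorflow as tf


def image_to_pixel_tokens(images, patch_size, pixel_size):
    """Hypothetical helper: split [B, C, H, W] images into patches, then pixel tokens.

    Assumes H and W are divisible by patch_size and patch_size by pixel_size.
    Returns [B * num_patches, pixels_per_patch, C * pixel_size * pixel_size].
    """
    batch = tf.shape(images)[0]
    channels, height, width = images.shape[1], images.shape[2], images.shape[3]
    rows, cols = height // patch_size, width // patch_size  # patch grid
    k = patch_size // pixel_size  # pixel tokens along each side of a patch
    # Split each spatial axis into (patch index, pixel index, offset inside pixel).
    x = tf.reshape(images, [batch, channels, rows, k, pixel_size, cols, k, pixel_size])
    # Reorder to [B, rows, cols, k_h, k_w, C, pixel_h, pixel_w].
    x = tf.transpose(x, [0, 2, 5, 3, 6, 1, 4, 7])
    # One sequence of k * k pixel tokens per patch.
    return tf.reshape(x, [batch * rows * cols, k * k, channels * pixel_size * pixel_size])


images = tf.random.uniform([5, 3, 256, 256])
pixel_tokens = image_to_pixel_tokens(images, patch_size=16, pixel_size=4)
print(pixel_tokens.shape)  # 5 * 256 patches, 16 tokens each, 48 values per token
```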

Want to Contribute 🙋‍♂️ ?

Awesome! If you want to contribute to this project, you're always welcome! See Contributing Guidelines. You can also take a look at open issues for more information about current or upcoming tasks.

Want to discuss? 💬

Have questions or doubts, or want to share your opinions and views? You're always welcome to start a discussion.

Citation

@misc{han2021transformer,
      title={Transformer in Transformer}, 
      author={Kai Han and An Xiao and Enhua Wu and Jianyuan Guo and Chunjing Xu and Yunhe Wang},
      year={2021},
      eprint={2103.00112},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

License

Copyright 2020 Rishit Dagli

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Comments
  • Add Unit Tests

    The tests should check the rank and shape of the output tensors, and the test classes should subclass the tf.test.TestCase base class.

    • [x] #15
    • [x] #16
    • [x] #18
    • [x] #17

    Feel free to take inspiration from:

    • https://github.com/Rishit-dagli/Fast-Transformer/blob/main/fast_transformer/test_fast_transformer.py
    • For parametrization, feel free to follow https://stackoverflow.com/a/34094/11878567; the same approach works with subTest in TensorFlow tests.
    enhancement good first issue 
    opened by Rishit-dagli 3
  • Update Workflows to run tests

    This issue follows #11

    Update GitHub Workflows to:

    • [ ] Run Tests before uploading to PyPI
    • [ ] Create a workflow to run tests on commits

    Feel free to take inspiration from https://github.com/Rishit-dagli/Fast-Transformer/tree/main/.github/workflows

    enhancement good first issue 
    opened by Rishit-dagli 0
  • Creates an Attention layer

    Verify output shapes just from the attention layer:

    import tensorflow as tf
    Attention(dim=256)(tf.random.normal([3,256,256]))
    
    # <tf.Tensor: shape=(3, 256, 256), dtype=float32, ...>
    

    Closes #3

    opened by Rishit-dagli 0
  • Put together a TNT class

    Verify shapes:

    import tensorflow as tf
    from tnt import TNT

    tnt = TNT(
        image_size=256,  # size of image
        patch_dim=512,  # dimension of patch token
        pixel_dim=24,  # dimension of pixel token
        patch_size=16,  # patch size
        pixel_size=4,  # pixel size
        depth=5,  # depth
        num_classes=1000,  # output number of classes
        attn_dropout=0.1,  # attention dropout
        ff_dropout=0.1,  # feedforward dropout
    )
    
    img = tf.random.uniform(shape=[1, 3, 256, 256])
    print(tnt(img).shape)
    
    # (1, 1000)
    opened by Rishit-dagli 0
  • Create an Attention layer

    Verify output shapes just from the attention layer:

    import tensorflow as tf
    Attention(dim=256)(tf.random.normal([3,256,256]))
    
    # <tf.Tensor: shape=(3, 256, 256), dtype=float32, ...>
    
    opened by Rishit-dagli 0
  • Create a PreNorm layer

    Verify output shapes from this layer:

    import tensorflow as tf
    PreNorm(dim=1, fn=tf.keras.layers.Dense(5))(tf.random.normal([10, 1]))
    
    # <tf.Tensor: shape=(10, 5), dtype=float32, ...>
    
    opened by Rishit-dagli 0
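The "Add Unit Tests" issue above asks for tests that check output rank and shape, subclass `tf.test.TestCase`, and parametrize with `subTest`. Below is a minimal sketch of that pattern. Here `PreNorm` is a hypothetical inline stand-in, assumed to apply layer normalization and then `fn`, which matches how the v0.1.0 notes describe PreNorm. Real tests would import the package's own classes instead. With `fn=tf.keras.layers.Dense(5)`, the last axis of the output has 5 units.

```python
import tensorflow as tf


class PreNorm(tf.keras.layers.Layer):
    """Hypothetical stand-in: layer-normalize the input, then apply `fn`."""

    def __init__(self, dim, fn, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim  # kept only for signature parity with the issue example
        self.norm = tf.keras.layers.LayerNormalization(axis=-1)
        self.fn = fn

    def call(self, x):
        return self.fn(self.norm(x))


class PreNormTest(tf.test.TestCase):
    def test_output_rank_and_shape(self):
        # (input shape, Dense units, expected output shape)
        cases = [
            ([10, 1], 5, (10, 5)),
            ([3, 256, 256], 256, (3, 256, 256)),
            ([2, 7, 24], 12, (2, 7, 12)),
        ]
        for in_shape, units, expected in cases:
            with self.subTest(in_shape=in_shape, units=units):
                layer = PreNorm(dim=in_shape[-1], fn=tf.keras.layers.Dense(units))
                out = layer(tf.random.normal(in_shape))
                self.assertEqual(len(out.shape), len(in_shape))  # rank is preserved
                self.assertEqual(tuple(out.shape), expected)     # Dense sets the last axis
```

Saved in a `test_*.py` file, this can be run with `python -m pytest`, or with a `tf.test.main()` call as the script entry point.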
Releases (v0.2.0)
  • v0.2.0 (Feb 2, 2022)

    This is an interesting release for the project: it includes a model pre-trained on ImageNet, reproduction of the paper's results, tests, and end-to-end training.

    ✅ Bug Fixes / Improvements

    • Create an end-to-end training example demonstrating how to train a TNT model for image classification through a custom training loop on the TF Flowers dataset (#14)
    • A pre-trained model that reproduces the paper's results has been made available (in this release as well as on TensorFlow Hub)
    • Create an off-the-shelf inference example that shows how to use the pre-trained model directly
    • Unit Tests for the Attention class (#19)
    • Unit Tests for the main TNT class (#20)

    Full Changelog: https://github.com/Rishit-dagli/Transformer-in-Transformer/compare/v0.1.0...v0.2.0

    tnt_s_patch16_224.tar.gz(84.42 MB)
  • v0.1.0 (Dec 3, 2021)

    This is the initial release of TNT TensorFlow. It implements Transformer in Transformer as a subclassed TensorFlow model.

    Classes

    • Attention: Implements attention as a TensorFlow Keras Layer, with some modifications.
    • PreNorm: Applies layer normalization to its input (normalizing the activations of each example in a batch independently) and then applies a given function to the result, implemented as a TensorFlow Keras Layer.
    • FeedForward: Creates a feedforward neural net with two Dense layers and GELU activation, implemented as a TensorFlow Keras Layer.
    • TNT: Implements the Transformer in Transformer model using all the other classes and outputs classification logits, implemented as a TensorFlow Keras Model.
    tnt_s_patch16_224.tar.gz(84.42 MB)
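The v0.2.0 notes mention an end-to-end example that trains a TNT model through a custom training loop on the TF Flowers dataset. Below is a minimal sketch of what such a loop typically looks like with `tf.GradientTape`. It is not the example notebook's code. The model is a tiny stand-in (`Flatten` plus `Dense`) and the data are random tensors, both hypothetical, so the snippet runs anywhere. For a real run you would swap in a `TNT(...)` model (with `num_classes=5` for TF Flowers) and real image batches.

```python
import tensorflow as tf

tf.random.set_seed(0)

# Hypothetical stand-ins: replace with TNT(...) and TF Flowers batches (5 classes).
num_classes = 5
images = tf.random.uniform([32, 3, 32, 32])  # channels-first, like the usage example
labels = tf.random.uniform([32], maxval=num_classes, dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((images, labels)).shuffle(32).batch(8)

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(num_classes),  # raw logits, as TNT returns
])
model(images[:1])  # build the weights before tracing train_step

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)


@tf.function
def train_step(x, y):
    # Forward pass under the tape, then one optimizer update.
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss = loss_fn(y, logits)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss


epoch_losses = []
for epoch in range(3):
    mean_loss = tf.keras.metrics.Mean()  # running mean of the batch losses
    for x, y in dataset:
        mean_loss.update_state(train_step(x, y))
    epoch_losses.append(float(mean_loss.result()))
    print(f"epoch {epoch}: mean loss {epoch_losses[-1]:.4f}")
```

Wrapping `train_step` in `tf.function` traces it into a graph once, which usually makes the per-batch update faster than running it eagerly.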
Owner
Rishit Dagli
High School, TEDx, 2x TED-Ed speaker | International Speaker | Microsoft Student Ambassador | Mentor, @TFUGMumbai | Organize @KotlinMumbai
Rotated Box Is Back : Accurate Box Proposal Network for Scene Text Detection

Rotated Box Is Back : Accurate Box Proposal Network for Scene Text Detection This material is supplementary code for paper accepted in ICDAR 2021 We h

NCSOFT 30 Dec 21, 2022
SingleVC performs any-to-one VC, which is an important component of MediumVC project.

SingleVC performs any-to-one VC, which is an important component of MediumVC project. Here is the official implementation of the paper, MediumVC.

谷下雨 26 Dec 28, 2022
Official repository of the paper Privacy-friendly Synthetic Data for the Development of Face Morphing Attack Detectors

SMDD-Synthetic-Face-Morphing-Attack-Detection-Development-dataset Official repository of the paper Privacy-friendly Synthetic Data for the Development

10 Dec 12, 2022
Caffe models in TensorFlow

Caffe to TensorFlow Convert Caffe models to TensorFlow. Usage Run convert.py to convert an existing Caffe model to TensorFlow. Make sure you're using

Saumitro Dasgupta 2.8k Dec 31, 2022
Neural Scene Flow Fields for Space-Time View Synthesis of Dynamic Scenes

Neural Scene Flow Fields PyTorch implementation of paper "Neural Scene Flow Fields for Space-Time View Synthesis of Dynamic Scenes", CVPR 2021 [Projec

Zhengqi Li 583 Dec 30, 2022
Automated Melanoma Recognition in Dermoscopy Images via Very Deep Residual Networks

Introduction This repository contains the modified caffe library and network architectures for our paper "Automated Melanoma Recognition in Dermoscopy

Lequan Yu 47 Nov 24, 2022
Information-Theoretic Multi-Objective Bayesian Optimization with Continuous Approximations

Information-Theoretic Multi-Objective Bayesian Optimization with Continuous Approximations Requirements The code is implemented in Python and requires

1 Nov 03, 2021
Target Propagation via Regularized Inversion

Target Propagation via Regularized Inversion The present code implements an ideal formulation of target propagation using regularized inverses compute

Vincent Roulet 0 Dec 02, 2021
KoCLIP: Korean port of OpenAI CLIP, in Flax

KoCLIP This repository contains code for KoCLIP, a Korean port of OpenAI's CLIP. This project was conducted as part of Hugging Face's Flax/JAX communi

Jake Tae 100 Jan 02, 2023
On Evaluation Metrics for Graph Generative Models

On Evaluation Metrics for Graph Generative Models Authors: Rylee Thompson, Boris Knyazev, Elahe Ghalebi, Jungtaek Kim, Graham Taylor This is the offic

13 Jan 07, 2023
A Python package for faster, safer, and simpler ML processes

Bender 🤖 A Python package for faster, safer, and simpler ML processes. Why use bender? Bender will make your machine learning processes, faster, safe

Otovo 6 Dec 13, 2022
Fantasy Points Prediction and Dream Team Formation

Fantasy-Points-Prediction-and-Dream-Team-Formation Collected Data from open source resources that have over 100 Parameters for predicting cricket play

Akarsh Singh 2 Sep 13, 2022
Audio2Face - Audio To Face With Python

Audio2Face Description We create a project that transforms audio to blendshape w

FACEGOOD 724 Dec 26, 2022
A tiny, friendly, strong baseline code for Person-reID (based on pytorch).

Pytorch ReID Strong, Small, Friendly A tiny, friendly, strong baseline code for Person-reID (based on pytorch). Strong. It is consistent with the new

Zhedong Zheng 3.5k Jan 08, 2023
MILK: Machine Learning Toolkit

MILK: MACHINE LEARNING TOOLKIT Machine Learning in Python Milk is a machine learning toolkit in Python. Its focus is on supervised classification with

Luis Pedro Coelho 610 Dec 14, 2022
Webcam demo of spoofing detection (anti-spoof-mn3)

FaceDetection-Anti-Spoof-Demo This is a webcam demo of spoofing detection (anti-spoof-mn3). The model used is the ONNX-format model from PINTO_model_zoo/191_anti-spoof-mn3. Requirement mediapipe

KazuhitoTakahashi 8 Nov 18, 2022
Neuron class provides LNU (Linear Neural Unit), QNU (Quadratic Neural Unit), RBF (Radial Basis Function), MLP (Multi Layer Perceptron), MLP-ELM (Multi Layer Perceptron - Extreme Learning Machine) neurons learned with Gradient descent or Levenberg–Marquardt algorithm

Neuron class provides LNU (Linear Neural Unit), QNU (Quadratic Neural Unit), RBF (Radial Basis Function), MLP (Multi Layer Perceptron), MLP-ELM (Multi Layer Perceptron - Extreme Learning Machine) neu

Filip Molcik 38 Dec 17, 2022
Transferable Unrestricted Attacks, which won 1st place in CVPR’21 Security AI Challenger: Unrestricted Adversarial Attacks on ImageNet.

Transferable Unrestricted Adversarial Examples This is the PyTorch implementation of the Arxiv paper: Towards Transferable Unrestricted Adversarial Ex

equation 16 Dec 29, 2022
1st place solution to the Satellite Image Change Detection Challenge hosted by SenseTime

1st place solution to the Satellite Image Change Detection Challenge hosted by SenseTime

Lihe Yang 209 Jan 01, 2023
A Python implementation of active inference for Markov Decision Processes

A Python package for simulating Active Inference agents in Markov Decision Process environments. Please see our companion preprint on arxiv for an ove

235 Dec 21, 2022