An implementation of Transformer in Transformer in TensorFlow for image classification, using attention inside local patches

Overview


An implementation of the Transformer in Transformer paper by Han et al. in TensorFlow. The model classifies images by pairing pixel-level attention inside each local patch with patch-level attention across patches.

PyTorch Implementation

Installation

Run the following to install:

pip install tnt-tensorflow

Developing tnt-tensorflow

To install tnt-tensorflow, along with tools you need to develop and test, run the following in your virtualenv:

git clone https://github.com/Rishit-dagli/Transformer-in-Transformer.git
# or clone your own fork

cd Transformer-in-Transformer
pip install -e ".[dev]"

Usage

import tensorflow as tf
from tnt import TNT

tnt = TNT(
    image_size=256,  # size of image
    patch_dim=512,  # dimension of patch token
    pixel_dim=24,  # dimension of pixel token
    patch_size=16,  # patch size
    pixel_size=4,  # pixel size
    depth=5,  # depth
    num_classes=1000,  # output number of classes
    attn_dropout=0.1,  # attention dropout
    ff_dropout=0.1,  # feedforward dropout
)

img = tf.random.uniform(shape=[5, 3, 256, 256])
logits = tnt(img) # (5, 1000)
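
As a rough guide to what the configuration above implies, the sketch below computes how many patch tokens and pixel tokens a TNT-style split produces. It assumes the image is cut into non-overlapping square patches and each patch into non-overlapping square pixel groups. `tnt_token_counts` is a hypothetical helper written for illustration, not part of tnt-tensorflow, and the library may add extra tokens (such as a class token) on top of these counts.

```python
def tnt_token_counts(image_size: int, patch_size: int, pixel_size: int) -> dict:
    """Count tokens for a TNT-style split, assuming non-overlapping square tiles.

    Hypothetical helper for illustration only; not part of tnt-tensorflow.
    """
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    if patch_size % pixel_size != 0:
        raise ValueError("patch_size must be divisible by pixel_size")

    patches_per_side = image_size // patch_size  # patches along one image edge
    pixels_per_side = patch_size // pixel_size   # pixel groups along one patch edge
    return {
        "num_patches": patches_per_side ** 2,
        "pixel_tokens_per_patch": pixels_per_side ** 2,
    }


# Same sizes as the usage example: 256x256 image, 16x16 patches, 4x4 pixel groups.
counts = tnt_token_counts(image_size=256, patch_size=16, pixel_size=4)
print(counts)  # {'num_patches': 256, 'pixel_tokens_per_patch': 16}
```

Under these assumptions, patch-level attention runs over 256 patch tokens and pixel-level attention runs over 16 pixel tokens inside each patch.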

Want to Contribute 🙋‍♂️?

Awesome! If you want to contribute to this project, you're always welcome! See Contributing Guidelines. You can also take a look at open issues for getting more information about current or upcoming tasks.

Want to discuss? 💬

Have questions or doubts, or want to share your opinions and views? You're always welcome to start a discussion.

Citation

@misc{han2021transformer,
      title={Transformer in Transformer}, 
      author={Kai Han and An Xiao and Enhua Wu and Jianyuan Guo and Chunjing Xu and Yunhe Wang},
      year={2021},
      eprint={2103.00112},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

License

Copyright 2020 Rishit Dagli

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Comments
  • Add Unit Tests

    The tests should check the rank and shape of the output tensors, and each test class should subclass tf.test.TestCase.

    • [x] #15
    • [x] #16
    • [x] #18
    • [x] #17

    Feel free to take inspiration from:

    • https://github.com/Rishit-dagli/Fast-Transformer/blob/main/fast_transformer/test_fast_transformer.py
    • For parametrization, feel free to follow https://stackoverflow.com/a/34094/11878567; the same approach works with subTest in TensorFlow
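
The rank and shape checks described in this issue could follow a pattern like the sketch below. It uses the standard library's `unittest` so that it runs without TensorFlow. Because `tf.test.TestCase` builds on `unittest.TestCase`, the `subTest` loop should carry over unchanged. `fake_classifier` and `shape_of` are hypothetical stand-ins for the model call and for a tensor's `.shape`; they are not part of tnt-tensorflow.

```python
import unittest


def fake_classifier(images, num_classes):
    """Hypothetical stand-in for a model call: one row of zero logits per image."""
    return [[0.0] * num_classes for _ in images]


def shape_of(nested):
    """Shape of a rectangular nested list, e.g. [[0, 0], [0, 0]] gives (2, 2)."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0] if nested else None
    return tuple(shape)


class LogitsShapeTest(unittest.TestCase):  # with TensorFlow, subclass tf.test.TestCase
    def test_rank_and_shape(self):
        # Each (batch_size, num_classes) pair runs as its own reported sub-test.
        cases = [(1, 10), (5, 1000)]
        for batch_size, num_classes in cases:
            with self.subTest(batch_size=batch_size, num_classes=num_classes):
                images = [object()] * batch_size  # placeholder inputs
                logits = fake_classifier(images, num_classes)
                shape = shape_of(logits)
                self.assertEqual(len(shape), 2)  # rank check
                self.assertEqual(shape, (batch_size, num_classes))  # shape check
```

Saved to a file, this can be run with `python -m unittest`. A failing case is reported together with its `subTest` parameters, which is what makes the loop behave like a parametrized test.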
    enhancement good first issue 
    opened by Rishit-dagli 3
  • Update Workflows to run tests

    This issue follows #11

    Update GitHub Workflows to:

    • [ ] Run Tests before uploading to PyPI
    • [ ] Create a workflow to run tests on commits

    Feel free to take inspiration from https://github.com/Rishit-dagli/Fast-Transformer/tree/main/.github/workflows

    enhancement good first issue 
    opened by Rishit-dagli 0
  • Creates an Attention layer

    Verify output shapes just from the attention layer:

    import tensorflow as tf
    Attention(dim=256)(tf.random.normal([3,256,256]))
    
    # <tf.Tensor: shape=(3, 256, 256), dtype=float32,
    

    Closes #3

    opened by Rishit-dagli 0
  • Put together a TNT class

    Verify shapes:

    tnt = TNT(
        image_size=256,  # size of image
        patch_dim=512,  # dimension of patch token
        pixel_dim=24,  # dimension of pixel token
        patch_size=16,  # patch size
        pixel_size=4,  # pixel size
        depth=5,  # depth
        num_classes=1000,  # output number of classes
        attn_dropout=0.1,  # attention dropout
        ff_dropout=0.1,  # feedforward dropout
    )
    
    img = tf.random.uniform(shape=[1, 3, 256, 256])
    print(tnt(img).shape)
    
    # (1, 1000)
    opened by Rishit-dagli 0
  • Create an Attention layer

    Verify output shapes just from the attention layer:

    import tensorflow as tf
    Attention(dim=256)(tf.random.normal([3,256,256]))
    
    # <tf.Tensor: shape=(3, 256, 256), dtype=float32,
    
    opened by Rishit-dagli 0
  • Create a PreNorm layer

    Verify output shapes from this layer:

    import tensorflow as tf
    PreNorm(dim=1, fn=tf.keras.layers.Dense(5))(tf.random.normal([10, 1]))
    
    # <tf.Tensor: shape=(10, 1), dtype=float32,
    
    opened by Rishit-dagli 0
Releases(v0.2.0)
  • v0.2.0(Feb 2, 2022)

    This is an interesting release for the project, including a pre-trained model on ImageNet, reproducibility of paper results, tests, and end-to-end training.

    ✅ Bug Fixes / Improvements

    • Created an end-to-end training example demonstrating how to train a TNT model for image classification through a custom training loop on the TF Flowers dataset (#14)
    • A pre-trained model that reproduces the paper results has been made available (in this release as well as on TensorFlow Hub)
    • Created an off-the-shelf inference example that shows how to use the pre-trained model directly
    • Unit Tests for the Attention class (#19)
    • Unit Tests for the main TNT class (#20)

    Full Changelog: https://github.com/Rishit-dagli/Transformer-in-Transformer/compare/v0.1.0...v0.2.0

    Source code(tar.gz)
    Source code(zip)
    tnt_s_patch16_224.tar.gz(84.42 MB)
  • v0.1.0(Dec 3, 2021)

    This is the initial release of TNT TensorFlow and implements Transformers in Transformers as a subclassed TensorFlow model.

    Classes

    • Attention: Implements attention as a TensorFlow Keras Layer making some modifications.
    • PreNorm: Normalize the activations of the previous layer for each given example in a batch independently and apply some function to it, implemented as a TensorFlow Keras Layer.
    • FeedForward: Create a FeedForward neural net with two Dense layers and GELU activation, implemented as a TensorFlow Keras Layer.
    • TNT: Implements the Transformers in Transformers model using all the other classes, and converts to logits. Implemented as a TensorFlow Keras Model.
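
The PreNorm idea described above (normalize each example independently, then apply a wrapped function) can be sketched in plain Python. This is an illustrative sketch only, not the library's Keras layer: `layer_norm` and `pre_norm` are hypothetical names, the normalization has no learned scale or offset, and the `eps` value is an assumption.

```python
import math


def layer_norm(vector, eps=1e-6):
    """Normalize one example to zero mean and unit variance (no learned scale/offset)."""
    mean = sum(vector) / len(vector)
    var = sum((v - mean) ** 2 for v in vector) / len(vector)
    return [(v - mean) / math.sqrt(var + eps) for v in vector]


def pre_norm(fn):
    """Wrap fn so every example in a batch is normalized before fn sees it."""
    def wrapped(batch):
        return [fn(layer_norm(example)) for example in batch]
    return wrapped


# Wrap a toy function (doubling) in place of an attention or feed-forward block.
double = pre_norm(lambda vec: [2.0 * v for v in vec])
out = double([[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]])
print(out)
```

Because each example is normalized on its own, the two rows come out nearly identical even though their inputs differ in scale by a factor of ten.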
    Source code(tar.gz)
    Source code(zip)
    tnt_s_patch16_224.tar.gz(84.42 MB)
Owner
Rishit Dagli
High School,TEDx,2xTED-Ed speaker | International Speaker | Microsoft Student Ambassador | Mentor, @TFUGMumbai | Organize @KotlinMumbai