Implementing SYNTHESIZER: Rethinking Self-Attention in Transformer Models using PyTorch

Overview

A PyTorch implementation of the Google Research paper "SYNTHESIZER: Rethinking Self-Attention in Transformer Models".

Reference

  • Paper URL

  • Authors: Yi Tay, Dara Bahri, Donald Metzler, Da-Cheng Juan, Zhe Zhao, Che Zheng

  • Google Research

Method

[Figure: model]

1. Dense Synthesizer

2. Fixed Random Synthesizer

3. Random Synthesizer

4. Factorized Dense Synthesizer

5. Factorized Random Synthesizer

6. Mixture of Synthesizers
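
In the paper's Dense Synthesizer, the query-key dot product is removed: each token X_i is passed through a small MLP F that directly emits a row of l attention logits, B_i = F(X_i), and the output is Y = softmax(B) G(X), where G is a value projection. The following is a minimal sketch under stated assumptions, not the repository's `SynthesizerDense`: the class name `DenseSynthesizerSketch` is hypothetical, and the MLP shape Linear(d, l) → ReLU → Linear(l, l) was chosen only because it reproduces the 1,083,456 parameters printed for `SynthesizerDense(1024, 32)` in the Output section. The repository's actual layers may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DenseSynthesizerSketch(nn.Module):
    """Hypothetical dense synthesizer: attention logits come from a per-token MLP,
    not from query-key dot products."""

    def __init__(self, channel_dim: int, sentence_length: int):
        super().__init__()
        # F(X): maps each d-dim token to sentence_length logits (one per key position).
        # Hidden width = sentence_length is an assumption that matches the reported n_params.
        self.synth = nn.Sequential(
            nn.Linear(channel_dim, sentence_length),
            nn.ReLU(),
            nn.Linear(sentence_length, sentence_length),
        )
        # G(X): value projection, as in vanilla attention.
        self.value = nn.Linear(channel_dim, channel_dim)

    def forward(self, x: torch.Tensor):
        # x: [batch, sentence_length, channel_dim]
        logits = self.synth(x)                  # [batch, l, l]
        attention = F.softmax(logits, dim=-1)   # each row sums to 1
        out = attention @ self.value(x)         # [batch, l, channel_dim]
        return out, attention


model = DenseSynthesizerSketch(1024, 32)
out, attn = model(torch.randn(2, 32, 1024))
print(out.shape, attn.shape)
```

Because F emits exactly l logits per token, a dense synthesizer is tied to a fixed sequence length, which is why `SynthesizerDense` takes `sentence_length` as a constructor argument while the vanilla `Transformer` does not.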

Usage

import torch

from synthesizer import Transformer, SynthesizerDense, SynthesizerRandom, FactorizedSynthesizerDense, FactorizedSynthesizerRandom, MixtureSynthesizers, get_n_params, calculate_flops


def main():
    # Input layout used by every module below: [batch_size, sentence_length, channel_dim]
    batch_size, channel_dim, sentence_length = 2, 1024, 32
    x = torch.randn([batch_size, sentence_length, channel_dim])

    vanilla = Transformer(channel_dim)
    out, attention_map = vanilla(x)
    print(out.size(), attention_map.size())
    n_params, flops = get_n_params(vanilla), calculate_flops(vanilla.children())
    print('vanilla, n_params: {}, flops: {}'.format(n_params, flops))

    dense_synthesizer = SynthesizerDense(channel_dim, sentence_length)
    out, attention_map = dense_synthesizer(x)
    print(out.size(), attention_map.size())
    n_params, flops = get_n_params(dense_synthesizer), calculate_flops(dense_synthesizer.children())
    print('dense_synthesizer, n_params: {}, flops: {}'.format(n_params, flops))

    random_synthesizer = SynthesizerRandom(channel_dim, sentence_length)
    out, attention_map = random_synthesizer(x)
    print(out.size(), attention_map.size())
    n_params, flops = get_n_params(random_synthesizer), calculate_flops(random_synthesizer.children())
    print('random_synthesizer, n_params: {}, flops: {}'.format(n_params, flops))

    random_synthesizer_fix = SynthesizerRandom(channel_dim, sentence_length, fixed=True)
    out, attention_map = random_synthesizer_fix(x)
    print(out.size(), attention_map.size())
    n_params, flops = get_n_params(random_synthesizer_fix), calculate_flops(random_synthesizer_fix.children())
    print('random_synthesizer_fix, n_params: {}, flops: {}'.format(n_params, flops))

    factorized_synthesizer_random = FactorizedSynthesizerRandom(channel_dim)
    out, attention_map = factorized_synthesizer_random(x)
    print(out.size(), attention_map.size())
    n_params, flops = get_n_params(factorized_synthesizer_random), calculate_flops(
        factorized_synthesizer_random.children())
    print('factorized_synthesizer_random, n_params: {}, flops: {}'.format(n_params, flops))

    factorized_synthesizer_dense = FactorizedSynthesizerDense(channel_dim, sentence_length)
    out, attention_map = factorized_synthesizer_dense(x)
    print(out.size(), attention_map.size())
    n_params, flops = get_n_params(factorized_synthesizer_dense), calculate_flops(
        factorized_synthesizer_dense.children())
    print('factorized_synthesizer_dense, n_params: {}, flops: {}'.format(n_params, flops))

    mixture_synthesizer = MixtureSynthesizers(channel_dim, sentence_length)
    out, attention_map = mixture_synthesizer(x)
    print(out.size(), attention_map.size())
    n_params, flops = get_n_params(mixture_synthesizer), calculate_flops(mixture_synthesizer.children())
    print('mixture_synthesizer, n_params: {}, flops: {}'.format(n_params, flops))


if __name__ == '__main__':
    main()
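
The `fixed=True` flag above separates the paper's Random Synthesizer (a learnable l×l attention matrix shared by every input) from the Fixed Random Synthesizer (the same matrix, frozen at its random initialization). The sketch below is a hypothetical `RandomSynthesizerSketch`, not the repository's `SynthesizerRandom`. It assumes the matrix is stored as an `nn.Parameter` with `requires_grad=not fixed`; if `get_n_params` counts every parameter regardless of `requires_grad`, this would explain why both variants report the same `n_params` in the Output section. Since the attention matrix never looks at `x`, it keeps a batch dimension of 1 and is broadcast across the batch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RandomSynthesizerSketch(nn.Module):
    """Hypothetical random synthesizer: one [l, l] logit matrix shared by every input."""

    def __init__(self, channel_dim: int, sentence_length: int, fixed: bool = False):
        super().__init__()
        # R: randomly initialized attention logits; frozen (no gradient) when fixed=True.
        self.random_logits = nn.Parameter(
            torch.randn(1, sentence_length, sentence_length), requires_grad=not fixed
        )
        # G(X): value projection.
        self.value = nn.Linear(channel_dim, channel_dim)

    def forward(self, x: torch.Tensor):
        # The attention map ignores x entirely: shape [1, l, l].
        attention = F.softmax(self.random_logits, dim=-1)
        # matmul broadcasts [1, l, l] @ [batch, l, d] -> [batch, l, d].
        out = attention @ self.value(x)
        return out, attention


out, attn = RandomSynthesizerSketch(1024, 32, fixed=True)(torch.randn(2, 32, 1024))
print(out.shape, attn.shape)
```

Registering the frozen matrix with `register_buffer` instead would keep it out of `parameters()`, and therefore out of both the optimizer and a parameter count.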

Output

torch.Size([2, 32, 1024]) torch.Size([2, 32, 32])
vanilla, n_params: 3148800, flops: 3145729
torch.Size([2, 32, 1024]) torch.Size([2, 32, 32])
dense_synthesizer, n_params: 1083456, flops: 1082370
torch.Size([2, 32, 1024]) torch.Size([1, 32, 32])
random_synthesizer, n_params: 1050624, flops: 1048577
torch.Size([2, 32, 1024]) torch.Size([1, 32, 32])
random_synthesizer_fix, n_params: 1050624, flops: 1048577
torch.Size([2, 32, 1024]) torch.Size([2, 32, 32])
factorized_synthesizer_random, n_params: 1066000, flops: 1064961
torch.Size([2, 32, 1024]) torch.Size([2, 32, 32])
factorized_synthesizer_dense, n_params: 1061900, flops: 1060865
torch.Size([2, 32, 1024]) torch.Size([2, 32, 32])
mixture_synthesizer, n_params: 3149824, flops: 3145729
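
Most of the `n_params` values above can be reproduced by hand from `channel_dim = d = 1024` and `sentence_length = l = 32`, counting each `nn.Linear(in, out)` as `in * out + out` (weights plus bias). The layer breakdowns in this sketch are inferred from the printed numbers, not read from the repository source: vanilla attention is consistent with three d→d projections (Q, K, V); the dense synthesizer with Linear(d, l) and Linear(l, l) plus a d→d value projection; the random synthesizer with one l×l matrix plus the value projection; the mixture with vanilla attention plus the same l×l matrix; and the factorized dense synthesizer with two projections Linear(d, 4) and Linear(d, 8), matching the paper's factorization a × b = l with a = 4, b = 8. The helper name `linear_params` is hypothetical.

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Parameters of nn.Linear(in_features, out_features) with bias: weight + bias."""
    return in_features * out_features + out_features


d, l = 1024, 32  # channel_dim, sentence_length from the usage example

value = linear_params(d, d)  # G(X): value projection assumed in every synthesizer variant

# Breakdowns are inferred from the printed n_params, not taken from the repo source.
counts = {
    "vanilla": 3 * linear_params(d, d),                                        # Q, K, V
    "dense_synthesizer": linear_params(d, l) + linear_params(l, l) + value,
    "random_synthesizer": l * l + value,                                       # R: [l, l]
    "factorized_synthesizer_dense": linear_params(d, 4) + linear_params(d, 8) + value,
    "mixture_synthesizer": 3 * linear_params(d, d) + l * l,                    # vanilla + R
}

for name, n in counts.items():
    print(f"{name}: {n}")
```

The arithmetic makes the main saving visible: the synthesizer variants drop the two d×d query and key projections, so they need roughly a third of the vanilla layer's parameters at d = 1024.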

Paper Performance

[Figure: evaluation results]

Owner: Myeongjun Kim (Computer Vision Research using Deep Learning)