🍀 Pytorch implementation of various Attention Mechanisms, MLP, Re-parameter, Convolution, which is helpful to further understand papers.⭐⭐⭐

Overview

A Codebase For Attention, MLP, Re-parameter(ReP), Convolution

If this project is helpful to you, welcome to give a star.

Don't forget to follow me to learn about project updates.

Installation (Optional)

For the convenience use of this project, the pip installation method is provided. You can run the following command directly:

$ pip install dlutils_add

(However, it is highly recommended that you git clone this project, because pip install may not be updated in a timely manner. .whl file can also be downloaded by BaiDuYun (Access code: c56j).)


Contents


Attention Series


1. External Attention Usage

1.1. Paper

"Beyond Self-attention: External Attention using Two Linear Layers for Visual Tasks"

1.2. Overview

1.3. Code

from attention.ExternalAttention import ExternalAttention
import torch

input=torch.randn(50,49,512)
ea = ExternalAttention(d_model=512,S=8)
output=ea(input)
print(output.shape)

2. Self Attention Usage

2.1. Paper

"Attention Is All You Need"

1.2. Overview

1.3. Code

from attention.SelfAttention import ScaledDotProductAttention
import torch

input=torch.randn(50,49,512)
sa = ScaledDotProductAttention(d_model=512, d_k=512, d_v=512, h=8)
output=sa(input,input,input)
print(output.shape)

3. Simplified Self Attention Usage

3.1. Paper

None

3.2. Overview

3.3. Code

from attention.SimplifiedSelfAttention import SimplifiedScaledDotProductAttention
import torch

input=torch.randn(50,49,512)
ssa = SimplifiedScaledDotProductAttention(d_model=512, h=8)
output=ssa(input,input,input)
print(output.shape)

4. Squeeze-and-Excitation Attention Usage

4.1. Paper

"Squeeze-and-Excitation Networks"

4.2. Overview

4.3. Code

from attention.SEAttention import SEAttention
import torch

input=torch.randn(50,512,7,7)
se = SEAttention(channel=512,reduction=8)
output=se(input)
print(output.shape)

5. SK Attention Usage

5.1. Paper

"Selective Kernel Networks"

5.2. Overview

5.3. Code

from attention.SKAttention import SKAttention
import torch

input=torch.randn(50,512,7,7)
se = SKAttention(channel=512,reduction=8)
output=se(input)
print(output.shape)

6. CBAM Attention Usage

6.1. Paper

"CBAM: Convolutional Block Attention Module"

6.2. Overview

6.3. Code

from attention.CBAM import CBAMBlock
import torch

input=torch.randn(50,512,7,7)
kernel_size=input.shape[2]
cbam = CBAMBlock(channel=512,reduction=16,kernel_size=kernel_size)
output=cbam(input)
print(output.shape)

7. BAM Attention Usage

7.1. Paper

"BAM: Bottleneck Attention Module"

7.2. Overview

7.3. Code

from attention.BAM import BAMBlock
import torch

input=torch.randn(50,512,7,7)
bam = BAMBlock(channel=512,reduction=16,dia_val=2)
output=bam(input)
print(output.shape)

8. ECA Attention Usage

8.1. Paper

"ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks"

8.2. Overview

8.3. Code

from attention.ECAAttention import ECAAttention
import torch

input=torch.randn(50,512,7,7)
eca = ECAAttention(kernel_size=3)
output=eca(input)
print(output.shape)

9. DANet Attention Usage

9.1. Paper

"Dual Attention Network for Scene Segmentation"

9.2. Overview

9.3. Code

from attention.DANet import DAModule
import torch

input=torch.randn(50,512,7,7)
danet=DAModule(d_model=512,kernel_size=3,H=7,W=7)
print(danet(input).shape)

10. Pyramid Split Attention Usage

10.1. Paper

"EPSANet: An Efficient Pyramid Split Attention Block on Convolutional Neural Network"

10.2. Overview

10.3. Code

from attention.PSA import PSA
import torch

input=torch.randn(50,512,7,7)
psa = PSA(channel=512,reduction=8)
output=psa(input)
print(output.shape)

11. Efficient Multi-Head Self-Attention Usage

11.1. Paper

"ResT: An Efficient Transformer for Visual Recognition"

11.2. Overview

11.3. Code

from attention.EMSA import EMSA
import torch
from torch import nn
from torch.nn import functional as F

input=torch.randn(50,64,512)
emsa = EMSA(d_model=512, d_k=512, d_v=512, h=8,H=8,W=8,ratio=2,apply_transform=True)
output=emsa(input,input,input)
print(output.shape)
    

12. Shuffle Attention Usage

12.1. Paper

"SA-NET: SHUFFLE ATTENTION FOR DEEP CONVOLUTIONAL NEURAL NETWORKS"

12.2. Overview

12.3. Code

from attention.ShuffleAttention import ShuffleAttention
import torch
from torch import nn
from torch.nn import functional as F


input=torch.randn(50,512,7,7)
se = ShuffleAttention(channel=512,G=8)
output=se(input)
print(output.shape)

    

13. MUSE Attention Usage

13.1. Paper

"MUSE: Parallel Multi-Scale Attention for Sequence to Sequence Learning"

13.2. Overview

13.3. Code

from attention.MUSEAttention import MUSEAttention
import torch
from torch import nn
from torch.nn import functional as F


input=torch.randn(50,49,512)
sa = MUSEAttention(d_model=512, d_k=512, d_v=512, h=8)
output=sa(input,input,input)
print(output.shape)

14. SGE Attention Usage

14.1. Paper

Spatial Group-wise Enhance: Improving Semantic Feature Learning in Convolutional Networks

14.2. Overview

14.3. Code

from attention.SGE import SpatialGroupEnhance
import torch
from torch import nn
from torch.nn import functional as F

input=torch.randn(50,512,7,7)
sge = SpatialGroupEnhance(groups=8)
output=sge(input)
print(output.shape)

15. A2 Attention Usage

15.1. Paper

A2-Nets: Double Attention Networks

15.2. Overview

15.3. Code

from attention.A2Atttention import DoubleAttention
import torch
from torch import nn
from torch.nn import functional as F

input=torch.randn(50,512,7,7)
a2 = DoubleAttention(512,128,128,True)
output=a2(input)
print(output.shape)

16. AFT Attention Usage

16.1. Paper

An Attention Free Transformer

16.2. Overview

16.3. Code

from attention.AFT import AFT_FULL
import torch
from torch import nn
from torch.nn import functional as F

input=torch.randn(50,49,512)
aft_full = AFT_FULL(d_model=512, n=49)
output=aft_full(input)
print(output.shape)

17. Outlook Attention Usage

17.1. Paper

VOLO: Vision Outlooker for Visual Recognition"

17.2. Overview

17.3. Code

from attention.OutlookAttention import OutlookAttention
import torch
from torch import nn
from torch.nn import functional as F

input=torch.randn(50,28,28,512)
outlook = OutlookAttention(dim=512)
output=outlook(input)
print(output.shape)

18. ViP Attention Usage

18.1. Paper

Vision Permutator: A Permutable MLP-Like Architecture for Visual Recognition"

18.2. Overview

18.3. Code

from attention.ViP import WeightedPermuteMLP
import torch
from torch import nn
from torch.nn import functional as F

input=torch.randn(64,8,8,512)
seg_dim=8
vip=WeightedPermuteMLP(512,seg_dim)
out=vip(input)
print(out.shape)

19. CoAtNet Attention Usage

19.1. Paper

CoAtNet: Marrying Convolution and Attention for All Data Sizes"

19.2. Overview

None

19.3. Code

from attention.CoAtNet import CoAtNet
import torch
from torch import nn
from torch.nn import functional as F

input=torch.randn(1,3,224,224)
mbconv=CoAtNet(in_ch=3,image_size=224)
out=mbconv(input)
print(out.shape)

20. HaloNet Attention Usage

20.1. Paper

Scaling Local Self-Attention for Parameter Efficient Visual Backbones"

20.2. Overview

20.3. Code

from attention.HaloAttention import HaloAttention
import torch
from torch import nn
from torch.nn import functional as F

input=torch.randn(1,512,8,8)
halo = HaloAttention(dim=512,
    block_size=2,
    halo_size=1,)
output=halo(input)
print(output.shape)

21. Polarized Self-Attention Usage

21.1. Paper

Polarized Self-Attention: Towards High-quality Pixel-wise Regression"

21.2. Overview

21.3. Code

from attention.PolarizedSelfAttention import ParallelPolarizedSelfAttention,SequentialPolarizedSelfAttention
import torch
from torch import nn
from torch.nn import functional as F

input=torch.randn(1,512,7,7)
psa = SequentialPolarizedSelfAttention(channel=512)
output=psa(input)
print(output.shape)

22. CoTAttention Usage

22.1. Paper

Contextual Transformer Networks for Visual Recognition---arXiv 2021.07.26

22.2. Overview

22.3. Code

from attention.CoTAttention import CoTAttention
import torch
from torch import nn
from torch.nn import functional as F

input=torch.randn(50,512,7,7)
cot = CoTAttention(dim=512,kernel_size=3)
output=cot(input)
print(output.shape)


MLP Series

1. RepMLP Usage

1.1. Paper

"RepMLP: Re-parameterizing Convolutions into Fully-connected Layers for Image Recognition"

1.2. Overview

1.3. Code

from mlp.repmlp import RepMLP
import torch
from torch import nn

N=4 #batch size
C=512 #input dim
O=1024 #output dim
H=14 #image height
W=14 #image width
h=7 #patch height
w=7 #patch width
fc1_fc2_reduction=1 #reduction ratio
fc3_groups=8 # groups
repconv_kernels=[1,3,5,7] #kernel list
repmlp=RepMLP(C,O,H,W,h,w,fc1_fc2_reduction,fc3_groups,repconv_kernels=repconv_kernels)
x=torch.randn(N,C,H,W)
repmlp.eval()
for module in repmlp.modules():
    if isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm1d):
        nn.init.uniform_(module.running_mean, 0, 0.1)
        nn.init.uniform_(module.running_var, 0, 0.1)
        nn.init.uniform_(module.weight, 0, 0.1)
        nn.init.uniform_(module.bias, 0, 0.1)

#training result
out=repmlp(x)
#inference result
repmlp.switch_to_deploy()
deployout = repmlp(x)

print(((deployout-out)**2).sum())

2. MLP-Mixer Usage

2.1. Paper

"MLP-Mixer: An all-MLP Architecture for Vision"

2.2. Overview

2.3. Code

from mlp.mlp_mixer import MlpMixer
import torch
mlp_mixer=MlpMixer(num_classes=1000,num_blocks=10,patch_size=10,tokens_hidden_dim=32,channels_hidden_dim=1024,tokens_mlp_dim=16,channels_mlp_dim=1024)
input=torch.randn(50,3,40,40)
output=mlp_mixer(input)
print(output.shape)

3. ResMLP Usage

3.1. Paper

"ResMLP: Feedforward networks for image classification with data-efficient training"

3.2. Overview

3.3. Code

from mlp.resmlp import ResMLP
import torch

input=torch.randn(50,3,14,14)
resmlp=ResMLP(dim=128,image_size=14,patch_size=7,class_num=1000)
out=resmlp(input)
print(out.shape) #the last dimention is class_num

4. gMLP Usage

4.1. Paper

"Pay Attention to MLPs"

4.2. Overview

4.3. Code

from mlp.g_mlp import gMLP
import torch

num_tokens=10000
bs=50
len_sen=49
num_layers=6
input=torch.randint(num_tokens,(bs,len_sen)) #bs,len_sen
gmlp = gMLP(num_tokens=num_tokens,len_sen=len_sen,dim=512,d_ff=1024)
output=gmlp(input)
print(output.shape)

Re-Parameter Series


1. RepVGG Usage

1.1. Paper

"RepVGG: Making VGG-style ConvNets Great Again"

1.2. Overview

1.3. Code

from rep.repvgg import RepBlock
import torch


input=torch.randn(50,512,49,49)
repblock=RepBlock(512,512)
repblock.eval()
out=repblock(input)
repblock._switch_to_deploy()
out2=repblock(input)
print('difference between vgg and repvgg')
print(((out2-out)**2).sum())

2. ACNet Usage

2.1. Paper

"ACNet: Strengthening the Kernel Skeletons for Powerful CNN via Asymmetric Convolution Blocks"

2.2. Overview

2.3. Code

from rep.acnet import ACNet
import torch
from torch import nn

input=torch.randn(50,512,49,49)
acnet=ACNet(512,512)
acnet.eval()
out=acnet(input)
acnet._switch_to_deploy()
out2=acnet(input)
print('difference:')
print(((out2-out)**2).sum())

2. Diverse Branch Block Usage

2.1. Paper

"Diverse Branch Block: Building a Convolution as an Inception-like Unit"

2.2. Overview

2.3. Code

2.3.1 Transform I
from rep.ddb import transI_conv_bn
import torch
from torch import nn
from torch.nn import functional as F

input=torch.randn(1,64,7,7)
#conv+bn
conv1=nn.Conv2d(64,64,3,padding=1)
bn1=nn.BatchNorm2d(64)
bn1.eval()
out1=bn1(conv1(input))

#conv_fuse
conv_fuse=nn.Conv2d(64,64,3,padding=1)
conv_fuse.weight.data,conv_fuse.bias.data=transI_conv_bn(conv1,bn1)
out2=conv_fuse(input)

print("difference:",((out2-out1)**2).sum().item())
2.3.2 Transform II
from rep.ddb import transII_conv_branch
import torch
from torch import nn
from torch.nn import functional as F

input=torch.randn(1,64,7,7)

#conv+conv
conv1=nn.Conv2d(64,64,3,padding=1)
conv2=nn.Conv2d(64,64,3,padding=1)
out1=conv1(input)+conv2(input)

#conv_fuse
conv_fuse=nn.Conv2d(64,64,3,padding=1)
conv_fuse.weight.data,conv_fuse.bias.data=transII_conv_branch(conv1,conv2)
out2=conv_fuse(input)

print("difference:",((out2-out1)**2).sum().item())
2.3.3 Transform III
from rep.ddb import transIII_conv_sequential
import torch
from torch import nn
from torch.nn import functional as F

input=torch.randn(1,64,7,7)

#conv+conv
conv1=nn.Conv2d(64,64,1,padding=0,bias=False)
conv2=nn.Conv2d(64,64,3,padding=1,bias=False)
out1=conv2(conv1(input))


#conv_fuse
conv_fuse=nn.Conv2d(64,64,3,padding=1,bias=False)
conv_fuse.weight.data=transIII_conv_sequential(conv1,conv2)
out2=conv_fuse(input)

print("difference:",((out2-out1)**2).sum().item())
2.3.4 Transform IV
from rep.ddb import transIV_conv_concat
import torch
from torch import nn
from torch.nn import functional as F

input=torch.randn(1,64,7,7)

#conv+conv
conv1=nn.Conv2d(64,32,3,padding=1)
conv2=nn.Conv2d(64,32,3,padding=1)
out1=torch.cat([conv1(input),conv2(input)],dim=1)

#conv_fuse
conv_fuse=nn.Conv2d(64,64,3,padding=1)
conv_fuse.weight.data,conv_fuse.bias.data=transIV_conv_concat(conv1,conv2)
out2=conv_fuse(input)

print("difference:",((out2-out1)**2).sum().item())
2.3.5 Transform V
from rep.ddb import transV_avg
import torch
from torch import nn
from torch.nn import functional as F

input=torch.randn(1,64,7,7)

avg=nn.AvgPool2d(kernel_size=3,stride=1)
out1=avg(input)

conv=transV_avg(64,3)
out2=conv(input)

print("difference:",((out2-out1)**2).sum().item())
2.3.6 Transform VI
from rep.ddb import transVI_conv_scale
import torch
from torch import nn
from torch.nn import functional as F

input=torch.randn(1,64,7,7)

#conv+conv
conv1x1=nn.Conv2d(64,64,1)
conv1x3=nn.Conv2d(64,64,(1,3),padding=(0,1))
conv3x1=nn.Conv2d(64,64,(3,1),padding=(1,0))
out1=conv1x1(input)+conv1x3(input)+conv3x1(input)

#conv_fuse
conv_fuse=nn.Conv2d(64,64,3,padding=1)
conv_fuse.weight.data,conv_fuse.bias.data=transVI_conv_scale(conv1x1,conv1x3,conv3x1)
out2=conv_fuse(input)

print("difference:",((out2-out1)**2).sum().item())

Convolution Series


1. Depthwise Separable Convolution Usage

1.1. Paper

"MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications"

1.2. Overview

1.3. Code

from conv.DepthwiseSeparableConvolution import DepthwiseSeparableConvolution
import torch
from torch import nn
from torch.nn import functional as F

input=torch.randn(1,3,224,224)
dsconv=DepthwiseSeparableConvolution(3,64)
out=dsconv(input)
print(out.shape)

2. MBConv Usage

2.1. Paper

"Efficientnet: Rethinking model scaling for convolutional neural networks"

2.2. Overview

2.3. Code

from conv.MBConv import MBConvBlock
import torch
from torch import nn
from torch.nn import functional as F

input=torch.randn(1,3,224,224)
mbconv=MBConvBlock(ksize=3,input_filters=3,output_filters=512,image_size=224)
out=mbconv(input)
print(out.shape)

3. Involution Usage

3.1. Paper

"Involution: Inverting the Inherence of Convolution for Visual Recognition"

3.2. Overview

3.3. Code

from conv.Involution import Involution
import torch
from torch import nn
from torch.nn import functional as F

input=torch.randn(1,4,64,64)
involution=Involution(kernel_size=3,in_channel=4,stride=2)
out=involution(input)
print(out.shape)

Owner
xmu-xiaoma66
A graduate student in MAC Lab of XMU
xmu-xiaoma66
Forecasting for knowable future events using Bayesian informative priors (forecasting with judgmental-adjustment).

What is judgyprophet? judgyprophet is a Bayesian forecasting algorithm based on Prophet, that enables forecasting while using information known by the

AstraZeneca 56 Oct 26, 2022
CNNs for Sentence Classification in PyTorch

Introduction This is the implementation of Kim's Convolutional Neural Networks for Sentence Classification paper in PyTorch. Kim's implementation of t

Shawn Ng 956 Dec 19, 2022
MG-GCN: Scalable Multi-GPU GCN Training Framework

MG-GCN MG-GCN: multi-GPU GCN training framework. For more information, please read our paper. After cloning our repository, run git submodule update -

Translational Data Analytics (TDA) Lab @GaTech 6 Oct 24, 2022
FrankMocap: A Strong and Easy-to-use Single View 3D Hand+Body Pose Estimator

FrankMocap pursues an easy-to-use single view 3D motion capture system developed by Facebook AI Research (FAIR). FrankMocap provides state-of-the-art 3D pose estimation outputs for body, hand, and bo

Facebook Research 1.9k Jan 07, 2023
NanoDet-Plus⚡Super fast and lightweight anchor-free object detection model. 🔥Only 980 KB(int8) / 1.8MB (fp16) and run 97FPS on cellphone🔥

NanoDet-Plus⚡Super fast and lightweight anchor-free object detection model. 🔥Only 980 KB(int8) / 1.8MB (fp16) and run 97FPS on cellphone🔥

4.8k Jan 07, 2023
This code is a toolbox that uses Torch library for training and evaluating the ERFNet architecture for semantic segmentation.

ERFNet This code is a toolbox that uses Torch library for training and evaluating the ERFNet architecture for semantic segmentation. NEW!! New PyTorch

Edu 104 Jan 05, 2023
ComPhy: Compositional Physical Reasoning ofObjects and Events from Videos

ComPhy This repository holds the code for the paper. ComPhy: Compositional Physical Reasoning ofObjects and Events from Videos, (Under review) PDF Pro

29 Dec 29, 2022
Syed Waqas Zamir 906 Dec 30, 2022
BMVC 2021: This is the github repository for "Few Shot Temporal Action Localization using Query Adaptive Transformers" accepted in British Machine Vision Conference (BMVC) 2021, Virtual

FS-QAT: Few Shot Temporal Action Localization using Query Adaptive Transformer Accepted as Poster in BMVC 2021 This is an official implementation in P

Sauradip Nag 14 Dec 09, 2022
A modular, research-friendly framework for high-performance and inference of sequence models at many scales

T5X T5X is a modular, composable, research-friendly framework for high-performance, configurable, self-service training, evaluation, and inference of

Google Research 1.1k Jan 08, 2023
Jittor implementation of PCT:Point Cloud Transformer

PCT: Point Cloud Transformer This is a Jittor implementation of PCT: Point Cloud Transformer.

MenghaoGuo 547 Jan 03, 2023
Text Summarization - WCN — Weighted Contextual N-gram method for evaluation of Text Summarization

Text Summarization WCN — Weighted Contextual N-gram method for evaluation of Text Summarization In this project, I fine tune T5 model on Extreme Summa

Aditya Shah 1 Jan 03, 2022
Probabilistic Entity Representation Model for Reasoning over Knowledge Graphs

Implementation for the paper: Probabilistic Entity Representation Model for Reasoning over Knowledge Graphs, Nurendra Choudhary, Nikhil Rao, Sumeet Ka

Nurendra Choudhary 8 Nov 15, 2022
DeepSTD: Mining Spatio-temporal Disturbances of Multiple Context Factors for Citywide Traffic Flow Prediction

DeepSTD: Mining Spatio-temporal Disturbances of Multiple Context Factors for Citywide Traffic Flow Prediction This is the implementation of DeepSTD in

5 Sep 26, 2022
We have made you a wrapper you can't refuse

We have made you a wrapper you can't refuse We have a vibrant community of developers helping each other in our Telegram group. Join us! Stay tuned fo

20.6k Jan 09, 2023
Nb workflows - A workflow platform which allows you to run parameterized notebooks programmatically

NB Workflows Description If SQL is a lingua franca for querying data, Jupyter sh

Xavier Petit 6 Aug 18, 2022
The Video-based Accident Detection System built in Python

Accident-detection-system About the Project This Repository contains the Video-based Accident Detection System built in Python. Contributors Yukta Gop

SURYAVANSHI SNEHAL BALKRISHNA 50 Dec 07, 2022
This repository contains the data and code for the paper "Diverse Text Generation via Variational Encoder-Decoder Models with Gaussian Process Priors" ([email protected])

GP-VAE This repository provides datasets and code for preprocessing, training and testing models for the paper: Diverse Text Generation via Variationa

Wanyu Du 18 Dec 29, 2022
LSTM and QRNN Language Model Toolkit for PyTorch

LSTM and QRNN Language Model Toolkit This repository contains the code used for two Salesforce Research papers: Regularizing and Optimizing LSTM Langu

Salesforce 1.9k Jan 08, 2023
PyTorch implementations for our SIGGRAPH 2021 paper: Editable Free-viewpoint Video Using a Layered Neural Representation.

st-nerf We provide PyTorch implementations for our paper: Editable Free-viewpoint Video Using a Layered Neural Representation SIGGRAPH 2021 Jiakai Zha

Diplodocus 258 Jan 02, 2023