Implementation of gMLP, an all-MLP replacement for Transformers, in PyTorch

Overview

gMLP - Pytorch

Install

$ pip install g-mlp-pytorch

Usage

For masked language modelling

import torch
from g_mlp_pytorch import gMLP

model = gMLP(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    seq_len = 256
)

x = torch.randint(0, 20000, (1, 256))
emb = model(x) # (1, 256, 512)

For image classification

import torch
from g_mlp_pytorch import gMLPVision

model = gMLPVision(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 512,
    depth = 6
)

img = torch.randn(1, 3, 256, 256)
pred = model(img) # (1, 1000)
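
As a rough guide to the shapes involved, the sketch below works out how many patches a square image yields and how large each flattened patch is, assuming the usual ViT-style non-overlapping patching. The helper name `patch_shapes` is hypothetical and not part of g_mlp_pytorch; the README itself does not spell out how patches are embedded.

```python
def patch_shapes(image_size, patch_size, channels=3):
    """Return (num_patches, patch_dim) for a square image cut into square patches.

    Hypothetical helper for illustration; assumes non-overlapping patches.
    """
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    per_side = image_size // patch_size
    num_patches = per_side * per_side                 # sequence length seen by the model
    patch_dim = channels * patch_size * patch_size    # values in one flattened patch
    return num_patches, patch_dim


num_patches, patch_dim = patch_shapes(256, 16)
print(num_patches, patch_dim)  # 256 768
```

With the settings above, the 256 x 256 image becomes a sequence of 256 patch tokens, each presumably projected to dim = 512 before the gMLP blocks.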

You can also add a tiny amount of single-headed attention to boost performance, a variant the paper calls aMLP, by passing one extra keyword, attn_dim. This applies to both gMLPVision and gMLP.

import torch
from g_mlp_pytorch import gMLPVision

model = gMLPVision(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 512,
    depth = 6,
    attn_dim = 64
)

img = torch.randn(1, 3, 256, 256)
pred = model(img) # (1, 1000)

Citations

@misc{liu2021pay,
    title   = {Pay Attention to MLPs}, 
    author  = {Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le},
    year    = {2021},
    eprint  = {2105.08050},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
Comments
  • Custom image sizes?

    Hi, thanks for your great (and very fast) contribution! I was wondering if you could help me figure out how to apply this to a different image size. It's not really an image, but rather a 2D tensor of size 4096 x 100.

    I saw that I can change the number of channels, so I could just set channels to 1. But I see that, first, your implementation is for square images and, second, it requires the image size to be divisible by the patch size.

    Since you've written this implementation perhaps you could help me to adapt it for my needs? (and maybe other users for their cases).

    Maybe I could pad the length to 128 so that both dimensions are divisible by 16, for example? But then where do I set different h and w?

    Thanks.

    opened by danarte 3
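
The padding arithmetic raised in this issue can be sketched as follows. This is a minimal illustration, not the library's API: `padded_size` and `patch_grid` are hypothetical helpers, and whether gMLPVision accepts a non-square size is not stated here.

```python
def padded_size(size, patch_size):
    """Smallest multiple of patch_size that is >= size (hypothetical helper)."""
    return -(-size // patch_size) * patch_size  # ceiling division, integers only


def patch_grid(height, width, patch_size):
    """Return the padded (height, width) and the resulting patch grid."""
    h = padded_size(height, patch_size)
    w = padded_size(width, patch_size)
    return (h, w), (h // patch_size, w // patch_size)


padded, grid = patch_grid(4096, 100, 16)
print(padded, grid, grid[0] * grid[1])  # (4096, 112) (256, 7) 1792
```

For a 4096 x 100 input and a patch size of 16, padding the short side to 112 is already enough; 128 also works but adds one more column of patches.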
  • Parameter count doesn't line up with paper

    Just a note (and correct me if I misunderstood the paper):

    The parameter count for the Tiny gMLP doesn't line up with the count from the paper for 30 layers, dim 128, and ff_mult 6. That's probably due to the doubling of parameters here - https://github.com/lucidrains/g-mlp-pytorch/blob/main/g_mlp_pytorch/g_mlp_pytorch.py#L111

    Halving this back to dim_ff should fix it; all 3 lines here would also need to halve their respective dims - https://github.com/lucidrains/g-mlp-pytorch/blob/main/g_mlp_pytorch/g_mlp_pytorch.py#L64-L66

    The param count is then roughly 5.5M.

    opened by titu1994 2
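
The count discussed in this issue can be approximated by hand. The sketch below is a back-of-the-envelope estimate under stated assumptions, not a reading of the repository's code: each block is assumed to hold a LayerNorm, an input projection dim -> dim_ff, a spatial gating unit with a LayerNorm over half the channels and a seq_len x seq_len projection, and an output projection dim_ff/2 -> dim. A sequence length of 196 (a 224 x 224 image with 16 x 16 patches) is also an assumption, and patch embedding and classifier head are left out.

```python
def gmlp_block_params(dim, dim_ff, seq_len):
    """Parameters in one gMLP block under the layout assumed above."""
    norm = 2 * dim                          # block LayerNorm: weight + bias
    proj_in = dim * dim_ff + dim_ff         # channel projection in
    half = dim_ff // 2                      # channels per branch after the split
    sgu_norm = 2 * half                     # LayerNorm on the gate branch
    spatial = seq_len * seq_len + seq_len   # spatial projection W and bias b
    proj_out = half * dim + dim             # channel projection out
    return norm + proj_in + sgu_norm + spatial + proj_out


dim, depth, ff_mult, seq_len = 128, 30, 6, 196
halved = depth * gmlp_block_params(dim, dim * ff_mult, seq_len)
doubled = depth * gmlp_block_params(dim, 2 * dim * ff_mult, seq_len)
print(f"{halved:,} vs {doubled:,}")
```

Under these assumptions the blocks alone come to about 5.6M parameters with hidden width dim * ff_mult, in the range the issue reports, and to about 10.1M if that width is doubled.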
  • Add Support for Stochastic Depth

    This PR adds support for stochastic depth, which is used in the paper for the vision experiments. I went ahead and added it to gMLP as well for completeness.

    I tried my best to match your style. Let me know if there are any problems, or if you want me to refactor anything.

    opened by mlw214 2
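
For readers unfamiliar with the technique, stochastic depth randomly skips whole layers during training, so each step trains a shallower sub-network. The sketch below shows only the layer-selection step in plain Python; the function name `surviving_layers`, the `prob_survival` parameter, and the rule of always keeping at least one layer are assumptions for illustration, not a description of what the PR implements.

```python
import random


def surviving_layers(layers, prob_survival, rng=random):
    """Pick the layers to run for one training step (hypothetical helper).

    Each layer is kept independently with probability prob_survival.
    If every layer is dropped, one is chosen at random so the input
    still passes through at least one layer.
    """
    if prob_survival >= 1.0:
        return list(layers)
    kept = [layer for layer in layers if rng.random() < prob_survival]
    if not kept:
        kept = [rng.choice(list(layers))]
    return kept


print(len(surviving_layers(list(range(6)), prob_survival=1.0)))  # 6
```

At evaluation time all layers would normally be used, which corresponds to calling the helper with prob_survival = 1.0.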
  • Don't you think this is more legible?

    import torch
    from torch import nn

    def exists(val):
        # small helper the snippet relies on
        return val is not None

    class SpatialGatingUnit(nn.Module):
        def __init__(self, dim, dim_seq, causal = False, act = nn.Identity(), init_eps = 1e-3):
            super().__init__()
            dim_out = dim // 2
            self.causal = causal

            self.norm = nn.LayerNorm(dim_out)
            #self.proj = nn.Conv1d(dim_seq, dim_seq, 1)

            self.dim_seq = dim_seq
            self.w_ = nn.Parameter(torch.zeros(dim_seq, dim_seq), requires_grad=True)   ####
            self.b_ = nn.Parameter(torch.ones(dim_seq), requires_grad=True)  ####

            self.act = act

            init_eps /= dim_seq
            #nn.init.uniform_(self.proj.weight, -init_eps, init_eps)
            #nn.init.constant_(self.proj.bias, 1.)

        def forward(self, x, gate_res = None): # x -> bsz, len, hidden*6
            device, n = x.device, x.shape[1]

            res, gate = x.chunk(2, dim = -1)
            gate = self.norm(gate)

            weight, bias = self.w_, self.b_ # weight -> len, len     bias -> len

            if self.causal:
                # weight is already 2D here, so no unsqueeze/squeeze is needed
                weight, bias = weight[:n, :n], bias[:n]
                mask = torch.ones(weight.shape[:2], device = device).triu_(1).bool()
                weight = weight.masked_fill(mask, 0.)  # zero out connections to future positions

            gate = torch.matmul(weight, gate) + bias[None, :self.dim_seq, None]   # WZ + b

            #gate = F.conv1d(gate, weight, bias)   # WZ + b

            if exists(gate_res):
                gate = gate + gate_res

            return self.act(gate) * res

    opened by ZIZUN 0
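
For reference, the computation that the forward pass in this snippet carries out can be written compactly as below. This is a restatement of the code above in the paper's notation, with the optional gate_res term left out; n denotes the sequence length.

```latex
\begin{aligned}
(Z_1, Z_2) &= \mathrm{split}(Z) && \text{(along the channel dimension)} \\
\tilde{Z}_2 &= W \,\mathrm{LayerNorm}(Z_2) + b, && W \in \mathbb{R}^{n \times n},\; b \in \mathbb{R}^{n} \\
s(Z) &= \mathrm{act}(\tilde{Z}_2) \odot Z_1
\end{aligned}
```

Here W mixes information across sequence positions, which is why it can be stored either as a plain n x n matrix, as in the snippet, or as the weight of a 1 x 1 convolution over the sequence dimension.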
  • Potentially missing the highway path

    Hello,

    Maybe I missed it, but would you mind pointing out where the highway path of the gMLP block is in the code? Based on the paper, there is a highway path (an addition) between the input and the output. I couldn't find it in the gMLPBlock code.

    Thank you

    opened by Vincent-Li-9701 1
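
On the last question, one common way to write such a skip connection is as a wrapper around the block rather than inside the block class, which makes it easy to miss when reading the block alone. The sketch below illustrates that pattern in plain Python; `Residual` and `toy_block` are hypothetical names, and the text here does not confirm whether this repository is organised the same way.

```python
class Residual:
    """Hypothetical wrapper: adds the input back onto the wrapped function's output."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, x):
        # element-wise x + fn(x), the highway (skip) path
        return [xi + yi for xi, yi in zip(x, self.fn(x))]


def toy_block(x):
    # stand-in for a gMLP block; simply halves every element
    return [0.5 * xi for xi in x]


block = Residual(toy_block)
print(block([1.0, 2.0]))  # [1.5, 3.0]
```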
Owner
Phil Wang