Implementation of a memory efficient multi-head attention as proposed in the paper, "Self-attention Does Not Need O(n²) Memory"

Overview

Memory Efficient Attention Pytorch

Implementation of a memory efficient multi-head attention as proposed in the paper, Self-attention Does Not Need O(n²) Memory. In addition, the module will take care of masking, causal masking, as well as cross attention.

Install

$ pip install memory-efficient-attention-pytorch

Usage

For autoregressive language model

import torch
from memory_efficient_attention_pytorch import Attention

attn = Attention(
    dim = 512,
    dim_head = 64,                # dimension per head
    heads = 8,                    # number of attention heads
    causal = True,                # autoregressive or not
    memory_efficient = True,      # whether to use memory efficient attention (can be turned off to test against normal attention)
    q_bucket_size = 1024,         # bucket size along queries dimension
    k_bucket_size = 2048          # bucket size along key / values dimension
).cuda()

x = torch.randn(1, 65536, 512).cuda()
out = attn(x) # (1, 65536, 512)

Cross attention

import torch
from memory_efficient_attention_pytorch import Attention

cross_attn = Attention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    memory_efficient = True,
    q_bucket_size = 1024,
    k_bucket_size = 2048
).cuda()

x = torch.randn(1, 65536, 512).cuda()
context = torch.randn(1, 65536, 512).cuda()
mask = torch.ones(1, 65536).bool().cuda()

out = cross_attn(x, context = context, mask = mask) # (1, 65536, 512)
  • benchmark and see how much torch jit helps
  • look at Triton and Keops and see if either can be a fit

Citations

@misc{rabe2021selfattention,
    title   = {Self-attention Does Not Need $O(n^2)$ Memory}, 
    author  = {Markus N. Rabe and Charles Staats},
    year    = {2021},
    eprint  = {2112.05682},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{liu2021swin,
    title   = {Swin Transformer V2: Scaling Up Capacity and Resolution},
    author  = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
    year    = {2021},
    eprint  = {2111.09883},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
Comments
  • [feature request] Combining with flash attention?

    [feature request] Combining with flash attention?

    There is a new algorithm to optimize the qkv attention, https://github.com/HazyResearch/flash-attention https://arxiv.org/abs/2205.14135 It optimises the qkv attention part. Maybe you can look into integrating it with this.

    opened by Vbansal21 15
  • i did this, we could build on top

    i did this, we could build on top

    Hi there!

    It seems I did already some of the code... https://github.com/CHARM-Tx/linear_mem_attention_pytorch could we build on top of this? I talked to https://github.com/Chillee about an experimental functionality from functorch: https://github.com/pytorch/functorch that would allow for increased speed (mainly i want to match jax perofmance but its just difficult w/ pytorch imperative style).

    I would love to collaborate on this if you want!

    opened by hypnopump 5
  • Added dropout support to memory efficient variant

    Added dropout support to memory efficient variant

    Hey Phil,

    I have been using this repository for a project and I wanted to add dropout for completeness. I checked consistency with perceiver-ar impl.. I hope this is helpful.

    -Matt

    opened by usryokousha 2
  • Making this work with relative position bias from XTransformers

    Making this work with relative position bias from XTransformers

    Is there a way to make this work with RelativePositionBias. Currently this produces an attention bias of size $BHN^2$ where B is batch size, H is number of heads and N is input size. Can this be chunked and computed per chunk?

    opened by pfeatherstone 5
  •  save_for_backward can only save variables, but argument 5 is of type bool

    save_for_backward can only save variables, but argument 5 is of type bool

    Hi,

    Thank you for your indescribable work. I was trying to test your method specifically for cross-attention but It seems I get the error " save_for_backward can only save variables, but argument 5 is of type bool". I am not sure what I am doing wrong. I tried your own examples too but get the same error.

    Can you please help me out?

    Code:

    import torch from memory_efficient_attention_pytorch import Attention

    cross_attn = Attention( dim = 512, dim_head = 64, heads = 8, memory_efficient = True, q_bucket_size = 1024, k_bucket_size = 2048 ).cuda() (# out = sm_mod(inp1)) did this to avoid being a header x = torch.randn(1, 65536, 512).cuda() context = torch.randn(1, 65536, 512).cuda() (# mask = torch.ones(1, 65536).bool().cuda()) did this to avoid being a heading out = cross_attn(x

    ERROR:

    File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/runpy.py", line 194, in _run_module_as_main return _run_code(code, main_globals, None, File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/runpy.py", line 87, in _run_code exec(code, run_globals) File "/home/abali/.vscode-server/extensions/ms-python.python-2022.8.1/pythonFiles/lib/python/debugpy/main.py", line 45, in cli.main() File "/home/abali/.vscode-server/extensions/ms-python.python-2022.8.1/pythonFiles/lib/python/debugpy/../debugpy/server/cli.py", line 444, in main run() File "/home/abali/.vscode-server/extensions/ms-python.python-2022.8.1/pythonFiles/lib/python/debugpy/../debugpy/server/cli.py", line 285, in run_file runpy.run_path(target_as_str, run_name=compat.force_str("main")) File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/runpy.py", line 265, in run_path return _run_module_code(code, init_globals, run_name, File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/runpy.py", line 97, in _run_module_code _run_code(code, mod_globals, init_globals, File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/runpy.py", line 87, in _run_code exec(code, run_globals) File "/data/stars/user/abali/Phd_work/ISBI2023/X3D-Multigrid/CrossAttn_X3d_v2.py", line 872, in out = cross_attn(x, context = context, mask = mask) # (1, 65536, 512) print(out) File "/home/abali/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl result = self.forward(*input, **kwargs) File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/site-packages/memory_efficient_attention_pytorch/memory_efficient_attention.py", line 215, in forward out = attn_fn(q, k, v, mask = mask, attn_bias = attn_bias, causal = self.causal, q_bucket_size = q_bucket_size, k_bucket_size = k_bucket_size) File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/site-packages/memory_efficient_attention_pytorch/memory_efficient_attention.py", line 127, in memory_efficient_attention exp_weight_chunk, weighted_value_chunk, weight_max_chunk = summarize_qkv_fn( File "/home/abali/.local/lib/python3.8/site-packages/torch/utils/checkpoint.py", line 163, in checkpoint return CheckpointFunction.apply(function, preserve, *args) TypeError: save_for_backward can only save variables, but argument 5 is of type bool

    opened by aliabid2243 1
  • Checkpointing is not compatible with .grad() or when an `inputs` parameter is passed to .backward()

    Checkpointing is not compatible with .grad() or when an `inputs` parameter is passed to .backward()

    https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/35559a05572f9d4eb982a8e2e399b40a2d61b85c/memory_efficient_attention_pytorch/memory_efficient_attention.py#L95

    Should this be: summarize_qkv_fn = summarize_qkv_chunk if needs_backwards else checkpointed_summarize_qkv_chunk instead of: summarize_qkv_fn = checkpointed_summarize_qkv_chunk if needs_backwards else summarize_qkv_chunk

    opened by vrobot 0
Releases(0.1.1)
Owner
Phil Wang
Working with Attention. It's all we need
Phil Wang
pytorch implementation of Attention is all you need

A Pytorch Implementation of the Transformer: Attention Is All You Need Our implementation is largely based on Tensorflow implementation Requirements N

230 Dec 07, 2022
SpecAugmentPyTorch - A Pytorch (support batch and channel) implementation of GoogleBrain's SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition

SpecAugment An implementation of SpecAugment for Pytorch How to use Install pytorch, version=1.9.0 (new feature (torch.Tensor.take_along_dim) is used

IMLHF 3 Oct 11, 2022
A Deep Learning Framework for Neural Derivative Hedging

NNHedge NNHedge is a PyTorch based framework for Neural Derivative Hedging. The following repository was implemented to ease the experiments of our pa

GUIJIN SON 17 Nov 14, 2022
Real-time face detection and emotion/gender classification using fer2013/imdb datasets with a keras CNN model and openCV.

Real-time face detection and emotion/gender classification using fer2013/imdb datasets with a keras CNN model and openCV.

Octavio Arriaga 5.3k Dec 30, 2022
Code for 'Self-Guided and Cross-Guided Learning for Few-shot segmentation. (CVPR' 2021)'

SCL Introduction Code for 'Self-Guided and Cross-Guided Learning for Few-shot segmentation. (CVPR' 2021)' We evaluated our approach using two baseline

34 Oct 08, 2022
Implementation of [Time in a Box: Advancing Knowledge Graph Completion with Temporal Scopes].

Time2box Implementation of [Time in a Box: Advancing Knowledge Graph Completion with Temporal Scopes].

LingCai 4 Aug 23, 2022
Code and Data for NeurIPS2021 Paper "A Dataset for Answering Time-Sensitive Questions"

Time-Sensitive-QA The repo contains the dataset and code for NeurIPS2021 (dataset track) paper Time-Sensitive Question Answering dataset. The dataset

wenhu chen 35 Nov 14, 2022
CVPR 2021 - Official code repository for the paper: On Self-Contact and Human Pose.

SMPLify-XMC This repo is part of our project: On Self-Contact and Human Pose. [Project Page] [Paper] [MPI Project Page] License Software Copyright Lic

Lea Müller 83 Dec 14, 2022
Tree-based Search Graph for Approximate Nearest Neighbor Search

TBSG: Tree-based Search Graph for Approximate Nearest Neighbor Search. TBSG is a graph-based algorithm for ANNS based on Cover Tree, which is also an

Fanxbin 2 Dec 27, 2022
Awesome Monocular 3D detection

Awesome Monocular 3D detection Paper list of 3D detetction, keep updating! Contents Paper List 2022 2021 2020 2019 2018 2017 2016 KITTI Results Paper

Zhikang Zou 184 Jan 04, 2023
Official implementation of the paper Label-Efficient Semantic Segmentation with Diffusion Models

Label-Efficient Semantic Segmentation with Diffusion Models Official implementation of the paper Label-Efficient Semantic Segmentation with Diffusion

Yandex Research 355 Jan 06, 2023
Linear Variational State Space Filters

Linear Variational State Space Filters To set up the environment, use the provided scripts in the docker/ folder to build and run the codebase inside

0 Dec 13, 2021
Line-level Handwritten Text Recognition (HTR) system implemented with TensorFlow.

Line-level Handwritten Text Recognition with TensorFlow This model is an extended version of the Simple HTR system implemented by @Harald Scheidl and

Hoàng Tùng Lâm (Linus) 72 May 07, 2022
Dynamic Token Normalization Improves Vision Transformers

Dynamic Token Normalization Improves Vision Transformers This is the PyTorch implementation of the paper Dynamic Token Normalization Improves Vision T

Wenqi Shao 20 Oct 09, 2022
Code samples for my book "Neural Networks and Deep Learning"

Code samples for "Neural Networks and Deep Learning" This repository contains code samples for my book on "Neural Networks and Deep Learning". The cod

Michael Nielsen 13.9k Dec 26, 2022
[CVPR'21] Learning to Recommend Frame for Interactive Video Object Segmentation in the Wild

IVOS-W Paper Learning to Recommend Frame for Interactive Video Object Segmentation in the Wild Zhaoyun Yin, Jia Zheng, Weixin Luo, Shenhan Qian, Hanli

SVIP Lab 38 Dec 12, 2022
SimpleDepthEstimation - An unified codebase for NN-based monocular depth estimation methods

SimpleDepthEstimation Introduction This is an unified codebase for NN-based monocular depth estimation methods, the framework is based on detectron2 (

8 Dec 13, 2022
An image base contains 490 images for learning (400 cars and 90 boats), and another 21 images for testingAn image base contains 490 images for learning (400 cars and 90 boats), and another 21 images for testing

SVM Données Une base d’images contient 490 images pour l’apprentissage (400 voitures et 90 bateaux), et encore 21 images pour fait des tests. Prétrait

Achraf Rahouti 3 Nov 30, 2021
Official implementation of ACMMM'20 paper 'Self-supervised Video Representation Learning Using Inter-intra Contrastive Framework'

Self-supervised Video Representation Learning Using Inter-intra Contrastive Framework Official code for paper, Self-supervised Video Representation Le

Li Tao 103 Dec 21, 2022
Measure WWjj polarization fraction

WlWl Polarization Measure WWjj polarization fraction Paper: arXiv:2109.09924 Notice: This code can only be used for the inference process, if you want

4 Apr 10, 2022