Memory Efficient Attention (O(sqrt(n))) for Jax and PyTorch

Overview

This is an unofficial implementation of Self-attention Does Not Need O(n^2) Memory for Jax and PyTorch.

The implementation is essentially the same as the one proposed in the paper, with added support for attention masks and biases, arbitrary batch dimensions, and a PyTorch version. The method computes attention using only O(sqrt(n)) memory, and the provided functions can be used as a drop-in replacement for standard attention calculation.

Important note: This implementation trades memory for runtime, so you should tune the key_chunk_size and query_chunk_size parameters to find the best configuration for your use case. Here is a note from the paper's authors:

While a constant chunk size for the queries and a chunk size of sqrt(n) for the keys and values is optimal for memory consumption, the runtime is also affected by the choice of chunk size in practice, which is heavily affected by the choice of hardware. Ultimately, we have to leave this trade-off to the programmer, and expose the chunk sizes as arguments query_chunk_size and key_chunk_size. In Figure 1 we provide default values for the chunk sizes that lead to minimal runtime impact (on TPUv2), while still providing significant memory savings.
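
As a rough illustration of this trade-off, here is a minimal sketch comparing a memory-oriented and a runtime-oriented configuration (the chunk sizes below are example values, not recommendations from this repository; mask and bias are omitted since they are optional):

import math
import numpy as np
from memory_efficient_attention import efficient_dot_product_attention_jax

n = 4096  # sequence length
# inputs in [batch..., length, num_heads, depth_per_head] layout
query = np.random.rand(1, n, 8, 64).astype("float32")
key = np.random.rand(1, n, 8, 64).astype("float32")
value = np.random.rand(1, n, 8, 64).astype("float32")

# Memory-oriented: key/value chunks of roughly sqrt(n), as described in the paper
out_low_mem = efficient_dot_product_attention_jax(
    query, key, value,
    query_chunk_size=1024, key_chunk_size=int(math.sqrt(n)))

# Runtime-oriented: larger key/value chunks at the cost of more memory
out_faster = efficient_dot_product_attention_jax(
    query, key, value,
    query_chunk_size=1024, key_chunk_size=4096)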

Quick Start

  1. Install the library
# for Jax
pip install memory-efficient-attention[jax]
# for PyTorch
pip install memory-efficient-attention[torch]
# for Running Tests
pip install memory-efficient-attention[testing]
  2. Compute attention with the proper function
import numpy as np
# for PyTorch
from memory_efficient_attention import efficient_dot_product_attention_pt
# or for Jax
from memory_efficient_attention import efficient_dot_product_attention_jax

# Random data in [batch..., seq_length, num_heads, depth_per_head] layout (batch dimensions are optional)
b = 8
query = np.random.rand(1, b, 128, 16, 8).astype("float32")
key = np.random.rand(1, b, 128, 16, 8).astype("float32")
value = np.random.rand(1, b, 128, 16, 8).astype("float32")
# optional mask and bias, e.g. for causal (autoregressive) tasks
mask = np.random.rand(1, b, 16, 128, 128) > 0.5
bias = np.random.rand(1, b, 16, 128, 128).astype("float32") / 100

# Adjust chunk sizes for your memory budget and hardware
efficient_dot_product_attention_jax(query, key, value, mask, bias, key_chunk_size=..., query_chunk_size=...)
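
The same call with the PyTorch backend, as a minimal sketch assuming efficient_dot_product_attention_pt mirrors the Jax signature above (the chunk sizes here are arbitrary illustrations):

import torch
from memory_efficient_attention import efficient_dot_product_attention_pt

b = 8
query = torch.rand(1, b, 128, 16, 8)
key = torch.rand(1, b, 128, 16, 8)
value = torch.rand(1, b, 128, 16, 8)
mask = torch.rand(1, b, 16, 128, 128) > 0.5
bias = torch.rand(1, b, 16, 128, 128) / 100

out = efficient_dot_product_attention_pt(
    query, key, value, mask, bias,
    key_chunk_size=64, query_chunk_size=64)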

Citation

Please cite if this implementation helps your research. You can use the following BibTeX entry:

@misc{memory_efficient_attention,
	title = {Memory Efficient Attention},
	author = {Rezaei, Amin},
	howpublished = {\url{github.com/AminRezaei0x443/memory-efficient-attention}},
	year = {2021}
}

Also, for the paper:

@misc{rabe2021selfattention,
      title={Self-attention Does Not Need $O(n^2)$ Memory}, 
      author={Markus N. Rabe and Charles Staats},
      year={2021},
      eprint={2112.05682},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
Comments
  • feat: output_attentions

    I'm looking into hacking some of the models in the transformers library to use this library for attention, and I don't see a way to support output_attentions yet. This is a flag passed in transformers that, when set, preserves the attention weights and returns them to the user.

    I looked a little at implementing this in the torch backend, and I note that the scan() function allows only a single tensor return value. It seems to me that scan() would be most clearly replaced by a for loop, but it could also be modified to handle tuples (see the illustrative sketch after this comment), or return_weights could be handled by accessing nonlocal data instead of returning the weights through the chunk scanner. I'm also not sure how the output would best be passed to the user.

    Edit: Draft implementation 01/28 at https://github.com/AminRezaei0x443/memory-efficient-attention/compare/main...xloem:faba6371ac7faaa2040a2c26e15ae7ab87f94ce4 . I ended up extending the scan function for parity between implementations. Edit 2: Turns out it's the postsoftmax attention weights, not the presoftmax attention weights. I've updated this post and the draft implementation for this output: https://github.com/AminRezaei0x443/memory-efficient-attention/compare/main...xloem:return_weights

    opened by xloem 4
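
    For illustration, a minimal, hypothetical sketch (not the library's actual scan()) of a tuple-handling scan in PyTorch:

    import torch

    def scan(f, init, xs):
        # Apply f(carry, x) over the leading dimension of xs, mirroring jax.lax.scan,
        # but allow f to return a tuple of per-step tensors that are stacked separately.
        carry = init
        collected = None
        for x in xs:
            carry, ys = f(carry, x)
            if not isinstance(ys, tuple):
                ys = (ys,)
            if collected is None:
                collected = tuple([] for _ in ys)
            for bucket, y in zip(collected, ys):
                bucket.append(y)
        stacked = tuple(torch.stack(bucket) for bucket in collected)
        return carry, stacked[0] if len(stacked) == 1 else stacked

    A chunk scanner could then return (chunk_output, chunk_weights), and the stacked weights could be handed back to the user when an output_attentions-style flag is set.
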
  • Provide a flag for the user to receive attention weights

    This is my draft code for #1. I saw this feature in the transformers library and wanted to implement it here.

    I'm curious what you think about this feature and implementation.

    The code is simply slightly instrumented so that the final attention weights can be returned to the user. Tests are augmented to test this use. In utils, the scan function is expanded to handle tuples.

    A change to dynamic_slice crept in from dev, to use slices rather than index_slice. I've retained this change because it looks like it would execute faster to me, but it can be removed.

    Rebased and squashed from 84724e1de4721ea0333d6bdbb91e8bce74fbeac.

    opened by xloem 2
  • Improve performance via batched-matmul and fused multiplies

    Many thanks for providing this reference implementation.

    I tried integrating this into stable-diffusion / diffusers. A fix was required to make it work on Mac (PyTorch MPS backend):
    https://github.com/Birch-san/diffusers/pull/1/commits/04372140a25d7f53549175f1f196599c3e9bf3a5

    Knowing that computing attention via baddbmm()+bmm() can outperform einsum by 18%, I tried to rewrite the algorithm to use those.

    I compared the speed of my optimized version against the implementation in this repository. This result is for "everything fits in one chunk" performance (i.e. chunk size = max token length). I was unable to compare chunked performance, because although I got chunking working in my version, I wasn't able to get it working in the version in this repository (it returned some unexpected-shape tensors).

    Compared to the implementation in this repository, my optimized version achieves a 2.78x speedup in the time it takes to generate a 512x512 image with stable-diffusion v2.1-base (i.e. 4096 vision tokens, 5 attention heads, batch size of 2 due to CFG).

    Here's my optimized implementation:
    https://github.com/Birch-san/diffusers/pull/1

    Batched matmuls require a 3D tensor, i.e. [batch * num_heads, tokens, channels_per_head].

    Code that currently integrates against this repository's [batch, q_length, num_heads, qk_depth_per_head] format can migrate those tensors to the [batch * num_heads, q_length, channels_per_head] format favoured by my implementation like so:

    query = query.transpose(1,2).flatten(end_dim=1)
    key = key.transpose(1,2).flatten(end_dim=1)
    value = value.transpose(1,2).flatten(end_dim=1)
    

    The result that's returned remains in [batch * num_heads, q_length, qk_depth_per_head] format, and can be restored to [batch, q_length, num_heads, qk_depth_per_head] format like so:

    result.unflatten(0, (-1, attn.heads)).transpose(1,2)
    

    I think a further speedup is possible too, by working out when chunking is not needed: we can compute whether unchunked attention would fit into memory, and prefer unchunked attention as a fast path where possible. This will be useful in a Unet, which runs attention at various resolutions.

    EDIT:
    I have now added fast-paths for:

    • skipping kv-chunking when kv_chunk_size >= k_tokens
      • this turns the algorithm into "attention slicing"
    • skipping q-chunking when q_chunk_size >= q_tokens
    • skipping all chunking when the kv_chunk_size >= k_tokens and q_chunk_size >= q_tokens
    • skipping all chunking when the q@k.T matmul requires fewer bytes than a user-provided threshold
    opened by Birch-san 1
Releases (0.1.3)
  • 0.1.2 (Mar 7, 2022)

    What's Changed

    This update fixes torch device handling issues in the code, so GPU and other kinds of tensors can now be used safely.

    • Update utils.py by @yhgon in https://github.com/AminRezaei0x443/memory-efficient-attention/pull/5
    • Update attention_torch.py by @yhgon in https://github.com/AminRezaei0x443/memory-efficient-attention/pull/6

    New Contributors

    • @yhgon made their first contribution in https://github.com/AminRezaei0x443/memory-efficient-attention/pull/5

    Full Changelog: https://github.com/AminRezaei0x443/memory-efficient-attention/compare/0.1.1.0...0.1.2

  • 0.1.1.0 (Feb 3, 2022)

    Added mask and bias calculation functions for custom, memory-efficient chunk computation, so masks and biases can now be computed with sub-linear memory as well.

    Full Changelog: https://github.com/AminRezaei0x443/memory-efficient-attention/compare/0.1.1...0.1.1.0

Owner
Amin Rezaei
Computer Science BSc, Neural Networks Enthusiast