Implementation of Memory-Compressed Attention, from the paper "Generating Wikipedia By Summarizing Long Sequences"

Overview

Memory Compressed Attention

Implementation of the self-attention layer of the proposed Memory-Compressed Attention, in PyTorch. This repository offers both the causal and non-causal variants, and takes care of the padding when the sequence length is not divisible by the compression ratio.
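
The padding arithmetic can be sketched in plain Python. This is a minimal illustration, not the repository's code: the helper names `padding_needed` and `compressed_length` are hypothetical, and whether the padding goes at the front or the back of the sequence is left open here.

```python
def padding_needed(seq_len: int, compression_factor: int) -> int:
    """Positions to add so that seq_len becomes a multiple of the factor.

    Returns 0 when seq_len is already divisible by the factor.
    """
    return (-seq_len) % compression_factor


def compressed_length(seq_len: int, compression_factor: int) -> int:
    """Number of compressed key/value slots after padding (hypothetical helper)."""
    padded = seq_len + padding_needed(seq_len, compression_factor)
    return padded // compression_factor


seq_len, factor = 1024, 3
print(padding_needed(seq_len, factor))     # 2 extra positions
print(compressed_length(seq_len, factor))  # 342 compressed slots
```

With a sequence of 1024 tokens and a compression ratio of 3, each query attends to a few hundred compressed slots instead of 1024 keys, which is where the memory saving comes from.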

The code also resolves an edge case where the very first query has no keys to attend to in the auto-regressive scenario. The solution is to use null key/values, appended to the final compressed set, so that every query always has at least one key to attend to.
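
To see why the edge case arises, here is a small plain-Python sketch. It assumes one conservative causal rule (a compressed block becomes visible only once its last position is at or before the query) and no padding. The repository may use a different rule, and the function names are hypothetical.

```python
def visible_blocks(query_pos, seq_len, cf):
    """Compressed blocks a query may attend to under the assumed causal rule.

    Block j is assumed to summarize positions j*cf .. j*cf + cf - 1, and is
    visible only if its last position is <= query_pos (no future leakage).
    Assumes seq_len is divisible by cf.
    """
    num_blocks = seq_len // cf
    return [j for j in range(num_blocks) if j * cf + cf - 1 <= query_pos]


def attendable_keys(query_pos, seq_len, cf):
    """Visible blocks plus the always-available null key."""
    blocks = [f"block{j}" for j in visible_blocks(query_pos, seq_len, cf)]
    return blocks + ["null"]


for pos in (0, 2, 5):
    print(pos, attendable_keys(pos, seq_len=6, cf=3))
# 0 ['null']
# 2 ['block0', 'null']
# 5 ['block0', 'block1', 'null']
```

Without the null entry, the query at position 0 would have an empty key set, and a softmax over an empty (fully masked) row is undefined.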

Install

$ pip install memory_compressed_attention

Usage

import torch
from memory_compressed_attention import MemoryCompressedAttention

attn = MemoryCompressedAttention(
    dim = 512,
    heads = 8,                 # number of heads
    causal = False,            # auto-regressive or not
    compression_factor = 3,    # compression ratio
    dropout = 0.1              # dropout post-attention
)

x = torch.randn(1, 1024, 512)
mask = torch.ones(1, 1024).bool()

attn(x, input_mask = mask) # (1, 1024, 512)

Citations

@misc{liu2018generating,
    title={Generating Wikipedia by Summarizing Long Sequences},
    author={Peter J. Liu and Mohammad Saleh and Etienne Pot and Ben Goodrich and Ryan Sepassi and Lukasz Kaiser and Noam Shazeer},
    year={2018},
    eprint={1801.10198},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}

Comments
  • The order of masking and softmax operation

    Hi,

    In memory_compressed_attention.py, I'm wondering whether the softmax operation should be applied after masking. Also, should the masked entries be float('-inf') instead of -float('-inf')? If I've got something wrong, please correct me.

    Thanks!

    opened by cfeng16 3
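
For reference, the conventional ordering (mask the scores first, then apply softmax) can be sketched in plain Python. This is a generic illustration, not the code in memory_compressed_attention.py, and `masked_softmax` is a hypothetical name.

```python
import math


def masked_softmax(scores, keep):
    """Softmax over scores where entries with keep=False get zero weight.

    Masked scores are set to float('-inf') BEFORE the softmax, so exp() maps
    them to exactly 0.0. At least one entry must be kept, otherwise every
    score is -inf and the normalization is undefined.
    """
    if not any(keep):
        raise ValueError("at least one key must be unmasked")
    masked = [s if k else float("-inf") for s, k in zip(scores, keep)]
    top = max(masked)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in masked]
    total = sum(exps)
    return [e / total for e in exps]


weights = masked_softmax([1.0, 2.0, 3.0], [True, True, False])
print(weights)  # the last weight is 0.0 and the rest sum to 1

# Note that -float('-inf') is +inf, which would make a masked entry dominate
# the softmax instead of removing it.
print(-float("-inf") == float("inf"))  # True
```

Masking after the softmax would zero the entries but leave the remaining weights summing to less than 1, so the scores are normally masked first.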
  • mask error in attention

    Very grateful for your pioneering work! I want to use it in the standard Transformer released at http://nlp.seas.harvard.edu/2018/04/03/attention.html, but I ran into a mask error during training. More details follow. The code I use:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # attention() is assumed to be the scaled dot-product helper from the Transformer code linked above

    class ConvCompress(nn.Module):
        def __init__(self, dim, ratio = 2, groups = 1):
            super(ConvCompress, self).__init__()
            self.conv = nn.Conv1d(dim, dim, ratio, stride = ratio, groups = groups)
            #self.linear = nn.Linear(dim, dim)

        def forward(self, mem):
            mem = mem.transpose(1, 2)
            compressed_mem = self.conv(mem)
            return compressed_mem.transpose(1, 2)

    class MemoryCompressedAttention(nn.Module):
        def __init__(self, h, d_model, compression_factor = 2, dropout = 0.1):
            super(MemoryCompressedAttention, self).__init__()
            assert (d_model % h) == 0, 'dimension must be divisible by number of heads'
            self.h = h
            self.d_model = d_model
            self.d_k = d_model // h

            self.compression_factor = compression_factor
            self.compress_fn = ConvCompress(d_model, compression_factor, groups = h)

            #self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
            self.wq = nn.Linear(d_model, d_model, bias = False)
            self.wk = nn.Linear(d_model, d_model, bias = False)
            self.wv = nn.Linear(d_model, d_model, bias = False)

            self.wo = nn.Linear(d_model, d_model)

            self.dropout = nn.Dropout(dropout)

            #self.null_k = nn.Parameter(torch.zeros(1, 1, d_model))
            #self.null_v = nn.Parameter(torch.zeros(1, 1, d_model))

        def forward(self, query, key, value, mask = None):
            if mask is not None:
                # Same mask applied to all h heads.
                mask = mask.unsqueeze(1)
            nbatches = query.size(0)
            t = query.size(1)
            cf = self.compression_factor

            query = self.wq(query)
            key = self.wk(key)
            value = self.wv(value)

            # make sure keys and values sequence lengths
            # are divisible by the compression factor
            padding = cf - (t % cf)
            if padding != 0:
                key, value = map(lambda t: F.pad(t, (0, 0, padding, 0)), (key, value))

            # compress keys and values
            key, value = map(self.compress_fn, (key, value))

            # attach a null key and value, in the case that the first query has no keys to pay attention to
            null_k = nn.Parameter(torch.zeros(key.size(0), 1, self.d_model)).cuda()
            null_v = nn.Parameter(torch.zeros(value.size(0), 1, self.d_model)).cuda()

            key = torch.cat((null_k, key), dim=1)
            value = torch.cat((null_v, value), dim=1)

            # merge heads
            #query, key, value = map(lambda t: t.reshape(*t.shape[:2], h, -1).transpose(1, 2), (query, key, value))
            # 1) Do all the linear projections in batch from d_model => h x d_k
            query = query.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            key = key.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            value = value.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)

            # 2) Apply attention on all the projected vectors in batch.
            x, self.attn = attention(query, key, value, mask=mask,
                                     dropout=self.dropout)

            # 3) "Concat" using a view and apply a final linear.   # split heads and combine
            x = x.contiguous().view(nbatches, -1, self.d_model)
            out = self.wo(x)

            return out

    The error was shown in the attached image.

    I want to know how to fix it, and how to build the mask for an N×M attention matrix.

    opened by HN123-123 0
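
One way to think about the N×M mask (N queries, M compressed key slots) is to compress the per-token key mask the same way the keys are compressed. The plain-Python sketch below is an illustration under stated assumptions, not a confirmed fix for the issue above: padding is added at the front and a null slot is prepended, mirroring the `F.pad` and `torch.cat` calls in the posted code, and a compressed slot is kept if any token inside it is real. The function names are hypothetical.

```python
def compress_key_mask(key_mask, cf):
    """Collapse a per-token key mask (True = real token) to one flag per slot.

    Assumptions: front padding up to a multiple of cf, a slot is visible if
    any of its tokens is real, and an always-visible null slot comes first.
    """
    pad = (-len(key_mask)) % cf
    padded = [False] * pad + list(key_mask)
    slots = [any(padded[i:i + cf]) for i in range(0, len(padded), cf)]
    return [True] + slots


def expand_to_scores(num_queries, slot_mask):
    """Broadcast the slot mask to an N x M boolean matrix."""
    return [list(slot_mask) for _ in range(num_queries)]


key_mask = [True] * 5 + [False] * 3       # 8 tokens, the last 3 are padding
slot_mask = compress_key_mask(key_mask, cf=3)
print(slot_mask)                          # [True, True, True, False]

score_mask = expand_to_scores(len(key_mask), slot_mask)
print(len(score_mask), len(score_mask[0]))  # 8 4
```

The key point is that the mask's last dimension has to match the compressed key length M (here 4), not the original sequence length (here 8). A mask built for uncompressed keys will not broadcast against the compressed attention scores.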
Releases(0.0.5)
Owner
Phil Wang
Working with Attention. It's all we need