Overview

Memorizing Transformers - Pytorch

Implementation of Memorizing Transformers (ICLR 2022), attention net augmented with indexing and retrieval of memories using approximate nearest neighbors, in Pytorch

This repository deviates from the paper slightly, using hybrid attention across local and distant (retrieved) attention logits rather than the paper's sigmoid gate. It also uses cosine-similarity attention, with a learned temperature, for the KNN attention layer.
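To make that deviation concrete, here is a minimal sketch of the hybrid cosine-similarity attention, assuming single-head tensors of shape (batch, seq, dim_head) and a learned scalar temperature (hypothetical helper, not the repository's exact code):

import torch
import torch.nn.functional as F

def hybrid_cosine_attention(q, k_local, v_local, k_mem, v_mem, temperature):
    # l2-normalizing q and k turns their dot products into cosine similarities
    q, k_local, k_mem = map(lambda t: F.normalize(t, dim = -1), (q, k_local, k_mem))

    # a shared learned temperature scales both sets of similarity logits
    local_logits = torch.einsum('b i d, b j d -> b i j', q, k_local) * temperature
    mem_logits = torch.einsum('b i d, b j d -> b i j', q, k_mem) * temperature

    # one softmax over local and distant logits together, in place of the paper's sigmoid gate
    # (causal masking of the local logits is omitted for brevity)
    attn = torch.cat((mem_logits, local_logits), dim = -1).softmax(dim = -1)

    values = torch.cat((v_mem, v_local), dim = 1)
    return torch.einsum('b i j, b j d -> b i d', attn, values)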

Install

$ pip install memorizing-transformers-pytorch

Usage

import torch
from memorizing_transformers_pytorch import MemorizingTransformer

model = MemorizingTransformer(
    num_tokens = 20000,                 # number of tokens
    dim = 512,                          # dimension
    dim_head = 64,                      # dimension per attention head
    depth = 8,                          # number of layers
    memorizing_layers = (4, 5),         # which layers to have ANN memories
    max_knn_memories = 64000,           # maximum ANN memories to keep (once it hits this capacity, it will be reset for now, due to limitations in faiss' ability to remove entries)
    num_retrieved_memories = 32,        # number of ANN memories to retrieve
    clear_memories_on_sos_token_id = 1, # clear passed in ANN memories automatically for batch indices which contain this specified SOS token id - otherwise, you can also manually iterate through the ANN memories and clear the indices before the next iteration
)

data = torch.randint(0, 20000, (2, 1024)) # mock data

knn_memories = model.create_knn_memories(batch_size = 2) # create collection of KNN memories with the correct batch size (2 in example)

logits = model(data, knn_memories = knn_memories) # (2, 1024, 20000)

You can make the KNN memories read-only by setting add_knn_memory = False on forward, e.g.

logits = model(data, knn_memories = knn_memories, add_knn_memory = False) # knn memories will not be updated

With Transformer-XL memories (only the memories that will be discarded will be added to the KNN memory)

import torch
from memorizing_transformers_pytorch import MemorizingTransformer

model = MemorizingTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 8,
    memorizing_layers = (4, 5),
    max_knn_memories = 64000,
    num_retrieved_memories = 32,
    clear_memories_on_sos_token_id = 1,
    xl_memory_layers = (2, 3, 4, 5),      # xl memory layers - (https://arxiv.org/abs/2007.03356 shows you do not need XL memory on all layers, just the latter ones) - if a KNNAttention layer ends up using XL memories, only the XL memories that will be discarded will be added to long term memory
    xl_max_memories = 512,                # number of xl memories to keep
    shift_knn_memories_down = 1,          # let a layer look at the KNN memories this number of layers above
    shift_xl_memories_down = 1,           # let a layer look at the XL memories this number of layers above, shown to enhance receptive field in ernie-doc paper
)

data = torch.randint(0, 20000, (2, 1024)) # mock data

xl_memories = None

with model.knn_memories_context(batch_size = 2) as knn_memories:
    logits1, xl_memories = model(data, knn_memories = knn_memories, xl_memories = xl_memories)
    logits2, xl_memories = model(data, knn_memories = knn_memories, xl_memories = xl_memories)
    logits3, xl_memories = model(data, knn_memories = knn_memories, xl_memories = xl_memories)

    # ... and so on

KNN Memory

This repository contains a wrapper around Faiss that can automatically store and retrieve key / values.

import torch
from memorizing_transformers_pytorch import KNNMemory

memory = KNNMemory(
    dim = 64,                   # dimension of key / values
    max_memories = 64000,       # maximum number of memories to keep (will throw out the oldest memories for now if it overfills)
    num_indices = 2             # this should equal the batch dimension, as each batch keeps track of its own memories, expiring when it sees a new document
)

memory.add(torch.randn(2, 512, 2, 64))  # (batch, seq, key | value, feature dim)
memory.add(torch.randn(2, 512, 2, 64))

memory.clear([0]) # clear batch 0, if it saw an <sos>

memory.add(torch.randn(2, 512, 2, 64))
memory.add(torch.randn(2, 512, 2, 64))

key_values, mask = memory.search(torch.randn(2, 512, 64), topk = 32)
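search returns the retrieved key / value pairs for every query position, along with a validity mask for padded retrieval slots. A minimal sketch of how the results might feed into an attention step, assuming key_values comes back as (batch, seq, topk, 2, dim) to mirror the add layout:

mem_k, mem_v = key_values.unbind(dim = -2)   # each (batch, seq, topk, dim)

queries = torch.randn(2, 512, 64)            # the same queries that were passed to search

sim = torch.einsum('b i d, b i j d -> b i j', queries, mem_k)
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)   # ignore padded retrieval slots
attn = sim.softmax(dim = -1)

out = torch.einsum('b i j, b i j d -> b i d', attn, mem_v)  # (batch, seq, dim)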

Training

Enwik8 training

$ python train.py
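train.py takes care of data loading and sampling; the core training step it performs looks roughly like the following sketch (hypothetical hyperparameters, not the script verbatim):

import torch
import torch.nn.functional as F
from memorizing_transformers_pytorch import MemorizingTransformer

model = MemorizingTransformer(
    num_tokens = 256,            # byte-level vocab for enwik8
    dim = 512,
    dim_head = 64,
    depth = 8,
    memorizing_layers = (4, 5),
    max_knn_memories = 64000,
    num_retrieved_memories = 32
)

optim = torch.optim.Adam(model.parameters(), lr = 2e-4)

seq = torch.randint(0, 256, (2, 1025))   # stand-in for a batch of enwik8 bytes
inputs, labels = seq[:, :-1], seq[:, 1:]

with model.knn_memories_context(batch_size = 2) as knn_memories:
    logits = model(inputs, knn_memories = knn_memories)
    loss = F.cross_entropy(logits.transpose(1, 2), labels)
    loss.backward()
    optim.step()
    optim.zero_grad()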

Todo

  • switch to ivfhnsw and just remember all memories
  • enwik8 demo
  • validation for enwik8
  • solve gradient accumulation problem by offering some way to scope reads and writes to knn memories with another indices array
  • setup text generation with memories
  • figure out how to deal with memories efficiently once capacity has been hit
  • try to speed up reading and writing to knn memories collection with multiprocessing

Citations

@article{wu2022memorizing,
  title   = {Memorizing transformers},
  author  = {Wu, Yuhuai and Rabe, Markus N and Hutchins, DeLesley and Szegedy, Christian},
  journal = {arXiv preprint arXiv:2203.08913},
  year    = {2022}
}
@article{Shazeer2019FastTD,
  title   = {Fast Transformer Decoding: One Write-Head is All You Need},
  author  = {Noam M. Shazeer},
  journal = {ArXiv},
  year    = {2019},
  volume  = {abs/1911.02150}
}
@article{AlphaFold2021,
  author  = {Jumper, John and Evans, Richard and Pritzel, Alexander and Green, Tim and Figurnov, Michael and Ronneberger, Olaf and Tunyasuvunakool, Kathryn and Bates, Russ and {\v{Z}}{\'\i}dek, Augustin and Potapenko, Anna and Bridgland, Alex and Meyer, Clemens and Kohl, Simon A A and Ballard, Andrew J and Cowie, Andrew and Romera-Paredes, Bernardino and Nikolov, Stanislav and Jain, Rishub and Adler, Jonas and Back, Trevor and Petersen, Stig and Reiman, David and Clancy, Ellen and Zielinski, Michal and Steinegger, Martin and Pacholska, Michalina and Berghammer, Tamas and Bodenstein, Sebastian and Silver, David and Vinyals, Oriol and Senior, Andrew W and Kavukcuoglu, Koray and Kohli, Pushmeet and Hassabis, Demis},
  journal = {Nature},
  title   = {Highly accurate protein structure prediction with {AlphaFold}},
  year    = {2021},
  doi     = {10.1038/s41586-021-03819-2},
  note    = {(Accelerated article preview)},
}
@inproceedings{Rae2020DoTN,
  title   = {Do Transformers Need Deep Long-Range Memory?},
  author  = {Jack W. Rae and Ali Razavi},
  booktitle = {ACL},
  year    = {2020}
}
@misc{ding2021erniedoc,
  title   = {ERNIE-Doc: A Retrospective Long-Document Modeling Transformer},
  author  = {Siyu Ding and Junyuan Shang and Shuohuan Wang and Yu Sun and Hao Tian and Hua Wu and Haifeng Wang},
  year    = {2021},
  eprint  = {2012.15688},
  archivePrefix = {arXiv},
  primaryClass = {cs.CL}
}
@misc{henry2020querykey,
  title   = {Query-Key Normalization for Transformers},
  author  = {Alex Henry and Prudhvi Raj Dachapally and Shubham Pawar and Yuxuan Chen},
  year    = {2020},
  eprint  = {2010.04245},
  archivePrefix = {arXiv},
  primaryClass = {cs.CL}
}

Memory is Attention through Time - Alex Graves

Comments
  • Arguments to reproduce the models from the original paper?


    Hi lucidrains,

    This looks like excellent work! I have gone through the original paper and your repo, and am now trying to reproduce the model from the paper as closely as possible. Of course, the modifications you made such as hybrid attention instead of sigmoid gate are fine.

    Specifically, I would like to be able to try some of the variations in Table 4 of the paper.

    Suppose I'm interested in the 4th-to-last row, with Context 512, Memory 8192, XL cache 512. Can you help me with the model arguments to do that? Here is my initial attempt, with reference to Section 4.2:

    model = MemorizingTransformer(
        num_tokens = 32000, # vocab 32k
        dim = 1024, 
        depth = 12,
        memorizing_layers = 9,
        max_knn_memories = 8192, # Memory column
        num_retrieved_memories = 32,
        clear_memories_on_sos_token_id = 1,
        xl_memory_layers = (6, 7, 8, 9),  # not sure about this?
        xl_max_memories = 512, # XL cache column
        shift_knn_memories_down = 1, 
        shift_xl_memories_down = 1,
        # which argument corresponds to Context column?
    ).cuda()
    
    

    A second question: what are the model arguments to reproduce the first row of Table 4, with no memory nor XL cache? Thanks in advance.

    opened by manestay 1
  • KNNMemory add() does not appear to update self.knns


    Thanks for the nice implementation. I've adapted this code for my own use, so I don't have the whole stack that would reproduce this bug. However, you can check for yourself.

    The following code ought to update the KNN objects in the KNNMemory class:

    @delayed
    def knn_add(knn, key, db_offset):
        knn.add(key, ids = knn_insert_ids + db_offset)
    
    Parallel(n_jobs = self.n_jobs)(knn_add(*args) for args in zip(knns, keys, db_offsets))
    

    [link to that code here]

    However, even after repeated calls to add to the memory, KNNMemory.search() returns empty values. If you inspect self.knns at this point, each index's is_trained remains False.

    When I modify the code as follows, this fixes the issue.

    @delayed
    def knn_add(knn, key, db_offset):
        knn.add(key, ids = knn_insert_ids + db_offset)
        return knn
    
    updated_knns = Parallel(n_jobs = self.n_jobs)(knn_add(*args) for args in zip(knns, keys, db_offsets))
    self.knns = updated_knns
    

    This will allow searches to return actual values.
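
    A plausible root cause, offered as an assumption rather than a verified diagnosis: joblib's default process-based backend pickles the knn objects for its workers, so in-place mutations never propagate back to the parent process. Returning the mutated objects, as above, works; forcing the threading backend would be another way to keep the in-place knn.add visible (names taken from the snippet above):

    from joblib import Parallel, delayed

    # threads share memory with the caller, so in-place mutation of the knn
    # objects stays visible without returning them from the workers
    Parallel(n_jobs = self.n_jobs, backend = 'threading')(knn_add(*args) for args in zip(knns, keys, db_offsets))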

    opened by vyaivo 0
  • FAISS hard reset


    Hello and thanks for this implementation!

    Do you know of any solutions to efficiently solve the "hard reset" problem in FAISS? I know that one could use IndexFlatL2 but that's not really efficient.

    Thank you!
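
    Not aware of a clean answer; one possible direction (a sketch independent of this repository) is to wrap a flat index in faiss's IndexIDMap and expire the oldest entries with remove_ids instead of hard-resetting, accepting that removal is linear-time for flat indices and unsupported for HNSW:

    import faiss
    import numpy as np

    dim = 64
    index = faiss.IndexIDMap(faiss.IndexFlatIP(dim))   # flat index whose entries carry explicit ids

    keys = np.random.randn(1000, dim).astype('float32')
    index.add_with_ids(keys, np.arange(1000, dtype = 'int64'))

    # expire the oldest 100 entries by id range, rather than resetting everything
    index.remove_ids(faiss.IDSelectorRange(0, 100))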

    opened by itsdaniele 0
  • index out of range

    When I run train.py, an error like this happens: "index out of range: Tried to access index 10218 out of table with 255 rows. at /pytorch/aten/src/TH/generic/THTensorEvenMoreMath.cpp:418"

    opened by chxiag 0
  • Support for Multi-GPU training?


    Thank you so much for the great implementation. I would like to ask whether your implementation of the Memorizing Transformer supports multi-GPU distributed training, as in the original paper. If you distribute the MemorizingTransformer model you created to each GPU, then every GPU holds its own memory and Faiss retrieval index, so each replica ends up with a different memory database and retrieval index, which differs from the original paper. I believe each replica should share the same retrieval context. This problem confuses me a lot.

    Thank you so much for your time. Looking forward to your response!

    opened by Victorwz 0
  • Dimensionality of key and values for Attention


    I have two questions about the key and value calculation in Attention (and similarly for KNNAttention).

    The relevant line is: https://github.com/lucidrains/memorizing-transformers-pytorch/blob/83fa1479d6f7881dd977fbff55681e709e3b250e/memorizing_transformers_pytorch/memorizing_transformers_pytorch.py#L135

    1. Why is there only one Linear layer to_kv, instead of 2 linear layers to_k and to_v?
    2. Why is the last dimension dim_head * 2? I get that the * 2 is for both k and v, but what about dim_head? I thought q, k, v should all have the same final dimension (i.e. inner_dim == dim_head * heads). My understanding is that this means either a) there is only 1 attention head, or b) k and v are shared across all heads. Is there a reason this is done, or am I misunderstanding?

    In your Attention class for Performer, q, k, v all have the same dimensions.

    Thanks in advance!
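
    For what it's worth, that pattern matches the one-write-head (multi-query) attention of the Shazeer paper in the citations below: queries get per-head projections, while a single key / value head is shared across all heads. A minimal sketch of the idea (hypothetical code, not this repository's exact implementation):

    import torch
    from torch import nn

    dim, heads, dim_head = 512, 8, 64

    to_q = nn.Linear(dim, dim_head * heads, bias = False)   # per-head queries
    to_kv = nn.Linear(dim, dim_head * 2, bias = False)      # one shared key / value head

    x = torch.randn(2, 1024, dim)

    q = to_q(x).reshape(2, 1024, heads, dim_head).transpose(1, 2)   # (b, h, n, d)
    k, v = to_kv(x).chunk(2, dim = -1)                              # each (b, n, d)

    sim = torch.einsum('b h i d, b j d -> b h i j', q, k) * dim_head ** -0.5
    attn = sim.softmax(dim = -1)                                    # (causal mask omitted)
    out = torch.einsum('b h i j, b j d -> b h i d', attn, v)        # (b, h, n, d)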

    opened by manestay 8
  • Maybe scale is wrong


    https://github.com/lucidrains/memorizing-transformers-pytorch/blob/83fa1479d6f7881dd977fbff55681e709e3b250e/memorizing_transformers_pytorch/memorizing_transformers_pytorch.py#L237

    Shouldn't this be (1-scale)?

    opened by denadai2 3
Releases
  • 0.3.10

Owner
  • Phil Wang ("Working with Attention. It's all we need")