PyTorch implementation of Compressive Transformers, from DeepMind

Overview

Compressive Transformer in PyTorch

PyTorch implementation of Compressive Transformers, a variant of Transformer-XL with compressed memory for long-range language modelling. I also combine this with an idea from another paper that adds gating at the residual intersection. The memory and the gating may be synergistic, and lead to further improvements in both language modelling and reinforcement learning.


Install

$ pip install compressive_transformer_pytorch

Usage

import torch
from compressive_transformer_pytorch import CompressiveTransformer

model = CompressiveTransformer(
    num_tokens = 20000,
    emb_dim = 128,                 # embedding dimension, uses the embedding factorization from the ALBERT paper
    dim = 512,
    depth = 12,
    seq_len = 1024,
    mem_len = 1024,                # memory length
    cmem_len = 1024 // 4,          # compressed memory buffer length
    cmem_ratio = 4,                # compressed memory ratio, 4 was recommended in the paper
    reconstruction_loss_weight = 1,# weight to place on the compressed memory reconstruction loss
    attn_dropout = 0.1,            # dropout post-attention
    ff_dropout = 0.1,              # dropout in feedforward
    attn_layer_dropout = 0.1,      # dropout for attention layer output
    gru_gated_residual = True,     # whether to gate the residual intersection, from the 'Stabilizing Transformers for RL' paper
    mogrify_gru = False,           # experimental feature that adds a Mogrifier to the update and residual before gating by the GRU
    memory_layers = range(6, 13),  # specify which layers use long-range memory (here layers 6-12), from the 'Do Transformers Need Deep Long-Range Memory?' paper
    ff_glu = True                  # use GLU variant for feedforward
)

inputs = torch.randint(0, 256, (1, 2048))
masks = torch.ones_like(inputs).bool()

segments = inputs.reshape(1, -1, 1024).transpose(0, 1)
masks = masks.reshape(1, -1, 1024).transpose(0, 1)

logits, memories, aux_loss = model(segments[0], mask = masks[0])
logits,        _, aux_loss = model(segments[1], mask = masks[1], memories = memories)

# memories is a named tuple that contains the memory (mem) and the compressed memory (cmem)
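
The same pattern extends to any number of segments. Below is a minimal sketch (reusing the model, segments, and masks defined above, and assuming memories may be passed as None on the first segment, equivalent to omitting it) of carrying the memories forward while accumulating the auxiliary reconstruction loss:

# iterate over segments in order, feeding each forward pass the memories
# returned by the previous one
memories = None
total_aux_loss = 0.

for segment, segment_mask in zip(segments, masks):
    logits, memories, aux_loss = model(segment, mask = segment_mask, memories = memories)
    total_aux_loss = total_aux_loss + aux_loss

# total_aux_loss would then be added to the language modelling loss before calling backward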

When training, you can use the AutoregressiveWrapper to have memory management across segments taken care of for you. As easy as it gets.

import torch
from compressive_transformer_pytorch import CompressiveTransformer
from compressive_transformer_pytorch import AutoregressiveWrapper

model = CompressiveTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    seq_len = 1024,
    mem_len = 1024,
    cmem_len = 256,
    cmem_ratio = 4,
    memory_layers = [5,6]
).cuda()

model = AutoregressiveWrapper(model)

inputs = torch.randint(0, 20000, (1, 2048 + 1)).cuda()

for loss, aux_loss, _ in model(inputs, return_loss = True):
    (loss + aux_loss).backward()
    # optimizer step and zero grad

# ... after much training ...

# generation is also greatly simplified and automated away
# just pass in the prime, which can be 1 start token or any length
# all is taken care of for you

prime = torch.ones(1, 1).long().cuda()  # assume token 1 is the start token
sample = model.generate(prime, 4096)
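
Spelling out the "optimizer step and zero grad" comment above, here is a minimal sketch of the training loop, continuing from the example (the Adam optimizer and learning rate are placeholder choices, not recommendations from the paper):

optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)

for loss, aux_loss, is_last in model(inputs, return_loss = True):
    (loss + aux_loss).backward()   # gradients for this segment
    optimizer.step()               # update the weights
    optimizer.zero_grad()          # clear gradients before the next segment
    # is_last flags the final segment of the sequence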

Citations

@misc{rae2019compressive,
    title   = {Compressive Transformers for Long-Range Sequence Modelling},
    author  = {Jack W. Rae and Anna Potapenko and Siddhant M. Jayakumar and Timothy P. Lillicrap},
    year    = {2019},
    eprint  = {1911.05507},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{parisotto2019stabilizing,
    title   = {Stabilizing Transformers for Reinforcement Learning},
    author  = {Emilio Parisotto and H. Francis Song and Jack W. Rae and Razvan Pascanu and Caglar Gulcehre and Siddhant M. Jayakumar and Max Jaderberg and Raphael Lopez Kaufman and Aidan Clark and Seb Noury and Matthew M. Botvinick and Nicolas Heess and Raia Hadsell},
    year    = {2019},
    eprint  = {1910.06764},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@inproceedings{rae-razavi-2020-transformers,
    title   = "Do Transformers Need Deep Long-Range Memory?",
    author  = "Rae, Jack  and
      Razavi, Ali",
    booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics",
    month   = jul,
    year    = "2020",
    address = "Online",
    publisher = "Association for Computational Linguistics",
    url     = "https://www.aclweb.org/anthology/2020.acl-main.672"
}
@article{Shazeer2019FastTD,
    title   = {Fast Transformer Decoding: One Write-Head is All You Need},
    author  = {Noam Shazeer},
    journal = {ArXiv},
    year    = {2019},
    volume  = {abs/1911.02150}
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}
}
@misc{lan2019albert,
    title       = {ALBERT: A Lite BERT for Self-supervised Learning of Language Representations},
    author      = {Zhenzhong Lan and Mingda Chen and Sebastian Goodman and Kevin Gimpel and Piyush Sharma and Radu Soricut},
    year        = {2019},
    url         = {https://arxiv.org/abs/1909.11942}
}
@misc{ding2021erniedoc,
    title   = {ERNIE-Doc: A Retrospective Long-Document Modeling Transformer},
    author  = {Siyu Ding and Junyuan Shang and Shuohuan Wang and Yu Sun and Hao Tian and Hua Wu and Haifeng Wang},
    year    = {2021},
    eprint  = {2012.15688},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
Comments
  • aux_loss does not update any weight

    Hi lucidrains, thanks for your implementation, it is very elegant and helped me a lot with my dissertation. Anyway, there is one detail I can't understand: it seems that aux_loss is not connected to any weight, because of the detaching in the last part of the SelfAttention layer. With the following code, for example, I find that no layer is optimized by aux_loss:

    import torch
    from compressive_transformer_pytorch import CompressiveTransformer
    from compressive_transformer_pytorch import AutoregressiveWrapper
    
    model = CompressiveTransformer(
        num_tokens = 20000,
        dim = 512,
        depth = 6,
        seq_len = 1024,
        mem_len = 1024,
        cmem_len = 256,
        cmem_ratio = 4,
        memory_layers = [5,6]
    ).cuda()
    
    model = AutoregressiveWrapper(model)
    
    inputs = torch.randint(0, 20000, (1, 1024)).cuda()
    
    optimizer = torch.optim.Adam(model.parameters())
    
    for loss, aux_loss, _ in model(inputs, return_loss = True):
        optimizer.zero_grad(set_to_none=True)
        loss.backward(retain_graph=True)
        print("OPTIMIZED BY LOSS ************************************************************")
        for module_name, parameter in model.named_parameters():
            if parameter.grad is not None:
                print(module_name)
        optimizer.zero_grad(set_to_none=True)
        aux_loss.backward(retain_graph=True)
        print("OPTIMIZED BY AUX_LOSS ************************************************************")
        for module_name, parameter in model.named_parameters():
            if parameter.grad is not None:
                print(module_name)
    

    I am not an expert on PyTorch's autograd mechanics, so maybe I am getting something wrong. Thank you again.

    opened by StefanoBerti 3
  • How to use this for speech/audio generation?

    Great work Phil! In their paper, the authors applied this model to speech modelling. How would you advise changing it for use with speech? In speech the data are raw signals, so we have neither num_tokens nor emb_dim; our input is simply [batch, channel, time]. Any advice?

    opened by jinglescode 3
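    For the speech/audio question above, one possible route (a sketch only, not from this thread): discretize the waveform into a small vocabulary, e.g. 8-bit mu-law as used in WaveNet-style models, so the [batch, channel, time] signal becomes integer tokens that can be fed to CompressiveTransformer(num_tokens = 256, ...) exactly as in the text examples. The mu_law_encode helper below is hypothetical, not part of this repository:

    import math
    import torch

    # hypothetical helper: 8-bit mu-law quantization maps a waveform in [-1, 1]
    # to integer tokens in [0, 255]
    def mu_law_encode(audio, quantization_channels = 256):
        mu = quantization_channels - 1
        scaled = torch.sign(audio) * torch.log1p(mu * audio.abs()) / math.log1p(mu)
        return ((scaled + 1) / 2 * mu + 0.5).long()

    waveform = torch.rand(1, 1, 2048) * 2 - 1     # [batch, channel, time], values in [-1, 1]
    tokens = mu_law_encode(waveform.squeeze(1))   # [batch, time] integer token ids in [0, 255]

    # tokens can now be segmented and fed to the model as in the usage examples above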
  • [Error] NameError: name 'math' is not defined in compressive_transformer_pytorch.py

    Hello, I ran the code in examples/enwik8_simple and got the following error:

    train.py:65: DeprecationWarning: The binary mode of fromstring is deprecated, as it behaves surprisingly on unicode inputs. Use frombuffer instead
      X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
    training: 0%| | 0/100000 [00:00<?, ?it/s]
    Traceback (most recent call last):
      File "train.py", line 101, in <module>
        for mlm_loss, aux_loss, is_last in model(next(train_loader), max_batch_size = MAX_BATCH_SIZE, return_loss = True):
      File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/compressive_transformer_pytorch/autoregressive_wrapper.py", line 151, in forward
        logits, new_mem, aux_loss = self.net(xi_seg_b, mask = mask_seg_b, memories = mem, **kwargs)
      File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/compressive_transformer_pytorch/compressive_transformer_pytorch.py", line 338, in forward
        x, = ff(x)
      File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/compressive_transformer_pytorch/compressive_transformer_pytorch.py", line 84, in forward
        out = self.fn(x, **kwargs)
      File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/compressive_transformer_pytorch/compressive_transformer_pytorch.py", line 106, in forward
        return self.fn(x, **kwargs)
      File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/compressive_transformer_pytorch/compressive_transformer_pytorch.py", line 140, in forward
        x = self.act(x)
      File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/compressive_transformer_pytorch/compressive_transformer_pytorch.py", line 122, in forward
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    NameError: name 'math' is not defined

    So I inserted "import math" into compressive_transformer_pytorch.py and it works well. I hope you can fix compressive_transformer_pytorch.py accordingly.

    opened by dinoSpeech 3
  • Training enwik8 but loss fail to converge

    Hi lucidrains, I appreciate your implementation very much, and it helped me a lot in understanding the compressive transformer. However, when I ran your code (enwik8, exactly the same code as on GitHub), the loss failed to converge after 100 epochs. Is this expected? Or should I put in additional effort to improve it, for example tokenizing the raw enwik8 data and removing all the XML tags? The figures below show the training and validation loss while training on enwik8 with the same code as on GitHub.

    (screenshots: training and validation loss curves, taken 2021-03-26)

    Thanks and look forward to your reply!

    opened by KaiPoChang 2
  • Details about text generation

    Hi lucidrains, thank you for your excellent code. I am curious about the generation scripts. Could you tell me how to generate text with the compressive transformer? Because it has compressive memory, perhaps we cannot simply use the current predicted word as the input for the next step (input length == 1). In addition, if the prompt has 100 words and we use tokens[0:100], tokens[1:101], tokens[2:102], ... as the inputs for the following timesteps, tokens[1:100] may overlap with the memory, because the memory already contains hidden states for tokens[1:100].

    I would really appreciate it if you could provide the generation scripts!

    Thank you

    opened by theseventhflow 3
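    For the generation question above, one way to avoid re-feeding overlapping context (a sketch only, assuming the bare CompressiveTransformer interface shown in the usage example, i.e. that a call returns (logits, memories, aux_loss) and accepts segments shorter than seq_len): consume the prompt once, then feed only each newly sampled token while carrying the memories forward, so the memory never sees the same position twice. The sample function below is hypothetical, not the repository's AutoregressiveWrapper.generate:

    import torch
    import torch.nn.functional as F

    @torch.no_grad()
    def sample(model, prompt, steps, temperature = 1.):
        # prompt: (1, prompt_len) tensor of token ids, consumed exactly once
        out = prompt
        logits, memories, _ = model(prompt)
        for _ in range(steps):
            probs = F.softmax(logits[:, -1] / temperature, dim = -1)
            next_token = torch.multinomial(probs, 1)                          # (1, 1)
            out = torch.cat((out, next_token), dim = -1)
            logits, memories, _ = model(next_token, memories = memories)      # memory carries the context
        return out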
  • Links to original tf code - fyi

    After reading the DeepMind blog post I was looking forward to downloading the model, but no luck. Looking forward to your implementation.

    You may already be aware of this post and link, but if not, this is the author's original TF implementation. Hope it helps.

    Copy of comment to original model request:

    https://github.com/huggingface/transformers/issues/4688

    I'm interested in the model weights too, but they are currently not available. The author does mention releasing the TF code here:

    https://news.ycombinator.com/item?id=22290227

    Requires TF 1.15+ and deepmind/sonnet version 1.36. Link to the Python script here:

    https://github.com/deepmind/sonnet/blob/cd5b5fa48e15e4d020f744968f5209949ebe750f/sonnet/python/modules/nets/transformer.py#L915

    I have tried running it as-is, but it doesn't appear to have options for training on custom data as per the paper and the available datasets.

    opened by GenTxt 8