Sinkhorn Transformer - Practical implementation of Sparse Sinkhorn Attention

Overview

This is a reproduction of the work outlined in Sparse Sinkhorn Attention, with additional enhancements.

It includes a parameterized sorting network, using sinkhorn normalization to sample a permutation matrix that matches the most relevant buckets of keys to the buckets of queries.
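
The Sinkhorn step can be illustrated in plain Python. This is a minimal sketch of Gumbel-Sinkhorn sampling, not this library's code. The names `gumbel_sinkhorn` and `logsumexp` and the toy score matrix are hypothetical. Row i stands for a query bucket and column j for a key bucket. The defaults (temperature 0.75, 7 iterations) mirror the `temperature` and `sinkhorn_iter` settings shown later in this README.

```python
import math
import random


def logsumexp(values):
    """Numerically stable log(sum(exp(v))) over a list of floats."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def gumbel_sinkhorn(log_scores, temperature=0.75, n_iters=7, seed=None):
    """Hypothetical helper: relax a square score matrix into an
    approximately doubly stochastic one (a soft permutation matrix).

    log_scores[i][j] is an unnormalized affinity between query bucket i and
    key bucket j. Gumbel noise turns the result into a sample; a lower
    temperature pushes it closer to a hard permutation matrix.
    """
    rng = random.Random(seed)
    n = len(log_scores)
    r = []
    for row in log_scores:
        noisy_row = []
        for s in row:
            # Gumbel(0, 1) sample via -log(-log(U)); clamp U away from 0 and 1
            u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
            noisy_row.append((s - math.log(-math.log(u))) / temperature)
        r.append(noisy_row)

    for _ in range(n_iters):
        # normalize each row in log space
        for i in range(n):
            lse = logsumexp(r[i])
            r[i] = [v - lse for v in r[i]]
        # normalize each column in log space
        for j in range(n):
            lse = logsumexp([r[i][j] for i in range(n)])
            for i in range(n):
                r[i][j] -= lse

    return [[math.exp(v) for v in row] for row in r]


scores = [[2.0, 0.1, -1.0],
          [0.3, 1.5, 0.2],
          [-0.5, 0.0, 3.0]]
soft_perm = gumbel_sinkhorn(scores, seed=0)
for row in soft_perm:
    print([round(p, 3) for p in row])
```

Because the loop ends on a column normalization, columns sum to 1 exactly (up to rounding). Rows only approach 1 as the iterations increase.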

This work also brings in reversible networks and feedforward chunking (concepts introduced in Reformer) for further memory savings.
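
Both memory-saving ideas can be shown as a toy on Python lists. The coupling y1 = x1 + f(x2), y2 = x2 + g(y1) matches the reversible block visible in the traceback further down this page. The inputs can be recomputed from the outputs, so activations need not be stored. Chunking splits a position-wise function into slices without changing its result. `f`, `g` and all helper names here are hypothetical stand-ins for attention and feedforward.

```python
def reversible_forward(x1, x2, f, g):
    """RevNet-style coupling: y1 = x1 + f(x2), y2 = x2 + g(y1)."""
    y1 = [a + b for a, b in zip(x1, f(x2))]
    y2 = [a + b for a, b in zip(x2, g(y1))]
    return y1, y2


def reversible_inverse(y1, y2, f, g):
    """Recompute the inputs from the outputs instead of storing them."""
    x2 = [a - b for a, b in zip(y2, g(y1))]
    x1 = [a - b for a, b in zip(y1, f(x2))]
    return x1, x2


def chunked_apply(fn, xs, chunks):
    """Apply a position-wise fn to `chunks` slices of xs, then concatenate."""
    size = -(-len(xs) // chunks)  # ceiling division, like torch.chunk
    out = []
    for start in range(0, len(xs), size):
        out.extend(fn(xs[start:start + size]))
    return out


def f(v):
    """Stand-in for attention: each output depends on other positions."""
    return [sum(v) - t for t in v]


def g(v):
    """Stand-in for a feedforward: purely position-wise."""
    return [t * t for t in v]


x1, x2 = [1, 2, 3, 4], [5, 6, 7, 8]
y1, y2 = reversible_forward(x1, x2, f, g)
print(y1, y2)                                   # [22, 22, 22, 22] [489, 490, 491, 492]
print(reversible_inverse(y1, y2, f, g))         # ([1, 2, 3, 4], [5, 6, 7, 8])
print(chunked_apply(g, y1, chunks=2) == g(y1))  # True
```

Chunking is only safe for position-wise functions like `g`. Applying it to `f` would change the result, because `f` mixes positions.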

Colab notebook: 204k tokens (demonstration purposes)

Install

$ pip install sinkhorn_transformer

Use

A Sinkhorn Transformer based language model

import torch
from sinkhorn_transformer import SinkhornTransformerLM

model = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 1024,
    heads = 8,
    depth = 12,
    max_seq_len = 8192,
    bucket_size = 128,        # size of the buckets
    causal = False,           # auto-regressive or not
    n_sortcut = 2,            # use sortcut to reduce memory complexity to linear
    n_top_buckets = 2,        # sort specified number of key/value buckets to one query bucket. paper is at 1, defaults to 2
    ff_chunks = 10,           # feedforward chunking, from Reformer paper
    reversible = True,        # make network reversible, from Reformer paper
    emb_dropout = 0.1,        # embedding dropout
    ff_dropout = 0.1,         # feedforward dropout
    attn_dropout = 0.1,       # post attention dropout
    attn_layer_dropout = 0.1, # post attention layer dropout
    layer_dropout = 0.1,      # add layer dropout, from 'Reducing Transformer Depth on Demand' paper
    weight_tie = True,        # tie layer parameters, from Albert paper
    emb_dim = 128,            # embedding factorization, from Albert paper
    dim_head = 64,            # be able to fix the dimension of each head, making it independent of the embedding dimension and the number of heads
    ff_glu = True,            # use GLU in feedforward, from paper 'GLU Variants Improve Transformer'
    n_local_attn_heads = 2,   # replace N heads with local attention, suggested to work well from Routing Transformer paper
    pkm_layers = (4,7),       # specify layers to use product key memory. paper shows 1 or 2 modules near the middle of the transformer is best
    pkm_num_keys = 128,       # defaults to 128, but can be increased to 256 or 512 as memory allows
)

x = torch.randint(0, 20000, (1, 2048))
model(x) # (1, 2048, 20000)
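
Why sortcut keeps memory linear can be seen with a back-of-the-envelope count. This sketch assumes the attention scores for sortcut have shape (num_buckets, bucket_size, n_sortcut * bucket_size), with every query bucket attending to the same n_sortcut * bucket_size keys. It counts score entries per head and per batch element only. It ignores the sort net and any additional keys an implementation may concatenate. Both helper names are hypothetical.

```python
def full_attention_entries(seq_len):
    """Score entries per head for dense attention: every query sees every key."""
    return seq_len * seq_len


def sortcut_attention_entries(seq_len, bucket_size, n_sortcut):
    """Score entries per head, assuming scores of shape
    (num_buckets, bucket_size, n_sortcut * bucket_size)."""
    if seq_len % bucket_size != 0:
        raise ValueError("seq_len must be a multiple of bucket_size")
    num_buckets = seq_len // bucket_size
    return num_buckets * bucket_size * (n_sortcut * bucket_size)


for n in (8192, 16384):
    print(n, full_attention_entries(n), sortcut_attention_entries(n, bucket_size=128, n_sortcut=2))
# 8192 67108864 2097152
# 16384 268435456 4194304
```

The sortcut count simplifies to seq_len * n_sortcut * bucket_size. With bucket_size and n_sortcut fixed, doubling the sequence length doubles it, while the dense count quadruples.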

A plain Sinkhorn Transformer, layers of sinkhorn attention

import torch
from sinkhorn_transformer import SinkhornTransformer

model = SinkhornTransformer(
    dim = 1024,
    heads = 8,
    depth = 12,
    bucket_size = 128
)

x = torch.randn(1, 2048, 1024)
model(x) # (1, 2048, 1024)

Sinkhorn Encoder / Decoder Transformer

import torch
from sinkhorn_transformer import SinkhornTransformerLM

DE_SEQ_LEN = 4096
EN_SEQ_LEN = 4096

enc = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    heads = 8,
    bucket_size = 128,
    max_seq_len = DE_SEQ_LEN,
    reversible = True,
    return_embeddings = True
).cuda()

dec = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    causal = True,
    bucket_size = 128,
    max_seq_len = EN_SEQ_LEN,
    receives_context = True,
    context_bucket_size = 128,  # context key / values can be bucketed differently
    reversible = True
).cuda()

x = torch.randint(0, 20000, (1, DE_SEQ_LEN)).cuda()
y = torch.randint(0, 20000, (1, EN_SEQ_LEN)).cuda()

x_mask = torch.ones_like(x).bool().cuda()
y_mask = torch.ones_like(y).bool().cuda()

context = enc(x, input_mask=x_mask)
dec(y, context=context, input_mask=y_mask, context_mask=x_mask) # (1, 4096, 20000)

Autopadder

By default the model will complain if given an input that is not a multiple of the bucket size. To avoid having to make the same padding calculations each time, you can use the helper Autopadder class. It will take care of the input_mask for you as well, if given. Contextual key/values and mask are supported as well.

import torch
from sinkhorn_transformer import SinkhornTransformerLM
from sinkhorn_transformer import Autopadder

model = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 1024,
    heads = 8,
    depth = 12,
    max_seq_len = 2048,
    bucket_size = 128,
    causal = True
)

model = Autopadder(model, pad_left=True) # autopadder will fetch the bucket size and autopad input

x = torch.randint(0, 20000, (1, 1117)) # odd sequence length
model(x) # (1, 1117, 20000)
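
The padding arithmetic the Autopadder saves you from writing looks roughly like this. This is a minimal sketch on Python lists, not the Autopadder's implementation. `pad_to_multiple` and `unpad` are hypothetical names, and the pad value of 0 is an assumption.

```python
def pad_to_multiple(tokens, bucket_size, pad_value=0, pad_left=False):
    """Pad tokens up to the next multiple of bucket_size.

    Returns (padded_tokens, mask, pad_len); mask is True at real tokens.
    """
    remainder = len(tokens) % bucket_size
    pad_len = 0 if remainder == 0 else bucket_size - remainder
    padding = [pad_value] * pad_len
    if pad_left:
        return padding + tokens, [False] * pad_len + [True] * len(tokens), pad_len
    return tokens + padding, [True] * len(tokens) + [False] * pad_len, pad_len


def unpad(outputs, pad_len, pad_left=False):
    """Drop the positions that were added as padding."""
    if pad_len == 0:
        return outputs
    return outputs[pad_len:] if pad_left else outputs[:-pad_len]


tokens = list(range(1117))  # the odd length from the example above
padded, mask, pad_len = pad_to_multiple(tokens, bucket_size=128, pad_left=True)
print(len(padded), pad_len, sum(mask))                  # 1152 35 1117
print(unpad(padded, pad_len, pad_left=True) == tokens)  # True
```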

Sinkhorn

This repository has diverged from the paper and now uses attention in place of the original sorting net + Gumbel Sinkhorn sampling. I have not found a noticeable difference in performance yet, and the new scheme allows me to generalize the network to flexible sequence lengths. If you would like to try Sinkhorn, please use the following settings, which only work for non-causal networks.

import torch
from sinkhorn_transformer import SinkhornTransformerLM

model = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 1024,
    heads = 8,
    depth = 12,
    bucket_size = 128,
    max_seq_len = 8192,
    use_simple_sort_net = True, # turn off attention sort net
    sinkhorn_iter = 7,          # number of sinkhorn iterations - default is set at reported best in paper
    n_sortcut = 2,              # use sortcut to reduce complexity to linear time
    temperature = 0.75,         # gumbel temperature - default is set at reported best in paper
    non_permutative = False,    # allow buckets of keys to be sorted to queries more than once
)

x = torch.randint(0, 20000, (1, 8192))
model(x) # (1, 8192, 20000)
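
The `non_permutative` comment above (whether key buckets may be sorted to queries more than once) can be pictured by turning a soft query-bucket x key-bucket matrix into hard choices. This is an illustration only. `hard_bucket_assignment` is hypothetical, and the greedy one-to-one matching is a stand-in, not necessarily what the library does.

```python
def hard_bucket_assignment(soft, non_permutative=False):
    """Pick one key bucket per query bucket from a soft square matrix.

    non_permutative=True: each query bucket takes its own best key bucket,
    so the same key bucket may be chosen by several query buckets.
    non_permutative=False: greedy one-to-one matching, so every key bucket
    is used exactly once (a permutation).
    """
    n = len(soft)
    if non_permutative:
        return [row.index(max(row)) for row in soft]
    # visit (score, query, key) triples from highest to lowest score
    pairs = sorted(((soft[i][j], i, j) for i in range(n) for j in range(n)), reverse=True)
    assignment = [None] * n
    used = set()
    for _, i, j in pairs:
        if assignment[i] is None and j not in used:
            assignment[i] = j
            used.add(j)
    return assignment


soft_matrix = [[0.7, 0.2, 0.1],
               [0.6, 0.3, 0.1],
               [0.1, 0.1, 0.8]]
print(hard_bucket_assignment(soft_matrix, non_permutative=True))  # [0, 0, 2]
print(hard_bucket_assignment(soft_matrix))                        # [0, 1, 2]
```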

Product Key Memory

To see the benefits of using PKM, the learning rate of the values must be set higher than that of the rest of the parameters (1e-2 is recommended).

You can follow the instructions here to set it correctly: https://github.com/lucidrains/product-key-memory#learning-rates
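
One general way to give a subset of parameters its own learning rate is to pass parameter groups to the optimizer. PyTorch optimizers accept a list of dicts with `params` and an optional `lr`. The sketch below only builds such a list. `build_param_groups` and the naming rule in `is_pkm_value` are hypothetical, so check `model.named_parameters()` for the real names. The linked instructions remain the authoritative route.

```python
def build_param_groups(named_params, is_pkm_value, pkm_value_lr=1e-2):
    """Split (name, param) pairs into optimizer parameter groups.

    Returns a list of dicts in the per-parameter-options format that
    torch.optim optimizers accept ('params' plus an optional 'lr').
    """
    pkm_values, others = [], []
    for name, param in named_params:
        (pkm_values if is_pkm_value(name) else others).append(param)
    groups = []
    if others:
        groups.append({"params": others})  # uses the optimizer's default lr
    if pkm_values:
        groups.append({"params": pkm_values, "lr": pkm_value_lr})
    return groups


def is_pkm_value(name):
    """Hypothetical naming rule for PKM value parameters."""
    return "pkm" in name and "values" in name


fake_named_params = [("layers.0.attn.to_q.weight", "q"),
                     ("layers.4.pkm.values.weight", "v4"),
                     ("layers.7.pkm.values.weight", "v7")]
print(build_param_groups(fake_named_params, is_pkm_value))
# [{'params': ['q']}, {'params': ['v4', 'v7'], 'lr': 0.01}]
```

With a real model, the result would be passed as e.g. `torch.optim.Adam(build_param_groups(model.named_parameters(), is_pkm_value), lr=1e-4)`.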

Issues

Decoding and sequence lengths

Sinkhorn, when trained on fixed-length sequences, seems to have trouble decoding sequences from scratch, mainly because the sorting net has trouble generalizing when buckets are partially filled with padding tokens.

Fortunately, I think I have found a simple solution. During training, for causal networks, randomly truncate the sequences and force the sorting net to generalize. I have provided a flag (randomly_truncate_sequence) for the AutoregressiveWrapper instance to make this easy.

import torch
from sinkhorn_transformer import SinkhornTransformerLM, AutoregressiveWrapper

model = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 1024,
    heads = 8,
    depth = 12,
    bucket_size = 75,
    max_seq_len = 8192,
    causal = True
)

model = AutoregressiveWrapper(model)

x = torch.randint(0, 20000, (1, 8192))
loss = model(x, return_loss = True, randomly_truncate_sequence = True) # scalar training loss

I am open to suggestions if someone has found a better solution.
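
The idea of random truncation can be sketched on a single token list. This is not what `randomly_truncate_sequence` does internally, since its exact behaviour is not described here. The function name, `min_len`, and the choice to keep a prefix and left-pad back to the original length are assumptions. Left padding matches the wrapper's training default described in the next section.

```python
import random


def randomly_truncate_and_pad(tokens, pad_value=0, min_len=2, seed=None):
    """Hypothetical helper: keep a random-length prefix, then left-pad back
    to the original length so that some buckets contain padding.

    Returns (padded_tokens, mask); mask is True at real tokens.
    """
    rng = random.Random(seed)
    keep = rng.randint(min_len, len(tokens))  # inclusive on both ends
    pad_len = len(tokens) - keep
    padded = [pad_value] * pad_len + tokens[:keep]
    mask = [False] * pad_len + [True] * keep
    return padded, mask


seq = list(range(1, 17))
padded, mask = randomly_truncate_and_pad(seq, seed=3)
print(padded)
print(mask)
```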

Causal sorting net

There is a potential problem with the causal sorting network: the decision of which past key/value buckets are sorted to a bucket depends only on that bucket's first token and not on the rest (due to the bucketing scheme and the need to prevent leakage from future to past).

I have attempted to alleviate this problem by rotating half the heads to the left by (bucket size - 1) positions, thereby promoting the last token of each bucket to be first. This is also why the AutoregressiveWrapper defaults to left padding during training: it ensures that the last token in the sequence always has a say in what to retrieve.
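
The position bookkeeping behind the rotation can be checked with plain lists. This sketch shows only which positions end up first in each bucket for a single rotated head. It does not show how the library chooses the rotated heads or handles the positions that wrap around. `rotate_left` and `to_buckets` are hypothetical helpers.

```python
def rotate_left(seq, shift):
    """Rotate a list left by `shift` positions, wrapping around."""
    shift %= len(seq)
    return seq[shift:] + seq[:shift]


def to_buckets(seq, bucket_size):
    """Split a list into consecutive buckets of bucket_size."""
    return [seq[i:i + bucket_size] for i in range(0, len(seq), bucket_size)]


bucket_size = 4
positions = list(range(8))
print(to_buckets(positions, bucket_size))                                # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(to_buckets(rotate_left(positions, bucket_size - 1), bucket_size))  # [[3, 4, 5, 6], [7, 0, 1, 2]]
```

After the rotation, positions 3 and 7 (the last tokens of the original buckets) lead their buckets.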

If anyone has found a cleaner solution, please let me know in the issues.

Alternatives

  1. Routing Transformer - https://github.com/lucidrains/routing-transformer
  2. Reformer - https://github.com/lucidrains/reformer-pytorch

Citations

@misc{tay2020sparse,
    title   = {Sparse Sinkhorn Attention},
    author  = {Yi Tay and Dara Bahri and Liu Yang and Donald Metzler and Da-Cheng Juan},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.11296}
}
@inproceedings{kitaev2020reformer,
    title       = {Reformer: The Efficient Transformer},
    author      = {Nikita Kitaev and Lukasz Kaiser and Anselm Levskaya},
    booktitle   = {International Conference on Learning Representations},
    year        = {2020},
    url         = {https://openreview.net/forum?id=rkgNKkHtvB}
}
@misc{lan2019albert,
    title       = {ALBERT: A Lite BERT for Self-supervised Learning of Language Representations},
    author      = {Zhenzhong Lan and Mingda Chen and Sebastian Goodman and Kevin Gimpel and Piyush Sharma and Radu Soricut},
    year        = {2019},
    url         = {https://arxiv.org/abs/1909.11942}
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}
}
@misc{roy2020efficient,
    title   = {Efficient Content-Based Sparse Attention with Routing Transformers},
    author  = {Aurko Roy and Mohammad Taghi Saffar and David Grangier and Ashish Vaswani},
    year    = {2020},
    url     = {https://openreview.net/forum?id=B1gjs6EtDr}
}
@inproceedings{fan2020reducing,
    title       = {Reducing Transformer Depth on Demand with Structured Dropout},
    author      = {Angela Fan and Edouard Grave and Armand Joulin},
    booktitle   = {International Conference on Learning Representations},
    year        = {2020},
    url         = {https://openreview.net/forum?id=SylO2yStDr}
}
@misc{lample2019large,
    title   = {Large Memory Layers with Product Keys},
    author  = {Guillaume Lample and Alexandre Sablayrolles and Marc'Aurelio Ranzato and Ludovic Denoyer and Hervé Jégou},
    year    = {2019},
    eprint  = {1907.05242},
    archivePrefix = {arXiv}
}
@misc{bhojanapalli2020lowrank,
    title   = {Low-Rank Bottleneck in Multi-head Attention Models},
    author  = {Srinadh Bhojanapalli and Chulhee Yun and Ankit Singh Rawat and Sashank J. Reddi and Sanjiv Kumar},
    year    = {2020},
    eprint  = {2002.07028}
}

Comments
  • Training falling on version 0.0.14 and 0.0.15

    Hi, I am testing model training on the new versions of the repo, and I have some trouble with 0.0.14 and 0.0.15. On 0.0.14 the model always returns NaN on the forward pass; version 0.0.15 leads to a CUDA error:

    RuntimeError: CUDA error: CUBLAS_STATUS_INTERNAL_ERROR when calling `cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)`
    

    Full error listing:

    <ipython-input-7-1329da5363de> in forward(self, inputs, labels)
          7   def forward(self, inputs, labels=None):
          8     loss_mx = labels != -100
    ----> 9     output = self.model(inputs)
         10     output = output[loss_mx].view(-1, tokenizer.vocab_size)
         11     labels = labels[loss_mx].view(-1)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x, input_mask)
        376         x = self.to_token_emb(x)
        377         x = self.pos_emb(torch.arange(t, device=device)) + x
    --> 378         x = self.sinkhorn_transformer(x)
        379         return self.to_logits(x)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x, input_mask)
        359 
        360     def forward(self, x, input_mask = None):
    --> 361         return self.layers(x)
        362 
        363 class SinkhornTransformerLM(nn.Module):
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x, **kwargs)
        330     def forward(self, x, **kwargs):
        331         x = torch.cat([x, x], dim=-1)
    --> 332         x = self.layers(x, **kwargs)
        333         return torch.stack(x.chunk(2, dim=-1)).sum(dim=0)
        334 
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/reversible.py in forward(self, x, arg_route, **kwargs)
        128         block_kwargs = {'f_args': f_args, 'g_args': g_args}
        129 
    --> 130         return _ReversibleFunction.apply(x, blocks, block_kwargs)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/reversible.py in forward(ctx, x, blocks, kwargs)
         98         ctx.kwargs = kwargs
         99         for block in blocks:
    --> 100             x = block(x, **kwargs)
        101         ctx.y = x.detach()
        102         ctx.blocks = blocks
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/reversible.py in forward(self, x, f_args, g_args)
         51         with torch.no_grad():
         52             y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
    ---> 53             y2 = x2 + self.g(y1, record_rng=self.training, **g_args)
         54 
         55         return torch.cat([y1, y2], dim=2)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/reversible.py in forward(self, record_rng, set_rng, *args, **kwargs)
         25 
         26         if not set_rng:
    ---> 27             return self.net(*args, **kwargs)
         28 
         29         rng_devices = []
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x)
         91     def forward(self, x):
         92         chunks = x.chunk(self.chunks, dim = self.dim)
    ---> 93         return torch.cat([self.fn(c) for c in chunks], dim = self.dim)
         94 
         95 class FeedForward(nn.Module):
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in <listcomp>(.0)
         91     def forward(self, x):
         92         chunks = x.chunk(self.chunks, dim = self.dim)
    ---> 93         return torch.cat([self.fn(c) for c in chunks], dim = self.dim)
         94 
         95 class FeedForward(nn.Module):
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x, **kwargs)
        112     def forward(self, x, **kwargs):
        113         x = self.norm(x)
    --> 114         return self.fn(x, **kwargs)
        115 
        116 class SortNet(nn.Module):
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x)
        103 
        104     def forward(self, x):
    --> 105         return self.net(x)
        106 
        107 class PreNorm(nn.Module):
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/container.py in forward(self, input)
         98     def forward(self, input):
         99         for module in self:
    --> 100             input = module(input)
        101         return input
        102 
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/linear.py in forward(self, input)
         85 
         86     def forward(self, input):
    ---> 87         return F.linear(input, self.weight, self.bias)
         88 
         89     def extra_repr(self):
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/functional.py in linear(input, weight, bias)
       1591         ret = torch.addmm(bias, input, weight.t())
       1592     else:
    -> 1593         output = input.matmul(weight.t())
       1594         if bias is not None:
       1595             output += bias
    
    RuntimeError: CUDA error: CUBLAS_STATUS_INTERNAL_ERROR when calling `cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)`
    

    Also, version 0.0.11 (and all other versions since 0.0.8) work stably.

    opened by blizda
  • generation problem in a toy task

    Here is the full script for my toy task (x -> xx, e.g. "abc" -> "abcabc"):

    from sinkhorn_transformer import SinkhornTransformerLM
    from sinkhorn_transformer.autoregressive_wrapper import AutoregressiveWrapper
    
    import random
    import tqdm
    import gzip
    import numpy as np
    import torch
    import torch.optim as optim
    from torch import nn
    from torch.nn import functional as F
    from torch.utils.data import DataLoader, Dataset
    
    # constants
    
    NUM_BATCHES = int(1e5)
    BATCH_SIZE = 4
    GRADIENT_ACCUMULATE_EVERY = 4
    LEARNING_RATE = 1e-4
    VALIDATE_EVERY  = 100
    GENERATE_EVERY  = 100
    ENC_SEQ_LEN=16
    DEC_SEQ_LEN=40
    NUM_TOKENS = 256 + 2
    BUCKET_SIZE = 8
    
    # helpers
    
    def top_k(logits, thres = 0.9):
        k = int((1 - thres) * logits.shape[-1])
        val, ind = torch.topk(logits, k)
        probs = torch.full_like(logits, float('-inf'))
        probs.scatter_(1, ind, val)
        return probs
    
    
    def cycle():
        while True:
            source = torch.randint(2, 258, (BATCH_SIZE, ENC_SEQ_LEN)).long().cuda()
    
            target = torch.cat((source, source), 1)
            prefix = torch.ones((BATCH_SIZE, 1)).long().cuda()
            target = torch.cat((prefix, target), axis=1)
    
            x_mask = torch.ones(BATCH_SIZE, ENC_SEQ_LEN).bool().cuda()
            y_mask = torch.ones(BATCH_SIZE, target.shape[1]).bool().cuda()
    
    
            yield (source, target, x_mask, y_mask)
    
    # instantiate model
    
    class MySinkhornTransformer(nn.Module):
        def __init__(self, num_tokens, dim, depth, heads, bucket_size, enc_max_seq_len, dec_max_seq_len):
            super().__init__()
            
            self.pad_token = 0
            self.sos_token = 1
    
            self.enc = SinkhornTransformerLM(num_tokens=num_tokens, dim=dim, depth=depth, heads=heads, bucket_size=bucket_size, max_seq_len=enc_max_seq_len,
                                             reversible=True, return_embeddings=True)
            self.dec = SinkhornTransformerLM(num_tokens=num_tokens, dim=dim, depth=depth, heads=heads, causal=True, bucket_size=bucket_size, max_seq_len=dec_max_seq_len, 
                                             receives_context=True, context_bucket_size=bucket_size, reversible=True)
            self.dec = AutoregressiveWrapper(self.dec, pad_value=num_tokens-2)
        
        @torch.no_grad()
        def generate(self, x, x_mask):
            context = self.enc(x, input_mask=x_mask)
            start_tokens = (torch.ones((x.shape[0],1)) * self.sos_token).long().cuda()
    
            return self.dec.generate(start_tokens, 32, context=context, context_mask=x_mask)
    
        def forward(self, x, y, x_mask, y_mask, return_loss):
            context = self.enc(x, input_mask=x_mask)
            return self.dec(y, context=context, input_mask=y_mask, context_mask=x_mask, return_loss=True)
    
    
    model = MySinkhornTransformer(num_tokens=NUM_TOKENS, dim=512, depth=1, heads=1, bucket_size=BUCKET_SIZE, enc_max_seq_len=ENC_SEQ_LEN, dec_max_seq_len=DEC_SEQ_LEN)
    model.cuda()
    # optimizer
    
    optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    
    # training
    
    for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
        model.train()
    
        for __ in range(GRADIENT_ACCUMULATE_EVERY):
            source, target, x_mask, y_mask = next(cycle())
            loss = model(x=source, y=target, x_mask=x_mask, y_mask=y_mask, return_loss=True)
            loss.backward()
    
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optim.step()
        optim.zero_grad()
    
        if i % VALIDATE_EVERY == 0:
            model.eval()
            with torch.no_grad():
                source, target, x_mask, y_mask = next(cycle())
                loss = model(x=source, y=target, x_mask=x_mask, y_mask=y_mask, return_loss=True)
                print(f'validation loss: {loss.item()}')
    
        if i % GENERATE_EVERY == 0:
            model.eval()
            
            source, target, x_mask, y_mask = next(cycle())
            
            sample = model.generate(x=source, x_mask=x_mask)
            print("input:  ", source[0])
            print("model output:  ", sample[0])
    

    After a few steps the loss becomes practically zero. I checked the logits during training and they seem to be OK, but during generation the model outputs the pattern "x,x,x,x,x,y,y,y,y,y" (like "aaaabbbb" instead of "abcdabcd"). I was wondering what the underlying issue might be. Do you have any idea?

    opened by py4
  • Sortcut variant and bucket size

    From the paper, it looks to me that for the sortcut variant in the non-causal case, the queries are allowed to attend to all the key buckets remaining after truncation. If this is correct, won't the output be the same no matter what the query bucket size is?

    Based on my understanding of the paper, the authors only report results that select the top 2 key/value blocks, with bucket sizes of 8, 16, and 32. They do not mention using different bucket sizes for queries and keys/values, which I found in this repo.

    opened by jsmith1915
  • A noobish question about training...

    @lucidrains

    First of all, everything seems to be working now in Google Colab, so thank you very much for fixing it.

    I have a quick question about training if you do not mind...

    I get very high loss values. Here is an example:

    training: 1%| | 509/100000 [1:55:22<387:23:57, 14.02s/it]training loss: 2.495532512664795

    Is this normal and I simply need to train more? Or does it mean that there is a problem somewhere?

    A loss of 2.49 after 2 hours seems way too high IMHO. I am pretty sure I am not doing it right, so your advice would be really appreciated.

    Thank you.

    P.S. I am running your wiki8 example in Google Colab Pro.

    opened by asigalov61
  • several question about implementation

    I am currently trying to implement Sparse Sinkhorn attention (non-causal self-attention, pretrained on an MLM task) in TensorFlow, and I would appreciate it if you could answer several questions about your code.

    1. I did not quite understand from the paper what happens to the queries and keys in SortCut, when most of the blocks are cut off. In your code it seems like you broadcast keys/queries like this: [number_blocks, block_len, size_per_head] -> [:n_sortcut, block_len, size_per_head] -> [1, n_sortcut * block_len, size_per_head] -> [number_blocks, n_sortcut * block_len, size_per_head]. I understand that PyTorch handles expand_dim automatically and memory is preserved, but I do not understand how the memory consumption stays linear: the resulting attention scores (dots) are still [number_blocks, block_len, num_sortcut * block_len], which is good but still quadratic. Am I missing something?

    2. In the case of SortCut what is the meaning of the softmax(A)*Q? For example, for full attention outputs are vectors which are weighted average of other vectors; For local block attention outputs are vectors that are weighted average of other vectors in the block. How would you interpret the output of SortCut attention layer?

    3. In your code you concatenate the regular keys and values with the permuted/sorted keys and values, but in the paper (section 3.2) it seems to be a simple addition. Why does this differ from the paper? Actually, concatenation makes more sense to me than summation (with summation I am confused about which attention mask to use: the one from the regular block or the one from the permuted block).

    4. It seems you use the concatenated queries and keys as the input to the SortNet. Maybe I completely missed the point, but I did not find anything about this in the article: there, the input sequence is used for the SortNet instead of its projections. In fact, for some reason I cannot make it converge if I use queries and keys (maybe for completely unrelated reasons). Have you tried comparing the input sequence vs. concat(q, k)?

    opened by w4-magnes
  • TypeError exception in AxialPositionalEncoding when using DataParallel

    Hello,

    I want to run SinkhornTransformerLM using multiple GPUs, so I'm wrapping the model into torch.nn.DataParallel. However, when I do this, I get an exception:

    Traceback (most recent call last):
      File "script.py", line 27, in <module>
        model(x)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 155, in forward
        outputs = self.parallel_apply(replicas, inputs, kwargs)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 165, in parallel_apply
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 85, in parallel_apply
        output.reraise()
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/_utils.py", line 395, in reraise
        raise self.exc_type(msg)
    TypeError: Caught TypeError in replica 0 on device 0.
    Original Traceback (most recent call last):
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
        output = module(*input, **kwargs)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py", line 792, in forward
        x = self.axial_pos_emb(x) + x
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py", line 243, in forward
        return pos_emb[:, :t]
    TypeError: 'int' object is not subscriptable
    

    Looking at the code, it would seem that self.weights does not get populated. To reproduce this error, I took the first example in README.md and changed

    model(x) # (1, 2048, 20000)
    

    to

    model = torch.nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count()))).to('cuda')
    model(x)
    
    opened by kl0211
  • Do I need to pad?

    Excuse the noob question. I have sequences of different lengths and I use:

    model = AutoregressiveWrapper(model, ignore_index=PAD_ID, pad_value=PAD_ID)
    

    Do I need to pad as in this example or not?

    def __getitem__(self, index):
        seq_tokens = self.examples[index]
        input_ids = torch.tensor(seq_tokens, dtype=torch.long)
        input_ids = F.pad(input_ids, (seq_len - len(input_ids), 0), value=self.pad_value)
        return input_ids
    
    opened by timsoraro 8
  • Performance of the different attention variants in the repo

    Really great work!

    I have a few questions:

    - Does the attention sort net work better than, or on par with, the simple sort net?
    - Does the attention sort net also work well for the sortcut variant?
    - How does the sortcut variant perform compared to vanilla sparse Sinkhorn attention?

    It would be very helpful if you could share some plots or numbers from your experiments comparing the different variants, such as the attention sort net, routing-based attention, etc.

    opened by jsmith1915 7
  • Some questions about dropout

    Hi again @lucidrains, I have some quick questions about dropout in the Sinkhorn Transformer. I have been using my Linformer implementation (which, as you know, is based on this repo), but it was overfitting my dataset, so I wanted to ask whether some of the dropout design choices here were intentional:

    1. In the original Transformer, dropout is applied to the output of each sublayer, before the residual connection. I noticed that you only have this after the SinkhornSelfAttention class, but not after the FeedForward class. Is this intentional?
    2. Speaking of the FeedForward class, you insert dropout after the first linear layer. I couldn't find this anywhere in the literature; were you able to find a reference for why this is effective? I put it into my implementation and it seems to help, but I just don't know where the idea came from.
    3. On a similar note, do you know why the dots tensor in the self-attention classes is dropped out? Again, I put it in my Linformer and it seems to work, but I can't find a reference to it in the literature.
    4. Finally, the original Transformer also applied dropout to the input embeddings (the sum of the token and positional embeddings). In the SinkhornTransformerLM class, that would go here:
        def forward(self, x, **kwargs):
            _, t, device = *x.shape, x.device
            assert t <= self.max_seq_len, f'sequence length {t} is greater than maximum sequence length {self.max_seq_len}'
    
            x = self.to_token_emb(x)
            x = self.axial_pos_emb(x) + x
            """ Dropout would go here"""
            x = self.sinkhorn_transformer(x, **kwargs)
            return self.to_logits(x)
    

    Should the embeddings be dropped out here as well?

    I have now updated my repo so that all four of these dropout options exist. I'll let you know whether this helps with the overfitting.

    Thank you for your time!

    opened by tatp22 6
  • Training crashes due to inplace operation

    Hi there, I am trying to train this model on my custom dataset. Training starts, but after many iterations it crashes with this error:

    RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [2, 128, 64]], which is output 0 of ViewBackward, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
    

    After enabling anomaly detection, here is the error:

    RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [2, 128, 64]], which is output 0 of ViewBackward, is at version 1; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
    

    The traceback is not useful, as it only points to the loss.backward() line and the torch.autograd module. Can you help me find where the issue might be?

    Thanks

    opened by NaxAlpha 6
  • Positional Embedding

    Hi! I found that SinkhornTransformerLM uses two positional encodings simultaneously:

    https://github.com/lucidrains/sinkhorn-transformer/blob/2b65e24085562a4e308251398007e2ca9b86d7cc/sinkhorn_transformer/sinkhorn_transformer.py#L778-L779

    https://github.com/lucidrains/sinkhorn-transformer/blob/2b65e24085562a4e308251398007e2ca9b86d7cc/sinkhorn_transformer/sinkhorn_transformer.py#L792-L793

    I guess pos_emb can be removed: it adds memory overhead and defeats the purpose of Axial Positional Encoding, which is designed to reduce the number of positional encoding parameters.

    Is there a special reasoning behind that?

    opened by ilya16 4
  • A wrapper of SinkhornTransformerEncDec

    Could you please code up a wrapper that removes much of the manual work of writing a generic SinkhornTransformer encoder/decoder architecture?

    Thanks a lot! halexan

    opened by halexan 0
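In the "Do I need to pad?" thread above, `F.pad(input_ids, (seq_len - len(input_ids), 0), value=self.pad_value)` left-pads each 1-D sequence of token ids to a common length so that variable-length examples can be stacked into one batch. Below is a minimal pure-Python sketch of that idea, plus a mask that is False exactly where a position holds the pad id, i.e. the positions a cross-entropy loss with `ignore_index=PAD_ID` would skip when those tokens serve as targets. `PAD_ID = 0`, `left_pad`, `pad_mask`, and the left-truncation policy for over-long sequences are hypothetical choices for illustration, not the library's `AutoregressiveWrapper` behavior.

```python
from typing import List

PAD_ID = 0  # hypothetical pad token id


def left_pad(tokens: List[int], seq_len: int, pad_value: int = PAD_ID) -> List[int]:
    """Left-pad (or left-truncate) a list of token ids to exactly seq_len items."""
    if len(tokens) >= seq_len:
        # assumed policy: keep the most recent seq_len tokens
        return tokens[len(tokens) - seq_len:]
    return [pad_value] * (seq_len - len(tokens)) + tokens


def pad_mask(padded: List[int], pad_value: int = PAD_ID) -> List[bool]:
    """True where a position holds a real token, False where it holds padding."""
    return [tok != pad_value for tok in padded]


batch = [[5, 6, 7], [8, 9], [1, 2, 3, 4, 5, 6]]
seq_len = 4
padded = [left_pad(seq, seq_len) for seq in batch]
masks = [pad_mask(seq) for seq in padded]
print(padded)  # [[0, 5, 6, 7], [0, 0, 8, 9], [3, 4, 5, 6]]
print(masks)   # [[False, True, True, True], [False, False, True, True], [True, True, True, True]]
```

Padding on the left keeps each sequence's most recent tokens at the end of the row, next to where an autoregressive model predicts the following token.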
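Point 1 of the "Some questions about dropout" thread refers to the original Transformer's residual sublayer, where dropout is applied to the sublayer output before it is added back to the input, i.e. `x + dropout(sublayer(x))`. Below is a minimal pure-Python sketch of that ordering using inverted dropout (surviving activations are scaled by `1 / (1 - p)` during training, so inference needs no rescaling). The helpers `dropout`, `residual_sublayer`, and `double` are hypothetical, plain lists stand in for tensors, layer normalization is omitted, and this is not the repository's code.

```python
import random
from typing import Callable, List

Vector = List[float]


def dropout(x: Vector, p: float, training: bool, rng: random.Random) -> Vector:
    """Inverted dropout: zero each element with probability p, scale survivors by 1/(1-p)."""
    if not training or p == 0.0:
        return list(x)  # identity at inference time or when p is 0
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    scale = 1.0 / (1.0 - p)
    return [0.0 if rng.random() < p else v * scale for v in x]


def residual_sublayer(x: Vector, sublayer: Callable[[Vector], Vector],
                      p: float, training: bool, rng: random.Random) -> Vector:
    """Original-Transformer ordering: x + dropout(sublayer(x)) (layer norm omitted)."""
    out = dropout(sublayer(x), p, training, rng)
    return [a + b for a, b in zip(x, out)]


def double(v: Vector) -> Vector:
    # toy stand-in for an attention or feedforward sublayer
    return [2.0 * e for e in v]


rng = random.Random(0)
x = [1.0, 2.0, 3.0]
print(residual_sublayer(x, double, p=0.1, training=False, rng=rng))  # [3.0, 6.0, 9.0]
print(residual_sublayer(x, double, p=0.5, training=True, rng=rng))
```

Inverted dropout is used here because it keeps the inference path a plain identity; the same helper could equally be applied to the embedding sum (point 4) or to the attention weights (point 3).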
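The "Positional Embedding" thread argues that keeping a full learned `pos_emb` alongside the axial positional embedding cancels the axial embedding's parameter savings. A back-of-the-envelope count in plain Python: a full table has `max_seq_len * dim` parameters, while an axial embedding that places positions on a 2-D grid `(rows, cols)` and sums one row vector and one column vector of width `dim` needs only `(rows + cols) * dim`. The grid shape `(64, 128)`, the summed (rather than concatenated) variant, and all helper names are assumptions for illustration; `max_seq_len = 8192` and `dim = 1024` come from the README example.

```python
from typing import List, Sequence, Tuple


def full_pos_emb_params(max_seq_len: int, dim: int) -> int:
    """A full learned table: one vector of width dim per position."""
    return max_seq_len * dim


def axial_pos_emb_params(axial_shape: Sequence[int], dim: int) -> int:
    """Summed axial variant (assumed): one table per axis, each row of width dim."""
    return sum(axial_shape) * dim


def axial_index(pos: int, axial_shape: Tuple[int, int]) -> Tuple[int, int]:
    """Map a flat position to its (row, col) cell in a row-major 2-D grid."""
    rows, cols = axial_shape
    if not 0 <= pos < rows * cols:
        raise IndexError("position lies outside the axial grid")
    return divmod(pos, cols)


def axial_embedding(pos: int, axial_shape: Tuple[int, int],
                    row_table: List[List[float]], col_table: List[List[float]]) -> List[float]:
    """Embedding of one position = its row vector + its column vector."""
    r, c = axial_index(pos, axial_shape)
    return [a + b for a, b in zip(row_table[r], col_table[c])]


max_seq_len, dim = 8192, 1024  # values from the README example
axial_shape = (64, 128)        # assumed factorization: 64 * 128 == 8192

full = full_pos_emb_params(max_seq_len, dim)
axial = axial_pos_emb_params(axial_shape, dim)
print(full)          # 8388608
print(axial)         # 196608
print(full + axial)  # 8585216, the cost of keeping both tables
```

Under these assumed numbers the full table alone is roughly 43 times larger than the axial one, so keeping both leaves almost all of the full table's memory cost in place.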
Owner
Phil Wang
Working with Attention. It's all we need