Implementation of H-Transformer-1D, Hierarchical Attention for Sequence Learning

Overview

H-Transformer-1D

Implementation of H-Transformer-1D, a Transformer that uses hierarchical attention for sequence learning with sub-quadratic costs.

For now, the H-Transformer only acts as a long-context encoder.

Install

$ pip install h-transformer-1d

Usage

import torch
from h_transformer_1d import HTransformer1D

model = HTransformer1D(
    num_tokens = 256,          # number of tokens
    dim = 512,                 # dimension
    depth = 2,                 # depth
    max_seq_len = 8192,        # maximum sequence length
    heads = 8,                 # heads
    dim_head = 64,             # dimension per head
    block_size = 128           # block size
)

x = torch.randint(0, 256, (1, 8000))   # variable sequence length
mask = torch.ones((1, 8000)).bool()    # variable mask length

# network will automatically pad to power of 2, do hierarchical attention, etc

logits = model(x, mask = mask) # (1, 8000, 256)

Citations

@misc{zhu2021htransformer1d,
    title   = {H-Transformer-1D: Fast One-Dimensional Hierarchical Attention for Sequences}, 
    author  = {Zhenhai Zhu and Radu Soricut},
    year    = {2021},
    eprint  = {2107.11906},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
Comments
  • Masking not working in training, thanks

    Hi, I have tried to train the model on GPU with masking enabled. Line 94, t = torch.flip(t, dims = (2,)), raises RuntimeError: "flip_cuda" not implemented for 'Bool', even though I have tried moving the mask to the CPU.

    Any ideas to solve the problem? Thanks a lot.
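
    A possible workaround (an assumption on my part, not a fix from the repo): older PyTorch builds have no bool kernel for flip on CUDA, so cast the mask to an integer type before flipping and cast it back afterwards.

    import torch

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    mask = torch.ones(1, 8, 128, dtype = torch.bool, device = device)
    # flip_cuda lacks a Bool kernel on older PyTorch, so round-trip through int
    flipped = torch.flip(mask.int(), dims = (2,)).bool()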

    opened by junyongyou 6
  • Application to sequence classification?

    Hi,

    Forgive the naive question, I am trying to make sense of this paper but it's tough going. If I understand correctly, this attention mechanism focuses mainly on nearby tokens and only attends to distant tokens via a hierarchical, low-rank approximation. In that case, can the usual sequence classification approach of having a global [CLS] token that can attend to all other tokens (and vice versa) still work? If not, how can this attention mechanism handle the text classification tasks in the long range arena benchmark?

    Cheers for whatever insights you can share, and thanks for the great work!
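
    One pattern that sidesteps the [CLS] question entirely (a sketch under my own assumptions, not something the repo documents): mean-pool the encoder output over the unmasked positions and classify the pooled vector. Here to_logits is swapped for an identity so the model returns dim-sized features, and the 2-class head is hypothetical.

    import torch
    import torch.nn as nn
    from h_transformer_1d import HTransformer1D

    model = HTransformer1D(num_tokens = 256, dim = 512, depth = 2, max_seq_len = 8192,
                           heads = 8, dim_head = 64, block_size = 128)
    model.to_logits = nn.Identity()    # assumption: return dim-sized features instead of token logits
    classifier = nn.Linear(512, 2)     # hypothetical 2-class head

    x = torch.randint(0, 256, (1, 8000))
    mask = torch.ones((1, 8000)).bool()

    feats = model(x, mask = mask)                                    # (1, 8000, 512)
    feats = feats.masked_fill(~mask[..., None], 0.)
    pooled = feats.sum(dim = 1) / mask.sum(dim = 1, keepdim = True)  # masked mean pool
    class_logits = classifier(pooled)                                # (1, 2)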

    opened by trpstra 4
  • eos token does not work in batch mode generation

    With the current code, it seems the eos_token only works when generating one sequence at a time: https://github.com/lucidrains/h-transformer-1d/blob/main/h_transformer_1d/autoregressive_wrapper.py#L59
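
    A minimal sketch (my assumption, not the wrapper's actual code) of batch-aware early stopping: track a per-sequence finished flag, pad already-finished rows with the eos token, and break once every row has emitted it.

    import torch

    eos_token, batch, vocab, max_new = 0, 4, 256, 16
    out = torch.randint(1, vocab, (batch, 1))           # running generations
    finished = torch.zeros(batch, dtype = torch.bool)

    for _ in range(max_new):
        next_tok = torch.randint(0, vocab, (batch, 1))  # stand-in for sampling from the model
        next_tok = next_tok.masked_fill(finished[:, None], eos_token)
        out = torch.cat((out, next_tok), dim = -1)
        finished |= (next_tok.squeeze(-1) == eos_token)
        if finished.all():
            break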

    opened by tmphex 4
  • RuntimeError: Tensor type unknown to einops <class 'torch.return_types.max'>

    lib/python3.7/site-packages/einops/_backends.py", line 52, in get_backend
        raise RuntimeError('Tensor type unknown to einops {}'.format(type(tensor)))

    RuntimeError: Tensor type unknown to einops <class 'torch.return_types.max'>

    I understand why this gets raised. Could it be a PyTorch version problem? Mine is 1.6.0.
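
    For what it's worth, the type in the error is the named tuple that torch.max returns when given a dim argument; einops only accepts plain tensors, so taking .values before handing the result to einops avoids the error (a guess at the cause based on the message, not a confirmed fix).

    import torch

    t = torch.randn(2, 3)
    result = torch.max(t, dim = -1)   # torch.return_types.max(values=..., indices=...)
    plain = result.values             # a plain tensor that einops can consume
    # newer PyTorch versions also offer torch.amax(t, dim = -1), which returns a tensor directly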

    opened by wajihullahbaig 4
  • Algorithm Mismatch

    Paper Implementation

    In the implementation, we get blocked Q, K, V tensors by level with the code below.

    https://github.com/lucidrains/h-transformer-1d/blob/110cab0038898560d72d460bfef8ca8b7f17f0a5/h_transformer_1d/h_transformer_1d.py#L164-L179

    The final result of the matrix-matrix product (Equation 29 or 69) is then assembled with the for loop below.

    https://github.com/lucidrains/h-transformer-1d/blob/110cab0038898560d72d460bfef8ca8b7f17f0a5/h_transformer_1d/h_transformer_1d.py#L234-L247

    What is the problem?

    However, with the current code it is not possible to include the information from the level-0 white blocks in the figure below. (Equation 70 of the paper includes the corresponding attention matrix entries.)

    [figure: the level-0 white (near-interaction) off-diagonal blocks of the attention matrix]

    I think you should also add the off-diagonal near-interaction (level-0) term to match Equation 69!
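
    For illustration, a minimal sketch (my own, not the repo's code, under the assumption that the missing piece is plain local attention between each level-0 block and its immediate neighbours) of a level-0 term that includes the super- and sub-diagonal blocks:

    import torch
    import torch.nn.functional as F
    from einops import rearrange

    def level0_near_attention(q, k, v, block_size):
        # q, k, v: (batch_heads, seq_len, dim_head), seq_len divisible by block_size
        q, k, v = map(lambda t: rearrange(t, 'b (n c) d -> b n c d', c = block_size), (q, k, v))
        # pad one zero block on each side so every block can see its neighbours
        k, v = map(lambda t: F.pad(t, (0, 0, 0, 0, 1, 1)), (k, v))
        # keys/values for each block: left neighbour, the block itself, right neighbour
        k = torch.cat([k[:, :-2], k[:, 1:-1], k[:, 2:]], dim = 2)
        v = torch.cat([v[:, :-2], v[:, 1:-1], v[:, 2:]], dim = 2)
        sim = torch.einsum('b n i d, b n j d -> b n i j', q, k) * q.shape[-1] ** -0.5
        out = torch.einsum('b n i j, b n j d -> b n i d', sim.softmax(dim = -1), v)
        return rearrange(out, 'b n c d -> b (n c) d')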

    opened by jinmang2 3
  • I have some questions about implementation details

    Thanks for making your implementation public. I have three questions about your h-transformer 1d implementation.

    1. The number of levels M

    https://github.com/lucidrains/h-transformer-1d/blob/63063d5bb036b56a7205aadc5c8198da02d698f6/h_transformer_1d/h_transformer_1d.py#L105-L114

    In the paper, Eq. (32) gives a guide on how M is determined.

    [image: Eq. (32) from the paper]

    In your code implementation:

    • n is the sequence length, which is not padded
    • bsz is the block size (the same as N_r, the numerical rank of the off-diagonal blocks)
    • Because code line 111 already contains level 0, M is equal to int(log2(n // bsz)) - 1

    However, in Section 6.1, Constructing Hierarchical Attention, I found that the sequence length L must be a power of 2. In my opinion, Eq. (31)'s L is equal to 2^{M+1}. In the implementation, n is the unpadded sequence length, so one level of M is missing.

    Since x is the sequence padded by the processing below, https://github.com/lucidrains/h-transformer-1d/blob/63063d5bb036b56a7205aadc5c8198da02d698f6/h_transformer_1d/h_transformer_1d.py#L83-L91

    I think the above implementation should be modified as below:

    107    num_levels = int(log2(x.size(1) // bsz)) - 1 
    

    2. Super- and sub-diagonal blocks of the coarsened matrix \tilde{A} at level l

    https://github.com/lucidrains/h-transformer-1d/blob/63063d5bb036b56a7205aadc5c8198da02d698f6/h_transformer_1d/h_transformer_1d.py#L190-L198

    Ys contains Y and A computed by calculate_Y_and_A. For example,

    # Ys
    [
        (batch_size*n_heads, N_b(2), N_r), (batch_size*n_heads, N_b(2)),  # level 2, (Y(2), tilde_A(2))
        (batch_size*n_heads, N_b(1), N_r), (batch_size*n_heads, N_b(1)),  # level 1, (Y(1), tilde_A(1))
        (batch_size*n_heads, N_b(0), N_r), (batch_size*n_heads, N_b(0)),  # level 0, (Y(0), tilde_A(0))
    ]
    

    In Eq. (29), Y is calculated as Y = AV = Y(0) + P(0)( Y(1) + P(1)Y(2) ). However, in code line 190, Y is calculated using only the level-0 and level-1 blocks, no matter how large M is: Y = AV = Y(0) + P(0)Y(1)

    Does increasing the number of levels cause performance degradation in the implementation? I'm curious!
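
    A minimal sketch (my own reading of Eq. (29), not the repo's code) of the full recursion, assuming each coarser level halves the number of blocks, so the interpolation P simply duplicates every coarse block for its two children:

    import torch

    def interpolate(y):
        # the action of P in Eq. (29): copy each coarse block to its two children
        return torch.repeat_interleave(y, 2, dim = 1)

    def combine_levels(Ys):
        # Ys: per-level partial products, ordered coarsest (level M) first and
        # finest (level 0) last, as in the example above;
        # computes Y = Y(0) + P(0)( Y(1) + P(1)( ... Y(M) ) )
        Y = Ys[0]
        for Y_l in Ys[1:]:
            Y = Y_l + interpolate(Y)
        return Y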


    3. Comparison with Luna: Linear Unified Nested Attention

    The H-Transformer significantly exceeded the scores of BigBird and Luna on LRA. However, what I missed while reading the paper was a comparison of computation time with the other sub-quadratic methods and with Luna. Is this algorithm much faster than the other sub-quadratic approaches? And how does it compare to Luna?


    Thanks again for releasing the implementation! The idea of calculating the off-diagonal blocks with flip was amazing and I learned a lot. Thank you!! 😄

    opened by jinmang2 3
  • Add Norm Missing

    I am using the code now, and I wonder whether add & norm is implemented. I only find the layer norm, but no add (residual) operation. Here is the code in h-transformer-1d.py line 489 ... Is this a bug or something? Thanks @Lucidrains

    for ind in range(depth):
        attn = attn_class(dim, dim_head = dim_head, heads = heads, block_size = block_size, pos_emb = self.pos_emb, **attn_kwargs)
        ff = FeedForward(dim, mult = ff_mult)

        if shift_tokens:
            attn, ff = map(lambda t: PreShiftTokens(shift_token_ranges, t), (attn, ff))

        attn, ff = map(lambda t: PreNorm(dim, t), (attn, ff))
        layers.append(nn.ModuleList([attn, ff]))
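
    If the residual additions really are missing, a standard fix (a sketch on my part, not the repo's code) is to wrap each pre-normed block so its output is added back to its input:

    import torch.nn as nn

    class Residual(nn.Module):
        # the "add" in add & norm: y = f(x) + x
        def __init__(self, fn):
            super().__init__()
            self.fn = fn

        def forward(self, x, **kwargs):
            return self.fn(x, **kwargs) + x

    # e.g. attn, ff = map(lambda t: Residual(PreNorm(dim, t)), (attn, ff))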
    
    opened by wwx13 2
  • Mask not working

    def forward(self, x, mask = None):
        b, n, device = *x.shape, x.device
        assert n <= self.max_seq_len, 'sequence length must be less than the maximum sequence length'
        x = self.token_emb(x)
        x = self.layers(x)
        return self.to_logits(x)
    

    I think... Masking does not work ???
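
    A minimal sketch of how the mask presumably needs to be threaded through (my own guess, assuming self.layers iterates over pre-normed (attention, feedforward) pairs as in the "Add Norm Missing" snippet above; this is not the repo's actual forward):

    def forward(self, x, mask = None):
        b, n, device = *x.shape, x.device
        assert n <= self.max_seq_len, 'sequence length must be less than the maximum sequence length'
        x = self.token_emb(x)
        for attn, ff in self.layers:
            x = attn(x, mask = mask) + x   # the mask has to reach every attention block
            x = ff(x) + x
        return self.to_logits(x)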

    opened by wwx13 2
  • One simple question

    Hi, Phil!

    One simple question (my math is not good): https://github.com/lucidrains/h-transformer-1d/blob/7c11d036d53926495ec0917a34a1aad7261892b5/train.py#L65

    Why not randint(0, self.data.size(0) - self.seq_len + 1), since the upper bound should be excluded?
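
    A quick illustration of the exclusive upper bound the question hinges on:

    import torch

    # torch.randint's high argument is exclusive: this can return 0..4 but never 5
    samples = torch.randint(0, 5, (1000,))
    assert samples.max().item() <= 4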

    opened by CiaoHe 2
  • Mini-batching (b > 1) does not work with masking

    When using x and mask with a batch size larger than 1, the following error arises:

    import torch
    from h_transformer_1d import HTransformer1D
    
    model = HTransformer1D(
        num_tokens = 256,          # number of tokens
        dim = 512,                 # dimension
        depth = 2,                 # depth
        causal = False,            # autoregressive or not
        max_seq_len = 8192,        # maximum sequence length
        heads = 8,                 # heads
        dim_head = 64,             # dimension per head
        block_size = 128           # block size
    )
    
    batch_size = 2
    x = torch.randint(0, 256, (batch_size, 8000))   # variable sequence length
    mask = torch.ones((batch_size, 8000)).bool()    # variable mask length
    
    # network will automatically pad to power of 2, do hierarchical attention, etc
    
    logits = model(x, mask = mask) # (batch_size, 8000, 256)
    

    Gives the following error:

    ~/git/h-transformer-1d/h_transformer_1d/h_transformer_1d.py in masked_aggregate(tensor, mask, dim, average)
         19     diff_len = len(tensor.shape) - len(mask.shape)
         20     mask = mask[(..., *((None,) * diff_len))]
    ---> 21     tensor = tensor.masked_fill(~mask, 0.)
         22 
         23     total_el = mask.sum(dim = dim)
    
    RuntimeError: The size of tensor a (2) must match the size of tensor b (16) at non-singleton dimension 0
    

    It seems the tensor has shape heads * batch in dimension 0, not batch, which is what the mask has.
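
    A minimal sketch (an assumption about the shapes, not the repo's fix) of reconciling them: since the attention tensors are flattened to (batch * heads, ...), the mask would need to be repeated once per head before masked_fill can broadcast against them.

    import torch
    from einops import repeat

    batch, heads, n = 2, 8, 8192
    mask = torch.ones((batch, n)).bool()
    mask = repeat(mask, 'b n -> (b h) n', h = heads)   # (16, 8192)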

    opened by jaak-s 2
  • Example in README does not work

    Executing the example:

    import torch
    from h_transformer_1d import HTransformer1D
    
    model = HTransformer1D(
        num_tokens = 256,          # number of tokens
        dim = 512,                 # dimension
        depth = 2,                 # depth
        causal = False,            # autoregressive or not
        max_seq_len = 8192,        # maximum sequence length
        heads = 8,                 # heads
        dim_head = 64,             # dimension per head
        block_size = 128           # block size
    )
    
    x = torch.randint(0, 256, (1, 8000))   # variable sequence length
    mask = torch.ones((1, 8000)).bool()    # variable mask length
    
    # network will automatically pad to power of 2, do hierarchical attention, etc
    
    logits = model(x, mask = mask) # (1, 8000, 256)
    

    Gives the following error:

    ~/miniconda3/lib/python3.7/site-packages/rotary_embedding_torch/rotary_embedding_torch.py in apply_rotary_emb(freqs, t, start_index)
         43     assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
         44     t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
    ---> 45     t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
         46     return torch.cat((t_left, t, t_right), dim = -1)
         47 
    
    RuntimeError: The size of tensor a (8192) must match the size of tensor b (8000) at non-singleton dimension 1
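
    A possible workaround until this is fixed (my assumption based on the shape mismatch, not an official fix): pad the inputs to the power-of-two length yourself so the rotary frequencies and the sequence length line up.

    import torch
    import torch.nn.functional as F

    seq_len, padded_len = 8000, 8192
    x = torch.randint(0, 256, (1, seq_len))
    x = F.pad(x, (0, padded_len - seq_len), value = 0)       # pad tokens up to 8192
    mask = torch.arange(padded_len)[None, :] < seq_len       # True only for the real tokens
    # both are now length 8192, so model(x, mask = mask) should no longer hit the rotary shape mismatch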
    
    opened by jaak-s 2
  • Fix indexing

    I am fixing a few apparent bugs in the code. The upshot is that the attention now supports a block size equal to the next-largest power of two of the input length, and for this value of the block size it becomes exact. This allows one to look at the systematic error in the output as a function of decreasing block size (and memory usage).

    I've found this module to reduce memory consumption by a factor of two, but the approximation quickly becomes too inaccurate with decreasing block size to use it as a drop-in replacement for an existing (full) attention layer.

    This repository shows how to compute the full attention with linear memory complexity: https://github.com/CHARM-Tx/linear_mem_attention_pytorch

    opened by jglaser 0
  • Approximated values are off

    I wrote a simple test to check the output of the hierarchical transformer self-attention against the BERT self-attention from Hugging Face Transformers.

    import torch
    import torch.nn as nn
    import math
    
    from h_transformer_1d.h_transformer_1d import HAttention1D
    
    def transpose_for_scores(x, num_attention_heads, attention_head_size):
        new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)
    
    def bert_self_attention(query, key, value, attention_mask=None, num_attention_heads=1):
            dim_head = query.size()[-1] // num_attention_heads
            all_head_size = dim_head*num_attention_heads
    
            query_layer = transpose_for_scores(query, num_attention_heads, dim_head)
            key_layer = transpose_for_scores(key, num_attention_heads, dim_head)
            value_layer = transpose_for_scores(value, num_attention_heads, dim_head)
    
            # Take the dot product between "query" and "key" to get the raw attention scores.
            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    
            attention_scores = attention_scores / math.sqrt(dim_head)
    
            if attention_mask is not None:
                # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
                attention_scores = attention_scores + attention_mask
    
            # Normalize the attention scores to probabilities.
            attention_probs = nn.functional.softmax(attention_scores, dim=-1)
    
            # This is actually dropping out entire tokens to attend to, which might
            # seem a bit unusual, but is taken from the original Transformer paper.
            #attention_probs = self.dropout(attention_probs)
    
            context_layer = torch.matmul(attention_probs, value_layer)
    
            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
            new_context_layer_shape = context_layer.size()[:-2] + (all_head_size,)
            context_layer = context_layer.view(*new_context_layer_shape)
    
            return context_layer, attention_probs
    
    if __name__ == "__main__":
        query = torch.tensor([[[0.1,0.2],[-0.5,0.7],[-0.5,-0.75],[.123,.456]]])
    #    query = torch.tensor([[[0.1,0.2],[-0.5,0.7]]])
        key = value = query
    
        n_heads = 1
        attn, probs = bert_self_attention(query, key, value, num_attention_heads=n_heads)
        print('bert_self_attention out: ', attn)
    
        block_size = 1
        for _ in range(0,2):
            dim_head = query.size()[-1]//n_heads
            h_attn = HAttention1D(
                dim=query.size()[-1],
                heads=n_heads,
                dim_head=dim_head,
                block_size=block_size
            )
    
            h_attn.to_qkv = torch.nn.Identity()
            h_attn.to_out = torch.nn.Identity()
    
            qkv = torch.stack([query, key, value], dim=2)
            qkv = torch.flatten(qkv, start_dim=2)
    
            attn_scores = h_attn(qkv)
            print('hattention_1d: (block_size = {})'.format(block_size), attn_scores)
    
            block_size *= 2
    

    This is the output I get

    bert_self_attention:  tensor([[[-0.1807,  0.1959],
             [-0.2096,  0.2772],
             [-0.2656, -0.0568],
             [-0.1725,  0.2442]]])
    hattention_1d: (block_size = 1) tensor([[[-0.2000,  0.4500],
             [-0.2000,  0.4500],
             [-0.1885, -0.1470],
             [-0.1885, -0.1470]]])
    

    before it errors out with

    assert num_levels >= 0, 'number of levels must be at least greater than 0'
    

    Some of the values are off in absolute magnitude by more than a factor of two.

    Looking at the code, this line seems problematic: https://github.com/lucidrains/h-transformer-1d/blob/8afd75cc6bc41754620bb6ab3737176cb69bdf93/h_transformer_1d/h_transformer_1d.py#L172

    I believe it should read

    num_levels = int(log2(pad_to_len // bsz)) - 1
    

    If I make that change, the approximated attention output is much closer to the exact one:

    bert_self_attention out:  tensor([[[-0.1807,  0.1959],
             [-0.2096,  0.2772],
             [-0.2656, -0.0568],
             [-0.1725,  0.2442]]])
    hattention_1d: (block_size = 1) tensor([[[-0.2590,  0.2020],
             [-0.2590,  0.2020],
             [-0.2590,  0.2020],
             [-0.2590,  0.2020]]])
    hattention_1d: (block_size = 2) tensor([[[-0.1808,  0.1972],
             [-0.1980,  0.2314],
             [-0.2438,  0.0910],
             [-0.1719,  0.2413]]])
    
    opened by jglaser 1