Implementation of Rotary Embeddings, from the RoFormer paper, in PyTorch

Overview

Rotary Embeddings - Pytorch

A standalone library for adding rotary embeddings to transformers in PyTorch, following their success as a relative positional encoding. Specifically, it makes rotating information into any axis of a tensor easy and efficient, whether the rotations are fixed positional or learned. This library aims to give you state-of-the-art results for positional embedding at little cost.

My gut also tells me there is something more to rotations that can be exploited in artificial neural networks.

Install

$ pip install rotary-embedding-torch

Usage

import torch
from rotary_embedding_torch import apply_rotary_emb, RotaryEmbedding

# instantiate the positional embedding in your transformer and pass to all your attention layers

pos_emb = RotaryEmbedding(dim = 32)

# generate the rotations

freqs = pos_emb(torch.arange(1024), cache_key = 1024) # cache with a key that is the sequence length, so that it does not need to recompute

# mock queries and keys

q = torch.randn(1, 1024, 64) # queries - (batch, seq len, dimension of head)
k = torch.randn(1, 1024, 64) # keys

# apply the rotations to your queries and keys after the heads have been split out, but prior to the dot product and subsequent softmax (attention)

freqs = freqs[None, ...] # unsqueeze for batch dimension
q = apply_rotary_emb(freqs, q)
k = apply_rotary_emb(freqs, k)

# then do your attention with your queries (q) and keys (k)

If you do all the steps above correctly, you should see a dramatic improvement during training.
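The reason rotating queries and keys acts as a relative positional encoding can be checked with a few lines of plain Python. The sketch below is not the library's implementation: `rotary_freqs`, `rotate`, and `dot` are hypothetical helpers, and the base of 10000 with adjacent-pair rotation is an assumption taken from the RoFormer formulation. It shows that the dot product of a rotated query and key depends only on the offset between their positions.

```python
import math
import random

def rotary_freqs(dim, theta=10000.0):
    # One frequency per pair of features (assumed RoFormer-style schedule).
    return [theta ** (-2.0 * i / dim) for i in range(dim // 2)]

def rotate(vec, position, freqs):
    # Rotate each adjacent pair (vec[2i], vec[2i+1]) by the angle position * freqs[i].
    out = []
    for i, f in enumerate(freqs):
        a, b = vec[2 * i], vec[2 * i + 1]
        c, s = math.cos(position * f), math.sin(position * f)
        out.extend((a * c - b * s, a * s + b * c))
    return out

def dot(u, v):
    return sum(x * y for x, y in zip(u, v))

random.seed(0)
dim = 8
freqs = rotary_freqs(dim)
q = [random.gauss(0, 1) for _ in range(dim)]
k = [random.gauss(0, 1) for _ in range(dim)]

# The same offset (3) at two very different absolute positions.
score_near = dot(rotate(q, 5, freqs), rotate(k, 2, freqs))
score_far = dot(rotate(q, 105, freqs), rotate(k, 102, freqs))
print(abs(score_near - score_far) < 1e-9)
```

Because each pair is only rotated, the norm of the vector is unchanged, which is part of why the embedding is cheap to add to an existing attention layer.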

Axial Rotary Embeddings

For easy use of 2D axial relative positional embeddings, e.g. in vision transformers

import torch
from rotary_embedding_torch import apply_rotary_emb, RotaryEmbedding, broadcat

pos_emb = RotaryEmbedding(
    dim = 32,
    freqs_for = 'pixel'
)

# queries and keys for frequencies to be rotated into

q = torch.randn(1, 256, 256, 64)
k = torch.randn(1, 256, 256, 64)

# get frequencies for each axis
# -1 to 1 has been shown to be a good choice for images and audio

freqs_h = pos_emb(torch.linspace(-1, 1, steps = 256), cache_key = 256)
freqs_w = pos_emb(torch.linspace(-1, 1, steps = 256), cache_key = 256)

# concat the frequencies from each axis
# broadcat function makes this easy without a bunch of expands

freqs = broadcat((freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim = -1)

# rotate in frequencies

q = apply_rotary_emb(freqs, q)
k = apply_rotary_emb(freqs, k)
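To make the broadcast-and-concatenate step concrete, here is a small plain-Python sketch of what the combined axial table looks like. It is only an illustration under stated assumptions: `linspace`, `axis_angles`, and `axial_angles` are hypothetical helpers written for this example, and the frequency values are arbitrary. Each grid cell ends up with the angles for its row followed by the angles for its column.

```python
import math

def linspace(start, stop, steps):
    # Evenly spaced values including both endpoints (steps >= 2).
    return [start + (stop - start) * i / (steps - 1) for i in range(steps)]

def axis_angles(positions, freqs):
    # angles[p][i] = position p times frequency i, for a single axis.
    return [[p * f for f in freqs] for p in positions]

def axial_angles(angles_h, angles_w):
    # Broadcast both per-axis tables over the grid, then concatenate on the last dimension.
    return [[row_h + row_w for row_w in angles_w] for row_h in angles_h]

freqs = [math.pi * f for f in linspace(1.0, 5.0, 4)]  # illustrative values only
height, width = 3, 5
grid = axial_angles(
    axis_angles(linspace(-1.0, 1.0, height), freqs),
    axis_angles(linspace(-1.0, 1.0, width), freqs),
)
print(len(grid), len(grid[0]), len(grid[0][0]))
```

The first half of each cell's angles varies only with the row and the second half only with the column, so half of the rotated features encode height and the other half encode width.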

Learned Rotations

For injecting learned rotations into a network. Experiments pending.

Update: doesn't seem to do anything -_-, will keep trying...

import torch
from torch import nn
from rotary_embedding_torch import apply_learned_rotations

x = torch.randn(1, 1024, 512)

# you can only rotate in (dim // 2) values
# ex. for 512, you can only rotate in 256 values

# say you have two sets of learned rotations of 128 values each

rots1 = nn.Linear(512, 128)(x)
rots2 = nn.Linear(512, 128)(x)

# you rotate in 256 (128 x 2) at first

x = apply_learned_rotations(rots1, x, start_index = 0)

# then you start at index 256 and rotate in the last (128 x 2)

x = apply_learned_rotations(rots2, x, start_index = 256)

# you could also concat the rotations together and pass it in all at once

rots = torch.cat((rots1, rots2), dim = -1)

x = apply_learned_rotations(rots, x)
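The "dim // 2" limit and the meaning of start_index are easier to see on a tiny example. The following is a minimal sketch in plain Python, not the library's code: `rotate_pairs` is a hypothetical helper that assumes each angle rotates one adjacent pair of features, so n angles cover 2n features beginning at start_index.

```python
import math

def rotate_pairs(x, angles, start_index=0):
    # Rotate features x[start_index : start_index + 2 * len(angles)] pairwise,
    # leaving everything outside that window untouched.
    out = list(x)
    for i, angle in enumerate(angles):
        j = start_index + 2 * i
        a, b = x[j], x[j + 1]
        c, s = math.cos(angle), math.sin(angle)
        out[j], out[j + 1] = a * c - b * s, a * s + b * c
    return out

x = [1.0, 0.0, 0.0, 2.0, 3.0, 4.0, 5.0, 6.0]  # 8 features, so at most 4 angles
rots1 = [math.pi / 2, math.pi]                 # covers features 0..3
rots2 = [0.0, math.pi / 2]                     # covers features 4..7

# Two partial rotations, the second starting where the first one ended ...
two_steps = rotate_pairs(rotate_pairs(x, rots1, start_index=0), rots2, start_index=4)
# ... give the same result as one rotation with the concatenated angles.
one_step = rotate_pairs(x, rots1 + rots2)
print(all(abs(a - b) < 1e-12 for a, b in zip(two_steps, one_step)))
```

With 512 features, the same reasoning gives 256 angles in total, which is why two sets of 128 rotations use start indices 0 and 256.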

Citations

@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding}, 
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}

Comments
  • Custom position offset when rotating queries or keys

    This library seems to assume that queries and keys are left-aligned position-wise e.g.

    q = [p_0, p_1, p_2]
    k = [p_0, p_1, p_2, p_3, p_4]
    

    where p_i are the corresponding positions. This is enforced by always starting the sequence of positions from 0 with torch.arange(seq_len) here. Applications like Perceiver AR, however, require position-wise right-alignment, e.g.

    q =           [p_2, p_3, p_4]
    k = [p_0, p_1, p_2, p_3, p_4]
    

    This pull request allows specifying a start position for queries and/or keys, enabling alignments other than left-alignment. For example

    import torch
    from rotary_embedding_torch.rotary_embedding_torch import RotaryEmbedding
    
    rot = RotaryEmbedding(dim=32)
    
    q = torch.ones(1, 8, 4, 32)
    k = torch.ones(1, 8, 6, 32)
    
    q = q / torch.norm(q, dim=-1, keepdim=True)
    k = k / torch.norm(k, dim=-1, keepdim=True)
    
    q_rot = rot.rotate_queries_or_keys(q, start_pos=k.shape[2] - q.shape[2])
    k_rot = rot.rotate_queries_or_keys(k)
    
    attn = torch.einsum("b h i c, b h j c -> b h i j", q_rot, k_rot)
    print(attn[0, 0])
    

    prints the following relative position embedding

    tensor([[0.8581, 0.9571, 1.0000, 0.9571, 0.8581, 0.7670],
            [0.7670, 0.8581, 0.9571, 1.0000, 0.9571, 0.8581],
            [0.7288, 0.7670, 0.8581, 0.9571, 1.0000, 0.9571],
            [0.7361, 0.7288, 0.7670, 0.8581, 0.9571, 1.0000]])
    

    (diagonal of 1s right-aligned) whereas the default behavior

    ...
    
    q_rot = rot.rotate_queries_or_keys(q)
    k_rot = rot.rotate_queries_or_keys(k)
    
    attn = torch.einsum("b h i c, b h j c -> b h i j", q_rot, k_rot)
    print(attn[0, 0])
    

    would print

    tensor([[1.0000, 0.9571, 0.8581, 0.7670, 0.7288, 0.7361],
            [0.9571, 1.0000, 0.9571, 0.8581, 0.7670, 0.7288],
            [0.8581, 0.9571, 1.0000, 0.9571, 0.8581, 0.7670],
            [0.7670, 0.8581, 0.9571, 1.0000, 0.9571, 0.8581]])
    

    (diagonal of 1s left-aligned).

    opened by krasserm 1
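The two tables in this thread can be approximated without the library. For a query and key that are both the normalized all-ones vector, the rotated dot product reduces to the mean of cos(offset * freq_i) over the rotation frequencies. The sketch below assumes 16 frequencies of the form 10000^(-2i/32) for dim = 32; that schedule is an assumption about the defaults, not something stated in the thread, and `relative_score` and `attention_table` are hypothetical helpers.

```python
import math

def relative_score(offset, dim=32, theta=10000.0):
    # Mean of cos(offset * freq_i) over dim // 2 assumed RoFormer-style frequencies.
    freqs = [theta ** (-2.0 * i / dim) for i in range(dim // 2)]
    return sum(math.cos(offset * f) for f in freqs) / len(freqs)

def attention_table(q_len, k_len, start_pos=0):
    # Query i sits at position start_pos + i, key j sits at position j.
    return [[relative_score(start_pos + i - j) for j in range(k_len)]
            for i in range(q_len)]

left = attention_table(4, 6)                    # queries start at position 0
right = attention_table(4, 6, start_pos=6 - 4)  # queries end where the keys end

for row in right:
    print(" ".join(f"{v:.4f}" for v in row))
```

Under these assumptions the scores for offsets 0, 1, and 2 come out close to the 1.0000, 0.9571, and 0.8581 shown in the thread, and shifting start_pos only moves the diagonal of ones; the relative profile itself is unchanged.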
  • about axial rotary embeddings

    Hi, thank you for sharing this code with us. However, I am confused by the axial rotary embeddings in the rotary_embedding_torch.py file, specifically this branch: " elif freqs_for == 'pixel': freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi ". Where does this formula come from? What does the max_freq parameter mean? Why are the freqs not " 1/(10000^(2i/d)) "?

    Thank you again.

    opened by raindrop313 0
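This does not settle where the 'pixel' formula comes from, but putting the two schedules from the question side by side may help frame it. The sketch below is plain Python with hypothetical helper names (`lang_freqs`, `pixel_freqs`, `linspace`), and the value of max_freq is chosen only for illustration.

```python
import math

def linspace(start, stop, steps):
    # Evenly spaced values including both endpoints (steps >= 2).
    return [start + (stop - start) * i / (steps - 1) for i in range(steps)]

def lang_freqs(dim, theta=10000.0):
    # The schedule from the question: 1 / theta^(2i / dim), a geometric decay starting at 1.
    return [1.0 / (theta ** (2.0 * i / dim)) for i in range(dim // 2)]

def pixel_freqs(dim, max_freq):
    # The quoted 'pixel' branch: linspace(1, max_freq / 2, dim // 2) * pi.
    return [math.pi * f for f in linspace(1.0, max_freq / 2.0, dim // 2)]

dim = 8
max_freq = 10.0  # illustrative value only
print([round(f, 4) for f in lang_freqs(dim)])
print([round(f, 4) for f in pixel_freqs(dim, max_freq)])
```

One observation that follows directly from the formula: with positions normalized to the range -1 to 1, the lowest pixel frequency (pi) sweeps angles from -pi to pi, exactly one full period across the axis, while the highest frequency sweeps max_freq / 2 periods. The geometric schedule instead targets integer token positions that can run into the thousands.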