PyTorch implementation of Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation.

Overview

ALiBi

PyTorch implementation of Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation.

alt-text

Quickstart

Clone this repository.

git clone https://github.com/jaketae/alibi.git

Navigate to the cloned directory. You can use the bare-bone ALiBi decoder via

>>> import torch; from alibi import ALiBiConfig, ALiBiTransformer
>>> config  = ALiBiConfig()
>>> model = ALiBiTransformer(config)
>>> x = torch.randn(8, 100, 256)
>>> model(x).shape
torch.Size([8, 100, 256])

By default, the model comes with the following parameters:

ALiBiConfig(
    num_layers=6, 
    d_model=256, 
    num_heads=8, 
    max_len=256, 
    dropout=0.1, 
    causal=True, 
    expansion_factor=1
)

To use an encoder instead of a decoder, simply toggle causal=False.

Abstract

Since the introduction of the transformer model by Vaswani et al. (2017), a fundamental question remains open: how to achieve extrapolation at inference time to longer sequences than seen during training? We first show that extrapolation can be improved by changing the position representation method, though we find that existing proposals do not allow efficient extrapolation. We introduce a simple and efficient method, Attention with Linear Biases (ALiBi), that allows for extrapolation. ALiBi does not add positional embeddings to the word embeddings; instead, it biases the query-key attention scores with a term that is proportional to their distance. We show that this method allows training a 1.3 billion parameter model on input sequences of length 1024 that extrapolates to input sequences of length 2048, achieving the same perplexity as a sinusoidal position embedding model trained on inputs of length 2048, 11% faster and using 11% less memory. ALiBi's inductive bias towards recency allows it to outperform multiple strong position methods on the WikiText-103 benchmark. Finally, we provide analysis of ALiBi to understand why it leads to better performance.

Citation

@misc{press2021train,
	title        = {Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation},
	author       = {Ofir Press and Noah A. Smith and Mike Lewis},
	year         = 2021,
	eprint       = {2108.12409},
	archiveprefix = {arXiv},
	primaryclass = {cs.CL}
}
Owner
Jake Tae
CS + Math @ Yale, SWE intern @huggingface
Jake Tae
NHS AI Lab Skunkworks project: Long Stayer Risk Stratification

NHS AI Lab Skunkworks project: Long Stayer Risk Stratification A pilot project for the NHS AI Lab Skunkworks team, Long Stayer Risk Stratification use

NHSX 21 Nov 14, 2022
GestureSSD CBAM - A gesture recognition web system based on SSD and CBAM, using pytorch, flask and node.js

GestureSSD_CBAM A gesture recognition web system based on SSD and CBAM, using pytorch, flask and node.js SSD implementation is based on https://github

xue_senhua1999 2 Jan 06, 2022
CARL provides highly configurable contextual extensions to several well-known RL environments.

CARL (context adaptive RL) provides highly configurable contextual extensions to several well-known RL environments.

AutoML-Freiburg-Hannover 51 Dec 28, 2022
Energy consumption estimation utilities for Jetson-based platforms

This repository contains a utility for measuring energy consumption when running various programs in NVIDIA Jetson-based platforms. Currently TX-2, NX, and AGX are supported.

OpenDR 10 Jun 17, 2022
This is a Machine Learning Based Hand Detector Project, It Uses Machine Learning Models and Modules Like Mediapipe, Developed By Google!

Machine Learning Hand Detector This is a Machine Learning Based Hand Detector Project, It Uses Machine Learning Models and Modules Like Mediapipe, Dev

Popstar Idhant 3 Feb 25, 2022
Implementation for the IJCAI2021 work "Beyond the Spectrum: Detecting Deepfakes via Re-synthesis"

Beyond the Spectrum Implementation for the IJCAI2021 work "Beyond the Spectrum: Detecting Deepfakes via Re-synthesis" by Yang He, Ning Yu, Margret Keu

Yang He 27 Jan 07, 2023
PyTorch implementation of ENet

PyTorch-ENet PyTorch (v1.1.0) implementation of ENet: A Deep Neural Network Architecture for Real-Time Semantic Segmentation, ported from the lua-torc

David Silva 333 Dec 29, 2022
On the adaptation of recurrent neural networks for system identification

On the adaptation of recurrent neural networks for system identification This repository contains the Python code to reproduce the results of the pape

Marco Forgione 3 Jan 13, 2022
Combine Tacotron2 and Hifi GAN to generate speech from text

EndToEndTextToSpeech Combine Tacotron2 and Hifi GAN to generate speech from text Download weights Hifi GAN - hifi_gan/checkpoint/ : pretrain 2.5M ste

Phạm Quốc Huy 1 Dec 18, 2021
🔮 Execution time predictions for deep neural network training iterations across different GPUs.

Habitat: A Runtime-Based Computational Performance Predictor for Deep Neural Network Training Habitat is a tool that predicts a deep neural network's

Geoffrey Yu 44 Dec 27, 2022
Ranger - a synergistic optimizer using RAdam (Rectified Adam), Gradient Centralization and LookAhead in one codebase

Ranger-Deep-Learning-Optimizer Ranger - a synergistic optimizer combining RAdam (Rectified Adam) and LookAhead, and now GC (gradient centralization) i

Less Wright 1.1k Dec 21, 2022
3D mesh stylization driven by a text input in PyTorch

Text2Mesh [Project Page] Text2Mesh is a method for text-driven stylization of a 3D mesh, as described in "Text2Mesh: Text-Driven Neural Stylization fo

Threedle (University of Chicago) 649 Dec 27, 2022
PyTorch implementation of our method for adversarial attacks and defenses in hyperspectral image classification.

Self-Attention Context Network for Hyperspectral Image Classification PyTorch implementation of our method for adversarial attacks and defenses in hyp

22 Dec 02, 2022
Public Models considered for emotion estimation from EEG

Emotion-EEG Set of models for emotion estimation from EEG. Composed by the combination of two deep-learing models learning together (RNN and CNN) with

Victor Delvigne 21 Dec 23, 2022
Eff video representation - Efficient video representation through neural fields

Neural Residual Flow Fields for Efficient Video Representations 1. Download MPI

41 Jan 06, 2023
A tool to estimate time varying instantaneous reproduction number during epidemics

EpiEstim A tool to estimate time varying instantaneous reproduction number during epidemics. It is described in the following paper: @article{Cori2013

MRC Centre for Global Infectious Disease Analysis 78 Dec 19, 2022
A 1.3B text-to-image generation model trained on 14 million image-text pairs

minDALL-E on Conceptual Captions minDALL-E, named after minGPT, is a 1.3B text-to-image generation model trained on 14 million image-text pairs for no

Kakao Brain 604 Dec 14, 2022
Code for "Learning Graph Cellular Automata"

Learning Graph Cellular Automata This code implements the experiments from the NeurIPS 2021 paper: "Learning Graph Cellular Automata" Daniele Grattaro

Daniele Grattarola 37 Oct 26, 2022
Implementations of orthogonal and semi-orthogonal convolutions in the Fourier domain with applications to adversarial robustness

Orthogonalizing Convolutional Layers with the Cayley Transform This repository contains implementations and source code to reproduce experiments for t

CMU Locus Lab 36 Dec 30, 2022