[NeurIPS 2020] Official Implementation: "SMYRF: Efficient Attention using Asymmetric Clustering".

Related tags

Deep Learningsmyrf
Overview

SMYRF: Efficient attention using asymmetric clustering

Get started:

Colab

Abstract

We propose a novel type of balanced clustering algorithm to approximate attention. Attention complexity is reduced from O(N^2) to O(NlogN), where N is the sequence length. Our algorithm, SMYRF, uses Locality Sensitive Hashing (LSH) in a novel way by defining new Asymmetric transformations and an adaptive scheme that produces balanced clusters. The biggest advantage of SMYRF is that it can be used as a drop-in replacement for dense attention layers without any retraining. On the contrary, prior fast attention methods impose constraints (e.g. tight queries and keys) and require re-training from scratch. We apply our method to pre-trained state-of-the-art Natural Language Processing and Computer Vision models and we report significant memory and speed benefits. Notably, SMYRF-BERT outperforms (slightly) BERT on GLUE, while using $50%$ less memory. We also show that SMYRF can be used interchangeably with dense attention before and after training. Finally, we use SMYRF to train GANs with attention in high resolutions. Using a single TPU, we train BigGAN on Celeba-HQ, with attention at resolution 128x128 and 256x256, capable of generating realistic human faces.

Authors: Giannis Daras, Nikita Kitaev, Augustus Odena, Alexandros G. Dimakis

Results

Memory-quality trade-off

GLUE benchmark

Avg. # C CoLA MNLI-m/mm MRPC QNLI QQP RTE SST-2 STS-B
BERT128 82.69 1 1 57.83 84.43/84.68 88.41 91.31 89.70 65.70 93.46 88.73
SMYRF-BERT2x32 82.98 2 32 58.79 83.76/84.27 87.69 91.14 89.72 68.59 93.23 89.65
SMYRF-BERT2x16 81.74 2 16 58.90 82.86/83.49 85.72 89.53 89.33 64.98 93.12 87.75
BERT64 81.57 1 64 58.80 82.34/82.47 87.02 90.48 89.69 61.73 93.00 88.64
BERT32 73.56 1 32 56.40 64.51/63.41 77.89 79.81 88.59 55.23 92.66 83.53

Interchangeability of SMYRF and dense attention

Results on IMDB dataset. Using dense attention on inference consistently improves results, nearly matching dense attention perf.

Memory SMYRF Inference Accuracy
RoBERTa 100% 94.96%
SMYRF-RoBERTa 50% 93.72%
SMYRF-RoBERTa 50% 94.62%
BERT 100% 94.12%
SMYRF-BERT 50% 92.64%
SMYRF-BERT 50% 93.54%

Smyrf-BigGAN training on Celeba-HQ-128

Generated faces by a Smyrf-BigGAN trained on 128x128 resolution with attention at 128x128, using 50% of dense memory.

Results after 120k iterations:

Resolution Attention # C FID
BigGAN 128x128 64x64 1 4096 26.06
Smyrf-BigGAN 128x128 128x128 4 2048 25.03

where # denotes number of hashes and C number of queries per cluster.

What's here

The code hosted in this repository is the one we used to run all the experiments in the paper. Get started:

Colab

For a deeper dive, look at the examples/ folder where we have code for pre-training SMYRF-BigGAN, sampling from a pre-trained BigGAN with SMYRF, finetuning state-of-the-art NLP models with SMYRF and a lot more.

Acknowledgments

We would like to wholeheartedly thank the TensorFlow Research Cloud (TFRC) program that gave us access to Cloud TPUs and GCP credits to train our models.

The code for the NLP experiments is exclusively based on the HuggingFace transformers library. We are very grateful to the authors of the library for their work.

The code for the CV experiments is based on the PyTorch implementation of BigGAN available in this url. The code has been expanded to support training on TPUs. Again, we want to thank the author for open-sourcing this implementation.

You might also like...
Code for ICE-BeeM paper - NeurIPS 2020

ICE-BeeM: Identifiable Conditional Energy-Based Deep Models Based on Nonlinear ICA This repository contains code to run and reproduce the experiments

Code for Discriminative Sounding Objects Localization (NeurIPS 2020)
Code for Discriminative Sounding Objects Localization (NeurIPS 2020)

Discriminative Sounding Objects Localization Code for our NeurIPS 2020 paper Discriminative Sounding Objects Localization via Self-supervised Audiovis

Advances in Neural Information Processing Systems (NeurIPS), 2020.

What is being transferred in transfer learning? This repo contains the code for the following paper: Behnam Neyshabur*, Hanie Sedghi*, Chiyuan Zhang*.

Neuron Merging: Compensating for Pruned Neurons (NeurIPS 2020)
Neuron Merging: Compensating for Pruned Neurons (NeurIPS 2020)

Neuron Merging: Compensating for Pruned Neurons Pytorch implementation of Neuron Merging: Compensating for Pruned Neurons, accepted at 34th Conference

Multi-Task Temporal Shift Attention Networks for On-Device Contactless Vitals Measurement (NeurIPS 2020)
Multi-Task Temporal Shift Attention Networks for On-Device Contactless Vitals Measurement (NeurIPS 2020)

MTTS-CAN: Multi-Task Temporal Shift Attention Networks for On-Device Contactless Vitals Measurement Paper Xin Liu, Josh Fromm, Shwetak Patel, Daniel M

Defending graph neural networks against adversarial attacks (NeurIPS 2020)
Defending graph neural networks against adversarial attacks (NeurIPS 2020)

GNNGuard: Defending Graph Neural Networks against Adversarial Attacks Authors: Xiang Zhang ([email protected]), Marinka Zitnik ([email protected].

Code for the Population-Based Bandits Algorithm, presented at NeurIPS 2020.

Population-Based Bandits (PB2) Code for the Population-Based Bandits (PB2) Algorithm, from the paper Provably Efficient Online Hyperparameter Optimiza

Code release for NeurIPS 2020 paper "Co-Tuning for Transfer Learning"

CoTuning Official implementation for NeurIPS 2020 paper Co-Tuning for Transfer Learning. [News] 2021/01/13 The COCO 70 dataset used in the paper is av

Discovering Interpretable GAN Controls [NeurIPS 2020]
Discovering Interpretable GAN Controls [NeurIPS 2020]

GANSpace: Discovering Interpretable GAN Controls Figure 1: Sequences of image edits performed using control discovered with our method, applied to thr

Comments
  • Auto-regressive

    Auto-regressive

    Hi Giannis!

    Thanks for the great paper! I am interested in your asymmetric LSH, as I think having separate query / key space (as opposed to shared QK as in Reformer) will bring performance improvements in LSH-based attention.

    I saw that you recommended to a previous user to use this form of clustering for the auto-regressive case, and just wanted to probe if you had considered the scenario where a bucket of queries do not get matched with any keys from the past at all. This was an issue I had with trying to make separate QK space work with routing transformer, but just wondering if you had identified and found a solution to this problem.

    Phil

    opened by lucidrains 2
  • Logging and scoring

    Logging and scoring

    Currently logging and scoring is disabled for TPU BigGAN for maximum efficiency. We can probably re-write the logger and scorer to lower their performance bottleneck by converting most cpu materializations to XLA ops.

    bug example 
    opened by giannisdaras 0
  • Ema not working on TPU

    Ema not working on TPU

    Exponential moving average on weights of G is not working on TPUs. The problem is related to the loading of the state dict: https://github.com/ajbrock/BigGAN-PyTorch/blob/master/utils.py#L614

    For now, we disable ema.

    bug example 
    opened by giannisdaras 0
Releases(1.0)
Owner
Giannis Daras
Machine Learning Researcher. Ph.D. student, UT Austin.
Giannis Daras
PyTorch implementation of our ICCV paper DeFRCN: Decoupled Faster R-CNN for Few-Shot Object Detection.

Introduction This repo contains the official PyTorch implementation of our ICCV paper DeFRCN: Decoupled Faster R-CNN for Few-Shot Object Detection. Up

133 Dec 29, 2022
scikit-learn: machine learning in Python

scikit-learn is a Python module for machine learning built on top of SciPy and is distributed under the 3-Clause BSD license. The project was started

scikit-learn 52.5k Jan 08, 2023
SurvITE: Learning Heterogeneous Treatment Effects from Time-to-Event Data

SurvITE: Learning Heterogeneous Treatment Effects from Time-to-Event Data SurvITE: Learning Heterogeneous Treatment Effects from Time-to-Event Data Au

14 Nov 28, 2022
CUda Matrix Multiply library.

cumm CUda Matrix Multiply library. cumm is developed during learning of CUTLASS, which use too much c++ template and make code unmaintainable. So I de

49 Dec 27, 2022
Vehicle direction identification consists of three module detection , tracking and direction recognization.

Vehicle-direction-identification Vehicle direction identification consists of three module detection , tracking and direction recognization. Algorithm

5 Nov 15, 2022
MatchGAN: A Self-supervised Semi-supervised Conditional Generative Adversarial Network

MatchGAN: A Self-supervised Semi-supervised Conditional Generative Adversarial Network This repository is the official implementation of MatchGAN: A S

Justin Sun 12 Dec 27, 2022
🚀 An end-to-end ML applications using PyTorch, W&B, FastAPI, Docker, Streamlit and Heroku

🚀 An end-to-end ML applications using PyTorch, W&B, FastAPI, Docker, Streamlit and Heroku

Made With ML 82 Jun 26, 2022
Retinal vessel segmentation based on GT-UNet

Retinal vessel segmentation based on GT-UNet Introduction This project is a retinal blood vessel segmentation code based on UNet-like Group Transforme

Kent0n 27 Dec 18, 2022
An Extendible (General) Continual Learning Framework based on Pytorch - official codebase of Dark Experience for General Continual Learning

Mammoth - An Extendible (General) Continual Learning Framework for Pytorch NEWS STAY TUNED: We are working on an update of this repository to include

AImageLab 277 Dec 28, 2022
Atif Hassan 103 Dec 14, 2022
tmm_fast is a lightweight package to speed up optical planar multilayer thin-film device computation.

tmm_fast tmm_fast or transfer-matrix-method_fast is a lightweight package to speed up optical planar multilayer thin-film device computation. It is es

26 Dec 11, 2022
CIFAR-10_train-test - training and testing codes for dataset CIFAR-10

CIFAR-10_train-test - training and testing codes for dataset CIFAR-10

Frederick Wang 3 Apr 26, 2022
MLSpace: Hassle-free machine learning & deep learning development

MLSpace: Hassle-free machine learning & deep learning development

abhishek thakur 293 Jan 03, 2023
Official implementation of Pixel-Level Bijective Matching for Video Object Segmentation

BMVOS This is the official implementation of Pixel-Level Bijective Matching for Video Object Segmentation, to appear in WACV 2022. @article{cho2021pix

Suhwan Cho 13 Dec 14, 2022
Mouse Brain in the Model Zoo

Deep Neural Mouse Brain Modeling This is the repository for the ongoing deep neural mouse modeling project, an attempt to characterize the representat

Colin Conwell 15 Aug 22, 2022
[CVPR 2021] Anycost GANs for Interactive Image Synthesis and Editing

Anycost GAN video | paper | website Anycost GANs for Interactive Image Synthesis and Editing Ji Lin, Richard Zhang, Frieder Ganz, Song Han, Jun-Yan Zh

MIT HAN Lab 726 Dec 28, 2022
Fast and customizable reconnaissance workflow tool based on simple YAML based DSL.

Fast and customizable reconnaissance workflow tool based on simple YAML based DSL, with support of notifications and distributed workload of that work

Américo Júnior 3 Mar 11, 2022
A clean implementation based on AlphaZero for any game in any framework + tutorial + Othello/Gobang/TicTacToe/Connect4 and more

Alpha Zero General (any game, any framework!) A simplified, highly flexible, commented and (hopefully) easy to understand implementation of self-play

Surag Nair 3.1k Jan 05, 2023
Explainability for Vision Transformers (in PyTorch)

Explainability for Vision Transformers (in PyTorch) This repository implements methods for explainability in Vision Transformers

Jacob Gildenblat 442 Jan 04, 2023
ADB-IP-ROTATION - Use your mobile phone to gain a temporary IP address using ADB and data tethering

ADB IP ROTATE This an Python script based on Android Debug Bridge (adb) shell sc

Dor Bismuth 2 Jul 12, 2022