Sequence Modeling with Structured State Spaces

Overview

Structured State Spaces for Sequence Modeling

This repository provides implementations and experiments for the following papers.

S4

Structured State Spaces

Efficiently Modeling Long Sequences with Structured State Spaces
Albert Gu, Karan Goel, Christopher Ré
Paper: https://arxiv.org/abs/2111.00396

LSSL

Linear State Space Layer

Combining Recurrent, Convolutional, and Continuous-time Models with Linear State-Space Layers
Albert Gu, Isys Johnson, Karan Goel, Khaled Saab, Tri Dao, Atri Rudra, Christopher Ré
Paper: https://arxiv.org/abs/2110.13985

HiPPO

HiPPO Framework

HiPPO: Recurrent Memory with Optimal Polynomial Projections
Albert Gu*, Tri Dao*, Stefano Ermon, Atri Rudra, Christopher Ré
Paper: https://arxiv.org/abs/2008.07669

Setup

Requirements

This repository requires Python 3.8+ and PyTorch 1.9+. Other packages are listed in requirements.txt.

Data

Datasets and Dataloaders

All logic for creating and loading datasets is in src/dataloaders. This folder includes many old and experimental datasets. The datasets that we consider core are located in src/dataloaders/datasets.py.

The data path can be configured with the environment variable DATA_PATH; it defaults to ./data, where . is the top-level directory of this repository (e.g. 'state-spaces'). The raw data should be organized under this path as described in the next section.
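The lookup described here can be sketched in a few lines of Python. This is only an illustration of the documented behavior (environment variable first, ./data as the fallback); the function name resolve_data_path is hypothetical and is not taken from the repository.

```python
import os
from pathlib import Path


def resolve_data_path(repo_root="."):
    """Return DATA_PATH if it is set and non-empty, otherwise <repo_root>/data.

    Hypothetical helper: it mirrors the behavior described in this README,
    not the repository's actual implementation.
    """
    env_value = os.environ.get("DATA_PATH")
    if env_value:
        return Path(env_value)
    return Path(repo_root) / "data"
```

For example, running export DATA_PATH=/scratch/datasets before launching training would make the helper return that directory instead of ./data.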

Data

External datasets include Long Range Arena (LRA), which can be downloaded from their GitHub page.

These external datasets should be organized as follows:

DATA_PATH/
  pathfinder/
    pathfinder32/
    pathfinder64/
    pathfinder128/
    pathfinder256/
  aan/
  listops/

Fine-grained control over the data directory is also possible. For example, if the LRA ListOps files are located in /home/lra/listops-1000/, you can pass +dataset.data_dir=/home/lra/listops-1000 on the command line.
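Before launching a long run, it can help to confirm that the expected folders exist. The sketch below checks a data root against the directory names listed in the layout above; the helper name missing_lra_dirs is hypothetical, and the check only looks for directories, not for the files inside them.

```python
from pathlib import Path

# Directory names copied from the layout shown in this README.
EXPECTED_LRA_DIRS = [
    "pathfinder/pathfinder32",
    "pathfinder/pathfinder64",
    "pathfinder/pathfinder128",
    "pathfinder/pathfinder256",
    "aan",
    "listops",
]


def missing_lra_dirs(data_path):
    """Return the expected subdirectories that are absent under data_path."""
    root = Path(data_path)
    return [name for name in EXPECTED_LRA_DIRS if not (root / name).is_dir()]
```

An empty list means every expected folder was found. A dataset stored elsewhere (as with the dataset.data_dir override) will be reported as missing by this check even though training can still locate it.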

Cauchy Kernel

A core operation of S4 is the "Cauchy kernel" described in the paper. The implementation of this requires one of two methods:

Custom CUDA Kernel

This version is faster but requires manual compilation on each machine. Run python setup.py install from the directory extensions/cauchy/.

Pykeops

This version is provided by the pykeops library. Installation usually works out of the box with pip install pykeops cmake; both packages are listed in the requirements file.

Note that running in a Colab requires installing a different pip package; instructions can be found in the pykeops documentation.
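For orientation, the quantity both backends compute can be written naively. The sketch below assumes the usual form of a Cauchy kernel, sum_j v_j / (z_i - w_j), evaluated at every point z_i. It uses plain Python complex numbers, costs O(len(z) * len(w)) operations, and is meant as a readable reference rather than a substitute for the CUDA or pykeops versions. The function name cauchy_reference is hypothetical.

```python
def cauchy_reference(v, z, w):
    """Naive Cauchy kernel: for each z_i return sum_j v[j] / (z_i - w[j]).

    v, w: sequences of (complex) numbers with equal length.
    z:    sequence of (complex) evaluation points, none equal to any w[j].
    """
    if len(v) != len(w):
        raise ValueError("v and w must have the same length")
    return [sum(vj / (zi - wj) for vj, wj in zip(v, w)) for zi in z]


if __name__ == "__main__":
    # Two poles on the negative real axis, evaluated at one imaginary point.
    print(cauchy_reference([1.0, 2.0], [1j], [-1.0, -2.0]))
```

A reference like this is mostly useful for checking a faster implementation on small inputs.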

S4 Experiments

This section describes how to use the latest S4 model and reproduce experiments immediately. More detailed descriptions of the infrastructure are in the subsequent sections.

Structured State Space (S4)

The S4 module is found at src/models/sequence/ss/s4.py.

For users who would like to import a single file that has the self-contained S4 layer, a standalone version can be found at src/models/sequence/ss/standalone/s4.py.

Testing

For testing, we frequently use synthetic datasets or the Permuted MNIST dataset. This can be run with python -m train wandb=null pipeline=mnist model=s4, which should reach around 90% accuracy after 1 epoch; an epoch takes 2-4 minutes depending on the GPU.

Long Range Arena (LRA)

python -m train wandb=null experiment=s4-lra-listops
python -m train wandb=null experiment=s4-lra-imdb
python -m train wandb=null experiment=s4-lra-cifar
python -m train wandb=null experiment=s4-lra-aan
python -m train wandb=null experiment=s4-lra-pathfinder
python -m train wandb=null experiment=s4-lra-pathx

Note that these experiments take different amounts of time to train. IMDB should take only 1-2 hours, while Path-X needs several epochs before it starts learning and takes over a day to train to completion.

CIFAR-10

python -m train wandb=null experiment=s4-cifar

The above command line reproduces our best sequential CIFAR model. Decreasing the model size should yield close results, e.g. halving the hidden dimension with model.d_model=512.

Speech Commands

The Speech Commands dataset we compare against is a modified smaller 10-way classification task.

python -m train wandb=null experiment=s4-sc

To use the original version with the full 35 classes, pass in dataset.all_classes=true.

Training

The core training infrastructure of this repository is based on PyTorch Lightning with a configuration scheme based on Hydra. The structure of this integration largely follows the Lightning+Hydra integration template described in https://github.com/ashleve/lightning-hydra-template.

The main experiment entrypoint is train.py and configs are found in configs/. In brief, the main config is found at configs/config.yaml, which is combined with other sets of configs that can be passed on the command line to define an overall YAML config. Most config groups define a single Python object (e.g. a PyTorch nn.Module). The end-to-end training pipeline can be broken down into the following rough groups, where group XX is found under configs/XX/:

model: the sequence-to-sequence model backbone (e.g. a src.models.sequence.SequenceModel)
dataset: the raw dataset (data/target pairs) (e.g. a PyTorch Dataset)
loader: how the data is loaded (e.g. a PyTorch DataLoader)
encoder: defines a Module that interfaces between data and model backbone
decoder: defines a Module that interfaces between model backbone and targets
task: specifies loss and metrics

Default combinations of dataset+loader+encoder+decoder+task are further consolidated into groups called pipelines.

A run can be performed by passing in a pipeline config, model config, and any additional arguments modifying the default configurations. A simple example experiment is

python -m train pipeline=mnist dataset.permute=True model=s4 model.n_layers=3 model.d_model=128 model.norm=batch model.prenorm=True wandb=null

This uses the permuted sequential MNIST task and uses an s4 model with a specified number of layers, backbone dimension, and normalization type.
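To make the effect of dotted arguments such as model.n_layers=3 concrete, the sketch below applies key=value overrides to a nested dictionary. It is a simplified illustration, not Hydra's implementation: values stay as strings (Hydra parses types), and group selections such as model=s4 or pipeline=mnist, which pick a whole config file, are not modeled. The function name apply_overrides and the base values are hypothetical.

```python
import copy


def apply_overrides(config, overrides):
    """Return a copy of config with dotted key=value overrides applied.

    Simplified illustration only: values are kept as strings and
    intermediate keys are created on demand.
    """
    result = copy.deepcopy(config)
    for item in overrides:
        dotted_key, raw_value = item.split("=", 1)
        *parents, leaf = dotted_key.split(".")
        node = result
        for name in parents:
            node = node.setdefault(name, {})
        node[leaf] = raw_value
    return result


if __name__ == "__main__":
    base = {"model": {"n_layers": 4, "d_model": 256}}  # made-up defaults
    print(apply_overrides(base, ["model.n_layers=3", "model.norm=batch"]))
```

The original dictionary is left untouched, which mirrors the idea that defaults in configs/ are fixed and each run only layers its own changes on top.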

Hydra

It is recommended to read the Hydra documentation to fully understand the configuration framework. For help launching specific experiments, please file an Issue.

Registries

This codebase uses a modification of the hydra instantiate utility that provides shorthand names of different classes, for convenience in configuration and logging. The mapping from shorthand to full path can be found in src/utils/registry.py.
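The general idea of a shorthand registry can be sketched as follows. The entries below point at standard-library classes purely for illustration; the real names and paths are the ones in src/utils/registry.py, and the instantiate function here is a hypothetical stand-in for the repository's utility, not a copy of it.

```python
import importlib

# Hypothetical entries for illustration only; the real mapping lives in
# src/utils/registry.py and refers to classes in this codebase.
REGISTRY = {
    "ordered": "collections.OrderedDict",
    "fraction": "fractions.Fraction",
}


def instantiate(name, *args, **kwargs):
    """Build an object from a shorthand name or a full dotted path."""
    path = REGISTRY.get(name, name)  # fall back to treating name as a path
    module_name, _, attr = path.rpartition(".")
    cls = getattr(importlib.import_module(module_name), attr)
    return cls(*args, **kwargs)
```

With this pattern a config can say fraction instead of fractions.Fraction, and logs show the short name while the full path remains recoverable from the mapping.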

WandB

Logging with WandB is built into this repository. To use it, set your WANDB_API_KEY environment variable and change the wandb.project attribute of configs/config.yaml (or pass it on the command line, e.g. python -m train .... wandb.project=s4).

Set wandb=null to turn off WandB logging.

Models

This repository provides a modular and flexible implementation of sequence models at large.

SequenceModule

SequenceModule src/models/sequence/base.py is the abstract interface that all sequence models adhere to. In this codebase, sequence models are defined as a sequence-to-sequence map of shape (batch size, sequence length, input dimension) to (batch size, sequence length, output dimension).

The SequenceModule comes with other methods such as step which is meant for autoregressive settings, and logic to carry optional hidden states (for stateful models such as RNNs or S4).
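The relationship between the full-sequence map and step can be illustrated with a toy stateful model. The sketch below uses nested Python lists in place of tensors and a simple leaky integrator in place of a real layer. Only the step method name comes from this README; forward, default_state, and the class itself are assumptions made for the example, not the interface defined in src/models/sequence/base.py.

```python
class LeakyIntegrator:
    """Toy stateful sequence model: h_t = decay * h_{t-1} + x_t, y_t = h_t."""

    def __init__(self, d_model, decay=0.5):
        self.d_model = d_model
        self.decay = decay

    def default_state(self, batch_size):
        """Initial hidden state of shape (batch, dim), all zeros."""
        return [[0.0] * self.d_model for _ in range(batch_size)]

    def step(self, x_t, state):
        """Process one time step. x_t and state have shape (batch, dim)."""
        new_state = [
            [self.decay * h + x for h, x in zip(h_row, x_row)]
            for h_row, x_row in zip(state, x_t)
        ]
        return new_state, new_state  # output equals the new hidden state

    def forward(self, x, state=None):
        """Map (batch, length, dim) to (batch, length, dim) plus final state."""
        batch_size = len(x)
        length = len(x[0]) if x else 0
        if state is None:
            state = self.default_state(batch_size)
        y = [[] for _ in range(batch_size)]
        for t in range(length):
            x_t = [x[b][t] for b in range(batch_size)]
            y_t, state = self.step(x_t, state)
            for b in range(batch_size):
                y[b].append(y_t[b])
        return y, state
```

Calling forward on a whole sequence and calling it on two halves while passing the returned state along give the same outputs, which is the property that makes step usable for autoregressive generation.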

SequenceModel

SequenceModel src/models/sequence/model.py is the main backbone with configurable options for residual function, normalization placement and type, etc. SequenceModel accepts a black box config for a layer. Compatible layers are SequenceModules (i.e. composable sequence transformations) found under src/models/sequence/.
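The normalization placement option (compare model.prenorm=True in the example command earlier) can be illustrated with a generic residual block. The sketch below works on a single feature vector stored as a Python list and shows the common pre-norm and post-norm arrangements. It is a generic illustration under that assumption, not the code in src/models/sequence/model.py, and the function names are hypothetical.

```python
def layer_norm(x, eps=1e-5):
    """Normalize a feature vector to zero mean and (roughly) unit variance."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / (var + eps) ** 0.5 for v in x]


def residual_block(x, layer, norm=layer_norm, prenorm=True):
    """Apply layer with a residual connection.

    prenorm=True:  x + layer(norm(x))
    prenorm=False: norm(x + layer(x))
    """
    if prenorm:
        return [a + b for a, b in zip(x, layer(norm(x)))]
    return norm([a + b for a, b in zip(x, layer(x))])
```

Here layer is any function from a vector to a vector of the same size; in the backbone that role is played by the configured SequenceModule.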

S4

This is the main model of this repository. See the instructions in the S4 Experiments section.

LSSL

The LSSL is an old version of S4. It is currently not recommended for use, but the model can be found at src/models/sequence/ss/lssl.py.

It can be run with model/layer=lssl or model/layer=lssl model.layer.learn=0 for the LSSL-fixed model which does not train A, B, or dt.

HiPPO

HiPPO is the mathematical framework on which the HiPPO, LSSL, and S4 papers are built. The logic for HiPPO operators is found under src/models/hippo/.

HiPPO-RNN cells from the original paper (https://arxiv.org/abs/2008.07669) can be found under the RNN cells.
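As a concrete example of a HiPPO operator, the sketch below builds the HiPPO-LegS matrices in plain Python, using the sign convention in which the state evolves as x' = Ax + Bu (so A has a negative diagonal). The time-dependent scaling from the original paper is omitted. This is a standalone illustration written from the published formula, not the code under src/models/hippo/, and the function name hippo_legs is hypothetical.

```python
import math


def hippo_legs(n):
    """Return (A, B) for HiPPO-LegS of size n as nested lists.

    A[i][k] = -sqrt((2i+1)(2k+1))  if i > k
            = -(i + 1)             if i == k
            = 0                    if i < k
    B[i]    = sqrt(2i + 1)
    """
    a = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for k in range(n):
            if i > k:
                a[i][k] = -math.sqrt((2 * i + 1) * (2 * k + 1))
            elif i == k:
                a[i][k] = -(i + 1.0)
    b = [math.sqrt(2 * i + 1) for i in range(n)]
    return a, b
```

The resulting A is lower triangular, so its eigenvalues are the diagonal entries -1, -2, ..., -n.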

RNNs

This codebase contains a flexible and modular implementation of many RNN cells.

Some examples include model=rnn/hippo-legs and model=rnn/hippo-legt for HiPPO variants from the original paper, or model=rnn/gru for a GRU reimplementation, etc.

An exception is model=lstm to use the PyTorch LSTM.

Example command (reproducing the Permuted MNIST number from the HiPPO paper, which was SotA at the time):

python train.py pipeline=mnist model=rnn/hippo-legs model.cell_args.hidden_size=512 train.epochs=50 train.batch_size=100 train.lr=0.001

Baselines

Other sequence models are easily incorporated into this repository, and several other baselines have been ported.

These include CNNs such as the WaveGAN Discriminator and CKConv and continuous-time/RNN models such as UnICORNN and LipschitzRNN.

python -m train dataset=mnist model={ckconv,unicornn}

Overall Repository Structure

configs/         config files for model, data pipeline, training loop, etc.
data/            default location of raw data
extensions/      CUDA extension for Cauchy kernel
src/             main source code for models, datasets, etc.
train.py         main entrypoint

Citation

If you use this codebase, or otherwise found our work valuable, please cite:

@article{gu2021efficiently,
  title={Efficiently Modeling Long Sequences with Structured State Spaces},
  author={Gu, Albert and Goel, Karan and R{\'e}, Christopher},
  journal={arXiv preprint arXiv:2111.00396},
  year={2021}
}

@article{gu2021combining,
  title={Combining Recurrent, Convolutional, and Continuous-time Models with Linear State-Space Layers},
  author={Gu, Albert and Johnson, Isys and Goel, Karan and Saab, Khaled and Dao, Tri and Rudra, Atri and R{\'e}, Christopher},
  journal={Advances in neural information processing systems},
  volume={34},
  year={2021}
}

@article{gu2020hippo,
  title={HiPPO: Recurrent Memory with Optimal Polynomial Projections},
  author={Gu, Albert and Dao, Tri and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
  journal={Advances in neural information processing systems},
  volume={33},
  year={2020}
}
Owner
HazyResearch
We are a CS research group led by Prof. Chris Ré.