The official repository for our paper "The Neural Data Router: Adaptive Control Flow in Transformers Improves Systematic Generalization".

Overview

Codebase for learning control flow in transformers

The official repository for our paper "The Neural Data Router: Adaptive Control Flow in Transformers Improves Systematic Generalization".

Paper: https://arxiv.org/abs/2110.07732

Please note that this repository is a cleaned-up version of the internal research repository we use. In case you encounter any problems with it, please don't hesitate to contact me.

Setup

This project requires Python 3 (tested with Python 3.8 and 3.9) and PyTorch 1.8.

pip3 install -r requirements.txt

Create a Weights and Biases account and run

wandb login

More information on setting up Weights and Biases can be found on https://docs.wandb.com/quickstart.

For plotting, LaTeX is required (to avoid Type 3 fonts and to render symbols). Installation is OS specific.

Usage

Running the experiments from the paper on a cluster

The code makes use of Weights and Biases for experiment tracking. In the sweeps directory, we provide sweep configurations for all experiments we have performed. The sweeps are officially meant for hyperparameter optimization, but we use them to run multiple configurations and seeds.

To reproduce our results, start a sweep for each of the YAML files in the sweeps directory. Run wandb agent for each of them in the root directory of the project. This will run all the experiments, and they will be displayed on the W&B dashboard. The name of the sweeps must match the name of the files in sweeps directory, except the .yaml ending. More details on how to run W&B sweeps can be found at https://docs.wandb.com/sweeps/quickstart. If you want to use a Linux cluster to run the experiments, you might find https://github.com/robertcsordas/cluster_tool useful.

For example, if you want to run NDR on compositional table lookup, run wandb sweep --name ctl_ndr sweeps/ctl_ndr.yaml. This creates the sweep and prints out its ID. Then run wandb agent <ID> with that ID.

Re-creating plots from the paper

Edit config file paper/config.json. Enter your project name in the field "wandb_project" (e.g. "username/project").

Run the scripts in the paper directory. For example:

cd paper
./run_all.sh

The output will be generated in the paper/out/ directory. Tables will be printed to stdout in latex format.

If you want to reproduce individual plots, it can be done by running individial python files in the paper directory.

Running experiments locally

It is possible to run single experiments with Tensorboard without using Weights and Biases. This is intended to be used for debugging the code locally.

If you want to run experiments locally, you can use run.py:

./run.py sweeps/ctl_ndr.yaml

If the sweep in question has multiple parameter choices, run.py will interactively prompt choices of each of them.

The experiment also starts a Tensorboard instance automatically on port 7000. If the port is already occupied, it will incrementally search for the next free port.

Note that the plotting scripts work only with Weights and Biases.

Reducing memory usage

In case some tasks won't fit on your GPU, play around with "-max_length_per_batch " argument. It can trade off memory usage/speed by slicing batches and executing them in multiple passes. Reduce it until the model fits.

BibText

@article{csordas2021neural,
      title={The Neural Data Router: Adaptive Control Flow in Transformers Improves Systematic Generalization}, 
      author={R\'obert Csord\'as and Kazuki Irie and J\"urgen Schmidhuber},
      journal={Preprint arXiv:2110.07732},
      year={2021},
      month={October}
}
Owner
Csordás Róbert
Csordás Róbert
Encoding Causal Macrovariables

Encoding Causal Macrovariables Data Natural climate data ('El Nino') Self-generated data ('Simulated') Experiments Detecting macrovariables through th

Benedikt Höltgen 3 Jul 31, 2022
This is a project based on ConvNets used to identify whether a road is clean or dirty. We have used MobileNet as our base architecture and the weights are based on imagenet.

PROJECT TITLE: CLEAN/DIRTY ROAD DETECTION USING TRANSFER LEARNING Description: This is a project based on ConvNets used to identify whether a road is

Faizal Karim 3 Nov 06, 2022
Implementation of Stochastic Image-to-Video Synthesis using cINNs.

Stochastic Image-to-Video Synthesis using cINNs Official PyTorch implementation of Stochastic Image-to-Video Synthesis using cINNs accepted to CVPR202

CompVis Heidelberg 135 Dec 28, 2022
This program presents convolutional kernel density estimation, a method used to detect intercritical epilpetic spikes (IEDs)

Description This program presents convolutional kernel density estimation, a method used to detect intercritical epilpetic spikes (IEDs) in [Gardy et

Ludovic Gardy 0 Feb 09, 2022
Computations and statistics on manifolds with geometric structures.

Geomstats Code Continuous Integration Code coverage (numpy) Code coverage (autograd, tensorflow, pytorch) Documentation Community NEWS: Geomstats is r

875 Dec 31, 2022
Official code of "R2RNet: Low-light Image Enhancement via Real-low to Real-normal Network."

R2RNet Official code of "R2RNet: Low-light Image Enhancement via Real-low to Real-normal Network." Jiang Hai, Zhu Xuan, Ren Yang, Yutong Hao, Fengzhu

77 Dec 24, 2022
Self-Supervised Monocular DepthEstimation with Internal Feature Fusion(arXiv), BMVC2021

DIFFNet This repo is for Self-Supervised Monocular Depth Estimation with Internal Feature Fusion(arXiv), BMVC2021 A new backbone for self-supervised d

Hang 94 Dec 25, 2022
Official code of paper "PGT: A Progressive Method for Training Models on Long Videos" on CVPR2021

PGT Code for paper PGT: A Progressive Method for Training Models on Long Videos. Install Run pip install -r requirements.txt. Run python setup.py buil

Bo Pang 27 Mar 30, 2022
GeDML is an easy-to-use generalized deep metric learning library

GeDML is an easy-to-use generalized deep metric learning library

Borui Zhang 32 Dec 05, 2022
When are Iterative GPs Numerically Accurate?

When are Iterative GPs Numerically Accurate? This is a code repository for the paper "When are Iterative GPs Numerically Accurate?" by Wesley Maddox,

Wesley Maddox 1 Jan 06, 2022
Official repository for: Continuous Control With Ensemble DeepDeterministic Policy Gradients

Continuous Control With Ensemble Deep Deterministic Policy Gradients This repository is the official implementation of Continuous Control With Ensembl

4 Dec 06, 2021
Unofficial PyTorch code for BasicVSR

Dependencies and Installation The code is based on BasicSR, Please install the BasicSR framework first. Pytorch=1.51 Training cd ./code CUDA_VISIBLE_

Long 59 Dec 06, 2022
a practicable framework used in Deep Learning. So far UDL only provide DCFNet implementation for the ICCV paper (Dynamic Cross Feature Fusion for Remote Sensing Pansharpening)

UDL UDL is a practicable framework used in Deep Learning (computer vision). Benchmark codes, results and models are available in UDL, please contact @

Xiao Wu 11 Sep 30, 2022
PyTorch implementation for "Sharpness-aware Quantization for Deep Neural Networks".

Sharpness-aware Quantization for Deep Neural Networks This is the official repository for our paper: Sharpness-aware Quantization for Deep Neural Netw

Zhuang AI Group 30 Dec 19, 2022
[2021 MultiMedia] CONQUER: Contextual Query-aware Ranking for Video Corpus Moment Retrieval

CONQUER: Contexutal Query-aware Ranking for Video Corpus Moment Retreival PyTorch implementation of CONQUER: Contexutal Query-aware Ranking for Video

Hou zhijian 23 Dec 26, 2022
Code repo for "RBSRICNN: Raw Burst Super-Resolution through Iterative Convolutional Neural Network" (Machine Learning and the Physical Sciences workshop in NeurIPS 2021).

RBSRICNN: Raw Burst Super-Resolution through Iterative Convolutional Neural Network An official PyTorch implementation of the RBSRICNN network as desc

Rao Muhammad Umer 6 Nov 14, 2022
General Vision Benchmark, a project from OpenGVLab

Introduction We build GV-B(General Vision Benchmark) on Classification, Detection, Segmentation and Depth Estimation including 26 datasets for model e

174 Dec 27, 2022
A FAIR dataset of TCV experimental results for validating edge/divertor turbulence models.

TCV-X21 validation for divertor turbulence simulations Quick links Intro Welcome to TCV-X21. We're glad you've found us! This repository is designed t

0 Dec 18, 2021
A Dynamic Residual Self-Attention Network for Lightweight Single Image Super-Resolution

DRSAN A Dynamic Residual Self-Attention Network for Lightweight Single Image Super-Resolution Karam Park, Jae Woong Soh, and Nam Ik Cho Environments U

4 May 10, 2022
Generative Adversarial Networks(GANs)

Generative Adversarial Networks(GANs) Vanilla GAN ClusterGAN Vanilla GAN Model Structure Final Generator Structure A MLP with 2 hidden layers of hidde

Zhenbang Feng 2 Nov 05, 2021