[ICLR'19] Trellis Networks for Sequence Modeling

Overview

TrellisNet for Sequence Modeling

PWC PWC

This repository contains the experiments done in paper Trellis Networks for Sequence Modeling by Shaojie Bai, J. Zico Kolter and Vladlen Koltun.

On the one hand, a trellis network is a temporal convolutional network with special structure, characterized by weight tying across depth and direct injection of the input into deep layers. On the other hand, we show that truncated recurrent networks are equivalent to trellis networks with special sparsity structure in their weight matrices. Thus trellis networks with general weight matrices generalize truncated recurrent networks. This allows trellis networks to serve as bridge between recurrent and convolutional architectures, benefitting from algorithmic and architectural techniques developed in either context. We leverage these relationships to design high-performing trellis networks that absorb ideas from both architectural families. Experiments demonstrate that trellis networks outperform the current state of the art on a variety of challenging benchmarks, including word-level language modeling on Penn Treebank and WikiText-103 (UPDATE: recently surpassed by Transformer-XL), character-level language modeling on Penn Treebank, and stress tests designed to evaluate long-term memory retention.

Our experiments were done in PyTorch. If you find our work, or this repository helpful, please consider citing our work:

@inproceedings{bai2018trellis,
  author    = {Shaojie Bai and J. Zico Kolter and Vladlen Koltun},
  title     = {Trellis Networks for Sequence Modeling},
  booktitle = {International Conference on Learning Representations (ICLR)},
  year      = {2019},
}

Datasets

The code should be directly runnable with PyTorch 1.0.0 or above. This repository contains the training script for the following tasks:

  • Sequential MNIST handwritten digit classification
  • Permuted Sequential MNIST that randomly permutes the pixel order in sequential MNIST
  • Sequential CIFAR-10 classification (more challenging, due to more intra-class variations, channel complexities and larger images)
  • Penn Treebank (PTB) word-level language modeling (with and without the mixture of softmax); vocabulary size 10K
  • Wikitext-103 (WT103) large-scale word-level language modeling; vocabulary size 268K
  • Penn Treebank medium-scale character-level language modeling

Note that these tasks are on very different scales, with unique properties that challenge sequence models in different ways. For example, word-level PTB is a small dataset that a typical model easily overfits, so judicious regularization is essential. WT103 is a hundred times larger, with less danger of overfitting, but with a vocabulary size of 268K that makes training more challenging (due to large embedding size).

Pre-trained Model(s)

We provide some reasonably good pre-trained weights here so that the users don't need to train from scratch. We'll update the table from time to time. (Note: if you train from scratch using different seeds, it's likely you will get better results :-))

Description Task Dataset Model
TrellisNet-LM Word-Level Language Modeling Penn Treebank (PTB) download (.pkl)
TrellisNet-LM Character-Level Language Modeling Penn Treebank (PTB) download (.pkl)

To use the pre-trained weights, use the flag --load_weight [.pkl PATH] when starting the training script (e.g., you can just use the default arg parameters). You can use the flag --eval turn on the evaluation mode only.

Usage

All tasks share the same underlying TrellisNet model, which is in file trellisnet.py (and the eventual models, including components like embedding layer, are in model.py). As discussed in the paper, TrellisNet is able to benefit significantly from techniques developed originally for RNNs as well as temporal convolutional networks (TCNs). Some of these techniques are also included in this repository. Each task is organized in the following structure:

[TASK_NAME] /
    data/
    logs/
    [TASK_NAME].py
    model.py
    utils.py
    data.py

where [TASK_NAME].py is the training script for the task (with argument flags; use -h to see the details).

Owner
CMU Locus Lab
Zico Kolter's Research Group
CMU Locus Lab
A tutorial showing how to train, convert, and run TensorFlow Lite object detection models on Android devices, the Raspberry Pi, and more!

A tutorial showing how to train, convert, and run TensorFlow Lite object detection models on Android devices, the Raspberry Pi, and more!

Evan 1.3k Jan 02, 2023
Code for 'Single Image 3D Shape Retrieval via Cross-Modal Instance and Category Contrastive Learning', ICCV 2021

CMIC-Retrieval Code for Single Image 3D Shape Retrieval via Cross-Modal Instance and Category Contrastive Learning. ICCV 2021. Introduction In this wo

42 Nov 17, 2022
A python program to hack instagram

hackinsta a program to hack instagram Yokoback_(instahack) is the file to open, you need libraries write on import. You run that file in the same fold

2 Jan 22, 2022
Implementation of Continuous Sparsification, a method for pruning and ticket search in deep networks

Continuous Sparsification Implementation of Continuous Sparsification (CS), a method based on l_0 regularization to find sparse neural networks, propo

Pedro Savarese 23 Dec 07, 2022
NExT-QA: Next Phase of Question-Answering to Explaining Temporal Actions (CVPR2021)

NExT-QA We reproduce some SOTA VideoQA methods to provide benchmark results for our NExT-QA dataset accepted to CVPR2021 (with 1 'Strong Accept' and 2

Junbin Xiao 50 Nov 24, 2022
Official implementation of "UCTransNet: Rethinking the Skip Connections in U-Net from a Channel-wise Perspective with Transformer"

[AAAI2022] UCTransNet This repo is the official implementation of "UCTransNet: Rethinking the Skip Connections in U-Net from a Channel-wise Perspectiv

Haonan Wang 199 Jan 03, 2023
code for the ICLR'22 paper: On Robust Prefix-Tuning for Text Classification

On Robust Prefix-Tuning for Text Classification Prefix-tuning has drawed much attention as it is a parameter-efficient and modular alternative to adap

Zonghan Yang 12 Nov 30, 2022
DGCNN - Dynamic Graph CNN for Learning on Point Clouds

DGCNN is the author's re-implementation of Dynamic Graph CNN, which achieves state-of-the-art performance on point-cloud-related high-level tasks including category classification, semantic segmentat

Wang, Yue 1.3k Dec 26, 2022
A demonstration of using a live Tensorflow session to create an interactive face-GAN explorer.

Streamlit Demo: The Controllable GAN Face Generator This project highlights Streamlit's new hash_func feature with an app that calls on TensorFlow to

Streamlit 257 Dec 31, 2022
Efficient Training of Audio Transformers with Patchout

PaSST: Efficient Training of Audio Transformers with Patchout This is the implementation for Efficient Training of Audio Transformers with Patchout Pa

165 Dec 26, 2022
Multi Camera Calibration

Multi Camera Calibration 'modules/camera_calibration/app/camera_calibration.cpp' is for calculating extrinsic parameter of each individual cameras. 'm

7 Dec 01, 2022
A generalist algorithm for cell and nucleus segmentation.

Cellpose | A generalist algorithm for cell and nucleus segmentation. Cellpose was written by Carsen Stringer and Marius Pachitariu. To learn about Cel

MouseLand 733 Dec 29, 2022
Breast cancer is been classified into benign tumour and malignant tumour.

Breast cancer is been classified into benign tumour and malignant tumour. Logistic regression is applied in this model.

1 Feb 04, 2022
(Personalized) Page-Rank computation using PyTorch

torch-ppr This package allows calculating page-rank and personalized page-rank via power iteration with PyTorch, which also supports calculation on GP

Max Berrendorf 69 Dec 03, 2022
This is the pytorch code for the paper Curious Representation Learning for Embodied Intelligence.

Curious Representation Learning for Embodied Intelligence This is the pytorch code for the paper Curious Representation Learning for Embodied Intellig

19 Oct 19, 2022
Official repository of the paper Privacy-friendly Synthetic Data for the Development of Face Morphing Attack Detectors

SMDD-Synthetic-Face-Morphing-Attack-Detection-Development-dataset Official repository of the paper Privacy-friendly Synthetic Data for the Development

10 Dec 12, 2022
Official Pytorch implementation of MixMo framework

MixMo: Mixing Multiple Inputs for Multiple Outputs via Deep Subnetworks Official PyTorch implementation of the MixMo framework | paper | docs Alexandr

79 Nov 07, 2022
Codes for ACL-IJCNLP 2021 Paper "Zero-shot Fact Verification by Claim Generation"

Zero-shot-Fact-Verification-by-Claim-Generation This repository contains code and models for the paper: Zero-shot Fact Verification by Claim Generatio

Liangming Pan 47 Jan 01, 2023
Data, notebooks, and articles associated with the RSNA AI Deep Learning Lab at RSNA 2021

RSNA AI Deep Learning Lab 2021 Intro Welcome Deep Learners! This document provides all the information you need to participate in the RSNA AI Deep Lea

RSNA 65 Dec 16, 2022