Tree LSTM implementation in PyTorch

Overview

Tree-Structured Long Short-Term Memory Networks

This is a PyTorch implementation of Tree-LSTM as described in the paper Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks by Kai Sheng Tai, Richard Socher, and Christopher Manning. On the semantic similarity task using the SICK dataset, this implementation reaches:

  • Pearson's coefficient: 0.8492 and MSE: 0.2842 using hyperparameters --lr 0.010 --wd 0.0001 --optim adagrad --batchsize 25
  • Pearson's coefficient: 0.8674 and MSE: 0.2536 using hyperparameters --lr 0.025 --wd 0.0001 --optim adagrad --batchsize 25 --freeze_embed
  • Pearson's coefficient: 0.8676 and MSE: 0.2532 are the numbers reported in the original paper.
  • Known differences from the original implementation include how gradients are accumulated (normalized by batch size or not).
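Both metrics compare predicted relatedness scores against the gold SICK scores, which range from 1 to 5. The plain-Python sketch below shows how they are computed. It is not the repository's own evaluation code, which presumably operates on PyTorch tensors.

```python
import math


def pearson(preds, targets):
    """Pearson correlation coefficient between two equal-length sequences."""
    n = len(preds)
    mean_p = sum(preds) / n
    mean_t = sum(targets) / n
    # Covariance and variances, each left unnormalized (the 1/n factors cancel).
    cov = sum((p - mean_p) * (t - mean_t) for p, t in zip(preds, targets))
    var_p = sum((p - mean_p) ** 2 for p in preds)
    var_t = sum((t - mean_t) ** 2 for t in targets)
    return cov / math.sqrt(var_p * var_t)


def mse(preds, targets):
    """Mean squared error between predicted and gold scores."""
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


preds = [1.2, 3.4, 4.9, 2.0]
gold = [1.0, 3.6, 4.8, 2.5]
print(mse(preds, gold))      # ≈ 0.085
print(pearson(preds, gold))
```

Higher Pearson and lower MSE are better. Pearson ignores any constant offset or scale in the predictions, while MSE penalizes it, which is why both are reported.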

Requirements

  • Python (tested on 3.6.5, should work on >=2.7)
  • Java >= 8 (for Stanford CoreNLP utilities)
  • Other dependencies are in requirements.txt. Note: the code currently works with PyTorch 0.4.0; switch to the pytorch-v0.3.1 branch if you want to use PyTorch 0.3.1.

Usage

Before delving into how to run the code, here is a quick overview of the contents:

  • Use the script fetch_and_preprocess.sh to download the SICK dataset, the Stanford Parser and Stanford POS Tagger, and GloVe word vectors (Common Crawl 840B; warning: this is a 2GB download!). The script also preprocesses the data, i.e. it generates dependency parses using the Stanford Neural Network Dependency Parser.
  • main.py does the actual heavy lifting of training the model and testing it on the SICK dataset. For a list of all command-line arguments, have a look at config.py.
    • The first run caches GloVe embeddings for words in the SICK vocabulary. Later runs read only the cache.
    • Logs and model checkpoints are saved to the checkpoints/ directory with the name specified by the command line argument --expname.
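The embedding cache described above can be sketched in plain Python as follows. The function name `load_glove_cached`, the pickle cache format and the file paths are hypothetical. The repository's own code likely stores the cache differently, for example as a PyTorch tensor. The sketch assumes the standard GloVe text format: one token per line followed by its space-separated vector components.

```python
import os
import pickle


def load_glove_cached(glove_path, vocab, cache_path, dim=300):
    """Return {word: vector} for words in `vocab`, reading GloVe only once.

    First run: scan the full GloVe text file, keep only vocabulary words,
    and pickle the (much smaller) result to `cache_path`.
    Later runs: load the pickle and skip the large text file entirely.
    """
    if os.path.exists(cache_path):
        with open(cache_path, "rb") as f:
            return pickle.load(f)

    embeddings = {}
    with open(glove_path, encoding="utf8") as f:
        for line in f:
            # Split from the right so that tokens containing spaces stay intact.
            parts = line.rstrip().rsplit(" ", dim)
            word = parts[0]
            if word in vocab:
                embeddings[word] = [float(x) for x in parts[1:]]

    with open(cache_path, "wb") as f:
        pickle.dump(embeddings, f)
    return embeddings
```

Splitting from the right matters for the Common Crawl vectors, which are known to include a few tokens that themselves contain spaces. The cache is keyed only on its path, so in this sketch it must be deleted by hand if the vocabulary changes.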

Next, here are the different ways to run the code to train a TreeLSTM model.

Local Python Environment

If you have a working Python3 environment, simply run the following sequence of steps:

- bash fetch_and_preprocess.sh
- pip install -r requirements.txt
- python main.py
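Plain `python main.py` uses the defaults from config.py. To reproduce the tuned numbers from the Overview, pass the listed flags explicitly, e.g. `python main.py --lr 0.025 --wd 0.0001 --optim adagrad --batchsize 25 --freeze_embed`. The sketch below shows how such flags are commonly declared with Python's argparse. It is a hypothetical subset covering only the flags named in this README, and its defaults are placeholders, not the repository's actual values.

```python
import argparse


def build_parser():
    # Hypothetical subset of config.py: only flags mentioned in this README.
    # Default values are placeholders, not the repository's real defaults.
    p = argparse.ArgumentParser(description="Tree-LSTM on SICK (sketch)")
    p.add_argument("--lr", type=float, default=0.01, help="learning rate")
    p.add_argument("--wd", type=float, default=1e-4, help="weight decay")
    p.add_argument("--optim", default="adagrad", help="optimizer name")
    p.add_argument("--batchsize", type=int, default=25,
                   help="examples per optimizer step")
    p.add_argument("--freeze_embed", action="store_true",
                   help="keep the GloVe embeddings fixed during training")
    p.add_argument("--sparse", action="store_true",
                   help="sparse gradient updates for nn.Embedding")
    p.add_argument("--expname", default="test",
                   help="name for logs and checkpoints in checkpoints/")
    return p


# Parse the second hyperparameter setting from the Overview.
args = build_parser().parse_args(
    ["--lr", "0.025", "--wd", "0.0001", "--optim", "adagrad",
     "--batchsize", "25", "--freeze_embed"])
print(args)
```

Boolean switches such as `--freeze_embed` and `--sparse` use `store_true`, so they are off unless given on the command line.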

Pure Docker Environment

If you want to use a Docker container, simply follow these steps:

- docker build -t treelstm .
- docker run -it treelstm bash
- bash fetch_and_preprocess.sh
- python main.py

Local Filesystem + Docker Environment

If you want to use a Docker container, but want to persist data and checkpoints in your local filesystem, simply follow these steps:

- bash fetch_and_preprocess.sh
- docker build -t treelstm .
- docker run -it --mount type=bind,source="$(pwd)",target="/root/treelstm.pytorch" treelstm bash
- python main.py

NOTE: Setting the environment variable OMP_NUM_THREADS=1 usually gives a speedup on the CPU. Use it like OMP_NUM_THREADS=1 python main.py. To run on a GPU, set the CUDA_VISIBLE_DEVICES environment variable instead. CUDA usually does not give much speedup here, since the model operates at a batch size of 1.
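The model operates at a batch size of 1, yet training still takes --batchsize 25. This suggests that per-example gradients are accumulated over --batchsize examples before each optimizer step. Whether that sum is divided by the batch size is the known difference listed in the Overview: without normalization, the effective learning rate is batchsize times larger.

The toy pure-Python sketch below makes this concrete. A one-parameter least-squares model stands in for the Tree-LSTM, and the function name is hypothetical.

```python
def train_accumulated(data, lr, batchsize, normalize, w=0.0):
    """Fit y ~ w * x one example at a time, stepping every `batchsize` examples.

    Per-example gradients of the squared loss (w*x - y)**2 are summed. If
    `normalize` is True the sum is divided by `batchsize` (i.e. averaged)
    before the update; otherwise the raw sum is applied. Gradients from a
    trailing partial batch are discarded in this sketch.
    """
    grad, seen = 0.0, 0
    for x, y in data:
        grad += 2 * (w * x - y) * x  # d/dw of (w*x - y)**2 for this example
        seen += 1
        if seen == batchsize:
            step = grad / batchsize if normalize else grad
            w -= lr * step
            grad, seen = 0.0, 0
    return w


data = [(1.0, 2.0), (2.0, 4.0)]
print(train_accumulated(data, lr=0.01, batchsize=2, normalize=True))   # ≈ 0.1
print(train_accumulated(data, lr=0.01, batchsize=2, normalize=False))  # ≈ 0.2
```

The unnormalized run with lr=0.01 takes the same step as the normalized run with lr=0.02. Hyperparameters tuned under one convention therefore need rescaling under the other.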

Notes

  • (Apr 02, 2018) Added Dockerfile
  • (Apr 02, 2018) Now works on PyTorch 0.3.1 and Python 3.6, removed dependency on Python 2.7
  • (Nov 28, 2017) Added frozen embeddings, closed gap to paper.
  • (Nov 08, 2017) Refactored model to get 1.5x - 2x speedup.
  • (Oct 23, 2017) Now works with PyTorch 0.2.0.
  • (May 04, 2017) Added support for sparse tensors. Using the --sparse argument will enable sparse gradient updates for nn.Embedding, potentially reducing memory usage.
    • There are a couple of caveats, however, viz. weight decay will not work in conjunction with sparsity, and results from the original paper might not be reproduced using sparse embeddings.
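The weight-decay caveat follows from what each update touches. A sparse update modifies only the embedding rows looked up in the current batch. L2 weight decay, by contrast, shrinks every parameter at every step, which makes the update inherently dense.

The plain-Python sketch below illustrates the difference. It represents the embedding table as a dict of rows, and the function names are hypothetical, not PyTorch's API.

```python
def sparse_step(table, grads, lr):
    """Update only the embedding rows that appeared in the batch.

    `table` maps row index -> list of floats; `grads` maps row index ->
    gradient vector and contains only the rows that were looked up.
    """
    for row, g in grads.items():
        table[row] = [v - lr * gv for v, gv in zip(table[row], g)]


def weight_decay_step(table, lr, wd):
    """L2 weight decay: shrinks every row, whether or not it was looked up."""
    for row in table:
        table[row] = [v - lr * wd * v for v in table[row]]


table = {0: [1.0, 1.0], 1: [2.0, 2.0], 2: [3.0, 3.0]}
sparse_step(table, {1: [1.0, 1.0]}, lr=0.5)  # only row 1 changes
weight_decay_step(table, lr=0.5, wd=0.1)     # every row changes
print(table)
```

With a vocabulary of tens of thousands of words, touching every row on every step defeats the purpose of sparse updates. That is why the two options do not combine.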

Acknowledgements

Shout-out to Kai Sheng Tai for the original LuaTorch implementation, and to the PyTorch team for the fun library.

Contact

Riddhiman Dasgupta

This is my first PyTorch-based implementation, and it might contain bugs. Please let me know if you find any!

License

MIT
