Scalable Attentive Sentence-Pair Modeling via Distilled Sentence Embedding (AAAI 2020) - PyTorch Implementation

Overview

Scalable Attentive Sentence-Pair Modeling via Distilled Sentence Embedding

PyTorch implementation for the Scalable Attentive Sentence-Pair Modeling via Distilled Sentence Embedding (AAAI 2020) paper.

Method Description

Distilled Sentence Embedding (DSE) distills knowledge from a finetuned state-of-the-art transformer model (BERT) to create high quality sentence embeddings. For a complete description, as well as implementation details and hyperparameters, please refer to the paper.

Usage

Follow the instructions below in order to run the training procedure of the Distilled Sentence Embedding (DSE) method. The python scripts below can be run with the -h parameter to get more information.

1. Install Requirements

Tested with Python 3.6+.

pip install -r requirements.txt

2. Download GLUE Datasets

Run the download_glue_data.py script to download the GLUE datasets.

python download_glue_data.py

3. Finetune BERT on a Specific Task

Finetune a standard BERT model on a specific task (e.g., MRPC, MNLI, etc.). Below is an example for the MRPC dataset.

python finetune_bert.py \
--bert_model bert-large-uncased-whole-word-masking \
--task_name mrpc \
--do_train \
--do_eval \
--do_lower_case \
--data_dir glue_data/MRPC \
--max_seq_length 128 \
--train_batch_size 32 \
--gradient_accumulation_steps 2 \
--learning_rate 2e-5 \
--num_train_epochs 3 \
--output_dir outputs/large_uncased_finetuned_mrpc \
--overwrite_output_dir \
--no_parallel

Note: For our code to work with the AllNLI dataset (a combination of the MNLI and SNLI datasets), you simply need to create a folder where the downloaded GLUE datasets are and copy the MNLI and SNLI datasets into it.

4. Create Logits for Distillation from the Finetuned BERT

Execute the following command to create the logits which will be used for the distillation training objective. Note that the bert_checkpoint_dir parameter has to match the output_dir from the previous command.

python run_distillation_logits_creator.py \
--task_name mrpc \
--data_dir glue_data/MRPC \
--do_lower_case \
--train_features_path glue_data/MRPC/train_bert-large-uncased-whole-word-masking_128_mrpc \
--bert_checkpoint_dir outputs/large_uncased_finetuned_mrpc

5. Train the DSE Model using the Finetuned BERT Logits

Train the DSE model using the extracted logits. Notice that the distillation_logits_path parameter needs to be changed according to the task.

python dse_train_runner.py \
--task_name mrpc \
--data_dir glue_data/MRPC \
--distillation_logits_path outputs/logits/large_uncased_finetuned_mrpc_logits.pt \
--do_lower_case \
--file_log \
--epochs 8 \
--store_checkpoints \
--fc_dims 512 1

Important Notes:

  • To store checkpoints for the model make sure that the --store_checkpoints flag is passed as shown above.
  • The fc_dims parameter accepts a list of space separated integers, and is the dimensions of the fully connected classifier that is put on top of the extracted features from the Siamese DSE model. The output dimension (in this case 1) needs to be changed according to the wanted output dimensionality. For example, for the MNLI dataset the fc_dims parameter should be 512 3 since it is a 3 class classification task.

6. Loading the Trained DSE Model

During training, checkpoints of the Trainer object which contains the model will be saved. You can load these checkpoints and extract the model state dictionary from them. Then you can load the state into a created DSESiameseClassifier model. The load_dse_checkpoint_example.py script contains an example of how to do that.

To load the model trained with the example commands above you can use:

python load_dse_checkpoint_example.py \
--task_name mrpc \
--trainer_checkpoint <path_to_saved_checkpoint> \
--do_lower_case \
--fc_dims 512 1

Acknowledgments

Citation

If you find this code useful, please cite the following paper:

@inproceedings{barkan2020scalable,
  title={Scalable Attentive Sentence-Pair Modeling via Distilled Sentence Embedding},
  author={Barkan, Oren and Razin, Noam and Malkiel, Itzik and Katz, Ori and Caciularu, Avi and Koenigstein, Noam},
  booktitle={AAAI Conference on Artificial Intelligence (AAAI)},
  year={2020}
}
Owner
Microsoft
Open source projects and samples from Microsoft
Microsoft
The codes and related files to reproduce the results for Image Similarity Challenge Track 2.

The codes and related files to reproduce the results for Image Similarity Challenge Track 2.

Wenhao Wang 89 Jan 02, 2023
The author's officially unofficial PyTorch BigGAN implementation.

BigGAN-PyTorch The author's officially unofficial PyTorch BigGAN implementation. This repo contains code for 4-8 GPU training of BigGANs from Large Sc

Andy Brock 2.6k Jan 02, 2023
EXplainable Artificial Intelligence (XAI)

EXplainable Artificial Intelligence (XAI) This repository includes the codes for different projects on eXplainable Artificial Intelligence (XAI) by th

4 Nov 28, 2022
Complete* list of autonomous driving related datasets

AD Datasets Complete* and curated list of autonomous driving related datasets Contributing Contributions are very welcome! To add or update a dataset:

Daniel Bogdoll 13 Dec 19, 2022
Cmsc11 arcade - Final Project for CMSC11

cmsc11_arcade Final Project for CMSC11 Developers: Limson, Mark Vincent Peñafiel

Gregory 1 Jan 18, 2022
Multiple paper open-source codes of the Microsoft Research Asia DKI group

📫 Paper Code Collection (MSRA DKI Group) This repo hosts multiple open-source codes of the Microsoft Research Asia DKI Group. You could find the corr

Microsoft 249 Jan 08, 2023
A scikit-learn-compatible module for estimating prediction intervals.

MAPIE - Model Agnostic Prediction Interval Estimator MAPIE allows you to easily estimate prediction intervals (or prediction sets) using your favourit

588 Jan 04, 2023
Bridging Vision and Language Model

BriVL BriVL (Bridging Vision and Language Model) 是首个中文通用图文多模态大规模预训练模型。BriVL模型在图文检索任务上有着优异的效果,超过了同期其他常见的多模态预训练模型(例如UNITER、CLIP)。 BriVL论文:WenLan: Bridgi

235 Dec 27, 2022
PyTorch implementation of ENet

PyTorch-ENet PyTorch (v1.1.0) implementation of ENet: A Deep Neural Network Architecture for Real-Time Semantic Segmentation, ported from the lua-torc

David Silva 333 Dec 29, 2022
PyTorch implementation for Partially View-aligned Representation Learning with Noise-robust Contrastive Loss (CVPR 2021)

2021-CVPR-MvCLN This repo contains the code and data of the following paper accepted by CVPR 2021 Partially View-aligned Representation Learning with

XLearning Group 33 Nov 01, 2022
TorchGeo is a PyTorch domain library, similar to torchvision, that provides datasets, transforms, samplers, and pre-trained models specific to geospatial data.

TorchGeo is a PyTorch domain library, similar to torchvision, that provides datasets, transforms, samplers, and pre-trained models specific to geospatial data.

Microsoft 1.3k Dec 30, 2022
A Decentralized Omnidirectional Visual-Inertial-UWB State Estimation System for Aerial Swar.

Omni-swarm A Decentralized Omnidirectional Visual-Inertial-UWB State Estimation System for Aerial Swarm Introduction Omni-swarm is a decentralized omn

HKUST Aerial Robotics Group 99 Dec 23, 2022
Repositório criado para abrigar os notebooks com a listas de exercícios propostos pelo professor Gustavo Guanabara do canal Curso em Vídeo do YouTube durante o Curso de Python 3

Curso em Vídeo - Exercícios de Python 3 Sobre o repositório Este repositório contém os notebooks com a listas de exercícios propostos pelo professor G

João Pedro Pereira 9 Oct 15, 2022
Pytorch implementation of the paper "Topic Modeling Revisited: A Document Graph-based Neural Network Perspective"

Graph Neural Topic Model (GNTM) This is the pytorch implementation of the paper "Topic Modeling Revisited: A Document Graph-based Neural Network Persp

Dazhong Shen 8 Sep 14, 2022
Pseudo-rng-app - whos needs science to make a random number when you have pseudoscience?

Pseudo-random numbers with pseudoscience rng is so complicated! Why cant we have a horoscopic, vibe-y way of calculating a random number? Why cant rng

Andrew Blance 1 Dec 27, 2021
Let's create a tool to convert Thailand budget from PDF to CSV.

thailand-budget-pdf2csv Let's create a tool to convert Thailand Government Budgeting from PDF to CSV! รวมพลัง Dev แปลงงบ จาก PDF สู่ Machine-readable

Kao.Geek 88 Dec 19, 2022
A program to recognize fruits on pictures or videos using yolov5

Yolov5 Fruits Detector Requirements Either Linux or Windows. We recommend Linux for better performance. Python 3.6+ and PyTorch 1.7+. Installation To

Fateme Zamanian 30 Jan 06, 2023
Improving Contrastive Learning by Visualizing Feature Transformation, ICCV 2021 Oral

Improving Contrastive Learning by Visualizing Feature Transformation This project hosts the codes, models and visualization tools for the paper: Impro

Bingchen Zhao 83 Dec 15, 2022
Repo for Photon-Starved Scene Inference using Single Photon Cameras, ICCV 2021

Photon-Starved Scene Inference using Single Photon Cameras ICCV 2021 Arxiv Project Video Bhavya Goyal, Mohit Gupta University of Wisconsin-Madison Abs

Bhavya Goyal 5 Nov 15, 2022