Pytorch implementation of paper "Efficient Nearest Neighbor Language Models" (EMNLP 2021)

Overview

Efficient Nearest Neighbor Language Models

This is implementation of the paper:

Efficient Nearest Neighbor Language Models
Junxian He, Graham Neubig, Taylor Berg-Kirkpatrick
EMNLP 2021

This repo implements several techniques to speed up the evaluation of non-parametric, nearest neighbor language models. Specifically, we improve the efficiency along three axes: adaptive retrieval, datastore prunning, and dimension reduction.

Install Dependencies

This repository is largly based on the knnlm repo which is a fork of Fairseq (commit da544b). Please use the exact commit page to determine software requirements for using this code.

git clone [email protected]:jxhe/efficient-knnlm.git

cd efficient-knnlm
pip install --editable .
pip install faiss

Hardware

Experiments for this paper were conducted on machines that contain 32 CPUs, 100GB of RAM, and one NVIDIA 3090 24GB GPU. Saving the Wikitext-103 datastore requires 200GB of disk space. Note that the number of CPUs has a great impact on the speed.

Running Efficient kNNLM

Preparation

Data

We share Fairseq's instructions on how to prepare the data here.

mkdir -p datasets/wikitext-103
cp examples/language_model/wikitext-103/prepare-wikitext-103.sh datasets/wikitext-103

cd datasets/wikitext-103
bash prepare-wikitext-103.sh
cd ../..

TEXT=datasets/wikitext-103
python preprocess.py \
    --only-source \
    --trainpref $TEXT/wiki.train.tokens \
    --validpref $TEXT/wiki.valid.tokens \
    --testpref $TEXT/wiki.test.tokens \
    --destdir data-bin/wikitext-103 \
    --workers 20

Download the language model checkpoint pretrained on WikiText-103

# the model checkpoint link is from the knnlm repo
wget https://nlp.stanford.edu/projects/knnlm/wt103_checkpoint_best.pt -P knnlm_ckpt

Save the datastore

mkdir -p dstore

python eval_lm.py data-bin/wikitext-103 \
    --path knnlm_ckpt/checkpoint_best.pt \
    --sample-break-mode none --max-tokens 3072 \
    --softmax-batch 1024 --gen-subset train \
    --context-window 1536 --tokens-per-sample 1536 \
    --dstore-mmap dstore/dstore --knn-keytype 'last_ffn_input' \
    --dstore-size 103225485 --model-overrides "{'knn_keytype': 'last_ffn_input'}" \
    --save-knnlm-dstore --fp16 --dstore-fp16

Dimension Reduction

# the script applies PCA of dimension 512 by default 
# the PCA hyperparameter can be tuned in this script
# set pca=0 to revert back to the vanilla version
bash ef_knnlm/build_faiss.sh

The faiss index is saved into dstore. Try it out:

bash ef_knnlm/utils_cmd/eval_knnlm.sh \
    -d wikitext-103 \
    -s valid \
    -p dstore/dstore_size103225485_embed1024_fp16 \
    -i dstore/knn.103225485.pca512.m64.index \
    -n 103225485 \

You should already observe a speedup.

Adaptive Retrieval

prepare heldout data to train the retrieval adaptor

# this randomly selects 90% of validation data as the training data to 
# train the retrieval adaptor
bash ef_knnlm/adaptive_retrieval/prepare_heldout.sh wikitext-103

prepare features

bash ef_knnlm/adaptive_retrieval/prepare_feature_pipeline.sh

train

bash ef_knnlm/adaptive_retrieval/train_ar.sh

It saves the retrieval adaptor checkpoints into checkpoint/wikitext-103-valid

evaluation

# the cutoff ratio in adaptive retrieval
# by default we cut off half of the retrieval
cutoff=50

# please change this to the .pt file path observed from the last step
ar_ckpt=xxx

# this hyperparameter needs to be changed if 
# the datastore sizes change (e.g. datastore pruning)
size=103225485

dstore_prefix=dstore/dstore_size${size}_embed1024_fp16
index_file=dstore/knn.${size}.pca512.m64.index

bash ef_knnlm/utils_cmd/eval_knnlm.sh \
    -d wikitext-103 \
    -s test \
    -p ${dstore_prefix} \
    -i ${index_file} \
    -c knnlm_ckpt/wt103_checkpoint_best.pt \
    -n ${size} \
    -f datasets/wikitext-103 \
    -a ctxt,freq,lm_ent,lm_max,fert \
    -u ${cutoff} \
    -h ${ar_ckpt} \
    # -w "True"

Datastore Pruning

precompute all the retrieval results for every record in the datastore:

# It is possible to parallel this operation by change 
# "--start-point" and "--num" arguments so that the training
# data would be splitted into multiple smaller ones. In this case
# the retrieval results would be saved into multiple files
bash ef_knnlm/dstore_compression/save_retrieval_results.sh

The retrieval results are saved into dstore/greedy_merge, other datastore pruning algorithms may be played around using these pre-computed results.

greedy merging

# perform greedy merging to yield a new smaller datastore, 
# and build faiss index from the new datastore
bash ef_knnlm/dstore_compression/merge_compression.sh

The pruned datastore and index are saved into dstore/greedy_merging, replace the previousdstore_prefix/index_file with the new ones to use the pruned the datastore. The option -w "True"needs to be passed to eval_knnlm.sh to read the generated datastore weights file from greedy merging.

Reference

@inproceedings{he2021eff,
title={Efficient Nearest Neighbor Language Models},
author={Junxian He and Graham Neubig and Taylor Berg-Kirkpatrick},
booktitle={Proceedings of EMNLP},
year={2021}
}
Owner
Junxian He
NLP/ML PhD student at CMU
Junxian He
Tensorflow implementation of ID-Unet: Iterative Soft and Hard Deformation for View Synthesis.

ID-Unet: Iterative-view-synthesis(CVPR2021 Oral) Tensorflow implementation of ID-Unet: Iterative Soft and Hard Deformation for View Synthesis. Overvie

17 Aug 23, 2022
Re-implementation of the vector capsule with dynamic routing

VectorCapsule Re-implementation of the vector capsule with dynamic routing We implement the vector capsule and dynamic routing via graph neural networ

ZhenchaoTang 10 Feb 10, 2022
Vehicle Detection Using Deep Learning and YOLO Algorithm

VehicleDetection Vehicle Detection Using Deep Learning and YOLO Algorithm Dataset take or find vehicle images for create a special dataset for fine-tu

Maryam Boneh 96 Jan 05, 2023
Greedy Gaussian Segmentation

GGS Greedy Gaussian Segmentation (GGS) is a Python solver for efficiently segmenting multivariate time series data. For implementation details, please

Stanford University Convex Optimization Group 72 Dec 07, 2022
[WACV 2022] Contextual Gradient Scaling for Few-Shot Learning

CxGrad - Official PyTorch Implementation Contextual Gradient Scaling for Few-Shot Learning Sanghyuk Lee, Seunghyun Lee, and Byung Cheol Song In WACV 2

Sanghyuk Lee 4 Dec 05, 2022
scikit-learn inspired API for CRFsuite

sklearn-crfsuite sklearn-crfsuite is a thin CRFsuite (python-crfsuite) wrapper which provides interface simlar to scikit-learn. sklearn_crfsuite.CRF i

417 Dec 20, 2022
Neural network-based build time estimation for additive manufacturing

Neural network-based build time estimation for additive manufacturing Oh, Y., Sharp, M., Sprock, T., & Kwon, S. (2021). Neural network-based build tim

Yosep 1 Nov 15, 2021
AI-generated-characters for Learning and Wellbeing

AI-generated-characters for Learning and Wellbeing Click here for the full project page. This repository contains the source code for the paper AI-gen

MIT Media Lab 214 Jan 01, 2023
Official implementation of the paper "Lightweight Deep CNN for Natural Image Matting via Similarity Preserving Knowledge Distillation"

Lightweight-Deep-CNN-for-Natural-Image-Matting-via-Similarity-Preserving-Knowledge-Distillation Introduction Accepted at IEEE Signal Processing Letter

DongGeun-Yoon 19 Jun 07, 2022
[CVPR2022] Representation Compensation Networks for Continual Semantic Segmentation

RCIL [CVPR2022] Representation Compensation Networks for Continual Semantic Segmentation Chang-Bin Zhang1, Jia-Wen Xiao1, Xialei Liu1, Ying-Cong Chen2

Chang-Bin Zhang 71 Dec 28, 2022
The official implementation of paper Siamese Transformer Pyramid Networks for Real-Time UAV Tracking, accepted by WACV22

SiamTPN Introduction This is the official implementation of the SiamTPN (WACV2022). The tracker intergrates pyramid feature network and transformer in

Robotics and Intelligent Systems Control @ NYUAD 28 Nov 25, 2022
Official code for "Mean Shift for Self-Supervised Learning"

MSF Official code for "Mean Shift for Self-Supervised Learning" Requirements Python = 3.7.6 PyTorch = 1.4 torchvision = 0.5.0 faiss-gpu = 1.6.1 In

UMBC Vision 44 Nov 21, 2022
ZeroVL - The official implementation of ZeroVL

This repository contains source code necessary to reproduce the results presente

31 Nov 04, 2022
PyTorch implementation of Wide Residual Networks with 1-bit weights by McDonnell (ICLR 2018)

1-bit Wide ResNet PyTorch implementation of training 1-bit Wide ResNets from this paper: Training wide residual networks for deployment using a single

Sergey Zagoruyko 122 Dec 07, 2022
This is a pytorch implementation for the BST model from Alibaba https://arxiv.org/pdf/1905.06874.pdf

Behavior-Sequence-Transformer-Pytorch This is a pytorch implementation for the BST model from Alibaba https://arxiv.org/pdf/1905.06874.pdf This model

Jaime Ferrando Huertas 83 Jan 05, 2023
Quasi-Dense Similarity Learning for Multiple Object Tracking, CVPR 2021 (Oral)

Quasi-Dense Tracking This is the offical implementation of paper Quasi-Dense Similarity Learning for Multiple Object Tracking. We present a trailer th

ETH VIS Research Group 327 Dec 27, 2022
Analysing poker data from home games with friends

Poker Game Analysis Analysing poker data from home games with friends. Not a lot of data is collected, so this project is primarily focussed on descri

Stavros Karmaniolos 1 Oct 15, 2022
Object classification with basic computer vision techniques

naive-image-classification Object classification with basic computer vision techniques. Final assignment for the computer vision course I took at univ

2 Jul 01, 2022
The comma.ai Calibration Challenge!

Welcome to the comma.ai Calibration Challenge! Your goal is to predict the direction of travel (in camera frame) from provided dashcam video. This rep

comma.ai 697 Jan 05, 2023
Tool cek opsi checkpoint facebook!

tool apa ini? cek_opsi_facebook adalah sebuah tool yang mengecek opsi checkpoint akun facebook yang terkena checkpoint! tujuan dibuatnya tool ini? too

Muhammad Latif Harkat 2 Jul 17, 2022