Repo for WWW 2022 paper: Progressively Optimized Bi-Granular Document Representation for Scalable Embedding Based Retrieval

BiDR

Requirements

torch==1.7
transformers==4.6
faiss-gpu==1.6.4.post2

Data Download and Preprocess

bash download_data.sh
python preprocess.py

These commands download and preprocess the MSMARCO Passage and Doc datasets and save the results to ./data.
We use the Passage dataset as the running example for the workflow below.

Conventional Workflow

Representation Learning

Train the encoder with random negatives (or set --hardneg_json to provide BM25/hard negatives):

mkdir log
dataset=passage
savename=dense_global_model
python train.py --model_name_or_path roberta-base \
--max_query_length 24 --max_doc_length 128 \
--data_dir ./data/${dataset}/preprocess \
--learning_rate 1e-4 --optimizer_str adamw \
--per_device_train_batch_size 128 \
--per_query_neg_num 1 \
--generate_batch_method random \
--loss_method multi_ce  \
--savename ${savename} --save_model_path ./model \
--world_size 8 --gpu_rank 0_1_2_3_4_5_6_7  --master_port 13256 \
--num_train_epochs 30  \
--use_pq False \
|tee ./log/${savename}.log
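
As a rough illustration of what this training stage optimizes, the sketch below computes a softmax cross-entropy loss over one positive and several negative documents for a single query. It assumes that `multi_ce` refers to this kind of multi-class cross-entropy over inner-product scores. That reading and the function name `contrastive_ce` are assumptions for illustration, not taken from the repo's code.

```python
import math


def contrastive_ce(query, positive, negatives):
    """Softmax cross-entropy with the positive document as the target class.

    All vectors are plain lists of floats; scores are inner products.
    """
    def dot(a, b):
        return sum(x * y for x, y in zip(a, b))

    # Index 0 holds the positive score, the rest are negative scores.
    scores = [dot(query, positive)] + [dot(query, n) for n in negatives]
    # Subtract the max score before exponentiating for numerical stability.
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return log_z - scores[0]


# Toy 2-dimensional example (hypothetical values).
q = [1.0, 0.0]
pos = [2.0, 0.0]
negs = [[0.0, 1.0], [-1.0, 0.0]]
loss = contrastive_ce(q, pos, negs)

# A negative that scores closer to the positive makes the loss larger.
harder_loss = contrastive_ce(q, pos, [[1.5, 0.0], [-1.0, 0.0]])
```

With random negatives the negatives tend to be easy, which is why later stages switch to negatives drawn from retrieved candidates.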

Unsupervised Quantization

Generate dense embeddings of queries and docs:

data_type=passage
savename=dense_global_model
epoch=20
python ./inference.py \
--data_type ${data_type} \
--preprocess_dir ./data/${data_type}/preprocess/ \
--max_doc_length 256 --max_query_length 32 \
--eval_batch_size 512 \
--ckpt_path ./model/${savename}/${epoch}/ \
--output_dir  evaluate/${savename}_${epoch} 

Run Faiss-based product quantization and evaluate recall:

data_type=passage
savename=dense_global_model
epoch=20
python ./test/lightweight_ann.py \
--output_dir ./data/${data_type}/evaluate/${savename}_${epoch} \
--ckpt_path ./model/${savename}/${epoch}/ \
--subvector_num 96 \
--index opq \
--topk 1000 \
--data_type ${data_type} \
--MRR_cutoff 10 \
--Recall_cutoff 5 10 30 50 100
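
For readers unfamiliar with the reported metrics, here is a minimal sketch of how MRR@k and Recall@k are commonly computed from ranked lists. The function names and the dictionary layout (`run`, `qrels`) are hypothetical and are not taken from `lightweight_ann.py`.

```python
def mrr_at_k(ranked, relevant, k):
    """Reciprocal rank of the first relevant doc within the top k, else 0."""
    for rank, doc_id in enumerate(ranked[:k], start=1):
        if doc_id in relevant:
            return 1.0 / rank
    return 0.0


def recall_at_k(ranked, relevant, k):
    """Fraction of the relevant docs that appear within the top k."""
    if not relevant:
        return 0.0
    return len(set(ranked[:k]) & relevant) / len(relevant)


def evaluate(run, qrels, mrr_cutoff, recall_cutoffs):
    """Average both metrics over all queries in `run`."""
    n = len(run)
    metrics = {
        f"MRR@{mrr_cutoff}": sum(
            mrr_at_k(docs, qrels.get(q, set()), mrr_cutoff)
            for q, docs in run.items()
        ) / n
    }
    for k in recall_cutoffs:
        metrics[f"Recall@{k}"] = sum(
            recall_at_k(docs, qrels.get(q, set()), k)
            for q, docs in run.items()
        ) / n
    return metrics


# Toy data: ranked doc ids per query and the set of relevant docs per query.
run = {"q1": ["d3", "d1", "d7"], "q2": ["d9", "d4", "d2"]}
qrels = {"q1": {"d1"}, "q2": {"d2", "d8"}}
metrics = evaluate(run, qrels, mrr_cutoff=10, recall_cutoffs=[2, 3])
```

In this toy run the first relevant document sits at rank 2 for `q1` and rank 3 for `q2`, so MRR@10 is the mean of 1/2 and 1/3.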

Progressively Optimized Bi-Granular Document Representation

Sparse Representation Learning

Instead of applying unsupervised quantization to the already-trained dense embeddings, the sparse embeddings are learned with contrastive learning. This optimizes global discrimination, so that high-quality answers are more likely to be covered in the candidate search.

Train

We find that initializing the PQ module with Faiss OPQ gives a significant gain on the MSMARCO dataset. On the largest dataset (the Ads dataset), however, initialization with Faiss OPQ is redundant and brings no improvement.

dataset=passage
savename=sparse_global_model
python train.py --model_name_or_path ./model/dense_global_model/20 \
--max_query_length 24 --max_doc_length 128 \
--data_dir ./data/${dataset}/preprocess \
--learning_rate 1e-4 --optimizer_str adamw \
--per_device_train_batch_size 128 \
--per_query_neg_num 1 \
--generate_batch_method random \
--loss_method multi_ce  \
--savename ${savename} --save_model_path ./model \
--world_size 8 --gpu_rank 0_1_2_3_4_5_6_7  --master_port 13256 \
--num_train_epochs 30  \
--use_pq True \
--init_index_path ./data/${dataset}/evaluate/dense_global_model_20/OPQ96,PQ96x8.index \
--partition 96 --centroids 256 --quantloss_weight 0.0 \
|tee ./log/${savename}.log

Here, ./model/dense_global_model/20 and ./data/${dataset}/evaluate/dense_global_model_20/OPQ96,PQ96x8.index are generated by the conventional workflow above.
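
To make the `--partition 96 --centroids 256` setting concrete, the sketch below shows how product quantization turns a vector into one centroid index per sub-vector, and how much storage that saves. The toy codebooks and the helper names `pq_encode` and `pq_decode` are hypothetical. The storage arithmetic assumes 768-dimensional float32 embeddings (the hidden size of roberta-base), which may not match the repo's actual output dimension.

```python
def pq_encode(vector, codebooks):
    """Map each sub-vector to the index of its nearest centroid (squared L2)."""
    sub_dim = len(vector) // len(codebooks)
    codes = []
    for i, centroids in enumerate(codebooks):
        sub = vector[i * sub_dim:(i + 1) * sub_dim]
        dists = [sum((a - b) ** 2 for a, b in zip(sub, c)) for c in centroids]
        codes.append(dists.index(min(dists)))
    return codes


def pq_decode(codes, codebooks):
    """Rebuild an approximate vector by concatenating the chosen centroids."""
    out = []
    for code, centroids in zip(codes, codebooks):
        out.extend(centroids[code])
    return out


# Toy setting: 4-dim vectors, 2 partitions, 2 centroids per partition.
codebooks = [
    [[0.0, 0.0], [1.0, 1.0]],
    [[0.0, 1.0], [1.0, 0.0]],
]
vec = [0.9, 0.8, 0.1, 0.7]
codes = pq_encode(vec, codebooks)
approx = pq_decode(codes, codebooks)

# Storage arithmetic for the flags above (assumed 768-dim float32 vectors).
dim, partitions, centroids = 768, 96, 256
bits_per_code = (centroids - 1).bit_length()  # bits needed to index a centroid
float_bytes = dim * 4
code_bytes = partitions * bits_per_code // 8
compression = float_bytes // code_bytes
```

Under these assumptions each document shrinks from 3072 bytes to 96 bytes. The index file name `OPQ96,PQ96x8` follows the same pattern: 96 sub-quantizers with 8 bits each.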

Test

data_type=passage
savename=sparse_global_model
epoch=20

python ./inference.py \
--data_type ${data_type} \
--preprocess_dir ./data/${data_type}/preprocess/ \
--max_doc_length 256 --max_query_length 32 \
--eval_batch_size 512 \
--ckpt_path ./model/${savename}/${epoch}/ \
--output_dir  evaluate/${savename}_${epoch} 

python ./test/lightweight_ann.py \
--output_dir ./data/${data_type}/evaluate/${savename}_${epoch} \
--subvector_num 96 \
--index opq \
--topk 1000 \
--data_type ${data_type} \
--MRR_cutoff 10 \
--Recall_cutoff 5 10 30 50 100 \
--ckpt_path ./model/${savename}/${epoch}/ \
--init_index_path ./data/${data_type}/evaluate/dense_global_model_20/OPQ96,PQ96x8.index

Dense Representation Learning

The dense embeddings are optimized based on the candidate distribution generated by the sparse embeddings. We propose a novel sampling strategy, locality-centric sampling, to enhance local discrimination: construct a bipartite proximity graph and run a random walk or snow sampling on it.
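
The random-walk variant can be sketched as follows. This is a simplified illustration, not the repo's implementation: it assumes the bipartite graph links each query to the documents retrieved for it by the sparse embeddings, and that a batch is gathered by alternating query-to-document and document-to-query steps. All function names are hypothetical.

```python
import random
from collections import defaultdict


def build_bipartite_graph(candidates):
    """candidates: dict mapping query_id to the doc_ids retrieved for it."""
    q2d = {q: list(docs) for q, docs in candidates.items()}
    d2q = defaultdict(list)
    for q, docs in candidates.items():
        for d in docs:
            d2q[d].append(q)
    return q2d, dict(d2q)


def random_walk_batch(q2d, d2q, start_query, batch_size, rng):
    """Collect up to batch_size distinct queries by walking query -> doc -> query."""
    batch = [start_query]
    seen = {start_query}
    current = start_query
    # Bound the walk so it terminates even inside a small connected component.
    for _ in range(100 * batch_size):
        if len(batch) >= batch_size:
            break
        doc = rng.choice(q2d[current])
        current = rng.choice(d2q[doc])
        if current not in seen:
            seen.add(current)
            batch.append(current)
    return batch


# Toy candidate lists; q4 shares no document with the other queries.
candidates = {
    "q1": ["d1", "d2"],
    "q2": ["d2", "d3"],
    "q3": ["d3", "d4"],
    "q4": ["d9"],
}
q2d, d2q = build_bipartite_graph(candidates)
batch = random_walk_batch(q2d, d2q, "q1", batch_size=3, rng=random.Random(0))
```

Queries gathered this way share nearby documents, so the documents of one query can act as hard in-batch negatives for the others. A query such as `q4`, which has no shared candidates, is never reached from `q1`.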

Train

Encode the queries in the training set and generate candidates for all training queries:

data_type=passage
savename=sparse_global_model
epoch=20

python ./inference.py \
--data_type ${data_type} \
--preprocess_dir ./data/${data_type}/preprocess/ \
--max_doc_length 256 --max_query_length 32 \
--eval_batch_size 512 \
--ckpt_path ./model/${savename}/${epoch}/ \
--output_dir  evaluate/${savename}_${epoch} \
--mode train

python ./test/lightweight_ann.py \
--output_dir ./data/${data_type}/evaluate/${savename}_${epoch} \
--subvector_num 96 \
--index opq \
--topk 1000 \
--data_type ${data_type} \
--MRR_cutoff 10 \
--Recall_cutoff 5 10 30 50 100 \
--ckpt_path ./model/${savename}/${epoch}/ \
--init_index_path ./data/${data_type}/evaluate/dense_global_model_20/OPQ96,PQ96x8.index \
--mode train \
--save_hardneg_to_json

This command saves train_hardneg.json to output_dir. Then train the dense embeddings to distinguish the ground truth from the negatives among the candidates:

dataset=passage
savename=dense_local_model
# Set --generate_batch_method to either random_walk or snow_sample.
python train.py --model_name_or_path roberta-base \
--max_query_length 24 --max_doc_length 128 \
--data_dir ./data/${dataset}/preprocess \
--learning_rate 1e-4 --optimizer_str adamw \
--per_device_train_batch_size 128 \
--per_query_neg_num 1 \
--generate_batch_method random_walk \
--loss_method multi_ce  \
--savename ${savename} --save_model_path ./model \
--world_size 8 --gpu_rank 0_1_2_3_4_5_6_7  --master_port 13256 \
--num_train_epochs 30  \
--use_pq False \
--hardneg_json ./data/${dataset}/evaluate/sparse_global_model_20/train_hardneg.json \
--mink 0  --maxk 200 \
|tee ./log/${savename}.log

Test

data_type=passage
savename=dense_local_model
epoch=10

python ./inference.py \
--data_type ${data_type} \
--preprocess_dir ./data/${data_type}/preprocess/ \
--ckpt_path ./model/${savename}/${epoch}/ \
--max_doc_length 256 --max_query_length 32 \
--eval_batch_size 512 \
--output_dir  evaluate/${savename}_${epoch} 

python ./test/post_verification.py \
--data_type ${data_type} \
--output_dir  evaluate/${savename}_${epoch} \
--candidate_from_ann ./data/${data_type}/evaluate/sparse_global_model_20/dev.rank_1000_score_faiss_opq.tsv \
--MRR_cutoff 10 \
--Recall_cutoff 5 10 30 50 100
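
Conceptually, post verification re-scores the candidates returned by the sparse (quantized) search with the dense embeddings and re-sorts them. The sketch below illustrates that idea under the assumption that relevance is scored by inner product. The function name `post_verify` and the toy vectors are hypothetical and do not reflect the internals of `post_verification.py`.

```python
def post_verify(query_vec, candidate_ids, doc_vecs, topk):
    """Re-rank ANN candidates by inner product with the dense embeddings."""
    def dot(a, b):
        return sum(x * y for x, y in zip(a, b))

    scored = [(dot(query_vec, doc_vecs[d]), d) for d in candidate_ids]
    # Highest dense score first; only the candidate set is re-scored.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [d for _, d in scored[:topk]]


# Toy dense embeddings; d4 was not retrieved by the sparse stage.
doc_vecs = {
    "d1": [0.2, 0.1],
    "d2": [0.9, 0.4],
    "d3": [0.5, 0.9],
    "d4": [1.0, 1.0],
}
query_vec = [1.0, 0.5]
ann_candidates = ["d1", "d2", "d3"]  # order given by the sparse search
reranked = post_verify(query_vec, ann_candidates, doc_vecs, topk=2)
```

Because only the candidates are re-scored, a document the sparse stage missed (`d4` here) cannot be recovered, which is why candidate coverage in the first stage matters.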

Contributing

This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.

When you submit a pull request, a CLA bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA.

This project has adopted the Microsoft Open Source Code of Conduct. For more information see the Code of Conduct FAQ or contact opencode@microsoft.com with any additional questions or comments.

Trademarks

This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow Microsoft's Trademark & Brand Guidelines. Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos are subject to those third-party's policies.
