PiRank: Learning to Rank via Differentiable Sorting

Related tags

Deep Learningpirank
Overview

PiRank: Learning to Rank via Differentiable Sorting

This repository provides a reference implementation for learning PiRank-based models as described in the paper:

PiRank: Learning to Rank via Differentiable Sorting
Robin Swezey, Aditya Grover, Bruno Charron and Stefano Ermon.
Paper: https://arxiv.org/abs/2012.06731

Requirements

The codebase is implemented in Python 3.7. To install the necessary base requirements, run the following commands:

pip install -r requirements.txt

If you intend to use a GPU, modify requirements.txt to install tensorflow-gpu instead of tensorflow.

You will also need the NeuralSort implementation available here. Make sure it is added to your PYTHONPATH.

Datasets

PiRank was tested on the two following datasets:

Additionally, the code is expected to work with any dataset stored in the standard LibSVM format used for LTR experiments.

Scripts

There are two scripts for the code:

  • pirank_simple.py implements a simple depth-1 PiRank loss (d=1). It is used in the experiments of sections 4.1 (benchmark evaluation on MSLR-WEB30K and Yahoo! C14 datasets), 4.2.1 (effect of temperature parameter), and 4.2.2 (effect of training list size).

  • pirank_deep.py implements the deeper PiRank losses (d>=1). It is used for the experiments of section 4.2.3 and comes with a convenient synthetic data generator as well as more tuning options.

Options

Options are handled by Sacred (see Examples section below).

pirank_simple.py and pirank_deep.py

PiRank-related:

Parameter Default Value Description
loss_fn pirank_simple_loss The loss function to use (either a TFR RankingLossKey, or loss function from the script)
ste False Whether to use the Straight-Through Estimator
ndcg_k 15 [email protected] cutoff when using NS-NDCG loss

NeuralSort-related:

Parameter Default Value Description
tau 5 Temperature
taustar 1e-10 Temperature for trues and straight-through estimation.

TensorFlow-Ranking and architecture-related:

Parameter Default Value Description
hidden_layers "256,tanh,128,tanh,64,tanh" Hidden layers for an example-wise feedforward network in the format size,activation,...,size,activation
num_features 136 Number of features per document. The default value is for MSLR and depends on the dataset (e.g. for Yahoo!, please change to 700).
list_size 100 List size used for training
group_size 1 Group size used in score function

Training-related:

Parameter Default Value Description
train_path "/data/MSLR-WEB30K/Fold*/train.txt" Input file path used for training
vali_path "/data/MSLR-WEB30K/Fold*/vali.txt" Input file path used for validation
test_path "/data/MSLR-WEB30K/Fold*/test.txt" Input file path used for testing
model_dir None Output directory for models
num_epochs 200 Number of epochs to train, set 0 to just test
lr 1e-4 initial learning rate
batch_size 32 The batch size for training
num_train_steps None Number of steps for training
num_vali_steps None Number of steps for validation
num_test_steps None Number of steps for testing
learning_rate 0.01 Learning rate for optimizer
dropout_rate 0.5 The dropout rate before output layer
optimizer Adagrad The optimizer for gradient descent

Sacred:

In addition, you can use regular parameters from Sacred (such as -m for logging the experiment to MongoDB).

pirank_deep.py only

Parameter Default Value Description
merge_block_size None Block size used if merging, None if not merging
top_k None Use a different Top-k for merging than final [email protected] for loss
straight_backprop False Backpropagate on scores only through NS operator
full_loss False Use the complete loss at the end of merge
tau_scheme None Which scheme to use for temperature going deeper (default: constant)
data_generator None Data generator (default: TFR\s libsvm); use this for synthetic generation
num_queries 30000 Number of queries for synthetic data generator
num_query_features 10 Number of columns used as factors for each query by synthetic data generator
actual_list_size None Size of actual list per query in synthetic data generation
train_path "/data/MSLR-WEB30K/Fold*/train.txt" Input file path used for training; alternatively value of seed if using data generator
vali_path "/data/MSLR-WEB30K/Fold*/vali.txt" Input file path used for validation; alternatively value of seed if using data generator
test_path "/data/MSLR-WEB30K/Fold*/test.txt" Input file path used for testing; alternatively value of seed if using data generator
with_opa True Include pairwise metric OPA

Examples

Run the benchmark experiment of section 4.1 with PiRank simple loss on MSLR-WEB30K

cd pirank
python3 pirank_simple.py with loss_fn=pirank_simple_loss \
    ndcg_k=10 \
    tau=5 \
    list_size=80 \
    hidden_layers=256,relu,256,relu,128,relu,64,relu \
    train_path=/data/MSLR-WEB30K/Fold1/train.txt \
    vali_path=/data/MSLR-WEB30K/Fold1/vali.txt \
    test_path=/data/MSLR-WEB30K/Fold1/test.txt \
    num_features=136 \
    optimizer=Adam \
    learning_rate=0.00001 \
    num_epochs=100 \
    batch_size=16 \
    model_dir=/tmp/model

Run the benchmark experiment of section 4.1 with PiRank simple loss on Yahoo! C14

cd pirank
python3 pirank_simple.py with loss_fn=pirank_simple_loss \
    ndcg_k=10 \
    tau=5 \
    list_size=80 \
    hidden_layers=256,relu,256,relu,128,relu,64,relu \
    train_path=/data/YAHOO/set1.train.txt \
    vali_path=/data/YAHOO/set1.valid.txt \
    test_path=/data/YAHOO/set1.test.txt \
    num_features=700 \
    optimizer=Adam \
    learning_rate=0.00001 \
    num_epochs=100 \
    batch_size=16 \
    model_dir=/tmp/model

Run the benchmark experiment of section 4.1 with classic LambdaRank on MSLR-WEB30K

cd pirank
python3 pirank_simple.py with loss_fn=lambda_rank_loss \
    ndcg_k=10 \
    tau=5 \
    list_size=80 \
    hidden_layers=256,relu,256,relu,128,relu,64,relu \
    train_path=/data/MSLR-WEB30K/Fold1/train.txt \
    vali_path=/data/MSLR-WEB30K/Fold1/vali.txt \
    test_path=/data/MSLR-WEB30K/Fold1/test.txt \
    num_features=136 \
    optimizer=Adam \
    learning_rate=0.00001 \
    num_epochs=100 \
    batch_size=16 \
    model_dir=/tmp/model

Run the scaling ablation experiment of section 4.2.3 using synthetic data generation (d=2)

cd pirank
python3 pirank_deep.py with loss_fn=pirank_deep_loss \
    ndcg_k=10 \
    ste=True \
    merge_block_size=100 \
    tau=5 \
    taustar=1e-10 \
    tau_scheme=square \
    data_generator=synthetic_data_generator \
    actual_list_size=1000 \
    list_size=1000 \
    vali_list_size=1000 \
    test_list_size=1000 \
    full_loss=False \
    train_path=0 \
    vali_path=1 \
    test_path=2 \
    num_queries=1000 \
    num_features=25 \
    num_query_features=5 \
    hidden_layers=256,relu,256,relu,128,relu,128,relu,64,relu,64,relu \
    optimizer=Adam \
    learning_rate=0.00001 \
    num_epochs=100 \
    batch_size=16

Help

If you need help, reach out to Robin Swezey or raise an issue.

Citing

If you find PiRank useful in your research, please consider citing the following paper:

@inproceedings{
swezey2020pirank,
title={PiRank: Learning to Rank via Differentiable Sorting},
author={Robin Swezey and Aditya Grover and Bruno Charron and Stefano Ermon},
year={2020},
url={},
}

Learning with Noisy Labels via Sparse Regularization, ICCV2021

Learning with Noisy Labels via Sparse Regularization This repository is the official implementation of [Learning with Noisy Labels via Sparse Regulari

Xiong Zhou 38 Oct 20, 2022
The repository forked from NVlabs uses our data. (Differentiable rasterization applied to 3D model simplification tasks)

nvdiffmodeling [origin_code] Differentiable rasterization applied to 3D model simplification tasks, as described in the paper: Appearance-Driven Autom

Qiujie (Jay) Dong 2 Oct 31, 2022
A rule-based log analyzer & filter

Flog 一个根据规则集来处理文本日志的工具。 前言 在日常开发过程中,由于缺乏必要的日志规范,导致很多人乱打一通,一个日志文件夹解压缩后往往有几十万行。 日志泛滥会导致信息密度骤减,给排查问题带来了不小的麻烦。 以前都是用grep之类的工具先挑选出有用的,再逐条进行排查,费时费力。在忍无可忍之后决

上山打老虎 9 Jun 23, 2022
Deep RGB-D Saliency Detection with Depth-Sensitive Attention and Automatic Multi-Modal Fusion (CVPR'2021, Oral)

DSA^2 F: Deep RGB-D Saliency Detection with Depth-Sensitive Attention and Automatic Multi-Modal Fusion (CVPR'2021, Oral) This repo is the official imp

如今我已剑指天涯 46 Dec 21, 2022
Exploiting a Zoo of Checkpoints for Unseen Tasks

Exploiting a Zoo of Checkpoints for Unseen Tasks This repo includes code to reproduce all results in the above Neurips paper, authored by Jiaji Huang,

Baidu Research 8 Sep 06, 2022
LRBoost is a scikit-learn compatible approach to performing linear residual based stacking/boosting.

LRBoost is a sckit-learn compatible package for linear residual boosting. LRBoost combines a linear estimator and a non-linear estimator to leverage t

Andrew Patton 5 Nov 23, 2022
Do you like Quick, Draw? Well what if you could train/predict doodles drawn inside Streamlit? Also draws lines, circles and boxes over background images for annotation.

Streamlit - Drawable Canvas Streamlit component which provides a sketching canvas using Fabric.js. Features Draw freely, lines, circles, boxes and pol

Fanilo Andrianasolo 325 Dec 28, 2022
[ACL-IJCNLP 2021] Improving Named Entity Recognition by External Context Retrieving and Cooperative Learning

CLNER The code is for our ACL-IJCNLP 2021 paper: Improving Named Entity Recognition by External Context Retrieving and Cooperative Learning CLNER is a

71 Dec 08, 2022
Genshin-assets - 👧 Public documentation & static assets for Genshin Impact data.

genshin-assets This repo provides easy access to the Genshin Impact assets, primarily for use on static sites. Sources Genshin Optimizer - An Artifact

Zerite Development 5 Nov 22, 2022
Combining Automatic Labelers and Expert Annotations for Accurate Radiology Report Labeling Using BERT

CheXbert: Combining Automatic Labelers and Expert Annotations for Accurate Radiology Report Labeling Using BERT CheXbert is an accurate, automated dee

Stanford Machine Learning Group 51 Dec 08, 2022
A framework for analyzing computer vision models with simulated data

3DB: A framework for analyzing computer vision models with simulated data Paper Quickstart guide Blog post Installation Follow instructions on: https:

3DB 112 Jan 01, 2023
Code for Emergent Translation in Multi-Agent Communication

Emergent Translation in Multi-Agent Communication PyTorch implementation of the models described in the paper Emergent Translation in Multi-Agent Comm

Facebook Research 75 Jul 15, 2022
Based on Yolo's low-power, ultra-lightweight universal target detection algorithm, the parameter is only 250k, and the speed of the smart phone mobile terminal can reach ~300fps+

Based on Yolo's low-power, ultra-lightweight universal target detection algorithm, the parameter is only 250k, and the speed of the smart phone mobile terminal can reach ~300fps+

567 Dec 26, 2022
DiffWave is a fast, high-quality neural vocoder and waveform synthesizer.

DiffWave DiffWave is a fast, high-quality neural vocoder and waveform synthesizer. It starts with Gaussian noise and converts it into speech via itera

LMNT 498 Jan 03, 2023
code and data for paper "GIANT: Scalable Creation of a Web-scale Ontology"

GIANT Code and data for paper "GIANT: Scalable Creation of a Web-scale Ontology" https://arxiv.org/pdf/2004.02118.pdf Please cite our paper if this pr

Excalibur 39 Dec 29, 2022
MonoRec: Semi-Supervised Dense Reconstruction in Dynamic Environments from a Single Moving Camera

MonoRec: Semi-Supervised Dense Reconstruction in Dynamic Environments from a Single Moving Camera

Felix Wimbauer 494 Jan 06, 2023
A PyTorch Implementation of Gated Graph Sequence Neural Networks (GGNN)

A PyTorch Implementation of GGNN This is a PyTorch implementation of the Gated Graph Sequence Neural Networks (GGNN) as described in the paper Gated G

Ching-Yao Chuang 427 Dec 13, 2022
Out-of-Town Recommendation with Travel Intention Modeling (AAAI2021)

TrainOR_AAAI21 This is the official implementation of our AAAI'21 paper: Haoran Xin, Xinjiang Lu, Tong Xu, Hao Liu, Jingjing Gu, Dejing Dou, Hui Xiong

Jack Xin 13 Oct 19, 2022
Multi-Anchor Active Domain Adaptation for Semantic Segmentation (ICCV 2021 Oral)

Multi-Anchor Active Domain Adaptation for Semantic Segmentation Munan Ning*, Donghuan Lu*, Dong Wei†, Cheng Bian, Chenglang Yuan, Shuang Yu, Kai Ma, Y

Munan Ning 36 Dec 07, 2022
The codes I made while I practiced various TensorFlow examples

TensorFlow_Exercises The codes I made while I practiced various TensorFlow examples About the codes I didn't create these codes by myself, but re-crea

Terry Taewoong Um 614 Dec 08, 2022