PiRank: Learning to Rank via Differentiable Sorting

This repository provides a reference implementation for learning PiRank-based models as described in the paper:

PiRank: Learning to Rank via Differentiable Sorting
Robin Swezey, Aditya Grover, Bruno Charron and Stefano Ermon.
Paper: https://arxiv.org/abs/2012.06731

Requirements

The codebase is implemented in Python 3.7. To install the necessary base requirements, run the following commands:

pip install -r requirements.txt

If you intend to use a GPU, modify requirements.txt to install tensorflow-gpu instead of tensorflow.

You will also need the NeuralSort implementation available here. Make sure it is added to your PYTHONPATH.

Datasets

PiRank was tested on the following two datasets:

  • MSLR-WEB30K

  • Yahoo! C14

Additionally, the code is expected to work with any dataset stored in the standard LibSVM format used for LTR experiments.
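For reference, each line of an LTR file in LibSVM format holds a relevance label, a query id (qid:<id>), and sparse index:value features, optionally followed by a # comment. The sketch below parses such lines and groups documents by query in plain Python. The function names are hypothetical, and this is not the repository's loader, which by default relies on TF-Ranking's libsvm input pipeline.

```python
from typing import Dict, Iterable, List, Tuple


def parse_libsvm_ltr_line(line: str) -> Tuple[int, str, Dict[int, float]]:
    """Parse '<label> qid:<id> <idx>:<val> ... [# comment]' into (label, qid, features)."""
    tokens = line.split("#", 1)[0].split()  # drop any trailing comment
    label = int(tokens[0])
    if not tokens[1].startswith("qid:"):
        raise ValueError("expected 'qid:<id>' as the second token")
    qid = tokens[1][len("qid:"):]
    features = {}
    for token in tokens[2:]:
        idx, val = token.split(":", 1)
        features[int(idx)] = float(val)
    return label, qid, features


def group_by_query(lines: Iterable[str]) -> Dict[str, List[Tuple[int, Dict[int, float]]]]:
    """Group (label, features) pairs by query id, preserving file order."""
    groups: Dict[str, List[Tuple[int, Dict[int, float]]]] = {}
    for line in lines:
        if not line.strip():
            continue  # skip blank lines
        label, qid, feats = parse_libsvm_ltr_line(line)
        groups.setdefault(qid, []).append((label, feats))
    return groups


sample = ["2 qid:10 1:0.5 3:1.2", "0 qid:10 1:0.1 2:0.7 # doc B", "1 qid:11 2:3.0"]
groups = group_by_query(sample)
```

Features that are absent from a line are implicitly zero, which is why a dense loader needs num_features (136 for MSLR-WEB30K, 700 for Yahoo!) to size its vectors.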

Scripts

The code is organized into two scripts:

  • pirank_simple.py implements a simple depth-1 PiRank loss (d=1). It is used in the experiments of sections 4.1 (benchmark evaluation on MSLR-WEB30K and Yahoo! C14 datasets), 4.2.1 (effect of temperature parameter), and 4.2.2 (effect of training list size).

  • pirank_deep.py implements the deeper PiRank losses (d>=1). It is used for the experiments of section 4.2.3 and comes with a convenient synthetic data generator as well as more tuning options.
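Both losses rest on the NeuralSort relaxation of the sorting operator. The NumPy sketch below assumes the simple (d=1) loss is one minus a relaxed NDCG@k. It builds the NeuralSort matrix from the scores, uses its first k rows to compute a soft DCG, and normalizes by the exact ideal DCG of the true labels, assuming the common 2^rel - 1 gain. The function names are hypothetical. The sketch only computes the forward value (the scripts implement the loss in TensorFlow so it can be differentiated), and it is not the repository's pirank_simple_loss.

```python
import numpy as np


def neuralsort(scores, tau):
    """NeuralSort relaxation of the descending-sort permutation matrix.

    Row i (1-based) is softmax(((n + 1 - 2i) * s - A_s @ 1) / tau), where
    A_s[j, k] = |s_j - s_k|; it softly selects the item with the i-th largest score.
    """
    s = np.asarray(scores, dtype=float)
    n = s.size
    b = np.abs(s[:, None] - s[None, :]).sum(axis=1)      # (A_s @ 1)_k
    coef = (n + 1 - 2 * np.arange(1, n + 1))[:, None]     # (n + 1 - 2i) for each row i
    logits = (coef * s[None, :] - b[None, :]) / tau
    logits -= logits.max(axis=1, keepdims=True)            # numerical stability
    p = np.exp(logits)
    return p / p.sum(axis=1, keepdims=True)


def relaxed_ndcg_loss(scores, labels, k, tau):
    """1 - relaxed NDCG@k for a single query (forward value only)."""
    gains = 2.0 ** np.asarray(labels, dtype=float) - 1.0
    k = min(k, gains.size)
    discounts = 1.0 / np.log2(np.arange(2, k + 2))         # positions 1..k
    p_hat = neuralsort(scores, tau)
    soft_dcg = (p_hat[:k] @ gains) @ discounts             # expected gain at each top-k position
    ideal_dcg = np.sort(gains)[::-1][:k] @ discounts       # exact ideal ordering of the true labels
    return 0.0 if ideal_dcg == 0 else 1.0 - soft_dcg / ideal_dcg


labels = [2, 0, 1]
good = relaxed_ndcg_loss(np.array([3.0, 1.0, 2.0]), labels, k=3, tau=0.01)
bad = relaxed_ndcg_loss(np.array([1.0, 3.0, 2.0]), labels, k=3, tau=0.01)
```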

Options

Options are handled by Sacred (see Examples section below).

pirank_simple.py and pirank_deep.py

PiRank-related:

Parameter | Default value | Description
loss_fn | pirank_simple_loss | The loss function to use (either a TFR RankingLossKey or a loss function from the script)
ste | False | Whether to use the Straight-Through Estimator
ndcg_k | 15 | NDCG@k cutoff when using the NS-NDCG loss
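The ndcg_k option sets the cutoff k of NDCG@k. As a reference point, here is a plain-Python sketch of the exact (non-differentiable) metric, assuming the common 2^rel - 1 gain and 1 / log2(position + 1) discount with positions starting at 1. The function names are hypothetical; TF-Ranking's own metric implementation may differ in details such as tie handling.

```python
import math


def dcg_at_k(ranked_labels, k):
    """DCG@k with gain 2^rel - 1 and discount 1 / log2(position + 1), positions starting at 1."""
    return sum((2 ** rel - 1) / math.log2(pos + 2) for pos, rel in enumerate(ranked_labels[:k]))


def ndcg_at_k(scores, labels, k):
    """Exact NDCG@k of the ranking obtained by sorting scores in descending order."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)  # ties keep input order
    ideal = dcg_at_k(sorted(labels, reverse=True), k)
    return dcg_at_k([labels[i] for i in order], k) / ideal if ideal > 0 else 0.0


perfect = ndcg_at_k([3.0, 1.0, 2.0], [2, 0, 1], k=3)
reversed_order = ndcg_at_k([1.0, 3.0, 2.0], [2, 0, 1], k=3)
```

Only the top k positions contribute, so with k=1 the reversed ranking above scores 0 because its first document has label 0.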

NeuralSort-related:

Parameter | Default value | Description
tau | 5 | Temperature
taustar | 1e-10 | Temperature used for sorting the true labels and for straight-through estimation
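tau controls how close the NeuralSort matrix is to a hard permutation, and a near-zero temperature such as taustar=1e-10 makes it effectively exact. The sketch below measures the gap between the relaxed matrix and the hard permutation obtained from its row-wise argmax. That hard matrix is the kind a straight-through estimator would use in the forward pass, but this is an assumption for illustration, not the scripts' ste code. The helper names are hypothetical, and neuralsort is repeated so the block runs on its own.

```python
import numpy as np


def neuralsort(scores, tau):
    # NeuralSort relaxation: row i softly selects the item with the i-th largest score.
    s = np.asarray(scores, dtype=float)
    n = s.size
    b = np.abs(s[:, None] - s[None, :]).sum(axis=1)
    coef = (n + 1 - 2 * np.arange(1, n + 1))[:, None]
    logits = (coef * s[None, :] - b[None, :]) / tau
    logits -= logits.max(axis=1, keepdims=True)
    p = np.exp(logits)
    return p / p.sum(axis=1, keepdims=True)


def hard_permutation(p_hat):
    # One-hot matrix from the row-wise argmax; for distinct scores this is the exact sort.
    n = p_hat.shape[0]
    hard = np.zeros_like(p_hat)
    hard[np.arange(n), p_hat.argmax(axis=1)] = 1.0
    return hard


scores = np.array([0.3, 1.2, -0.5, 0.9])
gaps = {}
for tau in (5.0, 1.0, 1e-10):
    p = neuralsort(scores, tau)
    gaps[tau] = np.abs(p - hard_permutation(p)).max()  # 0 means P_hat is already a hard permutation
    print(f"tau={tau:g}: max |P_hat - P_hard| = {gaps[tau]:.3f}")
```

Large temperatures such as the default tau=5 give smoother matrices and gradients that spread over more items, at the cost of a looser approximation of the true ranking.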

TensorFlow-Ranking and architecture-related:

Parameter | Default value | Description
hidden_layers | "256,tanh,128,tanh,64,tanh" | Hidden layers for an example-wise feedforward network, in the format size,activation,...,size,activation
num_features | 136 | Number of features per document. The default is for MSLR-WEB30K and depends on the dataset (e.g. change it to 700 for Yahoo!).
list_size | 100 | List size used for training
group_size | 1 | Group size used in the scoring function
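hidden_layers is a flat comma-separated list that alternates layer sizes and activation names. A minimal sketch of how such a string can be split into (size, activation) pairs follows; parse_hidden_layers is a hypothetical name, and the scripts' own parsing may differ, for example in how it validates input.

```python
from typing import List, Tuple


def parse_hidden_layers(spec: str) -> List[Tuple[int, str]]:
    """Split 'size,activation,...,size,activation' into (size, activation) pairs."""
    tokens = [t.strip() for t in spec.split(",") if t.strip()]
    if len(tokens) % 2 != 0:
        raise ValueError("hidden_layers must alternate size,activation")
    return [(int(tokens[i]), tokens[i + 1]) for i in range(0, len(tokens), 2)]


layers = parse_hidden_layers("256,tanh,128,tanh,64,tanh")
example_layers = parse_hidden_layers("256,relu,256,relu,128,relu,64,relu")
```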

Training-related:

Parameter | Default value | Description
train_path | "/data/MSLR-WEB30K/Fold*/train.txt" | Input file path used for training
vali_path | "/data/MSLR-WEB30K/Fold*/vali.txt" | Input file path used for validation
test_path | "/data/MSLR-WEB30K/Fold*/test.txt" | Input file path used for testing
model_dir | None | Output directory for models
num_epochs | 200 | Number of epochs to train; set to 0 to only run testing
lr | 1e-4 | Initial learning rate
batch_size | 32 | Batch size for training
num_train_steps | None | Number of steps for training
num_vali_steps | None | Number of steps for validation
num_test_steps | None | Number of steps for testing
learning_rate | 0.01 | Learning rate for the optimizer
dropout_rate | 0.5 | Dropout rate before the output layer
optimizer | Adagrad | Optimizer used for gradient descent

Sacred:

In addition, you can use regular parameters from Sacred (such as -m for logging the experiment to MongoDB).

pirank_deep.py only

Parameter | Default value | Description
merge_block_size | None | Block size used for merging; None if not merging
top_k | None | Top-k used for merging, if it should differ from the final NDCG@k cutoff of the loss
straight_backprop | False | Backpropagate on scores only through the NS operator
full_loss | False | Use the complete loss at the end of the merge
tau_scheme | None | Scheme for the temperature at deeper levels (default: constant)
data_generator | None | Data generator (default: TFR's libsvm); use this for synthetic data generation
num_queries | 30000 | Number of queries for the synthetic data generator
num_query_features | 10 | Number of columns used as factors for each query by the synthetic data generator
actual_list_size | None | Size of the actual list per query in synthetic data generation
train_path | "/data/MSLR-WEB30K/Fold*/train.txt" | Input file path used for training; alternatively, the seed value when using a data generator
vali_path | "/data/MSLR-WEB30K/Fold*/vali.txt" | Input file path used for validation; alternatively, the seed value when using a data generator
test_path | "/data/MSLR-WEB30K/Fold*/test.txt" | Input file path used for testing; alternatively, the seed value when using a data generator
with_opa | True | Include the pairwise metric OPA
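The merge_block_size and top_k options suggest a divide-and-conquer scheme: sort within blocks, keep each block's top-k candidates, then sort the merged candidates. Assuming that is how the deeper loss is structured, the hard-sorting sketch below shows why the scheme loses nothing: any item in the global top-k must also be in its own block's top-k. It uses ordinary sorting instead of NeuralSort, and the function names are hypothetical.

```python
import random


def top_k(scores, candidates, k):
    """Indices of the k highest-scoring candidates, best first (stable for ties)."""
    return sorted(candidates, key=lambda i: scores[i], reverse=True)[:k]


def blockwise_top_k(scores, block_size, k):
    """Keep each block's top-k, then take the top-k of the merged candidates."""
    candidates = []
    for start in range(0, len(scores), block_size):
        block = range(start, min(start + block_size, len(scores)))
        candidates.extend(top_k(scores, block, k))
    return top_k(scores, candidates, k)


random.seed(0)
scores = [random.random() for _ in range(1000)]
merged = blockwise_top_k(scores, block_size=100, k=10)
full = top_k(scores, range(len(scores)), k=10)
```

With 1000 items, blocks of 100, and k=10, the final sort only sees 10 blocks x 10 candidates = 100 items, which is the kind of saving that makes relaxed sorting affordable on long lists such as the list_size=1000 synthetic example below.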

Examples

Run the benchmark experiment of section 4.1 with PiRank simple loss on MSLR-WEB30K

cd pirank
python3 pirank_simple.py with loss_fn=pirank_simple_loss \
    ndcg_k=10 \
    tau=5 \
    list_size=80 \
    hidden_layers=256,relu,256,relu,128,relu,64,relu \
    train_path=/data/MSLR-WEB30K/Fold1/train.txt \
    vali_path=/data/MSLR-WEB30K/Fold1/vali.txt \
    test_path=/data/MSLR-WEB30K/Fold1/test.txt \
    num_features=136 \
    optimizer=Adam \
    learning_rate=0.00001 \
    num_epochs=100 \
    batch_size=16 \
    model_dir=/tmp/model

Run the benchmark experiment of section 4.1 with PiRank simple loss on Yahoo! C14

cd pirank
python3 pirank_simple.py with loss_fn=pirank_simple_loss \
    ndcg_k=10 \
    tau=5 \
    list_size=80 \
    hidden_layers=256,relu,256,relu,128,relu,64,relu \
    train_path=/data/YAHOO/set1.train.txt \
    vali_path=/data/YAHOO/set1.valid.txt \
    test_path=/data/YAHOO/set1.test.txt \
    num_features=700 \
    optimizer=Adam \
    learning_rate=0.00001 \
    num_epochs=100 \
    batch_size=16 \
    model_dir=/tmp/model

Run the benchmark experiment of section 4.1 with classic LambdaRank on MSLR-WEB30K

cd pirank
python3 pirank_simple.py with loss_fn=lambda_rank_loss \
    ndcg_k=10 \
    tau=5 \
    list_size=80 \
    hidden_layers=256,relu,256,relu,128,relu,64,relu \
    train_path=/data/MSLR-WEB30K/Fold1/train.txt \
    vali_path=/data/MSLR-WEB30K/Fold1/vali.txt \
    test_path=/data/MSLR-WEB30K/Fold1/test.txt \
    num_features=136 \
    optimizer=Adam \
    learning_rate=0.00001 \
    num_epochs=100 \
    batch_size=16 \
    model_dir=/tmp/model

Run the scaling ablation experiment of section 4.2.3 using synthetic data generation (d=2)

cd pirank
python3 pirank_deep.py with loss_fn=pirank_deep_loss \
    ndcg_k=10 \
    ste=True \
    merge_block_size=100 \
    tau=5 \
    taustar=1e-10 \
    tau_scheme=square \
    data_generator=synthetic_data_generator \
    actual_list_size=1000 \
    list_size=1000 \
    vali_list_size=1000 \
    test_list_size=1000 \
    full_loss=False \
    train_path=0 \
    vali_path=1 \
    test_path=2 \
    num_queries=1000 \
    num_features=25 \
    num_query_features=5 \
    hidden_layers=256,relu,256,relu,128,relu,128,relu,64,relu,64,relu \
    optimizer=Adam \
    learning_rate=0.00001 \
    num_epochs=100 \
    batch_size=16

Help

If you need help, reach out to Robin Swezey or raise an issue.

Citing

If you find PiRank useful in your research, please consider citing the following paper:

@inproceedings{
swezey2020pirank,
title={PiRank: Learning to Rank via Differentiable Sorting},
author={Robin Swezey and Aditya Grover and Bruno Charron and Stefano Ermon},
year={2020},
url={},
}
