KGI (Knowledge Graph Induction) for slot filling

This is the code for our KILT leaderboard submission to the T-REx and zsRE tasks. It includes code for training a DPR model then continuing training with RAG.

Our model is described in: Zero-shot Slot Filling with DPR and RAG

Available from Hugging Face as:

Dataset  Type       Model Name                                             Tokenizer Name
T-REx    DPR (ctx)  michaelrglass/dpr-ctx_encoder-multiset-base-kgi0-trex  facebook/dpr-ctx_encoder-multiset-base
T-REx    RAG        michaelrglass/rag-token-nq-kgi0-trex                   facebook/rag-token-nq
zsRE     DPR (ctx)  michaelrglass/dpr-ctx_encoder-multiset-base-kgi0-zsre  facebook/dpr-ctx_encoder-multiset-base
zsRE     RAG        michaelrglass/rag-token-nq-kgi0-zsre                   facebook/rag-token-nq
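As a quick check that a published checkpoint loads, the sketch below embeds one titled passage with the T-REx or zsRE DPR context encoder. It assumes `torch` and `transformers` are installed. `KGI_MODELS` and `embed_passage` are hypothetical names and are not part of this repository. The RAG tokenizer is written as its full Hub ID, facebook/rag-token-nq.

```python
"""Hedged sketch: embed a passage with a KGI DPR context encoder from the Hub.

Assumes `torch` and `transformers` are installed; the names below are
illustrative and not part of this repository.
"""

# (dataset, type) -> (model name, tokenizer name), taken from the table above
KGI_MODELS = {
    ("T-REx", "DPR (ctx)"): ("michaelrglass/dpr-ctx_encoder-multiset-base-kgi0-trex",
                             "facebook/dpr-ctx_encoder-multiset-base"),
    ("T-REx", "RAG"): ("michaelrglass/rag-token-nq-kgi0-trex", "facebook/rag-token-nq"),
    ("zsRE", "DPR (ctx)"): ("michaelrglass/dpr-ctx_encoder-multiset-base-kgi0-zsre",
                            "facebook/dpr-ctx_encoder-multiset-base"),
    ("zsRE", "RAG"): ("michaelrglass/rag-token-nq-kgi0-zsre", "facebook/rag-token-nq"),
}


def embed_passage(dataset, title, text):
    """Return the DPR context embedding (a 1 x 768 tensor) of one titled passage."""
    # Imported lazily so the mapping above can be inspected without torch installed.
    import torch
    from transformers import DPRContextEncoder, DPRContextEncoderTokenizer

    model_name, tokenizer_name = KGI_MODELS[(dataset, "DPR (ctx)")]
    tokenizer = DPRContextEncoderTokenizer.from_pretrained(tokenizer_name)
    encoder = DPRContextEncoder.from_pretrained(model_name).eval()
    # DPR encodes the title and the passage text as a sentence pair.
    inputs = tokenizer(title, text, truncation=True, max_length=512, return_tensors="pt")
    with torch.no_grad():
        return encoder(**inputs).pooler_output
```

The first call, for example `embed_passage("T-REx", "Barack Obama", "Barack Hussein Obama II is an American politician ...")`, downloads both the model and the tokenizer.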

Process to reproduce

Download the KILT data and knowledge source

Segment the KILT Knowledge Source into passages:

python slot_filling/kilt_passage_corpus.py \
--kilt_corpus kilt_knowledgesource.json --output_dir kilt_passages --passage_ids passage_ids.txt
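The KILT knowledge source stores one Wikipedia page per line. Each page has `wikipedia_id`, `wikipedia_title` and `text`, a list of paragraphs. The segmentation step can be pictured roughly as follows. This is a minimal sketch that assumes greedy 100-word windows (the DPR passage length) and a hypothetical `<wikipedia_id>::<n>` passage ID. `slot_filling/kilt_passage_corpus.py` may use a different length, boundary rule, or ID scheme.

```python
import json


def segment_document(doc, max_words=100):
    """Split one KILT knowledge-source record into passages of <= max_words words.

    Hypothetical scheme for illustration: words are packed greedily into
    windows, and each passage records the paragraph range it covers so that
    KILT provenance (given at paragraph level) can later be mapped to passages.
    """
    passages = []
    words, first_par, last_par = [], None, None
    for par_idx, paragraph in enumerate(doc["text"]):
        for word in paragraph.split():
            if first_par is None:
                first_par = par_idx
            last_par = par_idx
            words.append(word)
            if len(words) == max_words:
                passages.append((first_par, last_par, words))
                words, first_par = [], None
    if words:  # flush the final, shorter window
        passages.append((first_par, last_par, words))
    return [
        {
            "pid": f"{doc['wikipedia_id']}::{n}",  # hypothetical ID format
            "title": doc["wikipedia_title"],
            "text": " ".join(ws),
            "start_paragraph_id": first,
            "end_paragraph_id": last,
        }
        for n, (first, last, ws) in enumerate(passages)
    ]


def iter_passages(path, max_words=100):
    """Stream passages from a JSON-lines KILT knowledge-source file."""
    with open(path, encoding="utf-8") as f:
        for line in f:
            yield from segment_document(json.loads(line), max_words)
```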

Generate the first phase of the DPR training data

python dpr/dpr_kilt_slot_filling_dataset.py \
--kilt_data structured_zeroshot-train-kilt.jsonl \
--passage_ids passage_ids.txt \
--output_file zsRE_train_positive_pids.jsonl

python dpr/dpr_kilt_slot_filling_dataset.py \
--kilt_data trex-train-kilt.jsonl \
--passage_ids passage_ids.txt \
--output_file trex_train_positive_pids.jsonl
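Each KILT training record lists its gold evidence under `output[*].provenance`, with `wikipedia_id`, `start_paragraph_id` and `end_paragraph_id`. One plausible way to turn that evidence into positive passage IDs is to keep every passage whose paragraph range overlaps a provenance span. The sketch below assumes passages carry `pid`, `wikipedia_id` and paragraph-range fields. Those field names are hypothetical, and `dpr/dpr_kilt_slot_filling_dataset.py` may apply additional filters, such as requiring the answer string to appear.

```python
from collections import defaultdict


def build_passage_index(passages):
    """Group passage spans by page: wikipedia_id -> [(pid, start_par, end_par)]."""
    index = defaultdict(list)
    for p in passages:
        index[p["wikipedia_id"]].append(
            (p["pid"], p["start_paragraph_id"], p["end_paragraph_id"]))
    return index


def positive_pids(kilt_record, index):
    """Return IDs of passages whose paragraph range overlaps any gold provenance span."""
    found = []
    for output in kilt_record.get("output", []):
        # answer-only outputs may have no provenance at all
        for prov in output.get("provenance", []):
            for pid, start, end in index.get(prov["wikipedia_id"], []):
                overlaps = start <= prov["end_paragraph_id"] and prov["start_paragraph_id"] <= end
                if overlaps and pid not in found:
                    found.append(pid)
    return found
```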

Download and build Anserini. You will need to have Maven and a Java JDK.

git clone https://github.com/castorini/anserini.git
cd anserini
# check out the commit for Anserini 0.4.1, the version dprBM25.jar was built against
git checkout 3a60106fdc83473d147218d78ae7dca7c3b6d47c
export JAVA_HOME=/path/to/your/jdk
mvn clean package appassembler:assemble
cd ..

Build the DPR training instances: add the passage title/text and hard negatives from BM25 (dataset = trex, zsRE)

python dpr/anserini_prep.py \
--input kilt_passages \
--output anserini_passages
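Anserini's `JsonCollection` indexes documents that have an `id` field and a `contents` field. The conversion can be sketched as below. The input field names (`pid`, `title`, `text`) and the title/text separator are assumptions, not details taken from `dpr/anserini_prep.py`.

```python
import json
import os


def to_anserini_doc(passage):
    """Map a passage record to Anserini's JsonCollection schema (id + contents).

    Input field names (pid, title, text) are assumptions about kilt_passages.
    """
    return {"id": passage["pid"], "contents": f"{passage['title']}\n\n{passage['text']}"}


def convert_file(in_path, out_dir):
    """Rewrite one JSON-lines passage file into out_dir in JsonCollection format."""
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, os.path.basename(in_path))
    with open(in_path, encoding="utf-8") as fin, open(out_path, "w", encoding="utf-8") as fout:
        for line in fin:
            fout.write(json.dumps(to_anserini_doc(json.loads(line))) + "\n")
    return out_path
```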

sh anserini/target/appassembler/bin/IndexCollection -collection JsonCollection \
-generator LuceneDocumentGenerator -threads 40 -input anserini_passages \
-index anserini_passage_index -storePositions -storeDocvectors -storeRawDocs

export CLASSPATH=jar/dprBM25.jar:anserini/target/anserini-0.4.1-SNAPSHOT-fatjar.jar
java com.ibm.research.ai.pretraining.retrieval.DPRTrainingData \
-passageIndex anserini_passage_index \
-positivePidData ${dataset}_train_positive_pids.jsonl \
-trainingData ${dataset}_dpr_training_data.jsonl
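The idea behind BM25 hard negatives, as described in the DPR paper, is to take passages that BM25 ranks highly for the question but that do not contain the answer. The pure-Python sketch below illustrates that idea on a toy corpus. It uses Lucene-style BM25 with Anserini's default k1=0.9 and b=0.4. It is not the `DPRTrainingData` Java implementation, which queries the Lucene index and may filter candidates differently.

```python
import math
from collections import Counter


def bm25_scores(query, docs, k1=0.9, b=0.4):
    """Score every doc against the query with Lucene-style BM25."""
    tokenized = [d.lower().split() for d in docs]
    n_docs = len(tokenized)
    avgdl = sum(len(t) for t in tokenized) / n_docs
    df = Counter(term for t in tokenized for term in set(t))
    scores = []
    for tokens in tokenized:
        tf = Counter(tokens)
        score = 0.0
        for term in query.lower().split():
            if term not in tf:
                continue
            idf = math.log(1 + (n_docs - df[term] + 0.5) / (df[term] + 0.5))
            norm = tf[term] + k1 * (1 - b + b * len(tokens) / avgdl)
            score += idf * tf[term] * (k1 + 1) / norm
        scores.append(score)
    return scores


def hard_negatives(query, answer, passages, positive_ids, k=1):
    """Top-k BM25 passages that are neither positives nor contain the answer string."""
    scores = bm25_scores(query, [p["text"] for p in passages])
    ranked = sorted(range(len(passages)), key=lambda i: scores[i], reverse=True)
    negatives = []
    for i in ranked:
        pid, text = passages[i]["pid"], passages[i]["text"]
        if pid in positive_ids or answer.lower() in text.lower():
            continue
        negatives.append(pid)
        if len(negatives) == k:
            break
    return negatives
```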

Train DPR

# multi-GPU training is not well supported; use a single GPU
export CUDA_VISIBLE_DEVICES=0

python dpr/biencoder_trainer.py \
--train_dir zsRE_dpr_training_data.jsonl \
--output_dir models/DPR/zsRE \
--num_train_epochs 2 \
--num_instances 131610 \
--encoder_gpu_train_limit 32 \
--full_train_batch_size 128 \
--max_grad_norm 1.0 --learning_rate 5e-5

python dpr/biencoder_trainer.py \
--train_dir trex_dpr_training_data.jsonl \
--output_dir models/DPR/trex \
--num_train_epochs 2 \
--num_instances 2207953 \
--encoder_gpu_train_limit 32 \
--full_train_batch_size 128 \
--max_grad_norm 1.0 --learning_rate 5e-5
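DPR bi-encoders are normally trained with in-batch negatives. Each query's positive passage is scored against every other passage in the batch, and softmax cross-entropy pushes the positive to the top. A larger `--full_train_batch_size` therefore means more negatives per query. The exact behaviour of `--encoder_gpu_train_limit` is not documented here; presumably it caps how many instances go through the encoders at once. The sketch below computes the loss with plain lists. It assumes that positives come first (passage i is the positive for query i) and that any extra entries are hard negatives.

```python
import math


def dot(u, v):
    """Inner product used as the query-passage similarity."""
    return sum(a * b for a, b in zip(u, v))


def in_batch_negative_loss(query_vecs, passage_vecs):
    """Mean softmax cross-entropy where passage i is the positive for query i.

    passage_vecs may be longer than query_vecs (extra hard negatives appended);
    every passage that is not query i's positive acts as a negative.
    """
    total = 0.0
    for i, q in enumerate(query_vecs):
        logits = [dot(q, p) for p in passage_vecs]
        m = max(logits)  # subtract the max for a numerically stable log-sum-exp
        log_z = m + math.log(sum(math.exp(s - m) for s in logits))
        total += log_z - logits[i]
    return total / len(query_vecs)
```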

Put the trained DPR query encoder into the NQ RAG model (dataset = trex, zsRE)

python dpr/prepare_rag_model.py \
--save_dir models/RAG/${dataset}_dpr_rag_init  \
--qry_encoder_path models/DPR/${dataset}/qry_encoder

Encode the passages (dataset = trex, zsRE)

python dpr/index_simple_corpus.py \
--embed 1of2 \
--dpr_ctx_encoder_path models/DPR/${dataset}/ctx_encoder \
--corpus kilt_passages  \
--output_dir kilt_passages_${dataset}

python dpr/index_simple_corpus.py \
--embed 2of2 \
--dpr_ctx_encoder_path models/DPR/${dataset}/ctx_encoder \
--corpus kilt_passages \
--output_dir kilt_passages_${dataset}
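The `--embed 1of2` and `--embed 2of2` runs presumably each encode half of the corpus, so they can run in parallel (for example, on two GPUs) and write into the same output directory. The sketch below shows one way to interpret a `KofN` spec. Splitting into contiguous, near-equal slices is an assumption about how `dpr/index_simple_corpus.py` divides the work.

```python
import re


def parse_shard(spec):
    """Parse a shard spec such as '1of2' into (index, count), 1-based."""
    match = re.fullmatch(r"(\d+)of(\d+)", spec)
    if match is None:
        raise ValueError(f"expected KofN, got {spec!r}")
    k, n = int(match.group(1)), int(match.group(2))
    if not 1 <= k <= n:
        raise ValueError(f"shard index {k} out of range 1..{n}")
    return k, n


def shard_range(num_items, spec):
    """Return the [start, end) slice of the corpus that shard `spec` covers.

    Contiguous near-equal slices are an assumption made for illustration.
    """
    k, n = parse_shard(spec)
    return (k - 1) * num_items // n, k * num_items // n
```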

Index the passage vectors (dataset = trex, zsRE)

python dpr/faiss_index.py \
--corpus_dir kilt_passages_${dataset} \
--scalar_quantizer 8 \
--output_file kilt_passages_${dataset}/index.faiss
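`--scalar_quantizer 8` suggests 8-bit scalar quantization. In faiss, this corresponds to an index such as `faiss.IndexScalarQuantizer(d, faiss.ScalarQuantizer.QT_8bit, faiss.METRIC_INNER_PRODUCT)`, which is trained on the vectors before `add`. The exact index type built by `dpr/faiss_index.py` is not shown here. For 768-dimensional DPR vectors, 8-bit codes take 768 bytes per passage instead of 3,072 bytes for float32, a 4x reduction. The simplified pure-Python illustration below learns a per-dimension [min, max] range and maps each component to one byte. faiss's own rounding and reconstruction details differ.

```python
def train_sq8(vectors):
    """Learn per-dimension [min, max] ranges from training vectors."""
    dims = range(len(vectors[0]))
    lo = [min(v[d] for v in vectors) for d in dims]
    hi = [max(v[d] for v in vectors) for d in dims]
    return lo, hi


def encode_sq8(vec, lo, hi):
    """Map each component to a one-byte code in 0..255 (clipped to the trained range)."""
    codes = []
    for x, a, b in zip(vec, lo, hi):
        span = (b - a) or 1.0  # constant dimension: avoid dividing by zero
        t = min(max((x - a) / span, 0.0), 1.0)
        codes.append(round(t * 255))
    return bytes(codes)


def decode_sq8(codes, lo, hi):
    """Approximately reconstruct the float vector from its byte codes."""
    return [a + (c / 255) * ((b - a) or 1.0) for c, a, b in zip(codes, lo, hi)]
```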

Train RAG

python dataloader/file_splitter.py \
--input trex-train-kilt.jsonl \
--outdirs trex_training \
--file_counts 64
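This step splits the T-REx training file into 64 parts in `trex_training`, presumably so the RAG trainer can read the large file in pieces. The trainer is given that directory instead of a single file. A minimal version is sketched below. Round-robin assignment and the `part-XXXXX.jsonl` names are assumptions; `dataloader/file_splitter.py` may distribute lines or name files differently.

```python
import os
from contextlib import ExitStack


def split_jsonl(input_path, out_dir, file_count):
    """Deal the lines of a JSON-lines file round-robin into file_count shard files.

    Round-robin order and the part-XXXXX.jsonl names are illustrative assumptions.
    """
    os.makedirs(out_dir, exist_ok=True)
    paths = [os.path.join(out_dir, f"part-{i:05d}.jsonl") for i in range(file_count)]
    with ExitStack() as stack, open(input_path, encoding="utf-8") as fin:
        outs = [stack.enter_context(open(p, "w", encoding="utf-8")) for p in paths]
        for line_no, line in enumerate(fin):
            outs[line_no % file_count].write(line)
    return paths
```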

python slot_filling/rag_client_server_train.py \
  --kilt_data trex_training \
  --output models/RAG/trex_dpr_rag \
  --corpus_endpoint kilt_passages_trex \
  --model_name facebook/rag-token-nq \
  --model_path models/RAG/trex_dpr_rag_init \
  --num_instances 500000 --warmup_instances 10000  --num_train_epochs 1 \
  --learning_rate 3e-5 --full_train_batch_size 128 --gradient_accumulation_steps 64


python slot_filling/rag_client_server_train.py \
  --kilt_data structured_zeroshot-train-kilt.jsonl \
  --output models/RAG/zsRE_dpr_rag \
  --corpus_endpoint kilt_passages_zsRE \
  --model_name facebook/rag-token-nq \
  --model_path models/RAG/zsRE_dpr_rag_init \
  --num_instances 147909  --warmup_instances 10000 --num_train_epochs 1 \
  --learning_rate 3e-5 --full_train_batch_size 128 --gradient_accumulation_steps 64
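The RAG flags are given in instances rather than steps. If one optimizer step consumes `--full_train_batch_size` instances split over `--gradient_accumulation_steps` passes, then each pass sees 128 / 64 = 2 instances. T-REx then gets roughly 500000 / 128 ≈ 3906 optimizer steps, of which 10000 / 128 ≈ 78 are warmup. The sketch below makes this arithmetic explicit and pairs it with a linear warmup/decay schedule like the one in `transformers`' `get_linear_schedule_with_warmup`. Rounding down and the schedule shape are assumptions, not confirmed details of `rag_client_server_train.py`.

```python
def schedule(num_instances, warmup_instances, full_batch, accumulation_steps):
    """Derive per-pass batch size and optimizer-step counts from instance-based flags.

    Assumes one optimizer step consumes full_batch instances, split evenly over
    accumulation_steps forward/backward passes; partial batches are dropped.
    """
    if full_batch % accumulation_steps:
        raise ValueError("full batch must divide evenly into accumulation steps")
    return {
        "per_pass_batch": full_batch // accumulation_steps,
        "optimizer_steps": num_instances // full_batch,
        "warmup_steps": warmup_instances // full_batch,
    }


def linear_warmup_decay(step, warmup_steps, total_steps, peak_lr):
    """Learning rate under an assumed linear warmup followed by linear decay to zero."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    return peak_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))
```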

Apply RAG (dataset = trex, zsRE; dev_file = trex-dev-kilt.jsonl or structured_zeroshot-dev-kilt.jsonl, respectively)

python slot_filling/rag_client_server_apply.py \
  --kilt_data ${dev_file} \
  --corpus_endpoint kilt_passages_${dataset} \
  --output predictions/${dataset}_dev.jsonl \
  --model_name facebook/rag-token-nq \
  --model_path models/RAG/${dataset}_dpr_rag

python eval/convert_for_kilt_eval.py \
--apply_file predictions/${dataset}_dev.jsonl \
--eval_file predictions/${dataset}_dev_kilt_format.jsonl
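KILT's evaluation scripts expect one JSON object per line of the form `{"id": ..., "output": [{"answer": ..., "provenance": [{"wikipedia_id": ...}]}]}`. The sketch below shows the kind of mapping this conversion performs. The input fields (`id`, `prediction`, `doc_ids`) and the `<wikipedia_id>::<n>` passage ID format are hypothetical; the actual apply-file schema may differ.

```python
import json


def to_kilt_prediction(record):
    """Convert one (hypothetical) apply-file record into KILT's prediction format.

    Assumed input fields: 'id', 'prediction' (answer string) and 'doc_ids'
    (retrieved passage IDs shaped like '<wikipedia_id>::<n>').
    """
    pages = []
    for pid in record["doc_ids"]:
        wiki_id = pid.split("::")[0]
        if wiki_id not in pages:  # provenance is scored per page, so dedupe pages
            pages.append(wiki_id)
    return {
        "id": record["id"],
        "output": [{
            "answer": record["prediction"],
            "provenance": [{"wikipedia_id": w} for w in pages],
        }],
    }


def convert(apply_file, eval_file):
    """Rewrite a JSON-lines apply file as a KILT-format prediction file."""
    with open(apply_file, encoding="utf-8") as fin, open(eval_file, "w", encoding="utf-8") as fout:
        for line in fin:
            fout.write(json.dumps(to_kilt_prediction(json.loads(line))) + "\n")
```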

Run the official evaluation script

# install KILT evaluation scripts
git clone https://github.com/facebookresearch/KILT.git
cd KILT
conda create -n kilt37 -y python=3.7 && conda activate kilt37
pip install -r requirements.txt
export PYTHONPATH=`pwd`

# run evaluation
python kilt/eval_downstream.py predictions/${dataset}_dev_kilt_format.jsonl ${dev_file}
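To understand the reported numbers, it helps to know what the main slot-filling metrics measure. Exact-match accuracy compares SQuAD-style normalized answer strings. R-precision checks whether the top-R retrieved pages match a gold provenance set of size R. The KILT variant credits an exact match only when R-precision is 1. The sketch below is a simplified re-implementation for intuition only. Report results from `kilt/eval_downstream.py`, whose details (for example, F1 and Recall@k) are not reproduced here.

```python
import re
import string

PUNCT = set(string.punctuation)


def normalize_answer(s):
    """SQuAD-style normalization: lowercase, drop punctuation and articles, squash spaces."""
    s = "".join(ch for ch in s.lower() if ch not in PUNCT)
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())


def exact_match(guess, gold_answers):
    return float(any(normalize_answer(guess) == normalize_answer(g) for g in gold_answers))


def r_precision(guess_pages, gold_page_sets):
    """Best precision of the top-R guessed pages over gold sets of size R."""
    best = 0.0
    for gold in gold_page_sets:
        if gold:
            top = guess_pages[:len(gold)]
            best = max(best, sum(1 for p in top if p in gold) / len(gold))
    return best


def kilt_em(guess, guess_pages, gold_answers, gold_page_sets):
    """Exact match that only counts when the provenance is fully correct."""
    if r_precision(guess_pages, gold_page_sets) < 1.0:
        return 0.0
    return exact_match(guess, gold_answers)


def gold_from_kilt(record):
    """Extract gold answers and provenance page sets from a KILT gold record."""
    answers = [o["answer"] for o in record["output"] if "answer" in o]
    page_sets = [{p["wikipedia_id"] for p in o.get("provenance", [])} for o in record["output"]]
    return answers, [s for s in page_sets if s]
```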

Owner: International Business Machines