The repo contains the code of the ACL2020 paper `Dice Loss for Data-imbalanced NLP Tasks`

Overview

Dice Loss for NLP Tasks

This repository contains code for Dice Loss for Data-imbalanced NLP Tasks at ACL2020.

Setup

  • Install Package Dependencies

The code was tested in Python 3.6.9+ and Pytorch 1.7.1. If you are working on ubuntu GPU machine with CUDA 10.1, please run the following command to setup environment.

$ virtualenv -p /usr/bin/python3.6 venv
$ source venv/bin/activate
$ pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
$ pip install -r requirements.txt
  • Download BERT Model Checkpoints

Before running the repo you must download the BERT-Base and BERT-Large checkpoints from here and unzip it to some directory $BERT_DIR. Then convert original TensorFlow checkpoints for BERT to a PyTorch saved file by running bash scripts/prepare_ckpt.sh <path-to-unzip-tf-bert-checkpoints>.

Apply Dice-Loss to NLP Tasks

In this repository, we apply dice loss to four NLP tasks, including

  1. machine reading comprehension
  2. paraphrase identification task
  3. named entity recognition
  4. text classification

1. Machine Reading Comprehension

Datasets

We take SQuAD 1.1 as an example. Before training, you should download a copy of the data from here.
And move the SQuAD 1.1 train train-v1.1.json and dev file dev-v1.1.json to the directory $DATA_DIR.

Train

We choose BERT as the backbone. During training, the task trainer BertForQA will automatically evaluate on dev set every $val_check_interval epoch, and save the dev predictions into files called $OUTPUT_DIR/predictions_<train-epoch>_<total-train-step>.json and $OUTPUT_DIR/nbest_predictions_<train-epoch>_<total-train-step>.json.

Run scripts/squad1/bert_<model-scale>_<loss-type>.sh to reproduce our experimental results.
The variable <model-scale> should take the value of [base, large].
The variable <loss-type> should take the value of [bce, focal, dice] which denotes fine-tuning BERT-Base with binary cross entropy loss, focal loss, dice loss , respectively.

  • Run bash scripts/squad1/bert_base_focal.sh to start training. After training, run bash scripts/squad1/eval_pred_file.sh $DATA_DIR $OUTPUT_DIR for focal loss.

  • Run bash scripts/squad1/bert_base_dice.sh to start training. After training, run bash scripts/squad1/eval_pred_file.sh $DATA_DIR $OUTPUT_DIR for dice loss.

Evaluate

To evaluate a model checkpoint, please run

python3 tasks/squad/evaluate_models.py \
--gpus="1" \
--path_to_model_checkpoint  $OUTPUT_DIR/epoch=2.ckpt \
--eval_batch_size <evaluate-batch-size>

After evaluation, prediction results predictions_dev.json and nbest_predictions_dev.json can be found in $OUTPUT_DIR

To evaluate saved predictions, please run

python3 tasks/squad/evaluate_predictions.py <path-to-dev-v1.1.json> <directory-to-prediction-files>

2. Paraphrase Identification Task

Datasets

We use MRPC (GLUE Version) as an example. Before running experiments, you should download and save the processed dataset files to $DATA_DIR.

Run bash scripts/prepare_mrpc_data.sh $DATA_DIR to download and process datasets for MPRC (GLUE Version) task.

Train

Please run scripts/glue_mrpc/bert_<model-scale>_<loss-type>.sh to train and evaluate on the dev set every $val_check_interval epoch. After training, the task trainer evaluates on the test set with the best checkpoint which achieves the highest F1-score on the dev set.
The variable <model-scale> should take the value of [base, large].
The variable <loss-type> should take the value of [focal, dice] which denotes fine-tuning BERT with focal loss, dice loss , respectively.

  • Run bash scripts/glue_mrpc/bert_large_focal.sh for focal loss.

  • Run bash scripts/glue_mrpc/bert_large_dice.sh for dice loss.

The evaluation results on the dev and test set are saved at $OUTPUT_DIR/eval_result_log.txt file.
The intermediate model checkpoints are saved at most $max_keep_ckpt times.

Evaluate

To evaluate a model checkpoint on test set, please run

bash scripts/glue_mrpc/eval.sh \
$OUTPUT_DIR \
epoch=*.ckpt

3. Named Entity Recognition

For NER, we use MRC-NER model as the backbone.
Processed datasets and model architecture can be found here.

Train

Please run scripts/<ner-datdaset-name>/bert_<loss-type>.sh to train and evaluate on the dev set every $val_check_interval epoch. After training, the task trainer evaluates on the test set with the best checkpoint.
The variable <ner-dataset-name> should take the value of [ner_enontonotes5, ner_zhmsra, ner_zhonto4].
The variable <loss-type> should take the value of [focal, dice] which denotes fine-tuning BERT with focal loss, dice loss , respectively.

For Chinese MSRA,

  • Run scripts/ner_zhmsra/bert_focal.sh for focal loss.

  • Run scripts/ner_zhmsra/bert_dice.sh for dice loss.

For Chinese OntoNotes4,

  • Run scripts/ner_zhonto4/bert_focal.sh for focal loss.

  • Run scripts/ner_zhonto4/bert_dice.sh for dice loss.

For English OntoNotes5,

  • Run scripts/ner_enontonotes5/bert_focal.sh. After training, you will get 91.12 Span-F1 on the test set.

  • Run scripts/ner_enontonotes5/bert_dice.sh. After training, you will get 92.01 Span-F1 on the test set.

Evaluate

To evaluate a model checkpoint, please run

CUDA_VISIBLE_DEVICES=0 python3 ${REPO_PATH}/tasks/mrc_ner/evaluate.py \
--gpus="1" \
--path_to_model_checkpoint $OUTPUT_DIR/epoch=2.ckpt

4. Text Classification

Datasets

We use TNews (Chinese Text Classification) as an example. Before running experiments, you should download and save the processed dataset files to $DATA_DIR.

Train

We choose BERT as the backbone.
Please run scripts/tnews/bert_<loss-type>.sh to train and evaluate on the dev set every $val_check_interval epoch. The variable <loss-type> should take the value of [focal, dice] which denotes fine-tuning BERT with focal loss, dice loss , respectively.

  • Run bash scripts/tnews/bert_focal.sh for focal loss.

  • Run bash scripts/tnews/bert_dice.sh for dice loss.

The intermediate model checkpoints are saved at most $max_keep_ckpt times.

Citation

If you find this repository useful , please cite the following:

@article{li2019dice,
  title={Dice loss for data-imbalanced NLP tasks},
  author={Li, Xiaoya and Sun, Xiaofei and Meng, Yuxian and Liang, Junjun and Wu, Fei and Li, Jiwei},
  journal={arXiv preprint arXiv:1911.02855},
  year={2019}
}

Contact

xiaoyalixy AT gmail.com OR xiaoya_li AT shannonai.com

Any discussions, suggestions and questions are welcome!

Breaking the Dilemma of Medical Image-to-image Translation

Breaking the Dilemma of Medical Image-to-image Translation Supervised Pix2Pix and unsupervised Cycle-consistency are two modes that dominate the field

Kid Liet 86 Dec 21, 2022
Least Square Calibration for Peer Reviews

Least Square Calibration for Peer Reviews Requirements gurobipy - for solving convex programs GPy - for Bayesian baseline numpy pandas To generate p

Sigma <a href=[email protected]"> 1 Nov 01, 2021
Python Library for learning (Structure and Parameter) and inference (Statistical and Causal) in Bayesian Networks.

pgmpy pgmpy is a python library for working with Probabilistic Graphical Models. Documentation and list of algorithms supported is at our official sit

pgmpy 2.2k Jan 03, 2023
PyTorch code for the "Deep Neural Networks with Box Convolutions" paper

Box Convolution Layer for ConvNets Single-box-conv network (from `examples/mnist.py`) learns patterns on MNIST What This Is This is a PyTorch implemen

Egor Burkov 515 Dec 18, 2022
An open source object detection toolbox based on PyTorch

MMDetection is an open source object detection toolbox based on PyTorch. It is a part of the OpenMMLab project.

Bo Chen 24 Dec 28, 2022
Time series annotation library.

CrowdCurio Time Series Annotator Library The CrowdCurio Time Series Annotation Library implements classification tasks for time series. Features Suppo

CrowdCurio 51 Sep 15, 2022
PyTorch implementation of CVPR'18 - Perturbative Neural Networks

This is an attempt to reproduce results in Perturbative Neural Networks paper. See original repo for details.

Michael Klachko 57 May 14, 2021
Data-driven reduced order modeling for nonlinear dynamical systems

SSMLearn Data-driven Reduced Order Models for Nonlinear Dynamical Systems This package perform data-driven identification of reduced order model based

Haller Group, Nonlinear Dynamics 27 Dec 13, 2022
Linear Variational State Space Filters

Linear Variational State Space Filters To set up the environment, use the provided scripts in the docker/ folder to build and run the codebase inside

0 Dec 13, 2021
Extract MNIST handwritten digits dataset binary file into bmp images

MNIST-dataset-extractor Extract MNIST handwritten digits dataset binary file into bmp images More info at http://yann.lecun.com/exdb/mnist/ Dependenci

Omar Mostafa 6 May 24, 2021
Physics-Informed Neural Networks (PINN) and Deep BSDE Solvers of Differential Equations for Scientific Machine Learning (SciML) accelerated simulation

NeuralPDE NeuralPDE.jl is a solver package which consists of neural network solvers for partial differential equations using scientific machine learni

SciML Open Source Scientific Machine Learning 680 Jan 02, 2023
A simple tutoral for error correction task, based on Pytorch

gramcorrector A simple tutoral for error correction task, based on Pytorch Grammatical Error Detection (sentence-level) a binary sequence-based classi

peiyuan_gong 8 Dec 03, 2022
Diffusion Normalizing Flow (DiffFlow) Neurips2021

Diffusion Normalizing Flow (DiffFlow) Reproduce setup environment The repo heavily depends on jam, a personal toolbox developed by Qsh.zh. The API may

76 Jan 01, 2023
An implementation of the methods presented in Causal-BALD: Deep Bayesian Active Learning of Outcomes to Infer Treatment-Effects from Observational Data.

An implementation of the methods presented in Causal-BALD: Deep Bayesian Active Learning of Outcomes to Infer Treatment-Effects from Observational Data.

Andrew Jesson 9 Apr 04, 2022
A distributed, plug-n-play algorithm for multi-robot applications with a priori non-computable objective functions

A distributed, plug-n-play algorithm for multi-robot applications with a priori non-computable objective functions Kapoutsis, A.C., Chatzichristofis,

Athanasios Ch. Kapoutsis 5 Oct 15, 2022
The implementation code for "DAGAN: Deep De-Aliasing Generative Adversarial Networks for Fast Compressed Sensing MRI Reconstruction"

DAGAN This is the official implementation code for DAGAN: Deep De-Aliasing Generative Adversarial Networks for Fast Compressed Sensing MRI Reconstruct

TensorLayer Community 159 Nov 22, 2022
Unsupervised Foreground Extraction via Deep Region Competition

Unsupervised Foreground Extraction via Deep Region Competition [Paper] [Code] The official code repository for NeurIPS 2021 paper "Unsupervised Foregr

28 Nov 06, 2022
[NeurIPS 2021 Spotlight] Code for Learning to Compose Visual Relations

Learning to Compose Visual Relations This is the pytorch codebase for the NeurIPS 2021 Spotlight paper Learning to Compose Visual Relations. Demo Imag

Nan Liu 88 Jan 04, 2023
Skipgram Negative Sampling in PyTorch

PyTorch SGNS Word2Vec's SkipGramNegativeSampling in Python. Yet another but quite general negative sampling loss implemented in PyTorch. It can be use

Jamie J. Seol 287 Dec 14, 2022
Face recognition. Redefined.

FaceFinder Use a powerful CNN to identify faces in images! TABLE OF CONTENTS About The Project Built With Getting Started Prerequisites Installation U

BleepLogger 20 Jun 16, 2021