The Codebase for Causal Distillation for Language Models.

Overview

Python 3.7 · License: CC BY-NC

Causal Distillation for Language Models

Zhengxuan Wu*, Atticus Geiger*, Josh Rozner, Elisa Kreiss, Hanson Lu, Thomas Icard, Christopher Potts, Noah D. Goodman

This is an implementation of our preprint Causal Distillation for Language Models. The standard approach to distillation trains a student model against two objectives: a task-specific objective (e.g., language modeling) and an imitation objective that encourages the hidden states of the student model to be similar to those of the larger teacher model. In this paper, we show that it is beneficial to augment distillation with a third objective that encourages the student to imitate the causal computation process of the teacher through interchange intervention training (IIT).
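To make the third objective concrete, below is a minimal, self-contained PyTorch sketch of an interchange intervention loss. It is illustrative only: it swaps the full hidden state at one pair of aligned layers (the actual training code restricts the swap according to a neuron mapping such as training_configs/single_middle.nm), and the toy linear stacks, dimensions, and KL-divergence choice are assumptions rather than our exact implementation.

import torch
import torch.nn.functional as F

def run_with_intervention(layers, head, base, source, layer_idx):
    # Run the stack on both inputs; at `layer_idx`, swap the source
    # hidden state into the base computation (an interchange intervention).
    h_base, h_source = base, source
    for i, layer in enumerate(layers):
        h_base = torch.relu(layer(h_base))
        h_source = torch.relu(layer(h_source))
        if i == layer_idx:
            h_base = h_source
    return head(h_base)

# Toy stand-ins: a 12-layer "teacher" and a 3-layer "student".
teacher = torch.nn.ModuleList(torch.nn.Linear(16, 16) for _ in range(12))
student = torch.nn.ModuleList(torch.nn.Linear(16, 16) for _ in range(3))
teacher_head, student_head = torch.nn.Linear(16, 100), torch.nn.Linear(16, 100)

base, source = torch.randn(8, 16), torch.randn(8, 16)

# Intervene at aligned locations, e.g. teacher middle layer <-> student middle layer.
with torch.no_grad():
    t_logits = run_with_intervention(teacher, teacher_head, base, source, 6)
s_logits = run_with_intervention(student, student_head, base, source, 1)

# The causal objective: the student's counterfactual predictions should match
# the teacher's. In the training commands below this term is weighted by --alpha_causal.
loss_causal = F.kl_div(
    F.log_softmax(s_logits, dim=-1),
    F.softmax(t_logits, dim=-1),
    reduction="batchmean",
)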

We fork our main codebase from the Huggingface Distillation Interface.

Release Notes

✅ 12/02/2021 Our paper on Interchange Intervention Training (IIT) is released! Read it for a more formal definition of the method.
✅ 12/06/2021 Released the causal distillation codebase with the preprint.
✅ 12/06/2021 Released evaluation results on distilled tiny-BERT (3 layers) with the Wiki-Text 103M dataset.
⬜️ Released evaluation results on causal-distilled tiny-BERT (3 layers) with the Wiki-Text 103M + BookCorpus dataset.
⬜️ Released evaluation results on causal-distilled BERT (6 layers) with the Wiki-Text 103M + BookCorpus dataset.
⬜️ Released more ablation studies.
⬜️ Released causal-distilled tiny-BERT (3 layers) model files.
⬜️ Released causal-distilled BERT (6 layers) model files.

If you experience any issues or have suggestions, please contact me either through the issues page or at [email protected].

Benchmark Results

Here are the results on the dev sets of GLUE:

Model                   Average-score   CoLA   MNLI   MRPC   QNLI   QQP    RTE    SST-2   STS-B   WNLI
DistilBERT (3 layers)   67.8¹           22.8   71.6   78.2   82.1   84.3   55.4   86.5    56.7    24.2
CausalBERT (3 layers)   69.7¹           25.0   72.9   78.6   83.1   84.9   55.4   86.9    66.5    21.5

¹ Average score computed without WNLI.

Main Contents

Citation

If you use this repository, please cite the following two papers: the paper on interchange intervention training, and the paper on our distillation method.

  @article{geiger-etal-2021-iit,
        title={Inducing Causal Structure for Interpretable Neural Networks}, 
        author={Geiger, Atticus and Wu, Zhengxuan and Lu, Hanson and Rozner, Josh and Kreiss, Elisa and Icard, Thomas and Goodman, Noah D. and Potts, Christopher},
        year={2021},
        eprint={2112.00826},
        archivePrefix={arXiv},
        primaryClass={cs.LG}
  }

  @article{wu-etal-2021-distill,
        title={Causal Distillation for Language Models}, 
        author={Wu, Zhengxuan and Geiger, Atticus and Rozner, Josh and Kreiss, Elisa and Lu, Hanson and Icard, Thomas and Potts, Christopher and Goodman, Noah D.},
        year={2021},
        eprint={2112.02505},
        archivePrefix={arXiv},
        primaryClass={cs.CL}
  }

Requirements

  • Python 3.6 and 3.7 are supported.
  • PyTorch version: 1.9.0
  • Transformers version: 4.11.3
  • Datasets version: 1.8.0
  • We performed experiments on a Titan V GPU. We assume 12GB of GPU memory; more memory can expedite training.
  • Since we build our codebase off the Huggingface Distillation Interface, please review their documentation for requirements.

Dataset

Following the Huggingface Distillation Interface, we need to pre-process the datasets before distillation; you can refer to their repo for details. We adapt their pre-processing scripts and add a few improvements. For example, we can now binarize datasets directly from the Huggingface Dataset Hub.

# preprocessing from disk
python scripts/binarized_data.py \
--file_path ../../bert-mid-tuning/data-files/wikitext-15M \
--split train \
--field_name text \
--max_parsing_example 1000 \
--tokenizer_type bert \
--tokenizer_name bert-base-uncased \
--dump_file ./data/binarized_text

# preprocessing from huggingface.
python scripts/binarized_data.py \
--dataset_name bookcorpus \
--split train \
--field_name text \
--tokenizer_type bert \
--tokenizer_name bert-base-uncased \
--dump_file bookcorpus-dataset/binarized_text \
--cache_dir ./distill_cache/

python scripts/binarized_data.py \
--dataset_name wikitext \
--split train \
--field_name text \
--tokenizer_type bert \
--tokenizer_name bert-base-uncased \
--dump_file wikitext-dataset/binarized_text \
--cache_dir ./distill_cache/

python scripts/binarized_data.py \
--dataset_name wikitext+bookcorpus \
--split train \
--field_name text \
--tokenizer_type bert \
--tokenizer_name bert-base-uncased \
--dump_file wikitext+bookcorpus-dataset/binarized_text \
--cache_dir ./distill_cache/

# helper scripts to combine two binarized data files
python scripts/data_combinator.py \
--file_path_left ./bookcorpus-dataset/binarized_text.train.bert-base-uncased.pickle \
--file_path_right ./wikitext-dataset/binarized_text.train.bert-base-uncased.pickle \
--split train \
--tokenizer_name bert-base-uncased \
--dump_file wikitext+bookcorpus-dataset/binarized_text

# multiprocessing preprocessor.
python scripts/binarized_data.py \
--dataset_name bookcorpus \
--split train \
--field_name text \
--tokenizer_type bert \
--tokenizer_name bert-base-uncased \
--dump_file bookcorpus-dataset/binarized_text \
--cache_dir ./distill_cache/ \
--fast_process \
--preprocessing_num_workers 48
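
If you want to sanity-check a binarized dump, the scripts (following the Huggingface interface we fork) pickle a list of token-id sequences. The snippet below is a sketch under that assumption, reusing the example path from the commands above:

import pickle
from transformers import BertTokenizer

with open("bookcorpus-dataset/binarized_text.train.bert-base-uncased.pickle", "rb") as f:
    sequences = pickle.load(f)  # assumed: a list of token-id sequences

print(f"{len(sequences)} binarized sequences")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# Decode a prefix of the first sequence to eyeball the pre-processing.
print(tokenizer.decode([int(t) for t in sequences[0][:32]]))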

After you get the datasets ready, you need to generate token counts as well.

python scripts/token_counts.py \
--data_file data/binarized_text.train.bert-base-uncased.pickle \
--token_counts_dump data/binarized_text.train.token_counts.bert-base-uncased.pickle \
--vocab_size 30522
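
Conceptually, the token-counts file stores per-token occurrence counts over the binarized corpus, which are later used to shape the MLM masking distribution. A minimal sketch of the computation, assuming the pickled list-of-sequences format described above:

import pickle

with open("data/binarized_text.train.bert-base-uncased.pickle", "rb") as f:
    sequences = pickle.load(f)

vocab_size = 30522  # matches --vocab_size for bert-base-uncased
counts = [0] * vocab_size
for seq in sequences:
    for token_id in seq:
        counts[int(token_id)] += 1  # occurrences of each vocabulary id

with open("data/binarized_text.train.token_counts.bert-base-uncased.pickle", "wb") as f:
    pickle.dump(counts, f)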

Distillation

Before training, we recommend initializing the student model with weights extracted from the teacher model.

python scripts/extract_distilbert.py \
--model_type bert \
--model_name bert-base-uncased \
--dump_checkpoint ./distillation_checkpoints/bert-base-uncased_num_layer_3.pth \
--num_layers 3
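
In essence, this extraction copies the teacher's embeddings plus a subset of its transformer layers into a student-sized checkpoint. The sketch below illustrates the idea on BertModel; the evenly spaced layer choice [0, 5, 11] is an assumption for illustration, and the real script additionally renames keys to match the student architecture:

import torch
from transformers import BertModel

teacher_state = BertModel.from_pretrained("bert-base-uncased").state_dict()
picked = [0, 5, 11]  # assumed: evenly spaced teacher layers for a 3-layer student

student_state = {}
for key, value in teacher_state.items():
    if key.startswith("embeddings"):
        student_state[key] = value  # keep all embedding weights
    elif key.startswith("encoder.layer."):
        teacher_idx = int(key.split(".")[2])
        if teacher_idx in picked:
            # Remap the teacher layer index to the student's slot.
            student_idx = picked.index(teacher_idx)
            new_key = key.replace(f"layer.{teacher_idx}.", f"layer.{student_idx}.", 1)
            student_state[new_key] = value

torch.save(student_state, "./distillation_checkpoints/bert-base-uncased_num_layer_3.pth")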

Now, here are examples of distilling with and without our causal distillation objective:

CUDA_VISIBLE_DEVICES=9,4 python causal_train.py \
--force \
--n_gpu 2 \
--is_wandb \
--log_interval 10 \
--student_type distilbert \
--student_config ./training_configs/distilbert-base-uncased-small.json \
--student_pretrained_weights ./distillation_checkpoints/bert-base-uncased_num_layer_3.pth \
--teacher_type bert \
--teacher_name bert-base-uncased \
--neuron_mapping ./training_configs/single_middle.nm \
--mlm --alpha_ce 0.25 --alpha_mlm 0.25 --alpha_cos 0.25 --alpha_clm 0.0 --alpha_causal 0.25 \
--freeze_pos_embs \
--dump_path ./results/ \
--data_file ./wikitext-15M/binarized_text.train.bert-base-uncased.pickle \
--token_counts ./wikitext-15M/binarized_text.train.token_counts.bert-base-uncased.pickle \
--seed 42 \
--gradient_accumulation_steps 50 \
--n_epoch 3 \
--batch_size 5

CUDA_VISIBLE_DEVICES=0,1,2,3 python causal_train.py \
--force \
--n_gpu 4 \
--is_wandb \
--log_interval 10 \
--student_type distilbert \
--student_config ./training_configs/distilbert-base-uncased-small.json \
--student_pretrained_weights ./distillation_checkpoints/bert-base-uncased_num_layer_3.pth \
--teacher_type bert \
--teacher_name bert-base-uncased \
--neuron_mapping ./training_configs/single_middle.nm \
--mlm --alpha_ce 0.33 --alpha_mlm 0.33 --alpha_cos 0.33 --alpha_clm 0.0 --alpha_causal 0.00 \
--freeze_pos_embs \
--dump_path ./results/ \
--data_file ./wikitext-15M/binarized_text.train.bert-base-uncased.pickle \
--token_counts ./wikitext-15M/binarized_text.train.token_counts.bert-base-uncased.pickle \
--seed 42 \
--gradient_accumulation_steps 124 \
--n_epoch 6 \
--batch_size 4

Note that you can turn our causal distillation objective on or off simply by setting the loss-weight arguments: the first command above uses --alpha_causal 0.25, while the second disables it with --alpha_causal 0.00 (see the sketch below).
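
For reference, the alphas combine into a single weighted training objective. The toy snippet below (placeholder loss values, not our actual training loop) shows how setting --alpha_causal to 0.0 makes the causal term drop out:

from argparse import Namespace
import torch

# Placeholder losses standing in for the real distillation terms.
loss_ce = loss_mlm = loss_cos = loss_causal = torch.tensor(1.0)

args = Namespace(alpha_ce=0.33, alpha_mlm=0.33, alpha_cos=0.33, alpha_causal=0.0)
loss = (args.alpha_ce * loss_ce             # soft-target distillation loss
        + args.alpha_mlm * loss_mlm         # masked-LM task loss
        + args.alpha_cos * loss_cos         # hidden-state imitation loss
        + args.alpha_causal * loss_causal)  # interchange-intervention loss (off here)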

Evaluation

After you get your distilled models, you need to fine-tune and evaluate them on downstream tasks. We provide all the scripts you need.

MLM Evaluation

CUDA_VISIBLE_DEVICES=5 python run_mlm.py \
--model_name_or_path ./results/s_distilbert_t_bert_data_wikitext-15M_seed_42_mlm_True_ce_0.25_mlm_0.25_cos_0.25_causal_0.25_nm_single_multilayer/ \
--dataset_dir ../../bert-mid-tuning/data-files/wikitext-15M/ \
--tokenizer_name bert-base-uncased \
--do_eval \
--output_dir /tmp/test-mlm \
--cache_dir ./distill_cache/
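
run_mlm.py reports a masked-LM evaluation loss; if you want to report a perplexity number, it is just the exponential of that loss (the eval_loss value below is a made-up example):

import math

eval_loss = 2.31  # example value read from run_mlm.py's eval output
print(f"masked-LM perplexity = {math.exp(eval_loss):.2f}")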

GLUE Evaluation

CUDA_VISIBLE_DEVICES=5,7,8,9 python run_glue.py \
--model_name_or_path ./results/s_distilbert_t_bert_data_wikitext-dataset_seed_42_mlm_True_ce_0.33_mlm_0.33_cos_0.33_causal_0.0_nm_single_middle/ \
--tokenizer_name bert-base-uncased \
--task_name sst2 \
--do_train \
--do_eval \
--max_seq_length 128 \
--per_device_train_batch_size 32 \
--learning_rate 2e-5 \
--num_train_epochs 3 \
--output_dir ./results/ \
--save_total_limit 1 \
--cache_dir ./distill_cache/

CoNLL Evaluation

CUDA_VISIBLE_DEVICES=2,3,7,8 python run_ner.py \
--model_name_or_path ./results/s_distilbert_t_bert_data_wikitext-dataset_seed_42_mlm_True_ce_0.33_mlm_0.33_cos_0.33_causal_0.0_nm_single_middle_crossway_False/ \
--tokenizer_name bert-base-uncased \
--dataset_name conll2003 \
--do_train \
--do_eval \
--output_dir ./ner_results/ \
--save_total_limit 1 \
--cache_dir ./distill_cache/

SQuAD Evaluation

CUDA_VISIBLE_DEVICES=2,3,7,8 python run_qa.py \
--model_name_or_path ./results/s_distilbert_t_bert_data_wikitext-dataset_seed_42_mlm_True_ce_0.33_mlm_0.33_cos_0.33_causal_0.0_nm_single_middle_crossway_False/ \
--tokenizer_name bert-base-uncased \
--dataset_name squad \
--do_train \
--do_eval \
--per_device_train_batch_size 12 \
--learning_rate 3e-5 \
--num_train_epochs 2 \
--max_seq_length 384 \
--doc_stride 128 \
--save_total_limit 1 \
--output_dir ./qa_results/