Code for the ICLR'22 paper: On Robust Prefix-Tuning for Text Classification

Overview

On Robust Prefix-Tuning for Text Classification

Prefix-tuning has drawn much attention as a parameter-efficient and modular alternative for adapting pretrained language models to downstream tasks. However, we find that prefix-tuning is vulnerable to adversarial attacks. Unfortunately, current robust NLP methods are unsuitable for prefix-tuning, as they would inevitably hamper its modularity. In our ICLR'22 paper, we propose robust prefix-tuning for text classification. Our method leverages the idea of test-time tuning, which preserves the strengths of prefix-tuning while improving its robustness. This repository contains the code for the proposed robust prefix-tuning method.

Prerequisites

PyTorch>=1.2.0, pytorch-transformers==1.2.0, OpenAttack==2.0.1, and GPUtil==1.4.0.

Train the original prefix P_θ

For the training phase of standard prefix-tuning, the command is:

  source train.sh --preseqlen [A] --learning_rate [B] --tasks [C] --n_train_epochs [D] --device [E]

where

  • [A]: The length of the prefix P_θ.
  • [B]: The (initial) learning rate.
  • [C]: The benchmark. Default: sst.
  • [D]: The total number of training epochs.
  • [E]: The id of the GPU to be used.

We can also use adversarial training to improve the robustness of the prefix. For the training phase of adversarial prefix-tuning, the command is:

  source train_adv.sh --preseqlen [A] --learning_rate [B] --tasks [C] --n_train_epochs [D] --device [E] --pgd_ball [F]

where

  • [A]~[E] have the same meanings as above.
  • [F]: Whether the PGD norm ball is word-wise or sentence-wise (the running example below uses word).
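The difference between the two [F] options can be sketched as follows, assuming that [F] controls how a PGD perturbation on the input embeddings is projected back onto an L2 norm ball: word-wise clips each token's perturbation separately, while sentence-wise clips the perturbation of the whole sentence at once. The function names, the eps/alpha parameters and the choice of the L2 norm are illustrative assumptions, not the code behind train_adv.sh.

```python
import numpy as np

def project_word_wise(delta: np.ndarray, eps: float) -> np.ndarray:
    """Clip each token's perturbation (one row per token) to an L2 ball of radius eps."""
    norms = np.linalg.norm(delta, axis=-1, keepdims=True)  # shape (seq_len, 1)
    scale = np.minimum(1.0, eps / np.maximum(norms, 1e-12))
    return delta * scale

def project_sentence_wise(delta: np.ndarray, eps: float) -> np.ndarray:
    """Clip the perturbation of the whole sentence to a single L2 ball of radius eps."""
    norm = np.linalg.norm(delta)  # Frobenius norm over all tokens and dimensions
    return delta * min(1.0, eps / max(norm, 1e-12))

def pgd_step(delta, grad, alpha, eps, project):
    """One PGD ascent step on the loss, followed by projection onto the chosen ball."""
    return project(delta + alpha * grad, eps)

# Toy usage: 5 tokens with 16-dimensional embeddings (hypothetical sizes).
delta = np.zeros((5, 16))
grad = np.random.default_rng(0).normal(size=(5, 16))
delta = pgd_step(delta, grad, alpha=0.1, eps=0.5, project=project_word_wise)
```

In this sketch the word-wise ball bounds every token's perturbation individually, whereas the sentence-wise ball lets the perturbation concentrate its budget on a few tokens.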

Note that the DATA_DIR and MODEL_DIR in train_adv.sh are different from those in train.sh. When experimenting with the adversarially trained prefix P_θ's in the following steps, remember to switch the DATA_DIR and MODEL_DIR in the corresponding scripts as well.

Generate Adversarial Examples

We use the OpenAttack package to generate in-sentence adversaries. The command is:

  source generate_adv_insent.sh --preseqlen [A] --learning_rate [B] --tasks [C] --device [E] --test_ep [G] --attack [H]

where

  • [A], [B], [C], [E] have the same meanings as above.
  • [G]: Load the parameters of the prefix P_θ after [G] training epochs for testing. We set G=D.
  • [H]: The in-sentence attack used to generate adversarial examples from the clean test set (e.g. bug in the running example below).

We also implement the Universal Adversarial Trigger attack. The command is:

  source generate_adv_uat.sh --preseqlen [A] --learning_rate [B] --tasks [C] --device [E] --test_ep [G] --attack clean-[H2] --uat_len [I] --uat_epoch [J]

where

  • [A], [B], [C], [E], [G] have the same meanings as above.
  • [H2]: UATs are searched separately for each class of the benchmark; H2 is the class id: 0/1 for SST, 0/1/2/3 for AG News, and 0/1/2 for SNLI.
  • [I]: The length of the UAT (number of trigger tokens).
  • [J]: The number of epochs used to search for the UAT.
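Because UATs are searched separately for each class, the generation step is repeated once per class id H2. The sketch below builds the per-class generate_adv_uat.sh commands from the flags documented above; the helper name uat_generation_commands is hypothetical. Like the running example, it omits --preseqlen and --learning_rate and so relies on the script defaults. For SST it reproduces the two UAT generation commands of the running example.

```python
import shlex

def uat_generation_commands(task, num_classes, test_ep, uat_len, uat_epoch, device=0):
    """Build one generate_adv_uat.sh command per class id H2 = 0..num_classes-1."""
    commands = []
    for class_id in range(num_classes):
        args = [
            "source", "generate_adv_uat.sh",
            "--tasks", task,
            "--device", str(device),
            "--test_ep", str(test_ep),
            "--attack", f"clean-{class_id}",  # the class id is encoded in the attack name
            "--uat_len", str(uat_len),
            "--uat_epoch", str(uat_epoch),
        ]
        commands.append(shlex.join(args))  # shell-safe quoting (Python 3.8+)
    return commands

if __name__ == "__main__":
    # SST has two classes (H2 = 0/1).
    for cmd in uat_generation_commands("sst", num_classes=2, test_ep=100, uat_len=3, uat_epoch=10):
        print(cmd)
```

Since source is a bash builtin, each printed command would have to be executed by bash, for example with subprocess.run(["bash", "-c", cmd], check=True).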

Test the performance of P_θ

To test the performance of P_θ on clean data and under in-sentence attacks, run:

  source test_prefix_theta_insent.sh --preseqlen [A] --learning_rate [B] --tasks [C] --device [E] --test_ep [G] --attack [H] --test_batch_size [K]

Under UAT attack, the test command is:

  source test_prefix_theta_uat.sh --preseqlen [A] --learning_rate [B] --tasks [C] --device [E] --test_ep [G] --attack clean --uat_len [I] --test_batch_size [K]

where

  • [A]~[I] have the same meanings as above.
  • [K]: The test batch size. When K=0, the batch size is adaptive (determined by the available GPU memory); when K>0, the batch size is fixed to K.
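How the adaptive batch size (K=0) is computed is not described here. One plausible sketch, assuming the free memory is read with GPUtil (one of the prerequisites) and that each test example needs a roughly constant amount of memory, is shown below. The function names, mem_per_example_mb and the safety factor are hypothetical and are not options of the test scripts.

```python
def choose_test_batch_size(k: int, free_mem_mb: float, mem_per_example_mb: float,
                           safety: float = 0.9) -> int:
    """Return K if K > 0; otherwise derive a batch size from the free GPU memory."""
    if k < 0:
        raise ValueError("test_batch_size must be >= 0")
    if k > 0:
        return k  # fixed batch size
    usable = free_mem_mb * safety  # keep some headroom instead of filling the GPU
    return max(1, int(usable // mem_per_example_mb))

def free_gpu_memory_mb(device: int) -> float:
    """Free memory (MB) of GPU `device` as reported by GPUtil (list position assumed to match the id)."""
    import GPUtil  # imported lazily so the function above works without GPUtil installed
    return GPUtil.getGPUs()[device].memoryFree
```

With a GPU available, a call such as choose_test_batch_size(0, free_gpu_memory_mb(0), mem_per_example_mb=50.0) would give the adaptive size under these assumptions.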

Robust Prefix P'_ψ: Constructing the canonical manifolds

By constructing the canonical manifolds with PCA, we get the projection matrices. The command is:

  source get_proj.sh --preseqlen [A] --learning_rate [B] --tasks [C] --device [E] --test_ep [G]

where [A]~[G] have the same meanings as above.
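As a rough illustration of this step, the sketch below computes a PCA projection matrix Q = V_r V_r^T from a matrix of hidden activations, where the columns of V_r are the top-r principal directions. The inputs (an n × d matrix of activations, e.g. one per layer) and the rank r are assumptions for illustration; which layers, samples and number of components get_proj.sh uses, and whether it centers the data, are not specified here.

```python
import numpy as np

def canonical_projection(activations: np.ndarray, rank: int) -> np.ndarray:
    """Return the (d, d) projection onto the top-`rank` principal directions.

    activations: (n, d) hidden states, one row per example (assumed input format).
    """
    # PCA via SVD of the mean-centered data; rows of vt are the principal directions.
    centered = activations - activations.mean(axis=0, keepdims=True)
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    basis = vt[:rank].T        # (d, rank) with orthonormal columns
    return basis @ basis.T     # Q = V_r V_r^T

# Toy usage with random "activations" (hypothetical sizes).
rng = np.random.default_rng(0)
Q = canonical_projection(rng.normal(size=(200, 32)), rank=8)
```

Q is symmetric and idempotent, so h @ Q is the component of an activation h that lies in the principal subspace.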

Robust Prefix P'_ψ: Test its performance

On clean data and under in-sentence attacks, the command is:

  source test_robust_prefix_psi_insent.sh --preseqlen [A] --learning_rate [B] --tasks [C] --device [E] --test_ep [G] --attack [H] --test_batch_size [K] --PMP_lr [L] --PMP_iter [M]

Under UAT attack, the test command is:

  source test_robust_prefix_psi_uat.sh --preseqlen [A] --learning_rate [B] --tasks [C] --device [E] --test_ep [G] --attack clean --uat_len [I] --test_batch_size [K] --PMP_lr [L] --PMP_iter [M]

where

  • [A]~[K] have the same meanings as above.
  • [L]: The learning rate for test-time tuning of P'_ψ.
  • [M]: The number of iterations for test-time tuning of P'_ψ.
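The roles of [L] and [M] can be pictured with a toy version of test-time tuning: starting from psi = 0, take [M] gradient steps of size [L] that pull an activation h0 + psi towards the subspace given by a projection matrix Q, by minimizing ||(I - Q)(h0 + psi)||^2. This is a simplified sketch under strong assumptions: in the actual method P'_ψ is a prefix whose effect on the activations passes through the language model, whereas here psi is added directly to a single activation vector. The function name tune_psi is hypothetical.

```python
import numpy as np

def tune_psi(h0: np.ndarray, Q: np.ndarray, lr: float, n_iter: int):
    """Toy test-time tuning: gradient descent on ||(I - Q)(h0 + psi)||^2 w.r.t. psi.

    h0: (d,) activation of one test input; Q: (d, d) orthogonal projection matrix.
    Returns the learned shift psi and the loss recorded before each step.
    """
    residual_op = np.eye(Q.shape[0]) - Q  # projects onto the orthogonal complement
    psi = np.zeros_like(h0)
    losses = []
    for _ in range(n_iter):
        r = residual_op @ (h0 + psi)       # off-subspace component of the activation
        losses.append(float(r @ r))        # squared L2 distance to the subspace
        grad = 2.0 * residual_op.T @ r     # gradient of the loss w.r.t. psi
        psi = psi - lr * grad              # plain gradient-descent update
    return psi, losses
```

In this linear toy each step shrinks the off-subspace residual by a factor (1 - 2·lr), so the loss decays geometrically; with a real model the dependence on ψ is nonlinear and no such guarantee holds.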

Running Example

# Train the original prefix P_θ
source train.sh --tasks sst --n_train_epochs 100 --device 0
source train_adv.sh --tasks sst --n_train_epochs 100 --device 1 --pgd_ball word

# Generate Adversarial Examples
source generate_adv_insent.sh --tasks sst --device 0 --test_ep 100 --attack bug
source generate_adv_uat.sh --tasks sst --device 0 --test_ep 100 --attack clean-0 --uat_len 3 --uat_epoch 10
source generate_adv_uat.sh --tasks sst --device 0 --test_ep 100 --attack clean-1 --uat_len 3 --uat_epoch 10

# Test the performance of P_θ
source test_prefix_theta_insent.sh --tasks sst --device 0 --test_ep 100 --attack bug --test_batch_size 0
source test_prefix_theta_uat.sh --tasks sst --device 0 --test_ep 100 --attack clean --uat_len 3 --test_batch_size 0

# Robust Prefix P'_ψ: Constructing the canonical manifolds
source get_proj.sh --tasks sst --device 0 --test_ep 100

# Robust Prefix P'_ψ: Test its performance
source test_robust_prefix_psi_insent.sh --tasks sst --device 0 --test_ep 100 --attack bug --test_batch_size 0 --PMP_lr 0.15 --PMP_iter 10
source test_robust_prefix_psi_uat.sh --tasks sst --device 0 --test_ep 100 --attack clean --uat_len 3 --test_batch_size 0 --PMP_lr 0.05 --PMP_iter 10

Released Data & Models

Training the original prefix P_θ and generating adversarial examples can be time-consuming. As shown in our paper, adversarial prefix-tuning is particularly slow. Generating adversaries also takes considerable effort, since different attacks have to be performed on the test set for each trained prefix. We have also found that OpenAttack has since been upgraded to v2.1.1, which causes compatibility issues with our code (test_prefix_theta_insent.py).

To facilitate research on the robustness of prefix-tuning, we release the prefix checkpoints P_θ (with both standard and adversarial training), the processed test sets perturbed by in-sentence attacks (including PWWS and TextBugger), and the projection matrices of the canonical manifolds generated in our runs, for reproducibility and further enhancement. We have also hard-coded the exploited UAT tokens in test_prefix_theta_uat.py and test_robust_prefix_psi_uat.py. All the materials can be found here.

Acknowledgements

The implementation of robust prefix-tuning is based on the LAMOL repo, which contains the code of LAMOL: LAnguage MOdeling for Lifelong Language Learning, a study of lifelong learning in NLP with GPT-style pretrained language models.

BibTeX

If you find this repository useful for your research, please consider citing our work:

@inproceedings{
  yang2022on,
  title={On Robust Prefix-Tuning for Text Classification},
  author={Zonghan Yang and Yang Liu},
  booktitle={International Conference on Learning Representations},
  year={2022},
  url={https://openreview.net/forum?id=eBCmOocUejf}
}