Grad2Task: Improved Few-shot Text Classification Using Gradients for Task Representation

Prerequisites

This repo is built on a local copy of transformers==2.1.1 and has been tested with torch==1.4.0, Python 3.7, and CUDA 10.1.

To start, create a new environment and install:

conda create -n grad2task python=3.7
conda activate grad2task
cd Grad2Task
pip install -e .

We use wandb for logging. Set it up following the wandb documentation and specify your wandb project name in run_meta_training.sh:

export WANDB=[YOUR PROJECT NAME]
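
Because the project name is read from the environment, a run started without it may log to the wrong place or fail late. A minimal sketch of an early check follows; the function name require_wandb_project and the project name used in the example are hypothetical and not part of the repository:

```shell
#!/usr/bin/env bash
# Hypothetical guard; require_wandb_project is not part of the repository.

require_wandb_project() {
    # Fail early when WANDB is unset or empty.
    if [ -z "${WANDB:-}" ]; then
        echo "WANDB is not set; export WANDB=<your wandb project name>" >&2
        return 1
    fi
    return 0
}

# Example: the project name below is a placeholder for illustration only.
export WANDB="grad2task-demo"
require_wandb_project && echo "logging to wandb project: ${WANDB}"
```

A check like this could be called at the top of a launch script, before any training starts.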

Download the dataset and unzip it under the main folder: https://drive.google.com/file/d/1uAdgZFYv9epk6tQVQ3SwboxFpSlkC_ZW/view?usp=sharing

If you need to place it somewhere else, specify its path in path.sh.

Train & Evaluation

To train/evaluate models:

bash meta_learn.sh [MODEL_NAME] [MODE] [EXP_ID]

where [MODEL_NAME] is the model name, [MODE] is the experiment mode, and [EXP_ID] is an optional experiment ID used to distinguish different runs of the same model. Options for [MODE] and [MODEL_NAME] are listed below:

| [MODE] | Description |
| --- | --- |
| train | Train the model. |
| test_best | Test the model with the best validation performance. |
| test_latest | Test the latest checkpoint. |
| test | Test the model without meta-training. Only applicable to the fine-tune-baseline model. |

| [MODEL_NAME] | Description |
| --- | --- |
| fine-tune-baseline | Fine-tuning BERT for each task separately. |
| bert-protonet-euc | ProtoNet with BERT as encoder, using Euclidean distance as the distance metric. |
| bert-protonet-euc-bn | ProtoNet with BERT+Bottleneck Adapters as encoder, using Euclidean distance as the distance metric. |
| bert-protonet | ProtoNet with BERT as encoder, using cosine distance as the distance metric. |
| bert-protonet-bn | ProtoNet with BERT+Bottleneck Adapters as encoder, using cosine distance as the distance metric. |
| bert-leopard | Leopard with pretrained BERT [1]. |
| bert-leopard-fixlr | Leopard but with fixed learning rates. |
| bert-cnap-bn-euc-context-cls-shift-scale-ar | Our proposed approach using gradients as task representation. |
| bert-cnap-bn-euc-context-cls-shift-scale-ar-X | Our proposed approach using average input encoding as task representation. |
| bert-cnap-bn-euc-context-cls-shift-scale-ar-XGrad | Our proposed approach using both gradients and input encoding as task representation. |
| bert-cnap-bn-euc-context-cls-shift-scale-ar-XY | Our proposed approach using input and textual label encoding as task representation. |
| bert-cnap-bn-euc-context-shift-scale-ar | Same as our proposed approach, except that it adapts all tokens instead of only the [CLS] token. |
| bert-cnap-bn-pretrained-taskemb | Our proposed approach with a pretrained task embedding model. |
| bert-cnap-bn-hyper | A hypernetwork-based approach. |
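
The mode rules listed above (four modes, with test restricted to fine-tune-baseline) can be sketched as a small argument check. This is an illustrative sketch only: validate_args is a hypothetical helper, not a function in the repository, and the actual script may handle invalid arguments differently.

```shell
#!/usr/bin/env bash
# Hypothetical helper: validate_args is not part of the repository.
# It checks a [MODEL_NAME] [MODE] pair against the mode rules listed above.

validate_args() {
    local model_name="$1" mode="${2:-}"

    case "$mode" in
        train|test_best|test_latest)
            # These modes are valid for every model.
            ;;
        test)
            # "test" skips meta-training, so it only applies to the baseline.
            if [ "$model_name" != "fine-tune-baseline" ]; then
                echo "mode 'test' only applies to fine-tune-baseline" >&2
                return 1
            fi
            ;;
        *)
            echo "unknown mode: '$mode'" >&2
            return 1
            ;;
    esac
    return 0
}

# Example: check the arguments before launching a run.
if validate_args "bert-protonet-euc" "train"; then
    echo "arguments look valid"
fi
```

The check deliberately leaves the model name unvalidated apart from the test rule, so new model names can be added without touching it.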

To run a model with different hyperparameters, first name the run with an [EXP_ID] and then specify the new hyperparameters in run/meta_learn.sh. For example, to run bert-protonet-bn with a smaller learning rate, modify run/meta_learn.sh as follows:

...
elif [ "$1" == "bert-protonet-bn" ]; then # ProtoNet with cosine distance
    export LEARNING_RATE=2e-5
    export CHECKPOINT_FREQ=1000
    # [[ ]] is needed here: single brackets do not glob-match the suffix
    if [[ ${EXP_ID} == *"lr1e-5" ]]; then
        export LEARNING_RATE=1e-5
        export CHECKPOINT_FREQ=2000
        # modify other hyperparameters here
    fi
...

and then run:

bash meta_learn.sh bert-protonet-bn train lr1e-5
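
The suffix-based override can also be tried in isolation. The sketch below is a stand-alone illustration, not code from the repository: select_hparams is a hypothetical function, and it only covers the bert-protonet-bn values from the example above. Matching a pattern such as *"lr1e-5" relies on bash's [[ ]] test, which glob-matches the unquoted part of the right-hand side.

```shell
#!/usr/bin/env bash
# Hypothetical stand-alone sketch; select_hparams is not a function in the repository.
# The values mirror the bert-protonet-bn example above.

select_hparams() {
    local model_name="$1" exp_id="${2:-}"
    local learning_rate checkpoint_freq

    if [ "$model_name" == "bert-protonet-bn" ]; then
        # Defaults for this model.
        learning_rate=2e-5
        checkpoint_freq=1000
        # Any EXP_ID ending in "lr1e-5" selects the override.
        if [[ "$exp_id" == *"lr1e-5" ]]; then
            learning_rate=1e-5
            checkpoint_freq=2000
        fi
    else
        echo "no settings sketched for model: $model_name" >&2
        return 1
    fi

    echo "LEARNING_RATE=${learning_rate} CHECKPOINT_FREQ=${checkpoint_freq}"
}

select_hparams bert-protonet-bn          # default settings
select_hparams bert-protonet-bn lr1e-5   # override selected by the EXP_ID suffix
```

Because the match is on the suffix, an ID such as run2-lr1e-5 would select the same override, which makes it possible to combine several tags in one [EXP_ID].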

Reference

[1] T. Bansal, R. Jha, and A. McCallum. Learning to few-shot learn across diverse natural language classification tasks. In Proceedings of the 28th International Conference on Computational Linguistics, pages 5108–5123, 2020.

Owner
Jixuan Wang
Computer Science PhD student at University of Toronto. Research interests include deep learning and machine learning, and their applications in healthcare.