Learning to Prompt for Continual Learning

Overview

Learning to Prompt for Continual Learning (L2P) Official Jax Implementation

L2P is a novel continual learning technique which learns to dynamically prompt a pre-trained model to learn tasks sequentially under different task transitions. Different from mainstream rehearsal-based or architecture-based methods, L2P requires neither a rehearsal buffer nor test-time task identity. L2P can be generalized to various continual learning settings including the most challenging and realistic task-agnostic setting. L2P consistently outperforms prior state-of-the-art methods. Surprisingly, L2P achieves competitive results against rehearsal-based methods even without a rehearsal buffer.

Code is written by Zifeng Wang. Acknowledgement to https://github.com/google-research/nested-transformer.

This is not an officially supported Google product.

Enviroment setup

pip install -r requirements.txt

Getting pretrained ViT model

ViT-B/16 model used in this paper can be downloaded at here.

Instructions on running L2P

We provide the configuration file to train and evaluate L2P on multiple benchmarks in configs.

To run our method on the Split CIFAR-100 dataset (class-incremental setting):

python -m main.py --my_config configs/cifar100_l2p.py --workdir=./cifar100_l2p --my_config.init_checkpoint=<ViT-saved-path/ViT-B_16.npz>

To run our method on the more complex Gaussian Scheduled CIFAR-100 dataset (task-agnostic setting):

python -m main.py --my_config configs/cifar100_gaussian_l2p.py --workdir=./cifar100_gaussian_l2p --my_config.init_checkpoint=<ViT-saved-path/ViT-B_16.npz>

Note: we run our experiments using 8 V100 GPUs or 4 TPUs, and we specify a per device batch size of 16 in the config files. This indicates that we use a total batch size of 128.

Visualize results

We use tensorboard to visualize the result. For example, if the working directory specified to run L2P is workdir=./cifar100_l2p, the command to check result is as follows:

tensorboard --logdir ./cifar100_l2p

Here are the important metrics to keep track of, and their corresponding meanings:

Metric Description
accuracy_n Accuracy of the n-th task
forgetting Average forgetting up until the current task
avg_acc Average evaluation accuracy up until the current task

Cite

@inproceedings{wang2021learning,
  title={Learning to Prompt for Continual Learning},
  author={Zifeng Wang and Zizhao Zhang and Chen-Yu Lee and Han Zhang and Ruoxi Sun and Xiaoqi Ren and Guolong Su and Vincent Perot and Jennifer Dy and Tomas Pfister},
  booktitle={arXiv preprint arXiv:2112.08654},
  year={2021}
}
Code for the KDD 2021 paper 'Filtration Curves for Graph Representation'

Filtration Curves for Graph Representation This repository provides the code from the KDD'21 paper Filtration Curves for Graph Representation. Depende

Machine Learning and Computational Biology Lab 16 Oct 16, 2022
TrTr: Visual Tracking with Transformer

TrTr: Visual Tracking with Transformer We propose a novel tracker network based on a powerful attention mechanism called Transformer encoder-decoder a

趙 漠居(Zhao, Moju) 66 Dec 27, 2022
Unofficial Implement PU-Transformer

PU-Transformer-pytorch Pytorch unofficial implementation of PU-Transformer (PU-Transformer: Point Cloud Upsampling Transformer) https://arxiv.org/abs/

Lee Hyung Jun 7 Sep 21, 2022
An implementation of IMLE-Net: An Interpretable Multi-level Multi-channel Model for ECG Classification

IMLE-Net: An Interpretable Multi-level Multi-channel Model for ECG Classification The repostiory consists of the code, results and data set links for

12 Dec 26, 2022
MegEngine implementation of YOLOX

Introduction YOLOX is an anchor-free version of YOLO, with a simpler design but better performance! It aims to bridge the gap between research and ind

旷视天元 MegEngine 77 Nov 22, 2022
Official PyTorch implementation of "Synthesis of Screentone Patterns of Manga Characters"

Manga Character Screentone Synthesis Official PyTorch implementation of "Synthesis of Screentone Patterns of Manga Characters" presented in IEEE ISM 2

Tsubota 2 Nov 20, 2021
MERLOT: Multimodal Neural Script Knowledge Models

merlot MERLOT: Multimodal Neural Script Knowledge Models MERLOT is a model for learning what we are calling "neural script knowledge" -- representatio

Rowan Zellers 190 Dec 22, 2022
A novel benchmark dataset for Monocular Layout prediction

AutoLay AutoLay: Benchmarking Monocular Layout Estimation Kaustubh Mani, N. Sai Shankar, J. Krishna Murthy, and K. Madhava Krishna Abstract In this pa

Kaustubh Mani 39 Apr 26, 2022
PyTorch implementations of Top-N recommendation, collaborative filtering recommenders.

PyTorch implementations of Top-N recommendation, collaborative filtering recommenders.

Yoonki Jeong 129 Dec 22, 2022
CellRank's reproducibility repository.

CellRank's reproducibility repository We believe that reproducibility is key and have made it as simple as possible to reproduce our results. Please e

Theis Lab 8 Oct 08, 2022
A python library for time-series smoothing and outlier detection in a vectorized way.

tsmoothie A python library for time-series smoothing and outlier detection in a vectorized way. Overview tsmoothie computes, in a fast and efficient w

Marco Cerliani 517 Dec 28, 2022
LSTM Neural Networks for Spectroscopic Studies of Type Ia Supernovae

Package Description The difficulties in acquiring spectroscopic data have been a major challenge for supernova surveys. snlstm is developed to provide

7 Oct 11, 2022
Official repository of the paper 'Essentials for Class Incremental Learning'

Essentials for Class Incremental Learning Official repository of the paper 'Essentials for Class Incremental Learning' This Pytorch repository contain

33 Nov 27, 2022
Contrastive learning of Class-agnostic Activation Map for Weakly Supervised Object Localization and Semantic Segmentation (CVPR 2022)

CCAM (Unsupervised) Code repository for our paper "CCAM: Contrastive learning of Class-agnostic Activation Map for Weakly Supervised Object Localizati

Computer Vision Insitute, SZU 113 Dec 27, 2022
Learning Multiresolution Matrix Factorization and its Wavelet Networks on Graphs

Project Learning Multiresolution Matrix Factorization and its Wavelet Networks on Graphs, https://arxiv.org/pdf/2111.01940.pdf. Authors Truong Son Hy

5 Jun 28, 2022
Simple (but Strong) Baselines for POMDPs

Recurrent Model-Free RL is a Strong Baseline for Many POMDPs Welcome to the POMDP world! This repo provides some simple baselines for POMDPs, specific

Tianwei V. Ni 172 Dec 29, 2022
TensorFlow implementation of "Learning from Simulated and Unsupervised Images through Adversarial Training"

Simulated+Unsupervised (S+U) Learning in TensorFlow TensorFlow implementation of Learning from Simulated and Unsupervised Images through Adversarial T

Taehoon Kim 569 Dec 29, 2022
Multi-Stage Episodic Control for Strategic Exploration in Text Games

XTX: eXploit - Then - eXplore Requirements First clone this repo using git clone https://github.com/princeton-nlp/XTX.git Please create two conda envi

Princeton Natural Language Processing 9 May 24, 2022
InterfaceGAN++: Exploring the limits of InterfaceGAN

InterfaceGAN++: Exploring the limits of InterfaceGAN Authors: Apavou Clément & Belkada Younes From left to right - Images generated using styleGAN and

Younes Belkada 42 Dec 23, 2022
[ICCV 2021 Oral] PoinTr: Diverse Point Cloud Completion with Geometry-Aware Transformers

PoinTr: Diverse Point Cloud Completion with Geometry-Aware Transformers Created by Xumin Yu*, Yongming Rao*, Ziyi Wang, Zuyan Liu, Jiwen Lu, Jie Zhou

Xumin Yu 317 Dec 26, 2022