Efficient Lottery Ticket Finding: Less Data is More

Overview

Efficient Lottery Ticket Finding: Less Data is More

License: MIT

Codes for this paper Efficient Lottery Ticket Finding: Less Data is More. [ICML 2021]

Zhenyu Zhang*, Xuxi Chen*, Tianlong Chen*, Zhangyang Wang

Overview

The lottery ticket hypothesis (LTH) reveals the existence of winning tickets (sparse but critical subnetworks) for dense networks, that can be trained in isolation from random initialization to match the latter’s accuracies. However, finding winning tickets requires burdensome computations in the train-prune-retrain process, especially on large-scale datasets (e.g., ImageNet), restricting their practical benefits. This paper explores a new perspective on finding lottery tickets more efficiently, by doing so only with a specially selected subset of data, called Pruning- Aware Critical set (PrAC set), rather than using the full training set. The concept of PrAC set was inspired by the recent observation, that deep networks have samples that are either hard to memorize during training, or easy to forget during pruning. A PrAC set is thus hypothesized to capture those most challenging and informative examples for the dense model. We observe that a high-quality winning ticket can be found with training and pruning the dense network on the very compact PrAC set, which can substantially save training iterations for the ticket finding process.

Prerequisites

Pytorch >= 1.4

torchvision

advertorch

Usage

Vanilla Lottery Tickets

python -u main_imp.py \
	--data data/cifar10 \
	--dataset cifar10 \
	--arch res20s \
	--batch_size 128 \
	--lr 0.1 \
	--pruning_times 16 \
	--prune_type rewind_lt \
	--rewind_epoch 2 \
	--save_dir lt_cifar10_res20s

PrAC Lottery Tickets

python -u main_PrAC_imp.py \
	--data data/cifar10 \
	--dataset cifar10 \
	--arch res20s \
	--split_file npy_files/cifar10-train-val.npy \
	--batch_size 128 \
	--lr 0.1 \
	--pruning_times 16 \
	--eb_eps 0.08 \
	--prune_type rewind_lt \
	--rewind_epoch 2 \
	--threshold 0 \
	--save_dir PrAC_lt_cifar10_res20s
	

Train subnetworks

python -u main_train.py \
	--data data/cifar10 \
	--dataset cifar10 \
	--arch res20s \
	--batch_size 128 \
	--lr 0.1 \
	--init_dir PrAC_lt_cifar10_res20s/1checkpoint.pth.tar \ 
	--mask_dir PrAC_lt_cifar10_res20s/1checkpoint.pth.tar \ # sparsity=20%
	--save_dir retrain_PrAC_lt_cifar10_res20s/1

Citation


Owner
VITA
Visual Informatics Group @ University of Texas at Austin
VITA
Official PyTorch implementation of CAPTRA: CAtegory-level Pose Tracking for Rigid and Articulated Objects from Point Clouds

CAPTRA: CAtegory-level Pose Tracking for Rigid and Articulated Objects from Point Clouds Introduction This is the official PyTorch implementation of o

Yijia Weng 96 Dec 07, 2022
[CVPR 2022] Official code for the paper: "A Stitch in Time Saves Nine: A Train-Time Regularizing Loss for Improved Neural Network Calibration"

MDCA Calibration This is the official PyTorch implementation for the paper: "A Stitch in Time Saves Nine: A Train-Time Regularizing Loss for Improved

MDCA Calibration 21 Dec 22, 2022
This repository contains the code for the paper ``Identifiable VAEs via Sparse Decoding''.

Sparse VAE This repository contains the code for the paper ``Identifiable VAEs via Sparse Decoding''. Data Sources The datasets used in this paper wer

Gemma Moran 17 Dec 12, 2022
An implementation of Deep Graph Infomax (DGI) in PyTorch

DGI Deep Graph Infomax (Veličković et al., ICLR 2019): https://arxiv.org/abs/1809.10341 Overview Here we provide an implementation of Deep Graph Infom

Petar Veličković 491 Jan 03, 2023
The official code repo of "HTS-AT: A Hierarchical Token-Semantic Audio Transformer for Sound Classification and Detection"

Hierarchical Token Semantic Audio Transformer Introduction The Code Repository for "HTS-AT: A Hierarchical Token-Semantic Audio Transformer for Sound

Knut(Ke) Chen 134 Jan 01, 2023
Neural network pruning for finding a sparse computational model for controlling a biological motor task.

MothPruning Scientific Overview Originally inspired by biological nervous systems, deep neural networks (DNNs) are powerful computational tools for mo

Olivia Thomas 0 Dec 14, 2022
An official PyTorch implementation of the TKDE paper "Self-Supervised Graph Representation Learning via Topology Transformations".

Self-Supervised Graph Representation Learning via Topology Transformations This repository is the official PyTorch implementation of the following pap

Hsiang Gao 2 Oct 31, 2022
Machine Learning Privacy Meter: A tool to quantify the privacy risks of machine learning models with respect to inference attacks, notably membership inference attacks

ML Privacy Meter Machine learning is playing a central role in automated decision making in a wide range of organization and service providers. The da

Data Privacy and Trustworthy Machine Learning Research Lab 357 Jan 06, 2023
[CVPR2021 Oral] FFB6D: A Full Flow Bidirectional Fusion Network for 6D Pose Estimation.

FFB6D This is the official source code for the CVPR2021 Oral work, FFB6D: A Full Flow Biderectional Fusion Network for 6D Pose Estimation. (Arxiv) Tab

Yisheng (Ethan) He 201 Dec 28, 2022
A new test set for ImageNet

ImageNetV2 The ImageNetV2 dataset contains new test data for the ImageNet benchmark. This repository provides associated code for assembling and worki

186 Dec 18, 2022
Large-scale Hyperspectral Image Clustering Using Contrastive Learning, CIKM 21 Workshop

Spectral-spatial contrastive clustering (SSCC) Yaoming Cai, Yan Liu, Zijia Zhang, Zhihua Cai, and Xiaobo Liu, Large-scale Hyperspectral Image Clusteri

Yaoming Cai 4 Nov 02, 2022
Implementing a simplified copy of Shazam application from scratch using MinHashing and LSH.

Building Shazam from scratch In this repository we tried to implement a simplified copy of the Shazam application able to tell you the name of a song

Arturo Ghinassi 0 Nov 17, 2022
Fast algorithms to compute an approximation of the minimal volume oriented bounding box of a point cloud in 3D.

ApproxMVBB Status Build UnitTests Homepage Fast algorithms to compute an approximation of the minimal volume oriented bounding box of a point cloud in

Gabriel Nützi 390 Dec 31, 2022
Autonomous Perception: 3D Object Detection with Complex-YOLO

Autonomous Perception: 3D Object Detection with Complex-YOLO LiDAR object detect

Thomas Dunlap 2 Feb 18, 2022
For IBM Quantum Challenge Africa 2021, 9 September (07:00 UTC) - 20 September (23:00 UTC).

IBM Quantum Challenge Africa 2021 To ensure Africa is able to apply quantum computing to solve problems relevant to the continent, the IBM Research La

Qiskit Community 48 Dec 25, 2022
How to train a CNN to 99% accuracy on MNIST in less than a second on a laptop

Training a NN to 99% accuracy on MNIST in 0.76 seconds A quick study on how fast you can reach 99% accuracy on MNIST with a single laptop. Our answer

Tuomas Oikarinen 42 Dec 10, 2022
Roach: End-to-End Urban Driving by Imitating a Reinforcement Learning Coach

CARLA-Roach This is the official code release of the paper End-to-End Urban Driving by Imitating a Reinforcement Learning Coach by Zhejun Zhang, Alexa

Zhejun Zhang 118 Dec 28, 2022
GANSketchingJittor - Implementation of Sketch Your Own GAN in Jittor

GANSketching in Jittor Implementation of (Sketch Your Own GAN) in Jittor(计图). Or

Bernard Tan 10 Jul 02, 2022
MISSFormer: An Effective Medical Image Segmentation Transformer

MISSFormer Code for paper "MISSFormer: An Effective Medical Image Segmentation Transformer". Please read our preprint at the following link: paper_add

Fong 22 Dec 24, 2022
Pytorch implementation for "Open Compound Domain Adaptation" (CVPR 2020 ORAL)

Open Compound Domain Adaptation [Project] [Paper] [Demo] [Blog] Overview Open Compound Domain Adaptation (OCDA) is the author's re-implementation of t

Zhongqi Miao 137 Dec 15, 2022