Ensemble Knowledge Guided Sub-network Search and Fine-tuning for Filter Pruning

Overview

Ensemble Knowledge Guided Sub-network Search and Fine-tuning for Filter Pruning

This repository is official Tensorflow implementation of paper:

Ensemble Knowledge Guided Sub-network Search and Fine-tuning for Filter Pruning [paper link]

and Tensorflow 2 example code for
   "Custom layers", "Custom training loop", "XLA (JIT)-compiling", "Distributed learing", and "Gradients accumulator".

Paper abstract

Conventional NAS-based pruning algorithms aim to find the sub-network with the best validation performance. However, validation performance does not successfully represent test performance, i.e., potential performance. Also, although fine-tuning the pruned network to restore the performance drop is an inevitable process, few studies have handled this issue. This paper proposes a novel sub-network search and fine-tuning method, i.e., Ensemble Knowledge Guidance (EKG). First, we experimentally prove that the fluctuation of the loss landscape is an effective metric to evaluate the potential performance. In order to search a sub-network with the smoothest loss landscape at a low cost, we propose a pseudo-supernet built by an ensemble sub-network knowledge distillation. Next, we propose a novel fine-tuning that re-uses the information of the search phase. We store the interim sub-networks, that is, the by-products of the search phase, and transfer their knowledge into the pruned network. Note that EKG is easy to be plugged-in and computationally efficient. For example, in the case of ResNet-50, about 45% of FLOPS is removed without any performance drop in only 315 GPU hours.


Conceptual visualization of the goal of the proposed method.

Contribution points and key features

  • As a new tool to measure the potential performance of sub-network in NAS-based pruning, the smoothness of the loss landscape is presented. Also, the experimental evidence that the loss landscape fluctuation has a higher correlation with the test performance than the validation performance is provided.
  • The pseudo-supernet based on an ensemble sub-network knowledge distillation is proposed to find a sub-network of smoother loss landscape without increasing complexity. It helps NAS-based pruning to prune all pre-trained networks, and also allows to find optimal sub-network(s) more accurately.
  • To our knowledge, this paper provides the world-first approach to store the information of the search phase in a memory bank and to reuse it in the fine-tuning phase of the pruned network. The proposed memory bank contributes to greatly improving the performance of the pruned network.

Requirement

  • Tensorflow >= 2.7 (I have tested on 2.7-2.8)
  • Pickle
  • tqdm

How to run

  1. Move to the codebase.
  2. Train and evaluate our model by the below command.
  # ResNet-56 on CIFAR10
  python train_cifar.py --gpu_id 0 --arch ResNet-56 --dataset CIFAR10 --search_target_rate 0.45 --train_path ../test
  python test.py --gpu_id 0 --arch ResNet-56 --dataset CIFAR10 --trained_param ../test/trained_param.pkl

Experimental results


(Left) Potential performance vs. validation loss (right) Potential performance vs. condition number. 50 sub-networks of ResNet-56 trained on CIFAR10 were used for this experiment. accurately.


Visualization of loss landscapes of sub-networks searched by various filter importance scoring algorithms.

Comparison with various pruning techniques for ResNet family trained on ImageNet.


Performance analysis in case of ResNet-50 trained on ImageNet-2012. The left plot is the FLOPs reduction rate-Top-1 accuracy, and the right plot is the GPU hours-Top-1 accuracy.

Reference

@article{lee2022ensemble,
  title        = {Ensemble Knowledge Guided Sub-network Search and Fine-tuning for Filter Pruning},
  author       = {Seunghyun Lee, Byung Cheol Song},
  year         = 2022,
  journal      = {arXiv preprint arXiv:2203.02651}
}

Owner
Seunghyun Lee
Knowledge distillation; Neural network light-weighting; Tensorflow
Seunghyun Lee
Official implementation of "StyleCariGAN: Caricature Generation via StyleGAN Feature Map Modulation" (SIGGRAPH 2021)

StyleCariGAN: Caricature Generation via StyleGAN Feature Map Modulation This repository contains the official PyTorch implementation of the following

Wonjong Jang 270 Dec 30, 2022
Tensorflow implementation of ID-Unet: Iterative Soft and Hard Deformation for View Synthesis.

ID-Unet: Iterative-view-synthesis(CVPR2021 Oral) Tensorflow implementation of ID-Unet: Iterative Soft and Hard Deformation for View Synthesis. Overvie

17 Aug 23, 2022
Distributing Deep Learning Hyperparameter Tuning for 3D Medical Image Segmentation

DistMIS Distributing Deep Learning Hyperparameter Tuning for 3D Medical Image Segmentation. DistriMIS Distributing Deep Learning Hyperparameter Tuning

HiEST 2 Sep 09, 2022
Generative Adversarial Text to Image Synthesis

Text To Image Synthesis This is a tensorflow implementation of synthesizing images. The images are synthesized using the GAN-CLS Algorithm from the pa

Hao 575 Jan 08, 2023
Use graph-based analysis to re-classify stocks and to improve Markowitz portfolio optimization

Dynamic Stock Industrial Classification Use graph-based analysis to re-classify stocks and experiment different re-classification methodologies to imp

Sheng Yang 10 Dec 05, 2022
code for `Look Closer to Segment Better: Boundary Patch Refinement for Instance Segmentation`

Look Closer to Segment Better: Boundary Patch Refinement for Instance Segmentation (CVPR 2021) Introduction PBR is a conceptually simple yet effective

H.Chen 143 Jan 05, 2023
The repository contain code for building compiler using puthon.

Building Compiler This is a python implementation of JamieBuild's "Super Tiny Compiler" Overview JamieBuilds developed a wonderfully educative compile

Shyam Das Shrestha 1 Nov 21, 2021
Space Invaders For Python

Space-Invaders Just download or clone the git repository. To run the Space Invader game you need to have pyhton installed in you system. If you dont h

Fei 5 Jul 27, 2022
The codebase for our paper "Generative Occupancy Fields for 3D Surface-Aware Image Synthesis" (NeurIPS 2021)

Generative Occupancy Fields for 3D Surface-Aware Image Synthesis (NeurIPS 2021) Project Page | Paper Xudong Xu, Xingang Pan, Dahua Lin and Bo Dai GOF

xuxudong 97 Nov 10, 2022
Experiments on continual learning from a stream of pretrained models.

Ex-model CL Ex-model continual learning is a setting where a stream of experts (i.e. model's parameters) is available and a CL model learns from them

Antonio Carta 6 Dec 04, 2022
FocusFace: Multi-task Contrastive Learning for Masked Face Recognition

FocusFace This is the official repository of "FocusFace: Multi-task Contrastive Learning for Masked Face Recognition" accepted at IEEE International C

Pedro Neto 21 Nov 17, 2022
A vision library for performing sliced inference on large images/small objects

SAHI: Slicing Aided Hyper Inference A vision library for performing sliced inference on large images/small objects Overview Object detection and insta

Open Business Software Solutions 2.3k Jan 04, 2023
Tensorflow implementation of the paper "HumanGPS: Geodesic PreServing Feature for Dense Human Correspondences", CVPR 2021.

HumanGPS: Geodesic PreServing Feature for Dense Human Correspondences Tensorflow implementation of the paper "HumanGPS: Geodesic PreServing Feature fo

Google Interns 50 Dec 21, 2022
Algorithmic trading with deep learning experiments

Deep-Trading Algorithmic trading with deep learning experiments. Now released part one - simple time series forecasting. I plan to implement more soph

Alex Honchar 1.4k Jan 02, 2023
Pytorch implementation of our paper under review — Lottery Jackpots Exist in Pre-trained Models

Lottery Jackpots Exist in Pre-trained Models (Paper Link) Requirements Python = 3.7.4 Pytorch = 1.6.1 Torchvision = 0.4.1 Reproduce the Experiment

Yuxin Zhang 27 Jun 28, 2022
LiDAR R-CNN: An Efficient and Universal 3D Object Detector

LiDAR R-CNN: An Efficient and Universal 3D Object Detector Introduction This is the official code of LiDAR R-CNN: An Efficient and Universal 3D Object

TuSimple 295 Jan 05, 2023
This is the pytorch implementation for the paper: *Learning Accurate Performance Predictors for Ultrafast Automated Model Compression*, which is in submission to TPAMI

SeerNet This is the pytorch implementation for the paper: Learning Accurate Performance Predictors for Ultrafast Automated Model Compression, which is

3 May 01, 2022
Implementation of PersonaGPT Dialog Model

PersonaGPT An open-domain conversational agent with many personalities PersonaGPT is an open-domain conversational agent cpable of decoding personaliz

ILLIDAN Lab 42 Jan 01, 2023
Code for MarioNette: Self-Supervised Sprite Learning, in NeurIPS 2021

MarioNette | Webpage | Paper | Video MarioNette: Self-Supervised Sprite Learning Dmitriy Smirnov, Michaël Gharbi, Matthew Fisher, Vitor Guizilini, Ale

Dima Smirnov 28 Nov 18, 2022
Cross-Task Consistency Learning Framework for Multi-Task Learning

Cross-Task Consistency Learning Framework for Multi-Task Learning Tested on numpy(v1.19.1) opencv-python(v4.4.0.42) torch(v1.7.0) torchvision(v0.8.0)

Aki Nakano 2 Jan 08, 2022