PyTorch implementation for "Sharpness-aware Quantization for Deep Neural Networks".

Related tags

Deep LearningSAQ
Overview

Sharpness-aware Quantization for Deep Neural Networks

License

Recent Update

2021.11.23: We release the source code of SAQ.

Setup the environments

  1. Clone the repository locally:
git clone https://github.com/zhuang-group/SAQ
  1. Install pytorch 1.8+, tensorboard and prettytable
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
pip install tensorboard
pip install prettytable

Data preparation

ImageNet

  1. Download the ImageNet 2012 dataset from here, and prepare the dataset based on this script.

  2. Change the dataset path in link_imagenet.py and link the ImageNet-100 by

python link_imagenet.py

CIFAR-100

Download the CIFAR-100 dataset from here.

After downloading ImageNet and CIFAR-100, the file structure should look like:

dataset
├── imagenet
    ├── train
    │   ├── class1
    │   │   ├── img1.jpeg
    │   │   ├── img2.jpeg
    │   │   └── ...
    │   ├── class2
    │   │   ├── img3.jpeg
    │   │   └── ...
    │   └── ...
    └── val
        ├── class1
        │   ├── img4.jpeg
        │   ├── img5.jpeg
        │   └── ...
        ├── class2
        │   ├── img6.jpeg
        │   └── ...
        └── ...
├── cifar100
    ├── cifar-100-python
    │   ├── meta
    │   ├── test
    │   ├── train
    │   └── ...
    └── ...

Training

Fixed-precision quantization

  1. Download the pre-trained full-precision models from the model zoo.

  2. Train low-precision models.

To train low-precision ResNet-20 on CIFAR-100, run:

sh script/train_qsam_cifar_r20.sh

To train low-precision ResNet-18 on ImageNet, run:

sh script/train_qsam_imagenet_r18.sh

Mixed-precision quantization

  1. Download the pre-trained full-precision models from the model zoo.

  2. Train the configuration generator.

To train the configuration generator of ResNet-20 on CIFAR-100, run:

sh script/train_generator_cifar_r20.sh

To train the configuration generator on ImageNet, run:

sh script/train_generator_imagenet_r18.sh
  1. After training the configuration generator, run following commands to fine-tune the resulting models with the obtained bitwidth configurations on CIFAR-100 and ImageNet.
sh script/finetune_cifar_r20.sh
sh script/finetune_imagenet_r18.sh

Results on CIFAR-100

Network Method Bitwidth BOPs (M) Top-1 Acc. (%) Top-5 Acc. (%)
ResNet-20 SAQ 4 674.6 68.7 91.2
ResNet-20 SAMQ MP 659.3 68.7 91.2
ResNet-20 SAQ 3 392.1 67.7 90.8
ResNet-20 SAMQ MP 374.4 68.6 91.2
MobileNetV2 SAQ 4 1508.9 75.6 93.7
MobileNetV2 SAMQ MP 1482.1 75.5 93.6
MobileNetV2 SAQ 3 877.1 74.4 93.2
MobileNetV2 SAMQ MP 869.5 75.5 93.7

Results on ImageNet

Network Method Bitwidth BOPs (G) Top-1 Acc. (%) Top-5 Acc. (%)
ResNet-18 SAQ 4 34.7 71.3 90.0
ResNet-18 SAMQ MP 33.7 71.4 89.9
ResNet-18 SAQ 2 14.4 67.1 87.3
MobileNetV2 SAQ 4 5.3 70.2 89.4
MobileNetV2 SAMQ MP 5.3 70.3 89.4

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.

Acknowledgement

This repository has adopted codes from SAM, ASAM and ESAM, we thank the authors for their open-sourced code.

You might also like...
Objective of the repository is to learn and build machine learning models using Pytorch. 30DaysofML Using Pytorch
Objective of the repository is to learn and build machine learning models using Pytorch. 30DaysofML Using Pytorch

30 Days Of Machine Learning Using Pytorch Objective of the repository is to learn and build machine learning models using Pytorch. List of Algorithms

Pretrained SOTA Deep Learning models, callbacks and more for research and production with PyTorch Lightning and PyTorch
Pretrained SOTA Deep Learning models, callbacks and more for research and production with PyTorch Lightning and PyTorch

Pretrained SOTA Deep Learning models, callbacks and more for research and production with PyTorch Lightning and PyTorch

Amazon Forest Computer Vision: Satellite Image tagging code using PyTorch / Keras with lots of PyTorch tricks
Amazon Forest Computer Vision: Satellite Image tagging code using PyTorch / Keras with lots of PyTorch tricks

Amazon Forest Computer Vision Satellite Image tagging code using PyTorch / Keras Here is a sample of images we had to work with Source: https://www.ka

The Incredible PyTorch: a curated list of tutorials, papers, projects, communities and more relating to PyTorch.
The Incredible PyTorch: a curated list of tutorials, papers, projects, communities and more relating to PyTorch.

This is a curated list of tutorials, projects, libraries, videos, papers, books and anything related to the incredible PyTorch. Feel free to make a pu

Amazon Forest Computer Vision: Satellite Image tagging code using PyTorch / Keras with lots of PyTorch tricks
Amazon Forest Computer Vision: Satellite Image tagging code using PyTorch / Keras with lots of PyTorch tricks

Amazon Forest Computer Vision Satellite Image tagging code using PyTorch / Keras Here is a sample of images we had to work with Source: https://www.ka

A bunch of random PyTorch models using PyTorch's C++ frontend
A bunch of random PyTorch models using PyTorch's C++ frontend

PyTorch Deep Learning Models using the C++ frontend Gettting started Clone the repo 1. https://github.com/mrdvince/pytorchcpp 2. cd fashionmnist or

PyTorch Autoencoders - Implementing a Variational Autoencoder (VAE) Series in Pytorch.

PyTorch Autoencoders Implementing a Variational Autoencoder (VAE) Series in Pytorch. Inspired by this repository Model List check model paper conferen

PyTorch-LIT is the Lite Inference Toolkit (LIT) for PyTorch which focuses on easy and fast inference of large models on end-devices.

PyTorch-LIT PyTorch-LIT is the Lite Inference Toolkit (LIT) for PyTorch which focuses on easy and fast inference of large models on end-devices. With

A general framework for deep learning experiments under PyTorch based on pytorch-lightning

torchx Torchx is a general framework for deep learning experiments under PyTorch based on pytorch-lightning. TODO list gan-like training wrapper text

Comments
  • Quantize_first_last_layer

    Quantize_first_last_layer

    Hi! I noticed that in your code, you set bits_weights=8 and bits_activations=32 for first layer as default, it's not what is claimed in your paper " For the first and last layers of all quantized models, we quantize both weights and activations to 8-bit. " And I see an accuracy drop if I adjust the bits_activations to 8 for the first layer, could u please explain what is the reason? Thanks!

    opened by mmmiiinnnggg 0
  • 代码问题请求帮助

    代码问题请求帮助

    你好,带佬的代码写的很好,有部分代码不太懂,想请教一下, parser.add_argument( "--arch_bits", type=lambda s: [float(item) for item in s.split(",")] if len(s) != 0 else "", default=" ", help="bits configuration of each layer",

    if len(args.arch_bits) != 0: if args.wa_same_bit: set_wae_bits(model, args.arch_bits) elif args.search_w_bit: set_w_bits(model, args.arch_bits) else: set_bits(model, args.arch_bits) show_bits(model) logger.info("Set arch bits to: {}".format(args.arch_bits)) logger.info(model) 这个arch_bits主要是做什么的呢,卡在这里有段时间了

    opened by LKAMING97 0
Releases(v0.1.1)
Owner
Zhuang AI Group
Zhuang AI Group
This repository contains the source code of Auto-Lambda and baselines from the paper, Auto-Lambda: Disentangling Dynamic Task Relationships.

Auto-Lambda This repository contains the source code of Auto-Lambda and baselines from the paper, Auto-Lambda: Disentangling Dynamic Task Relationship

Shikun Liu 76 Dec 20, 2022
Motion and Shape Capture from Sparse Markers

MoSh++ This repository contains the official chumpy implementation of mocap body solver used for AMASS: AMASS: Archive of Motion Capture as Surface Sh

Nima Ghorbani 135 Dec 23, 2022
This repository contains the code used for Predicting Patient Outcomes with Graph Representation Learning (https://arxiv.org/abs/2101.03940).

Predicting Patient Outcomes with Graph Representation Learning This repository contains the code used for Predicting Patient Outcomes with Graph Repre

Emma Rocheteau 76 Dec 22, 2022
An educational AI robot based on NVIDIA Jetson Nano.

JetBot Looking for a quick way to get started with JetBot? Many third party kits are now available! JetBot is an open-source robot based on NVIDIA Jet

NVIDIA AI IOT 2.6k Dec 29, 2022
ICCV2021 - Mining Contextual Information Beyond Image for Semantic Segmentation

Introduction The official repository for "Mining Contextual Information Beyond Image for Semantic Segmentation". Our full code has been merged into ss

55 Nov 09, 2022
Video-Music Transformer

VMT Video-Music Transformer (VMT) is an attention-based multi-modal model, which generates piano music for a given video. Paper https://arxiv.org/abs/

Chin-Tung Lin 5 Jul 13, 2022
MEDS: Enhancing Memory Error Detection for Large-Scale Applications

MEDS: Enhancing Memory Error Detection for Large-Scale Applications Prerequisites cmake and clang Build MEDS supporting compiler $ make Build Using Do

Secomp Lab at Purdue University 34 Dec 14, 2022
A tool to analyze leveraged liquidity mining and find optimal option combination for hedging.

LP-Option-Hedging Description A Python program to analyze leveraged liquidity farming/mining and find the optimal option combination for hedging imper

Aureliano 18 Dec 19, 2022
Multimodal Temporal Context Network (MTCN)

Multimodal Temporal Context Network (MTCN) This repository implements the model proposed in the paper: Evangelos Kazakos, Jaesung Huh, Arsha Nagrani,

Evangelos Kazakos 13 Nov 24, 2022
Open source code for the paper of Neural Sparse Voxel Fields.

Neural Sparse Voxel Fields (NSVF) Project Page | Video | Paper | Data Photo-realistic free-viewpoint rendering of real-world scenes using classical co

Meta Research 647 Dec 27, 2022
LogAvgExp - Pytorch Implementation of LogAvgExp

LogAvgExp - Pytorch Implementation of LogAvgExp for Pytorch Install $ pip instal

Phil Wang 31 Oct 14, 2022
The FIRST GANs-based omics-to-omics translation framework

OmiTrans Please also have a look at our multi-omics multi-task DL freamwork 👀 : OmiEmbed The FIRST GANs-based omics-to-omics translation framework Xi

Xiaoyu Zhang 6 Dec 14, 2022
Definition of a business problem according to Wilson Lower Bound Score and Time Based Average Rating

Wilson Lower Bound Score, Time Based Rating Average In this study I tried to calculate the product rating and sorting reviews more accurately. I have

3 Sep 30, 2021
HEAM: High-Efficiency Approximate Multiplier Optimization for Deep Neural Networks

Approximate Multiplier by HEAM What's HEAM? HEAM is a general optimization method to generate high-efficiency approximate multipliers for specific app

4 Sep 11, 2022
Pcos-prediction - Predicts the likelihood of Polycystic Ovary Syndrome based on patient attributes and symptoms

PCOS Prediction 🥼 Predicts the likelihood of Polycystic Ovary Syndrome based on

Samantha Van Seters 1 Jan 10, 2022
Image-generation-baseline - MUGE Text To Image Generation Baseline

MUGE Text To Image Generation Baseline Requirements and Installation More detail

23 Oct 17, 2022
Co-GAIL: Learning Diverse Strategies for Human-Robot Collaboration

CoGAIL Table of Content Overview Installation Dataset Training Evaluation Trained Checkpoints Acknowledgement Citations License Overview This reposito

Jeremy Wang 29 Dec 24, 2022
Practical tutorials and labs for TensorFlow used by Nvidia, FFN, CNN, RNN, Kaggle, AE

TensorFlow Tutorial - used by Nvidia Learn TensorFlow from scratch by examples and visualizations with interactive jupyter notebooks. Learn to compete

Alexander R Johansen 1.9k Dec 19, 2022
The comma.ai Calibration Challenge!

Welcome to the comma.ai Calibration Challenge! Your goal is to predict the direction of travel (in camera frame) from provided dashcam video. This rep

comma.ai 697 Jan 05, 2023
Our implementation used for the MICCAI 2021 FLARE Challenge titled 'Efficient Multi-Organ Segmentation Using SpatialConfiguartion-Net with Low GPU Memory Requirements'.

Efficient Multi-Organ Segmentation Using SpatialConfiguartion-Net with Low GPU Memory Requirements Our implementation used for the MICCAI 2021 FLARE C

Franz Thaler 3 Sep 27, 2022