Implementation of Continuous Sparsification, a method for pruning and ticket search in deep networks

Overview


Continuous Sparsification

Implementation of Continuous Sparsification (CS), a method based on l_0 regularization to find sparse neural networks, proposed in [Winning the Lottery with Continuous Sparsification].

Requirements

Python 2/3, PyTorch == 1.1.0

Training a ResNet on CIFAR with Continuous Sparsification

The main.py script can be used to train a ResNet-18 on CIFAR-10 with Continuous Sparsification. By default it performs 3 rounds of training, each consisting of 85 epochs. With the default hyperparameter values for the mask initialization, mask penalty, and final temperature, the method finds a sub-network with 20-30% sparsity that achieves 91.5-92.0% test accuracy when trained after rewinding (the dense network achieves 90-91%). The training and rewinding protocols follow those of the Lottery Ticket Hypothesis papers by Frankle et al.
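The README does not state how the temperature evolves within a round. The sketch below assumes an exponential schedule with two properties:

- The temperature is reset to 1 at the start of each round.
- It is multiplied by a constant factor at each step, so it reaches the final temperature at the round's last iteration.

The function temperature_schedule and its arguments are hypothetical and are not flags of main.py.

```python
def temperature_schedule(final_temp, iters_per_round, rounds):
    """Yield (round, iteration, temp) under an assumed exponential schedule.

    Each round restarts at temp = 1 and multiplies it by a constant factor,
    so that the last iteration of the round reaches final_temp.
    """
    if iters_per_round < 2:
        raise ValueError("need at least two iterations per round")
    # Per-step growth factor: 1 * factor ** (iters_per_round - 1) == final_temp
    factor = final_temp ** (1.0 / (iters_per_round - 1))
    for r in range(rounds):
        temp = 1.0  # the temperature is reset at the start of every round
        for it in range(iters_per_round):
            yield r, it, temp
            temp *= factor
```

For example, list(temperature_schedule(200.0, 5, 3)) yields 15 entries. The temperature returns to 1.0 at the start of each of the three rounds. The value 200.0 is arbitrary, chosen for illustration, and is not the script's default.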

In general, the sparsity of the final sub-network can be controlled by changing the value used to initialize the soft mask parameters. This can be done with, for example:

python main.py --mask-initial-value 0.1

The default value is 0.0. Increasing it results in less sparse sub-networks, while highly sparse sub-networks can be found by setting it to -0.1.
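The effect of the initial value can be read off compute_mask in SoftMaskedConv2d (models/layers.py). The pure-Python sketch below mirrors that computation for a single mask parameter s. The helper name effective_mask is hypothetical.

Because of the 1/sigmoid(s0) scaling, the soft mask equals 1 at initialization for any s0, so the initial value does not change the starting network. What it changes is where each entry starts relative to the hard threshold (s > 0):

- With s0 = 0.1, all entries start on the keep side.
- With s0 = 0.0 or -0.1, they start at or below the threshold.

This is one intuitive reading of why larger values yield less sparse tickets.

```python
import math


def sigmoid(x):
    # Scalar logistic function
    return 1.0 / (1.0 + math.exp(-x))


def effective_mask(s, s0, temp=1.0, ticket=False):
    """Scaled mask value for one mask parameter s, mirroring compute_mask.

    s0 plays the role of --mask-initial-value. The 1/sigmoid(s0) scaling makes
    the soft mask equal to 1 at initialization (s == s0, temp == 1).
    """
    scaling = 1.0 / sigmoid(s0)
    if ticket:
        mask = 1.0 if s > 0 else 0.0  # hard mask: keep only positive entries
    else:
        mask = sigmoid(temp * s)      # soft mask, sharpened by the temperature
    return scaling * mask
```

For instance, effective_mask(0.3, 0.3) is 1 up to rounding. effective_mask(0.0, 0.0, ticket=True) is 0, because the default initial value sits exactly on the threshold.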

Extending the code

To train other network models with Continuous Sparsification, first choose which layers you want to sparsify. Then implement PyTorch modules that perform soft masking on their original parameters. This repository contains code for 2D convolutions with soft masking: the SoftMaskedConv2d module in models/layers.py:

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def sigmoid(x):
    # Scalar sigmoid for the float mask_initial_value (torch.sigmoid expects a tensor)
    return 1. / (1. + math.exp(-x))


class SoftMaskedConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding=1, stride=1, mask_initial_value=0.):
        super(SoftMaskedConv2d, self).__init__()
        self.mask_initial_value = mask_initial_value

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride

        self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels, kernel_size, kernel_size))
        nn.init.xavier_normal_(self.weight)
        # Frozen copy of the weights, filled by checkpoint() and used by rewind_weights()
        self.init_weight = nn.Parameter(torch.zeros_like(self.weight), requires_grad=False)
        self.init_mask()

    def init_mask(self):
        # One mask parameter per weight entry, initialized to mask_initial_value
        self.mask_weight = nn.Parameter(torch.Tensor(self.out_channels, self.in_channels, self.kernel_size, self.kernel_size))
        nn.init.constant_(self.mask_weight, self.mask_initial_value)

    def compute_mask(self, temp, ticket):
        # Rescale so that the soft mask equals 1 at initialization (temp=1)
        scaling = 1. / sigmoid(self.mask_initial_value)
        if ticket:
            # Hard, binary mask: keep only entries with a positive mask parameter
            mask = (self.mask_weight > 0).float()
        else:
            # Soft mask: sigmoid sharpened by the current temperature
            mask = torch.sigmoid(temp * self.mask_weight)
        return scaling * mask

    def prune(self, temp):
        # Scale the mask parameters by temp and cap them at mask_initial_value
        self.mask_weight.data = torch.clamp(temp * self.mask_weight.data, max=self.mask_initial_value)

    def forward(self, x, temp=1, ticket=False):
        self.mask = self.compute_mask(temp, ticket)
        masked_weight = self.weight * self.mask
        out = F.conv2d(x, masked_weight, stride=self.stride, padding=self.padding)
        return out

    def checkpoint(self):
        # Store the current weights as the rewinding point
        self.init_weight.data = self.weight.data.clone()

    def rewind_weights(self):
        # Restore the weights saved by checkpoint()
        self.weight.data = self.init_weight.data.clone()

    def extra_repr(self):
        return '{}, {}, kernel_size={}, stride={}, padding={}'.format(
            self.in_channels, self.out_channels, self.kernel_size, self.stride, self.padding)
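The prune method is the least obvious part of the module. It multiplies every mask parameter by the current temperature and caps the result at mask_initial_value. The pure-Python sketch below applies the same rule to a list of floats, together with the hard mask used when ticket=True. The helper names are hypothetical, and the sketch does not use PyTorch.

```python
def prune_mask(mask_values, temp, mask_initial_value):
    """Mirror of SoftMaskedConv2d.prune on a plain list of floats.

    Each value s becomes min(temp * s, mask_initial_value). Negative entries
    (already below the hard threshold) are scaled further down, while positive
    entries that exceed mask_initial_value once scaled are reset to it.
    """
    return [min(temp * s, mask_initial_value) for s in mask_values]


def hard_mask(mask_values):
    """Binary ticket mask used when ticket=True: 1 where s > 0, else 0."""
    return [1.0 if s > 0 else 0.0 for s in mask_values]
```

For example, prune_mask([2.0, -0.5, 0.5], temp=10.0, mask_initial_value=0.1) returns [0.1, -5.0, 0.1]. The two positive entries are reset to the initial value and still pass the hard threshold, while the negative entry is pushed further below it.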

Extending it to other layers is straightforward, since you only need to change the __init__, init_mask, and forward methods. In init_mask, create a mask parameter (an nn.Parameter) for each set of parameters that you want to sparsify. Each mask parameter must have the same dimensions as the corresponding parameter.

    def init_mask(self):
        self.mask_weight = nn.Parameter(torch.Tensor(...))
        nn.init.constant_(self.mask_weight, self.mask_initial_value)

In the forward method, compute the masked version of each parameter to be sparsified (e.g. masked weights for a Linear layer). Then compute the output of the layer with the corresponding PyTorch functional call (e.g. F.linear for Linear layers). For example:

    def forward(self, x, temp=1, ticket=False):
        self.mask = self.compute_mask(temp, ticket)
        masked_weight = self.weight * self.mask
        out = F.linear(x, masked_weight)        
        return out
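Putting the pieces together, a Linear counterpart of SoftMaskedConv2d could look like the sketch below. SoftMaskedLinear is hypothetical: it is not in models/layers.py. It masks only the weight matrix and leaves the optional bias dense. It omits prune, checkpoint, and rewind_weights, which would carry over from SoftMaskedConv2d unchanged.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def sigmoid(x):
    # Scalar sigmoid for the float mask initial value
    return 1.0 / (1.0 + math.exp(-x))


class SoftMaskedLinear(nn.Module):
    """Hypothetical soft-masked Linear layer; only the weight is masked."""

    def __init__(self, in_features, out_features, bias=True, mask_initial_value=0.):
        super(SoftMaskedLinear, self).__init__()
        self.mask_initial_value = mask_initial_value
        self.in_features = in_features
        self.out_features = out_features

        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.xavier_normal_(self.weight)
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter('bias', None)
        # Frozen copy of the weights for rewinding (as in SoftMaskedConv2d)
        self.init_weight = nn.Parameter(torch.zeros_like(self.weight), requires_grad=False)
        self.init_mask()

    def init_mask(self):
        # One mask parameter per weight entry, same shape as self.weight
        self.mask_weight = nn.Parameter(torch.empty(self.out_features, self.in_features))
        nn.init.constant_(self.mask_weight, self.mask_initial_value)

    def compute_mask(self, temp, ticket):
        scaling = 1. / sigmoid(self.mask_initial_value)
        if ticket:
            mask = (self.mask_weight > 0).float()          # hard, binary mask
        else:
            mask = torch.sigmoid(temp * self.mask_weight)  # soft mask
        return scaling * mask

    def forward(self, x, temp=1, ticket=False):
        self.mask = self.compute_mask(temp, ticket)
        masked_weight = self.weight * self.mask
        return F.linear(x, masked_weight, self.bias)
```

With the default mask_initial_value of 0.0, calling the layer with temp=1 reproduces a plain linear layer, since the scaled soft mask is exactly 1. By contrast, ticket=True zeroes every weight until training pushes some mask parameters above 0.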

Once all the required layers have been implemented, what remains is to implement the network that CS will sparsify. In models/networks.py you can find code for ResNet-18, which you can use as a base for other networks. In general, your network can inherit from MaskedNet instead of nn.Module, which makes most of the required functionality immediately available. You then only need to use the layers you implemented (the ones with soft-masked parameters) in your network, and remember to pass temp and ticket to them as additional inputs.

temp is the current temperature of CS (assumed to be the attribute model.temp in main.py). ticket is a boolean that controls whether the parameters' masks are soft (ticket=False) or hard (ticket=True). With ticket=True the mask is binary and the masked parameters are actually sparse. Use ticket=False during training (i.e. sub-network search), and use ticket=True once you are done and want to evaluate the sparse sub-network.
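The sketch below shows how temp and ticket flow through a network. TinyMaskedLinear, TinyMaskedMLP, and ticket_sparsity are hypothetical names. The network subclasses nn.Module rather than the repository's MaskedNet, whose interface is not reproduced here. It keeps the current temperature in model.temp, as main.py assumes.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyMaskedLinear(nn.Module):
    # Minimal soft-masked linear layer (hypothetical; weight-only mask, no bias)
    def __init__(self, in_f, out_f, mask_initial_value=0.):
        super(TinyMaskedLinear, self).__init__()
        # 1 + exp(-s0) equals 1 / sigmoid(s0)
        self.scaling = 1. + math.exp(-mask_initial_value)
        self.weight = nn.Parameter(torch.randn(out_f, in_f) * 0.1)
        self.mask_weight = nn.Parameter(torch.full((out_f, in_f), float(mask_initial_value)))

    def forward(self, x, temp=1, ticket=False):
        if ticket:
            mask = (self.mask_weight > 0).float()
        else:
            mask = torch.sigmoid(temp * self.mask_weight)
        self.mask = self.scaling * mask
        return F.linear(x, self.weight * self.mask)


class TinyMaskedMLP(nn.Module):
    # Hypothetical two-layer network built from soft-masked layers
    def __init__(self, mask_initial_value=0.):
        super(TinyMaskedMLP, self).__init__()
        self.temp = 1.  # current CS temperature, updated by the training loop
        self.fc1 = TinyMaskedLinear(8, 16, mask_initial_value)
        self.fc2 = TinyMaskedLinear(16, 2, mask_initial_value)

    def forward(self, x, temp=1, ticket=False):
        # temp and ticket must be forwarded to every soft-masked layer
        x = F.relu(self.fc1(x, temp, ticket))
        return self.fc2(x, temp, ticket)


def ticket_sparsity(model):
    # Fraction of weights removed by the hard (ticket=True) mask
    total = zeros = 0
    for name, p in model.named_parameters():
        if name.endswith('mask_weight'):
            total += p.numel()
            zeros += (p <= 0).sum().item()
    return zeros / total
```

During search you would call model(x, model.temp) and raise model.temp over time. To evaluate the ticket, call model(x, model.temp, ticket=True). ticket_sparsity(model) then reports the fraction of weights removed by the hard mask.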

Future plans

We plan to make applying CS to other layers and networks considerably easier. We hope to achieve this by offering a function that takes a standard PyTorch Module and returns an equivalent Module with the mask parameters already created and the forward passes overridden to use masked parameters.

If there are specific features that would help you in your research, or in applying our method in general, feel free to suggest them and we will consider implementing them.

Citation

If you use our method for research purposes, please cite our work:

@article{ssm2019cs,
       author = {Savarese, Pedro and Silva, Hugo and Maire, Michael},
        title = {Winning the Lottery with Continuous Sparsification},
      journal = {arXiv:1912.04427},
         year = "2019"
}
Owner
Pedro Savarese
PhD student at TTIC