DiffStride: Learning strides in convolutional neural networks

Overview

DiffStride: Learning strides in convolutional neural networks

Overview

DiffStride is a pooling layer with learnable strides. Unlike strided convolutions, average pooling or max-pooling that require cross-validating stride values at each layer, DiffStride can be initialized with an arbitrary value at each layer (e.g. (2, 2) and during training its strides will be optimized for the task at hand.

We describe DiffStride in our ICLR 2022 paper Learning Strides in Convolutional Neural Network. Compared to the experiments described in the paper, this implementation uses a Pre-Act Resnet and uses Mixup in training.

Installation

To install the diffstride library, run the following pip git clone this repo:

git clone https://github.com/google-research/diffstride.git

The cd into the root and run the command:

pip install -e .

Example training

To run an example training on CIFAR10 and save the result in TensorBoard:

python3 -m diffstride.examples.main \
  --gin_config=cifar10.gin \
  --gin_bindings="train.workdir = '/tmp/exp/diffstride/resnet18/'"

Using custom parameters

This implementation uses Gin to parametrize the model, data processing and training loop. To use custom parameters, one should edit examples/cifar10.gin.

For example, to train with SpectralPooling on cifar100:

data.load_datasets:
  name = 'cifar100'

resnet.Resnet:
  pooling_cls = @pooling.FixedSpectralPooling

Or to train with strided convolutions and without Mixup:

data.load_datasets:
  mixup_alpha = 0.0

resnet.Resnet:
  pooling_cls = None

Results

This current implementation gives the following accuracy on CIFAR-10 and CIFAR-100, averaged over three runs. To show the robustness of DiffStride to stride initialization, we run both with the standard strides of ResNet (resnet.resnet18.strides = '1, 1, 2, 2, 2') and with a 'poor' choice of strides (resnet.resnet18.strides = '1, 1, 3, 2, 3'). Unlike Strided Convolutions and fixed Spectral Pooling, DiffStride is not affected by the stride initialization.

CIFAR-10

Pooling Test Accuracy (%) w/ strides = (1, 1, 2, 2, 2) Test Accuracy (%) w/ strides = (1, 1, 3, 2, 3)
Strided Convolution (Baseline) 91.06 ± 0.04 89.21 ± 0.27
Spectral Pooling 93.49 ± 0.05 92.00 ± 0.08
DiffStride 94.20 ± 0.06 94.19 ± 0.15

CIFAR-100

Pooling Test Accuracy (%) w/ strides = (1, 1, 2, 2, 2) Test Accuracy (%) w/ strides = (1, 1, 3, 2, 3)
Strided Convolution (Baseline) 65.75 ± 0.39 60.82 ± 0.42
Spectral Pooling 72.86 ± 0.23 67.74 ± 0.43
DiffStride 76.08 ± 0.23 76.09 ± 0.06

CPU/GPU Warning

We rely on the tensorflow FFT implementation which requires the input data to be in the channels_first format. This is usually not the regular data format of most datasets (including CIFAR) and running with channels_first also prevents from using of convolutions on CPU. Therefore even if we do support channels_last data format for CPU compatibility , we do encourage the user to run with channels_first data format on GPU.

Reference

If you use this repository, please consider citing:

@article{riad2022diffstride,
  title={Learning Strides in Convolutional Neural Networks},
  author={Riad, Rachid and Teboul, Olivier and Grangier, David and Zeghidour, Neil},
  journal={ICLR},
  year={2022}
}

Disclainer

This is not an official Google product.

Owner
Google Research
Google Research
FAMIE is a comprehensive and efficient active learning (AL) toolkit for multilingual information extraction (IE)

FAMIE: A Fast Active Learning Framework for Multilingual Information Extraction

18 Sep 01, 2022
sequitur is a library that lets you create and train an autoencoder for sequential data in just two lines of code

sequitur sequitur is a library that lets you create and train an autoencoder for sequential data in just two lines of code. It implements three differ

Jonathan Shobrook 305 Dec 21, 2022
Image-Scaling Attacks and Defenses

Image-Scaling Attacks & Defenses This repository belongs to our publication: Erwin Quiring, David Klein, Daniel Arp, Martin Johns and Konrad Rieck. Ad

Erwin Quiring 163 Nov 21, 2022
PyTorch implementation of residual gated graph ConvNets, ICLR’18

Residual Gated Graph ConvNets April 24, 2018 Xavier Bresson http://www.ntu.edu.sg/home/xbresson https://github.com/xbresson https://twitter.com/xbress

Xavier Bresson 112 Aug 10, 2022
Pytorch implementation for the paper: Contrastive Learning for Cold-start Recommendation

Contrastive Learning for Cold-start Recommendation This is our Pytorch implementation for the paper: Yinwei Wei, Xiang Wang, Qi Li, Liqiang Nie, Yan L

45 Dec 13, 2022
🔥RandLA-Net in Tensorflow (CVPR 2020, Oral & IEEE TPAMI 2021)

RandLA-Net: Efficient Semantic Segmentation of Large-Scale Point Clouds (CVPR 2020) This is the official implementation of RandLA-Net (CVPR2020, Oral

Qingyong 1k Dec 30, 2022
Out-of-boundary View Synthesis towards Full-frame Video Stabilization

Out-of-boundary View Synthesis towards Full-frame Video Stabilization Introduction | Update | Results Demo | Introduction This repository contains the

25 Oct 10, 2022
Official repository for "Action-Based Conversations Dataset: A Corpus for Building More In-Depth Task-Oriented Dialogue Systems"

Action-Based Conversations Dataset (ABCD) This respository contains the code and data for ABCD (Chen et al., 2021) Introduction Whereas existing goal-

ASAPP Research 49 Oct 09, 2022
Live training loss plot in Jupyter Notebook for Keras, PyTorch and others

livelossplot Don't train deep learning models blindfolded! Be impatient and look at each epoch of your training! (RECENT CHANGES, EXAMPLES IN COLAB, A

Piotr Migdał 1.2k Jan 08, 2023
Pretrained SOTA Deep Learning models, callbacks and more for research and production with PyTorch Lightning and PyTorch

Pretrained SOTA Deep Learning models, callbacks and more for research and production with PyTorch Lightning and PyTorch

Pytorch Lightning 1.4k Jan 01, 2023
Tensorflow Repo for "DeepGCNs: Can GCNs Go as Deep as CNNs?"

DeepGCNs: Can GCNs Go as Deep as CNNs? In this work, we present new ways to successfully train very deep GCNs. We borrow concepts from CNNs, mainly re

Guohao Li 612 Nov 15, 2022
List of all dependencies affected by node-ipc malicious commit

node-ipc-dependencies-list List of all dependencies affected by node-ipc malicious commit as of 17/3/2022 - 19/3/2022 (timestamp) Please improve upon

99 Oct 15, 2022
Iowa Project - My second project done at General Assembly, focused on feature engineering and understanding Linear Regression as a concept

Project 2 - Ames Housing Data and Kaggle Challenge PROBLEM STATEMENT Inferring or Predicting? What's more valuable for a housing model? When creating

Adam Muhammad Klesc 1 Jan 03, 2022
Paddle pit - Rethinking Spatial Dimensions of Vision Transformers

基于Paddle实现PiT ——Rethinking Spatial Dimensions of Vision Transformers,arxiv 官方原版代

Hongtao Wen 4 Jan 15, 2022
The implemention of Video Depth Estimation by Fusing Flow-to-Depth Proposals

Flow-to-depth (FDNet) video-depth-estimation This is the implementation of paper Video Depth Estimation by Fusing Flow-to-Depth Proposals Jiaxin Xie,

32 Jun 14, 2022
CaLiGraph Ontology as a Challenge for Semantic Reasoners ([email protected]'21)

CaLiGraph for Semantic Reasoning Evaluation Challenge This repository contains code and data to use CaLiGraph as a benchmark dataset in the Semantic R

Nico Heist 0 Jun 08, 2022
Introducing neural networks to predict stock prices

IntroNeuralNetworks in Python: A Template Project IntroNeuralNetworks is a project that introduces neural networks and illustrates an example of how o

Vivek Palaniappan 637 Jan 04, 2023
Simple tutorials using Google's TensorFlow Framework

TensorFlow-Tutorials Introduction to deep learning based on Google's TensorFlow framework. These tutorials are direct ports of Newmu's Theano Tutorial

Nathan Lintz 6k Jan 06, 2023
Enigma-Plus - Python based Enigma machine simulator with some extra features

Enigma-Plus Python based Enigma machine simulator with some extra features Examp

1 Jan 05, 2022
Not All Points Are Equal: Learning Highly Efficient Point-based Detectors for 3D LiDAR Point Clouds (CVPR 2022, Oral)

Not All Points Are Equal: Learning Highly Efficient Point-based Detectors for 3D LiDAR Point Clouds (CVPR 2022, Oral) This is the official implementat

Yifan Zhang 259 Dec 25, 2022