DiffStride: Learning strides in convolutional neural networks

Overview

DiffStride is a pooling layer with learnable strides. Unlike strided convolutions, average pooling or max-pooling, which require cross-validating the stride value of each layer, DiffStride can be initialized with an arbitrary value at each layer (e.g. (2, 2)), and its strides are then optimized for the task at hand during training.
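
To make the idea of a learnable stride concrete, here is a minimal NumPy sketch of a smooth low-pass mask of the kind the paper describes, under stated assumptions: frequencies below the cutoff size / (2 * stride) are kept, a linear ramp of width `ramp` follows, and higher frequencies are zeroed. Because the ramp region moves continuously with `stride`, the mask has a nonzero gradient with respect to the stride, which is what lets backpropagation update it. The helpers `soft_lowpass_mask` and `mask_grad_wrt_stride` are hypothetical, the sketch is 1D only, and it is not the repository's TensorFlow layer.

```python
import numpy as np


def soft_lowpass_mask(size: int, stride: float, ramp: float) -> np.ndarray:
    """Hypothetical 1D soft mask over frequencies 0..size // 2.

    Frequencies below size / (2 * stride) get weight 1, the next `ramp`
    frequencies decay linearly to 0, and higher frequencies get weight 0.
    """
    freqs = np.arange(size // 2 + 1, dtype=np.float64)
    cutoff = size / (2.0 * stride)
    return np.clip((ramp + cutoff - freqs) / ramp, 0.0, 1.0)


def mask_grad_wrt_stride(size: int, stride: float, ramp: float) -> np.ndarray:
    """Analytic d(mask)/d(stride): nonzero only on the linear ramp."""
    freqs = np.arange(size // 2 + 1, dtype=np.float64)
    cutoff = size / (2.0 * stride)
    raw = (ramp + cutoff - freqs) / ramp
    on_ramp = (raw > 0.0) & (raw < 1.0)
    # d/dS [(R + H / (2S) - n) / R] = -H / (2 * S**2 * R)
    return np.where(on_ramp, -size / (2.0 * stride**2 * ramp), 0.0)


mask = soft_lowpass_mask(size=32, stride=2.0, ramp=4.0)
print(mask)  # 9 ones, then 0.75, 0.5, 0.25, then 5 zeros
print(mask_grad_wrt_stride(size=32, stride=2.0, ramp=4.0))
```

Increasing `stride` moves the cutoff towards lower frequencies, i.e. stronger downsampling. In the full method the spectrum is also cropped so that the output resolution follows the stride; that step is omitted here.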

We describe DiffStride in our ICLR 2022 paper Learning Strides in Convolutional Neural Networks. Compared to the experiments described in the paper, this implementation uses a pre-activation ResNet (Pre-Act ResNet) and trains with Mixup.

Installation

To install the diffstride library, first clone this repository:

git clone https://github.com/google-research/diffstride.git

Then cd into the root of the repository and run:

pip install -e .

Example training

To run an example training on CIFAR-10 and save the results for TensorBoard:

python3 -m diffstride.examples.main \
  --gin_config=cifar10.gin \
  --gin_bindings="train.workdir = '/tmp/exp/diffstride/resnet18/'"

Using custom parameters

This implementation uses Gin to parametrize the model, the data processing and the training loop. To use custom parameters, edit examples/cifar10.gin.

For example, to train with fixed spectral pooling on CIFAR-100:

data.load_datasets:
  name = 'cifar100'

resnet.Resnet:
  pooling_cls = @pooling.FixedSpectralPooling

Or to train with strided convolutions and without Mixup:

data.load_datasets:
  mixup_alpha = 0.0

resnet.Resnet:
  pooling_cls = None
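
Both FixedSpectralPooling and DiffStride downsample in the frequency domain rather than by skipping pixels. As an illustration of the fixed variant only, the NumPy sketch below applies generic spectral pooling to a single 2D feature map: it takes the 2D DFT, keeps a centred out_h x out_w block of low frequencies, and transforms back. The function name `spectral_pool_2d` is hypothetical, the mean-preserving rescale is an assumption of this sketch, and the library's TensorFlow layer may handle batching, channels and edge frequencies differently.

```python
import numpy as np


def spectral_pool_2d(x: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
    """Hypothetical helper: downsample a (H, W) map by keeping low frequencies."""
    h, w = x.shape
    # Centre the spectrum so the lowest frequencies sit in the middle.
    spectrum = np.fft.fftshift(np.fft.fft2(x))
    top = h // 2 - out_h // 2
    left = w // 2 - out_w // 2
    cropped = spectrum[top:top + out_h, left:left + out_w]
    # ifft2 divides by out_h * out_w rather than h * w, so rescale to keep
    # the mean (DC component) of the input unchanged.
    scale = (out_h * out_w) / (h * w)
    # Keep only the real part: cropping can break the conjugate symmetry
    # of the spectrum and leave a small imaginary residue.
    return np.real(np.fft.ifft2(np.fft.ifftshift(cropped))) * scale


x = np.random.default_rng(0).normal(size=(8, 8))
y = spectral_pool_2d(x, 4, 4)  # stride 2 along each axis
print(y.shape, np.isclose(y.mean(), x.mean()))  # (4, 4) True
```

Unlike strided convolution, this keeps the smooth, low-frequency content of the whole map instead of sampling every other pixel.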

Results

The current implementation achieves the following test accuracies on CIFAR-10 and CIFAR-100, averaged over three runs. To show the robustness of DiffStride to stride initialization, we run each method both with the standard strides of ResNet (resnet.resnet18.strides = '1, 1, 2, 2, 2') and with a 'poor' choice of strides (resnet.resnet18.strides = '1, 1, 3, 2, 3'). With the poor strides, strided convolutions and fixed spectral pooling lose roughly 1.5 to 5.1 points of accuracy, whereas DiffStride changes by at most 0.01 points, i.e. it is essentially unaffected by the stride initialization.
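
The stride strings bound to resnet.resnet18.strides presumably list one stride per downsampling stage. As a rough illustration of why '1, 1, 3, 2, 3' is a harsher initialization, the sketch below tracks the spatial size of a 32 x 32 CIFAR image through the five strides, assuming each stride divides the size with 'SAME'-style rounding up, ceil(size / stride). The helpers `parse_strides` and `feature_map_sizes` are hypothetical, and the exact rounding used by the library's layers may differ.

```python
import math


def parse_strides(spec: str) -> list:
    """Turn a gin-style string such as '1, 1, 2, 2, 2' into [1, 1, 2, 2, 2]."""
    return [int(part) for part in spec.split(",")]


def feature_map_sizes(input_size: int, strides: list) -> list:
    """Spatial size after each stride, assuming ceil(size / stride) rounding."""
    sizes = []
    size = input_size
    for stride in strides:
        size = math.ceil(size / stride)
        sizes.append(size)
    return sizes


for spec in ("1, 1, 2, 2, 2", "1, 1, 3, 2, 3"):
    strides = parse_strides(spec)
    print(spec, "->", feature_map_sizes(32, strides), "total:", math.prod(strides))
```

Under these assumptions the standard setting ends at a 4 x 4 map (total downsampling 8), while the 'poor' setting ends at 2 x 2 (total downsampling 18).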

CIFAR-10

Pooling                          Test accuracy (%)           Test accuracy (%)
                                 strides = (1, 1, 2, 2, 2)   strides = (1, 1, 3, 2, 3)
Strided Convolution (Baseline)   91.06 ± 0.04                89.21 ± 0.27
Spectral Pooling                 93.49 ± 0.05                92.00 ± 0.08
DiffStride                       94.20 ± 0.06                94.19 ± 0.15

CIFAR-100

Pooling                          Test accuracy (%)           Test accuracy (%)
                                 strides = (1, 1, 2, 2, 2)   strides = (1, 1, 3, 2, 3)
Strided Convolution (Baseline)   65.75 ± 0.39                60.82 ± 0.42
Spectral Pooling                 72.86 ± 0.23                67.74 ± 0.43
DiffStride                       76.08 ± 0.23                76.09 ± 0.06
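
The sensitivity of each method to the stride initialization can be read off the tables as the drop in mean accuracy when switching from the standard strides to the 'poor' ones. The short computation below does exactly that with the mean values copied from the two tables; the ± spreads are ignored, and the names `RESULTS` and `accuracy_drops` are purely illustrative.

```python
# Mean test accuracies (%) from the tables above, as
# (strides = 1, 1, 2, 2, 2, strides = 1, 1, 3, 2, 3). The ± spreads are ignored.
RESULTS = {
    "CIFAR-10": {
        "Strided Convolution (Baseline)": (91.06, 89.21),
        "Spectral Pooling": (93.49, 92.00),
        "DiffStride": (94.20, 94.19),
    },
    "CIFAR-100": {
        "Strided Convolution (Baseline)": (65.75, 60.82),
        "Spectral Pooling": (72.86, 67.74),
        "DiffStride": (76.08, 76.09),
    },
}


def accuracy_drops(table: dict) -> dict:
    """Points of accuracy lost when switching to the 'poor' strides."""
    return {name: round(standard - poor, 2) for name, (standard, poor) in table.items()}


for dataset, table in RESULTS.items():
    print(dataset, accuracy_drops(table))
```

This gives drops of 1.85, 1.49 and 0.01 points on CIFAR-10, and 4.93, 5.12 and -0.01 points on CIFAR-100, for strided convolutions, spectral pooling and DiffStride respectively.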

CPU/GPU Warning

We rely on the TensorFlow FFT implementation, which requires the input data to be in the channels_first format. This is not the usual data format of most datasets (including CIFAR), and running with channels_first also prevents the use of convolutions on CPU. Therefore, even though we support the channels_last data format for CPU compatibility, we encourage users to run with the channels_first data format on GPU.
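
The layout constraint comes from how 2D FFTs pick their axes: tf.signal.fft2d always transforms the two innermost dimensions and has no axis argument, so height and width must be the last two axes, which is the channels_first (NCHW) layout. The NumPy sketch below illustrates this with a random stand-in for a CIFAR batch; NumPy's fft2 also defaults to the last two axes, but unlike TensorFlow it lets us target other axes for comparison.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for a CIFAR batch in channels_last (NHWC) layout: 2 RGB 32x32 images.
batch_nhwc = rng.normal(size=(2, 32, 32, 3))

# channels_first (NCHW): height and width become the two innermost axes.
batch_nchw = np.transpose(batch_nhwc, (0, 3, 1, 2))

# fft2 transforms the last two axes by default, i.e. (H, W) in NCHW layout,
# which mirrors what tf.signal.fft2d does with no way to choose other axes.
spectrum = np.fft.fft2(batch_nchw)

# NumPy can instead target axes (1, 2) of the NHWC batch; after transposing,
# the result matches, confirming that only the axis placement differs.
alt = np.transpose(np.fft.fft2(batch_nhwc, axes=(1, 2)), (0, 3, 1, 2))
print(batch_nchw.shape, np.allclose(spectrum, alt))  # (2, 3, 32, 32) True
```

On GPU the extra transposes are avoided entirely by feeding channels_first data from the start.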

Reference

If you use this repository, please consider citing:

@article{riad2022diffstride,
  title={Learning Strides in Convolutional Neural Networks},
  author={Riad, Rachid and Teboul, Olivier and Grangier, David and Zeghidour, Neil},
  journal={ICLR},
  year={2022}
}

Disclaimer

This is not an official Google product.

Owner
Google Research