Details about the wide minima density hypothesis and metrics to compute width of a minima

Overview

wide-minima-density-hypothesis

Details about the wide minima density hypothesis and metrics to compute width of a minima

This repo presents the wide minima density hypothesis as proposed in the following paper:

Key contributions:

  • Hypothesis about minima density
  • A SOTA LR schedule that exploits the hypothesis and beats general baseline schedules
  • Reducing wall clock training time and saving GPU compute hours with our LR schedule (Pretraining BERT-Large in 33% less training steps)
  • SOTA BLEU score on IWSLT'14 ( DE-EN )

Prerequisite:

  • CUDA, cudnn
  • Python 3.6+
  • PyTorch 1.4.0

Knee LR Schedule

Based on the density of wide vs narrow minima , we propose the Knee LR schedule that pushes generalization boundaries further by exploiting the nature of the loss landscape. The LR schedule is an explore-exploit based schedule, where the explore phase maintains a high lr for a significant time to access and land into a wide minimum with a good probability. The exploit phase is a simple linear decay scheme, which decays the lr to zero over the exploit phase. The only hyperparameter to tune is the explore epochs/steps. We have shown that 50% of the training budget allocated for explore is good enough for landing in a wider minimum and better generalization, thus removing the need for hyperparameter tuning.

  • Note that many experiments require warmup, which is done in the initial phase of training for a fixed number of steps and is usually required for Adam based optimizers/ large batch training. It is complementary with the Knee schedule and can be added to it.

To use the Knee Schedule, import the scheduler into your training file:

>>> from knee_lr_schedule import KneeLRScheduler
>>> scheduler = KneeLRScheduler(optimizer, peak_lr, warmup_steps, explore_steps, total_steps)

To use it during training :

>>> model.train()
>>> output = model(inputs)
>>> loss = criterion(output, targets)
>>> loss.backward()
>>> optimizer.step()
>>> scheduler.step()

Details about args:

  • optimizer: optimizer needed for training the model ( SGD/Adam )
  • peak_lr: the peak learning required for explore phase to escape narrow minimas
  • warmup_steps: steps required for warmup( usually needed for adam optimizers/ large batch training) Default value: 0
  • explore_steps: total steps for explore phase.
  • total_steps: total training budget steps for training the model

Measuring width of a minima

Keskar et.al 2016 (https://arxiv.org/abs/1609.04836) argue that wider minima generalize much better than sharper minima. The computation method in their work uses the compute expensive LBFGS-B second order method, which is hard to scale. We use a projected gradient ascent based method, which is first order in nature and very easy to implement/use. Here is a simple way you can compute the width of the minima your model finds during training:

>>> from minima_width_compute import ComputeKeskarSharpness
>>> cks = ComputeKeskarSharpness(model_final_ckpt, optimizer, criterion, trainloader, epsilon, lr, max_steps)
>>> width = cks.compute_sharpness()

Details about args:

  • model_final_ckpt: model loaded with the saved checkpoint after final training step
  • optimizer : optimizer to use for projected gradient ascent ( SGD, Adam )
  • criterion : criterion for computing loss (e.g. torch.nn.CrossEntropyLoss)
  • trainloader : iterator over the training dataset (torch.utils.data.DataLoader)
  • epsilon : epsilon value determines the local boundary around which minima witdh is computed (Default value : 1e-4)
  • lr : lr for the optimizer to perform projected gradient ascent ( Default: 0.001)
  • max_steps : max steps to compute the width (Default: 1000). Setting it too low could lead to the gradient ascent method not converging to an optimal point.

The above default values have been chosen after tuning and observing the loss values of projected gradient ascent on Cifar-10 with ResNet-18 and SGD-Momentum optimizer, as mentioned in our paper. The values may vary for experiments with other optimizers/datasets/models. Please tune them for optimal convergence.

  • Acknowledgements: We would like to thank Harshay Shah (https://github.com/harshays) for his helpful discussions for computing the width of the minima.

Citation

Please cite our paper in your publications if you use our work:

@article{iyer2020wideminima,
  title={Wide-minima Density Hypothesis and the Explore-Exploit Learning Rate Schedule},
  author={Iyer, Nikhil and Thejas, V and Kwatra, Nipun and Ramjee, Ramachandran and Sivathanu, Muthian},
  journal={arXiv preprint arXiv:2003.03977},
  year={2020}
}
  • Note: This work was done during an internship at Microsoft Research India
Owner
Nikhil Iyer
Studied at BITS-Pilani Hyderabad Campus. AI Research @ Jio-Haptik. Ex Microsoft Research India
Nikhil Iyer
Official repository for Fourier model that can generate periodic signals

Conditional Generation of Periodic Signals with Fourier-Based Decoder Jiyoung Lee, Wonjae Kim, Daehoon Gwak, Edward Choi This repository provides offi

8 May 25, 2022
SwinIR: Image Restoration Using Swin Transformer

SwinIR: Image Restoration Using Swin Transformer This repository is the official PyTorch implementation of SwinIR: Image Restoration Using Shifted Win

Jingyun Liang 2.4k Jan 05, 2023
Решения, подсказки, тесты и утилиты для тренировки по алгоритмам от Яндекса.

Решения и подсказки к тренировке по алгоритмам от Яндекса Что есть внутри Решения с подсказками и комментариями; рекомендую сначала смотреть md файл п

Yankovsky Andrey 50 Dec 26, 2022
YoHa - A practical hand tracking engine.

YoHa - A practical hand tracking engine.

2k Jan 06, 2023
Spatial Contrastive Learning for Few-Shot Classification (SCL)

This repo contains the official implementation of Spatial Contrastive Learning for Few-Shot Classification (SCL), which presents of a novel contrastive learning method applied to few-shot image class

Yassine 34 Dec 25, 2022
Code for the prototype tool in our paper "CoProtector: Protect Open-Source Code against Unauthorized Training Usage with Data Poisoning".

CoProtector Code for the prototype tool in our paper "CoProtector: Protect Open-Source Code against Unauthorized Training Usage with Data Poisoning".

Zhensu Sun 1 Oct 26, 2021
PyTorch implementation for Score-Based Generative Modeling through Stochastic Differential Equations (ICLR 2021, Oral)

Score-Based Generative Modeling through Stochastic Differential Equations This repo contains a PyTorch implementation for the paper Score-Based Genera

Yang Song 757 Jan 04, 2023
Self-labelling via simultaneous clustering and representation learning. (ICLR 2020)

Self-labelling via simultaneous clustering and representation learning 🆗 🆗 🎉 NEW models (20th August 2020): Added standard SeLa pretrained torchvis

Yuki M. Asano 469 Jan 02, 2023
Sequence modeling benchmarks and temporal convolutional networks

Sequence Modeling Benchmarks and Temporal Convolutional Networks (TCN) This repository contains the experiments done in the work An Empirical Evaluati

CMU Locus Lab 3.5k Jan 01, 2023
Tianshou - An elegant PyTorch deep reinforcement learning library.

Tianshou (天授) is a reinforcement learning platform based on pure PyTorch. Unlike existing reinforcement learning libraries, which are mainly based on

Tsinghua Machine Learning Group 5.5k Jan 05, 2023
This code implements constituency parse tree aggregation

README This code implements constituency parse tree aggregation. Folder details code: This folder contains the code that implements constituency parse

Adithya Kulkarni 0 Oct 11, 2021
[SIGGRAPH 2020] Attribute2Font: Creating Fonts You Want From Attributes

Attr2Font Introduction This is the official PyTorch implementation of the Attribute2Font: Creating Fonts You Want From Attributes. Paper: arXiv | Rese

Yue Gao 200 Dec 15, 2022
Patch Rotation: A Self-Supervised Auxiliary Task for Robustness and Accuracy of Supervised Models

Patch-Rotation(PatchRot) Patch Rotation: A Self-Supervised Auxiliary Task for Robustness and Accuracy of Supervised Models Submitted to Neurips2021 To

4 Jul 12, 2021
This is an official implementation for "SimMIM: A Simple Framework for Masked Image Modeling".

Project This repo has been populated by an initial template to help get you started. Please make sure to update the content to build a great experienc

Microsoft 674 Dec 26, 2022
Learning Neural Network Subspaces

Learning Neural Network Subspaces Welcome to the codebase for Learning Neural Network Subspaces by Mitchell Wortsman, Maxwell Horton, Carlos Guestrin,

Apple 117 Nov 17, 2022
Speeding-Up Back-Propagation in DNN: Approximate Outer Product with Memory

Approximate Outer Product Gradient Descent with Memory Code for the numerical experiment of the paper Speeding-Up Back-Propagation in DNN: Approximate

2 Mar 02, 2022
Second Order Optimization and Curvature Estimation with K-FAC in JAX.

KFAC-JAX - Second Order Optimization with Approximate Curvature in JAX Installation | Quickstart | Documentation | Examples | Citing KFAC-JAX KFAC-JAX

DeepMind 90 Dec 22, 2022
Official PyTorch implementation of Retrieve in Style: Unsupervised Facial Feature Transfer and Retrieval.

Retrieve in Style: Unsupervised Facial Feature Transfer and Retrieval PyTorch This is the PyTorch implementation of Retrieve in Style: Unsupervised Fa

60 Oct 12, 2022
Code for Temporally Abstract Partial Models

Code for Temporally Abstract Partial Models Accompanies the code for the experimental section of the paper: Temporally Abstract Partial Models, Khetar

DeepMind 19 Jul 13, 2022
Self-supervised Deep LiDAR Odometry for Robotic Applications

DeLORA: Self-supervised Deep LiDAR Odometry for Robotic Applications Overview Paper: link Video: link ICRA Presentation: link This is the correspondin

Robotic Systems Lab - Legged Robotics at ETH Zürich 181 Dec 29, 2022