Details about the wide minima density hypothesis and metrics to compute width of a minima

Overview

wide-minima-density-hypothesis

Details about the wide minima density hypothesis and metrics to compute width of a minima

This repo presents the wide minima density hypothesis as proposed in the following paper:

Key contributions:

  • Hypothesis about minima density
  • A SOTA LR schedule that exploits the hypothesis and beats general baseline schedules
  • Reducing wall clock training time and saving GPU compute hours with our LR schedule (Pretraining BERT-Large in 33% less training steps)
  • SOTA BLEU score on IWSLT'14 ( DE-EN )

Prerequisite:

  • CUDA, cudnn
  • Python 3.6+
  • PyTorch 1.4.0

Knee LR Schedule

Based on the density of wide vs narrow minima , we propose the Knee LR schedule that pushes generalization boundaries further by exploiting the nature of the loss landscape. The LR schedule is an explore-exploit based schedule, where the explore phase maintains a high lr for a significant time to access and land into a wide minimum with a good probability. The exploit phase is a simple linear decay scheme, which decays the lr to zero over the exploit phase. The only hyperparameter to tune is the explore epochs/steps. We have shown that 50% of the training budget allocated for explore is good enough for landing in a wider minimum and better generalization, thus removing the need for hyperparameter tuning.

  • Note that many experiments require warmup, which is done in the initial phase of training for a fixed number of steps and is usually required for Adam based optimizers/ large batch training. It is complementary with the Knee schedule and can be added to it.

To use the Knee Schedule, import the scheduler into your training file:

>>> from knee_lr_schedule import KneeLRScheduler
>>> scheduler = KneeLRScheduler(optimizer, peak_lr, warmup_steps, explore_steps, total_steps)

To use it during training :

>>> model.train()
>>> output = model(inputs)
>>> loss = criterion(output, targets)
>>> loss.backward()
>>> optimizer.step()
>>> scheduler.step()

Details about args:

  • optimizer: optimizer needed for training the model ( SGD/Adam )
  • peak_lr: the peak learning required for explore phase to escape narrow minimas
  • warmup_steps: steps required for warmup( usually needed for adam optimizers/ large batch training) Default value: 0
  • explore_steps: total steps for explore phase.
  • total_steps: total training budget steps for training the model

Measuring width of a minima

Keskar et.al 2016 (https://arxiv.org/abs/1609.04836) argue that wider minima generalize much better than sharper minima. The computation method in their work uses the compute expensive LBFGS-B second order method, which is hard to scale. We use a projected gradient ascent based method, which is first order in nature and very easy to implement/use. Here is a simple way you can compute the width of the minima your model finds during training:

>>> from minima_width_compute import ComputeKeskarSharpness
>>> cks = ComputeKeskarSharpness(model_final_ckpt, optimizer, criterion, trainloader, epsilon, lr, max_steps)
>>> width = cks.compute_sharpness()

Details about args:

  • model_final_ckpt: model loaded with the saved checkpoint after final training step
  • optimizer : optimizer to use for projected gradient ascent ( SGD, Adam )
  • criterion : criterion for computing loss (e.g. torch.nn.CrossEntropyLoss)
  • trainloader : iterator over the training dataset (torch.utils.data.DataLoader)
  • epsilon : epsilon value determines the local boundary around which minima witdh is computed (Default value : 1e-4)
  • lr : lr for the optimizer to perform projected gradient ascent ( Default: 0.001)
  • max_steps : max steps to compute the width (Default: 1000). Setting it too low could lead to the gradient ascent method not converging to an optimal point.

The above default values have been chosen after tuning and observing the loss values of projected gradient ascent on Cifar-10 with ResNet-18 and SGD-Momentum optimizer, as mentioned in our paper. The values may vary for experiments with other optimizers/datasets/models. Please tune them for optimal convergence.

  • Acknowledgements: We would like to thank Harshay Shah (https://github.com/harshays) for his helpful discussions for computing the width of the minima.

Citation

Please cite our paper in your publications if you use our work:

@article{iyer2020wideminima,
  title={Wide-minima Density Hypothesis and the Explore-Exploit Learning Rate Schedule},
  author={Iyer, Nikhil and Thejas, V and Kwatra, Nipun and Ramjee, Ramachandran and Sivathanu, Muthian},
  journal={arXiv preprint arXiv:2003.03977},
  year={2020}
}
  • Note: This work was done during an internship at Microsoft Research India
Owner
Nikhil Iyer
Studied at BITS-Pilani Hyderabad Campus. AI Research @ Jio-Haptik. Ex Microsoft Research India
Nikhil Iyer
Robust Consistent Video Depth Estimation

[CVPR 2021] Robust Consistent Video Depth Estimation This repository contains Python and C++ implementation of Robust Consistent Video Depth, as descr

Facebook Research 213 Dec 17, 2022
The code for "Deep Level Set for Box-supervised Instance Segmentation in Aerial Images".

Deep Levelset for Box-supervised Instance Segmentation in Aerial Images Wentong Li, Yijie Chen, Wenyu Liu, Jianke Zhu* This code is based on MMdetecti

sunshine.lwt 112 Jan 05, 2023
PyTorch version of the paper 'Enhanced Deep Residual Networks for Single Image Super-Resolution' (CVPRW 2017)

About PyTorch 1.2.0 Now the master branch supports PyTorch 1.2.0 by default. Due to the serious version problem (especially torch.utils.data.dataloade

Sanghyun Son 2.1k Jan 01, 2023
Official implementation of MSR-GCN (ICCV 2021 paper)

MSR-GCN Official implementation of MSR-GCN: Multi-Scale Residual Graph Convolution Networks for Human Motion Prediction (ICCV 2021 paper) [Paper] [Sup

LevonDang 42 Nov 07, 2022
A more easy-to-use implementation of KPConv based on PyTorch.

A more easy-to-use implementation of KPConv This repo contains a more easy-to-use implementation of KPConv based on PyTorch. Introduction KPConv is a

Zheng Qin 36 Dec 29, 2022
September-Assistant - Open-source Windows Voice Assistant

September - Windows Assistant September is an open-source Windows personal assis

The Nithin Balaji 9 Nov 22, 2022
Deep Learning Based Fasion Recommendation System for Ecommerce

Project Name: Fasion Recommendation System for Ecommerce A Deep learning based streamlit web app which can recommened you various types of fasion prod

BAPPY AHMED 13 Dec 13, 2022
这是一个利用facenet和retinaface实现人脸识别的库,可以进行在线的人脸识别。

Facenet+Retinaface:人脸识别模型在Keras当中的实现 目录 注意事项 Attention 所需环境 Environment 文件下载 Download 预测步骤 How2predict 参考资料 Reference 注意事项 该库中包含了两个网络,分别是retinaface和fa

Bubbliiiing 31 Nov 15, 2022
Ontologysim: a Owlready2 library for applied production simulation

Ontologysim: a Owlready2 library for applied production simulation Ontologysim is an open-source deep production simulation framework, with an emphasi

10 Nov 30, 2022
Security evaluation module with onnx, pytorch, and SecML.

🚀 🐼 🔥 PandaVision Integrate and automate security evaluations with onnx, pytorch, and SecML! Installation Starting the server without Docker If you

Maura Pintor 11 Apr 12, 2022
Pytorch implementation of forward and inverse Haar Wavelets 2D

Pytorch implementation of forward and inverse Haar Wavelets 2D

Sergei Belousov 9 Oct 30, 2022
Help you understand Manual and w/ Clutch point while driving.

简体中文 forza_auto_gear forza_auto_gear is a tool for Forza Horizon 5. It will help us understand the best gear shift point using Manual or w/ Clutch in

15 Oct 08, 2022
Repositório criado para abrigar os notebooks com a listas de exercícios propostos pelo professor Gustavo Guanabara do canal Curso em Vídeo do YouTube durante o Curso de Python 3

Curso em Vídeo - Exercícios de Python 3 Sobre o repositório Este repositório contém os notebooks com a listas de exercícios propostos pelo professor G

João Pedro Pereira 9 Oct 15, 2022
Python implementation of "Single Image Haze Removal Using Dark Channel Prior"

##Dependencies pillow(~2.6.0) Numpy(~1.9.0) If the scripts throw AttributeError: __float__, make sure your pillow has jpeg support e.g. try: $ sudo ap

Joyee Cheung 73 Dec 20, 2022
Supervised domain-agnostic prediction framework for probabilistic modelling

A supervised domain-agnostic framework that allows for probabilistic modelling, namely the prediction of probability distributions for individual data

The Alan Turing Institute 112 Oct 23, 2022
RodoSol-ALPR Dataset

RodoSol-ALPR Dataset This dataset, called RodoSol-ALPR dataset, contains 20,000 images captured by static cameras located at pay tolls owned by the Ro

Rayson Laroca 45 Dec 15, 2022
Forecasting Nonverbal Social Signals during Dyadic Interactions with Generative Adversarial Neural Networks

ForecastingNonverbalSignals This is the implementation for the paper Forecasting Nonverbal Social Signals during Dyadic Interactions with Generative A

1 Feb 10, 2022
SmallInitEmb - LayerNorm(SmallInit(Embedding)) in a Transformer to improve convergence

SmallInitEmb LayerNorm(SmallInit(Embedding)) in a Transformer I find that when t

PENG Bo 11 Dec 25, 2022
Using deep learning model to detect breast cancer.

Breast-Cancer-Detection Breast cancer is the most frequent cancer among women, with around one in every 19 women at risk. The number of cases of breas

1 Feb 13, 2022
DAN: Unfolding the Alternating Optimization for Blind Super Resolution

DAN-Basd-on-Openmmlab DAN: Unfolding the Alternating Optimization for Blind Super Resolution We reproduce DAN via mmediting based on open-sourced code

AlexZou 72 Dec 13, 2022