Visualizing the Loss Landscape of Neural Nets

This repository contains the PyTorch code for the paper

Hao Li, Zheng Xu, Gavin Taylor, Christoph Studer and Tom Goldstein. Visualizing the Loss Landscape of Neural Nets. NIPS, 2018.

An interactive 3D visualizer for loss surfaces has been provided by telesens.

Given a network architecture and its pre-trained parameters, this tool calculates and visualizes the loss surface along random direction(s) near the optimal parameters. The calculation can be parallelized across multiple GPUs per node and across multiple nodes. The random direction(s) and loss surface values are stored in HDF5 (.h5) files after they are produced.

Setup

Environment: One or more multi-GPU node(s) with PyTorch, openmpi/mpi4py (for multi-GPU and multi-node runs), numpy, h5py (for storing directions and surfaces), and matplotlib (for plotting) installed.

Pre-trained models: The code accepts pre-trained PyTorch models for the CIFAR-10 dataset. To load a pre-trained model correctly, the model file should contain a state_dict saved via the model's state_dict() method. The default path for pre-trained networks is cifar10/trained_nets. Some of the pre-trained models and plotted figures can be downloaded here.
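For reference, a checkpoint in this format can be produced and reloaded as follows. This is a minimal sketch: the import path and file name are illustrative, not part of the repository's API.

import torch
from cifar10.models.vgg import VGG9  # illustrative import; use your own model definition

model = VGG9()

# Saving: store the parameters under a 'state_dict' key, saved from state_dict().
torch.save({'state_dict': model.state_dict()}, 'model_300.t7')

# Loading: restore the parameters into a freshly constructed network.
checkpoint = torch.load('model_300.t7')
model.load_state_dict(checkpoint['state_dict'])
model.eval()  # evaluation mode, so BN layers use their running statistics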

Data preprocessing: The data pre-processing method used for visualization should be consistent with the one used for model training. No data augmentation (random cropping or horizontal flipping) is used in calculating the loss values.
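For example, an evaluation data loader without augmentation might look like this minimal sketch; the normalization statistics shown are commonly used CIFAR-10 values and are an assumption here — use whatever preprocessing your model was trained with.

import torch
import torchvision
import torchvision.transforms as transforms

# Deterministic preprocessing only: no random cropping or horizontal flipping,
# so repeated loss evaluations see exactly the same inputs.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),   # per-channel mean (assumed)
                         (0.2470, 0.2435, 0.2616)),  # per-channel std (assumed)
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=False)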

Visualizing 1D loss curve

Creating 1D linear interpolations

The 1D linear interpolation method [1] evaluates the loss values along the direction between two minimizers of the same network loss function. This method has been used to compare the flatness of minimizers trained with different batch sizes [2]. A 1D linear interpolation plot is produced with the plot_surface.py script.

mpirun -n 4 python plot_surface.py --mpi --cuda --model vgg9 --x=-0.5:1.5:401 --dir_type states \
--model_file cifar10/trained_nets/vgg9_sgd_lr=0.1_bs=128_wd=0.0_save_epoch=1/model_300.t7 \
--model_file2 cifar10/trained_nets/vgg9_sgd_lr=0.1_bs=8192_wd=0.0_save_epoch=1/model_300.t7 --plot
  • --x=-0.5:1.5:401 sets the range and resolution for the plot. The x-coordinates in the plot will run from -0.5 to 1.5 (the minimizers are located at 0 and 1), and the loss value will be evaluated at 401 locations along this line.
  • --dir_type states indicates that the direction includes dimensions for all parameters as well as the statistics of the batch-normalization (BN) layers (running_mean and running_var). Note that ignoring running_mean and running_var cannot produce correct loss values when plotting two solutions together in the same figure.
  • The two model files contain the network parameters at two distinct minimizers of the loss function. The plot interpolates between these two minima; the interpolation itself is sketched below.
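Conceptually, the interpolation evaluates the loss at theta(alpha) = (1 - alpha) * theta_1 + alpha * theta_2 for each alpha on the grid. A minimal sketch of that loop, assuming state1 and state2 are the two loaded state_dicts, and model, trainloader, and eval_loss (a helper returning the training loss) are defined as in the sketches above:

import copy
import numpy as np

def interpolate_states(state1, state2, alpha):
    """Linearly interpolate two state_dicts, including the BN running statistics."""
    state = copy.deepcopy(state1)
    for k in state:
        if state[k].dtype.is_floating_point:  # skip integer buffers like num_batches_tracked
            state[k] = (1 - alpha) * state1[k] + alpha * state2[k]
    return state

alphas = np.linspace(-0.5, 1.5, 401)  # matches --x=-0.5:1.5:401
losses = []
for alpha in alphas:
    model.load_state_dict(interpolate_states(state1, state2, alpha))
    losses.append(eval_loss(model, trainloader))  # hypothetical loss-evaluation helper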

VGG-9 SGD, WD=0

Producing plots along random normalized directions

A random direction with the same dimensions as the model parameters is created and "filter normalized." Then we can sample loss values along this direction.

mpirun -n 4 python plot_surface.py --mpi --cuda --model vgg9 --x=-1:1:51 \
--model_file cifar10/trained_nets/vgg9_sgd_lr=0.1_bs=128_wd=0.0_save_epoch=1/model_300.t7 \
--dir_type weights --xnorm filter --xignore biasbn --plot
  • --dir_type weights indicates the direction has the same dimensions as the learned parameters, including bias and parameters in the BN layers.
  • --xnorm filter normalizes the random direction at the filter level (see the sketch after this list). Here, a "filter" refers to the parameters that produce a single feature map. For fully connected layers, a "filter" contains the weights that contribute to a single neuron.
  • --xignore biasbn ignores the directions corresponding to bias and BN parameters (the corresponding entries in the random vector are filled with zeros).
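Filter normalization rescales each filter d_f of the random direction so that it has the same norm as the corresponding filter theta_f of the trained weights, d_f <- (d_f / ||d_f||) * ||theta_f||. A minimal sketch of the idea (not the repository's exact implementation):

import torch

def filter_normalize(direction, weights, eps=1e-10):
    """Rescale each filter of `direction` to match the norm of the
    corresponding filter in `weights` (both are lists of tensors)."""
    for d, w in zip(direction, weights):
        if d.dim() <= 1:
            d.zero_()  # --xignore biasbn: zero out bias and BN entries
        else:
            for df, wf in zip(d, w):  # each slice along dim 0 is one filter/neuron
                df.mul_(wf.norm() / (df.norm() + eps))
    return direction

weights = [p.data for p in model.parameters()]
direction = [torch.randn_like(w) for w in weights]  # random Gaussian direction
direction = filter_normalize(direction, weights)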

VGG-9 SGD, WD=0

We can also customize the appearance of the 1D plots by calling plot_1D.py once the surface file is available.
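For example (the flags below are assumptions modeled on plot_2D.py; run python plot_1D.py --help for the actual options):

python plot_1D.py --surf_file path_to_surf_file --log --show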

Visualizing 2D loss contours

To plot the loss contours, we choose two random directions and normalize them in the same way as the 1D plotting.

mpirun -n 4 python plot_surface.py --mpi --cuda --model resnet56 --x=-1:1:51 --y=-1:1:51 \
--model_file cifar10/trained_nets/resnet56_sgd_lr=0.1_bs=128_wd=0.0005/model_300.t7 \
--dir_type weights --xnorm filter --xignore biasbn --ynorm filter --yignore biasbn --plot

ResNet-56

Once a surface is generated and stored in a .h5 file, we can produce and customize a contour plot using the script plot_2D.py.

python plot_2D.py --surf_file path_to_surf_file --surf_name train_loss
  • --surf_name specifies the type of surface; the default is train_loss.
  • --vmin and --vmax set the range of values to be plotted.
  • --vlevel sets the step between contour levels.
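For example, to restrict the contour range and spacing (the numeric values are only illustrative):

python plot_2D.py --surf_file path_to_surf_file --surf_name train_loss --vmin 0.1 --vmax 10 --vlevel 0.5

The stored surface can also be inspected directly with h5py before plotting. A minimal sketch, assuming the file stores the grids under xcoordinates/ycoordinates and the surface under the name passed as --surf_name (key names inferred from the options above):

import h5py

with h5py.File('path_to_surf_file', 'r') as f:
    print(list(f.keys()))        # see what the file actually contains
    x = f['xcoordinates'][:]     # assumed key: grid along the first direction
    y = f['ycoordinates'][:]     # assumed key: grid along the second direction
    loss = f['train_loss'][:]    # surface values on the (y, x) grid
    print(x.shape, y.shape, loss.shape)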

Visualizing 3D loss surface

plot_2D.py can make a basic 3D loss surface plot with matplotlib. If you want a more detailed rendering that uses lighting to reveal fine structure, you can render the loss surface with ParaView.

ResNet-56-noshort, ResNet-56

To do this, you must:

  1. Convert the surface .h5 file to a .vtp file.
python h52vtp.py --surf_file path_to_surf_file --surf_name train_loss --zmax 10 --log

This will generate a VTK (.vtp) file containing the loss surface, with values clamped at a maximum of 10 and plotted on a log scale.

  2. Open the .vtp file in ParaView, which reads it with its built-in VTK reader. Click the eye icon in the Pipeline Browser to make the figure show up. You can drag the surface around and change the colors in the Properties window.

  3. If the surface appears extremely skinny and needle-like, you may need to adjust the "transforming" parameters in the left control panel. Enter numbers larger than 1 in the "scale" fields to widen the plot.

  4. Select Save screenshot in the File menu to save the image.

Reference

[1] Ian J. Goodfellow, Oriol Vinyals, and Andrew M. Saxe. Qualitatively characterizing neural network optimization problems. ICLR, 2015.

[2] Nitish Shirish Keskar, Dheevatsa Mudigere, Jorge Nocedal, Mikhail Smelyanskiy, and Ping Tak Peter Tang. On large-batch training for deep learning: Generalization gap and sharp minima. ICLR, 2017.

Citation

If you find this code useful in your research, please cite:

@inproceedings{visualloss,
  title={Visualizing the Loss Landscape of Neural Nets},
  author={Li, Hao and Xu, Zheng and Taylor, Gavin and Studer, Christoph and Goldstein, Tom},
  booktitle={Neural Information Processing Systems},
  year={2018}
}