Supporting code for the paper "Dangers of Bayesian Model Averaging under Covariate Shift"

Overview

Dangers of Bayesian Model Averaging under Covariate Shift

This repository contains the code to reproduce the experiments in the paper Dangers of Bayesian Model Averaging under Covariate Shift by Pavel Izmailov, Patrick Nicholson, Sanae Lotfi and Andrew Gordon Wilson.

The code is forked from the Google Research BNN HMC repo.

Introduction

Approximate Bayesian inference for neural networks is considered a robust alternative to standard training, often providing good performance on out-of-distribution data. However, it was recently shown that Bayesian neural networks (BNNs) with high-fidelity inference through Hamiltonian Monte Carlo (HMC) provide shockingly poor performance under covariate shift. For example, below we show that a ResNet-20 BNN approximated with HMC underperforms a maximum a posteriori (MAP) solution by 25% on the pixelate-corrupted CIFAR-10 test set. This result is particularly surprising given that on the in-distribution test data, the BNN outperforms the MAP solution by over 5%. In this work, we seek to understand, further demonstrate, and help remedy this concerning behaviour.

As an example, let us consider a fully-connected network on MNIST. MNIST contains many dead pixels, i.e. pixels near the boundary that are zero for all training images. The corresponding weights in the first layer of the network are always multiplied by zero, and have no effect on the likelihood of the training data. Consequently, in a Bayesian neural network, these weights will be sampled from the prior. A MAP solution on the other hand will set these parameters close to zero. In the animation, we visualize the weights in the first layer of a Bayesian neural network and a MAP solution. For each sample, we show the value of the weight corresponding to the highlighted pixel.

If at test time the data is corrupted, e.g. by Gaussian noise, and the pixels near the boundary of the image are activated, the MAP solution will ignore these pixels, while the predictions of the BNN will be significantly affected.
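The dead-pixel argument above can be illustrated with a small NumPy sketch. This is a toy linear example, not code from the repo: the feature counts, the noise level, and the names `w_map` and `w_bnn` are all made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy "images": 8 features, the last 3 are dead (always zero in training).
n_train, n_features, n_dead = 100, 8, 3
x_train = rng.normal(size=(n_train, n_features))
x_train[:, -n_dead:] = 0.0

# Two first-layer weight vectors that agree on the live features.
w_map = rng.normal(size=n_features)
w_map[-n_dead:] = 0.0                      # MAP-like: dead weights pushed to zero
w_bnn = w_map.copy()
w_bnn[-n_dead:] = rng.normal(size=n_dead)  # BNN-like: dead weights drawn from a N(0, 1) prior

# On the training data both weight vectors give the same pre-activations,
# so the likelihood cannot distinguish them.
train_gap = np.abs(x_train @ w_map - x_train @ w_bnn).max()

# Gaussian noise activates the dead features and the outputs diverge.
x_noisy = x_train + 0.5 * rng.normal(size=x_train.shape)
noisy_gap = np.abs(x_noisy @ w_map - x_noisy @ w_bnn).max()

print(f"max gap on training data: {train_gap:.3g}")
print(f"max gap on noisy data:    {noisy_gap:.3g}")
```

The gap on the training data is zero up to floating point error, while the gap on the noisy data is driven entirely by the weights attached to the dead features.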

In the paper, we extend this reasoning to general linear dependencies between input features for both fully connected and convolutional Bayesian neural networks. We also propose EmpCov, a prior based on the empirical covariance of the data which significantly improves robustness of BNNs to covariate shift. We implement EmpCov as well as other priors for Bayesian neural networks in this repo.
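As a rough sketch of the idea behind a prior based on the empirical covariance of the data, the snippet below builds a first-layer prior covariance of the form alpha * Sigma + eps * I, where Sigma is the empirical covariance of the flattened inputs. The exact form, the centering of the data, and the names `empcov_prior_cov`, `alpha` and `eps` are assumptions made for illustration; the notebooks in this repo are the reference for how the shipped matrices were built.

```python
import numpy as np


def empcov_prior_cov(x, alpha=1.0, eps=1e-2):
    """Return alpha * Sigma + eps * I for inputs x of shape (n, d).

    Sigma is the (centered) empirical covariance of the inputs; alpha and eps
    are hypothetical scaling and jitter parameters.
    """
    sigma = np.cov(x, rowvar=False)
    return alpha * sigma + eps * np.eye(x.shape[1])


rng = np.random.default_rng(0)
x = rng.normal(size=(500, 6))
x[:, -1] = 0.0  # a dead feature, as in the MNIST example

cov = empcov_prior_cov(x)
inv_cov = np.linalg.inv(cov)  # an inverse covariance, like the .npy files store

# Sample first-layer weights for 4 hidden units from N(0, cov).
w = rng.multivariate_normal(np.zeros(x.shape[1]), cov, size=4)

prior_std = np.sqrt(np.diag(cov))
print("prior std per input feature:", np.round(prior_std, 3))
```

Under this construction the prior standard deviation along the dead feature collapses to sqrt(eps), so weights attached to directions with no variance in the data stay close to zero instead of being sampled from a broad prior.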

Requirements

We provide a requirements.txt file that can be used to create a conda environment to run the code in this repo:

conda create --name <env> --file requirements.txt

Example set-up using pip:

pip install tensorflow

pip install --upgrade pip
pip install --upgrade jax jaxlib==0.1.65+cuda112 -f \
https://storage.googleapis.com/jax-releases/jax_releases.html

pip install git+https://github.com/deepmind/dm-haiku
pip install tensorflow_datasets
pip install tabulate
pip install optax

Please see the JAX repo for the latest instructions on how to install JAX on your hardware.

File Structure

The implementations of HMC and other methods forked from the BNN HMC repo are in the bnn_hmc folder. The main training scripts are run_hmc.py for HMC and run_sgd.py for SGD respectively. In the notebooks folder we show examples of how to extract the covariance matrices for EmpCov priors, and evaluate the results under various corruptions.

.
+-- bnn_hmc/
|   +-- core/
|   |   +-- hmc.py (The Hamiltonian Monte Carlo algorithm)
|   |   +-- sgmcmc.py (SGMCMC methods as optax optimizers)
|   |   +-- vi.py (Mean field variational inference)
|   +-- utils/ (Utility functions used by the training scripts)
|   |   +-- train_utils.py (The training epochs and update rules)
|   |   +-- models.py (Models used in the experiments)
|   |   +-- losses.py (Prior and likelihood functions)
|   |   +-- data_utils.py (Loading and pre-processing the data)
|   |   +-- optim_utils.py (Optimizers and learning rate schedules)
|   |   +-- ensemble_utils.py (Implementation of ensembling of predictions)
|   |   +-- metrics.py (Metrics used in evaluation)
|   |   +-- cmd_args_utils.py (Common command line arguments)
|   |   +-- script_utils.py (Common functionality of the training scripts)
|   |   +-- checkpoint_utils.py (Saving and loading checkpoints)
|   |   +-- logging_utils.py (Utilities for logging and printing the results)
|   |   +-- precision_utils.py (Controlling the numerical precision)
|   |   +-- tree_utils.py (Common operations on pytree objects)
+-- notebooks/  
|   +-- cnn_robustness_cifar10.ipynb (Creates CIFAR-10 CNN figures used in paper)  
|   +-- mlp_robustness_mnist.ipynb (Creates MNIST MLP figures used in paper)
|   +-- cifar10_cnn_extract_empcov.ipynb (Constructs EmpCov prior covariance matrix for CIFAR-10 CNN)
|   +-- mnist_extract_empcov.ipynb (Constructs EmpCov prior covariance matrices for MNIST CNN and MLP)
+-- empcov_covs/
|   +-- cifar_cnn_pca_inv_cov.npy (EmpCov inverse prior covariance for CIFAR-10 CNN)
|   +-- mnist_cnn_pca_inv_cov.npy (EmpCov inverse prior covariance for MNIST CNN)
|   +-- mnist_mlp_pca_inv_cov.npy (EmpCov inverse prior covariance for MNIST MLP)
+-- run_hmc.py (HMC training script)
+-- run_sgd.py (SGD training script)

Training Scripts

The training scripts are adapted from the Google Research BNN HMC repo. For completeness, we provide full details about the command line arguments here.

Common command line arguments:

  • seed — random seed
  • dir — training directory for saving the checkpoints and tensorboard logs
  • dataset_name — name of the dataset, e.g. cifar10, cifar100, mnist
  • subset_train_to — number of datapoints to use from the dataset; by default, the full dataset is used
  • model_name — name of the neural network architecture, e.g. lenet, resnet20_frn_swish, cnn_lstm, mlp_regression_small
  • weight_decay — weight decay; for Bayesian methods, weight decay determines the prior variance (prior_var = 1 / weight_decay)
  • temperature — posterior temperature (default: 1)
  • init_checkpoint — path to the checkpoint to use for initialization (optional)
  • tabulate_freq — frequency of tabulate table header logging
  • use_float64 — use float64 precision (does not work on TPUs and some GPUs); by default, we use float32 precision
  • prior_family — type of prior to use; must be one of Gaussian, ExpFNormP, Laplace, StudentT, SumFilterLeNet, EmpCovLeNet or EmpCovMLP; see the next section for more details
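To make the roles of weight_decay and temperature concrete, here is a minimal sketch of an iid Gaussian log-prior with prior_var = 1 / weight_decay and a tempered log-posterior. The function names are hypothetical, and dividing the whole log-posterior by the temperature is one common convention, assumed here rather than taken from the repo's losses.py.

```python
import numpy as np


def gaussian_log_prior(params, weight_decay):
    """Log density, up to an additive constant, of an iid N(0, 1 / weight_decay) prior."""
    return -0.5 * weight_decay * np.sum(params ** 2)


def tempered_log_posterior(log_likelihood, log_prior, temperature=1.0):
    """Unnormalized log-posterior scaled by 1 / temperature (assumed convention)."""
    return (log_likelihood + log_prior) / temperature


params = np.array([0.1, -0.2, 0.05])
log_prior = gaussian_log_prior(params, weight_decay=100.0)
log_post = tempered_log_posterior(log_likelihood=-10.0, log_prior=log_prior,
                                  temperature=0.1)
print(log_prior, log_post)
```

A larger weight_decay means a smaller prior variance and a stronger pull towards zero; a temperature below 1 sharpens the posterior around its modes.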

Prior Families

In this repo we implement several prior distribution families. Some of the prior families have additional command line arguments specifying the parameters of the prior:

  • Gaussian — iid Gaussian prior centered at 0 with variance equal to 1 / weight_decay
  • Laplace — iid Laplace prior centered at 0 with variance equal to 1 / weight_decay
  • StudentT — iid Student-t prior centered at 0 with studentt_degrees_of_freedom degrees of freedom and scaled by 1 / weight_decay
  • ExpFNormP — iid ExpNorm prior centered at 0, as defined in the paper; expfnormp_power specifies the power under the exponent in the prior, and 1 / weight_decay defines the scale of the prior
  • EmpCovLeNet and EmpCovMLP — EmpCov priors with the inverse of the empirical covariance matrix of the data provided as a .npy array through empcov_invcov_ckpt; empcov_wd rescales the covariance matrix for the first layer
  • SumFilterLeNet — SumFilter prior presented in the paper; 1 / sumfilterlenet_weight_decay determines the prior variance for the sum of the filter weights in the first layer

Some prior types require additional arguments, such as empcov_wd and studentt_degrees_of_freedom; run the scripts with --help for full details.
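As a small illustration of what "variance equal to 1 / weight_decay" means for two of the families above, the sketch below draws samples from a Gaussian and a Laplace distribution with matched variance. The helper name is hypothetical, and the mapping from weight_decay to the Laplace scale is derived from the stated variance (a Laplace(0, b) distribution has variance 2 * b**2), not read from the repo's code.

```python
import numpy as np


def laplace_scale_from_weight_decay(weight_decay):
    """Scale b of Laplace(0, b) such that its variance 2 * b**2 equals 1 / weight_decay."""
    return np.sqrt(0.5 / weight_decay)


rng = np.random.default_rng(0)
weight_decay = 100.0
n_samples = 200_000

gauss = rng.normal(0.0, np.sqrt(1.0 / weight_decay), size=n_samples)
lap = rng.laplace(0.0, laplace_scale_from_weight_decay(weight_decay), size=n_samples)

print("target variance:  ", 1.0 / weight_decay)
print("Gaussian variance:", gauss.var())
print("Laplace variance: ", lap.var())
```

Both sample variances should come out close to 1 / weight_decay; the Laplace prior puts more mass both near zero and in the tails than the Gaussian with the same variance.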

Running HMC

To run HMC, you can use the run_hmc.py training script. Arguments:

  • step_size — HMC step size
  • trajectory_len — HMC trajectory length
  • num_iterations — Total number of HMC iterations
  • max_num_leapfrog_steps — Maximum number of leapfrog steps allowed; meant as a sanity check and should be greater than trajectory_len / step_size
  • num_burn_in_iterations — Number of burn-in iterations (default: 0)
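The relation between step_size, trajectory_len and max_num_leapfrog_steps can be checked before launching a long run. The helpers below are a hypothetical convenience, not part of the repo, and they assume the number of leapfrog steps is roughly trajectory_len / step_size rounded up; the training script may round differently.

```python
import math


def required_leapfrog_steps(step_size, trajectory_len):
    """Approximate number of leapfrog steps needed to cover one trajectory."""
    return math.ceil(trajectory_len / step_size)


def leapfrog_budget_ok(step_size, trajectory_len, max_num_leapfrog_steps):
    """True if the configured maximum allows a full trajectory."""
    return max_num_leapfrog_steps >= required_leapfrog_steps(step_size, trajectory_len)


# Settings in the style of the CIFAR-10 examples in this README.
print(required_leapfrog_steps(step_size=3e-5, trajectory_len=0.15))
print(leapfrog_budget_ok(step_size=3e-5, trajectory_len=0.15,
                         max_num_leapfrog_steps=5300))
```

Because of floating point rounding in the division, the computed step count can be off by one, which is one reason to leave some slack in max_num_leapfrog_steps.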

Examples

CNN on CIFAR-10 with different priors:

# Gaussian prior
python3 run_hmc.py --seed=0 --weight_decay=100 --temperature=1. \
  --dir=runs/hmc/cifar10/gaussian/ --dataset_name=cifar10 \
  --model_name=lenet --step_size=3.e-05 --trajectory_len=0.15 \
  --num_iterations=100 --max_num_leapfrog_steps=5300 \
  --num_burn_in_iterations=10

# Laplace prior
python3 run_hmc.py --seed=0 --weight_decay=100 --temperature=1. \
  --dir=runs/hmc/cifar10/laplace --dataset_name=cifar10 \
  --model_name=lenet --step_size=3.e-05 --trajectory_len=0.15 \
  --num_iterations=100 --max_num_leapfrog_steps=5300 \
  --num_burn_in_iterations=10 --prior_family=Laplace

# Gaussian prior, T=0.01
python3 run_hmc.py --seed=0 --weight_decay=3 --temperature=0.01 \
  --dir=runs/hmc/cifar10/lenet/temp --dataset_name=cifar10 \
  --model_name=lenet --step_size=1.e-05 --trajectory_len=0.1 \
  --num_iterations=100 --max_num_leapfrog_steps=10000 \
  --num_burn_in_iterations=10

# EmpCov prior
python3 run_hmc.py --seed=0 --weight_decay=100. --temperature=1. \
  --dir=runs/hmc/cifar10/EmpCov --dataset_name=cifar10 \
  --model_name=lenet --step_size=1.e-4 --trajectory_len=0.157 \
  --num_iterations=100 --max_num_leapfrog_steps=2000 \
  --num_burn_in_iterations=10 --prior_family=EmpCovLeNet \
  --empcov_invcov_ckpt=empcov_covs/cifar_cnn_pca_inv_cov.npy \
  --empcov_wd=100.

We ran these commands on a machine with 8 NVIDIA Tesla V100 GPUs.

MLP on MNIST using different priors:

# Gaussian prior
python3 run_hmc.py --seed=2 --weight_decay=100  \
  --dir=runs/hmc/mnist/gaussian \
  --dataset_name=mnist --model_name=mlp_classification \
  --step_size=1.e-05 --trajectory_len=0.15 \
  --num_iterations=100 --max_num_leapfrog_steps=15500 \
  --num_burn_in_iterations=10

# Laplace prior
python3 run_hmc.py --seed=0 --weight_decay=3.0 \
  --dir=runs/hmc/mnist/laplace --dataset_name=mnist \
  --model_name=mlp_classification --step_size=6.e-05 \
  --trajectory_len=0.9 --num_iterations=100 \
  --max_num_leapfrog_steps=15500 \
  --num_burn_in_iterations=10 --prior_family=Laplace

# Student-T prior
python3 run_hmc.py --seed=0 --weight_decay=10. \
  --dir=runs/hmc/mnist/studentt --dataset_name=mnist \
  --model_name=mlp_classification --step_size=1.e-4 --trajectory_len=0.49 \
  --num_iterations=100 --max_num_leapfrog_steps=5000 \
  --num_burn_in_iterations=10 --prior_family=StudentT \
  --studentt_degrees_of_freedom=5.

# Gaussian prior, T=0.01
python3 run_hmc.py --seed=11 --weight_decay=100 \
  --temperature=0.01 --dir=runs/hmc/mnist/temp \
  --dataset_name=mnist --model_name=mlp_classification \
  --step_size=6.3e-07 --trajectory_len=0.015 \
  --num_iterations=100 --max_num_leapfrog_steps=25500 \
  --num_burn_in_iterations=10

# EmpCov prior
python3 run_hmc.py --seed=0 --weight_decay=100 \
  --dir=runs/hmc/mnist/empcov --dataset_name=mnist \
  --model_name=mlp_classification --step_size=1.e-05 \
  --trajectory_len=0.15 --num_iterations=100 \
  --max_num_leapfrog_steps=15500 \
  --num_burn_in_iterations=10 --prior_family=EmpCovMLP \
  --empcov_invcov_ckpt=empcov_covs/mnist_mlp_pca_inv_cov.npy \
  --empcov_wd=100  

This script can be run on a single GPU or a TPU v3-8.

Running SGD

To run SGD, you can use the run_sgd.py training script. Arguments:

  • init_step_size — Initial SGD step size; we use a cosine schedule
  • num_epochs — total number of SGD epochs
  • batch_size — batch size
  • eval_freq — frequency of evaluation (epochs)
  • save_freq — frequency of checkpointing (epochs)
  • momentum_decay — momentum decay parameter for SGD
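For reference, a cosine schedule that starts at init_step_size and decays to zero can be sketched as follows. This is a generic formulation with hypothetical names; the schedule actually used by run_sgd.py lives in optim_utils.py and may differ in details such as warm-up or a non-zero final value.

```python
import math


def cosine_step_size(step, total_steps, init_step_size):
    """Cosine decay from init_step_size at step 0 to 0 at total_steps."""
    progress = min(step, total_steps) / total_steps
    return 0.5 * init_step_size * (1.0 + math.cos(math.pi * progress))


total_steps = 1000
for step in (0, 250, 500, 750, 1000):
    print(step, cosine_step_size(step, total_steps, init_step_size=1e-7))
```

The step size decreases slowly at the start and end of training and fastest in the middle.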

Examples

MLP on MNIST:

python3 run_sgd.py --seed=0 --weight_decay=100 --dir=runs/sgd/mnist/ \
  --dataset_name=mnist --model_name=mlp_classification \
  --init_step_size=1e-7 --eval_freq=10 --batch_size=80 \
  --num_epochs=100 --save_freq=100

CNN on CIFAR-10:

python3 run_sgd.py --seed=0 --weight_decay=100. --dir=runs/sgd/cifar10/lenet \
  --dataset_name=cifar10 --model_name=lenet --init_step_size=1e-7 --batch_size=80 \
  --num_epochs=300 --save_freq=300

To train a deep ensemble, we simply train multiple copies of SGD with different random seeds.
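Once several SGD runs with different seeds have finished, their predictions need to be combined. A common way, sketched below with random logits standing in for real model outputs, is to average the predicted class probabilities across members. The function names are hypothetical; the repo's own ensembling code is in ensemble_utils.py.

```python
import numpy as np


def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def ensemble_predict(member_logits):
    """Average class probabilities over members.

    member_logits has shape (num_members, num_examples, num_classes).
    """
    return softmax(member_logits).mean(axis=0)


rng = np.random.default_rng(0)
member_logits = rng.normal(size=(5, 4, 10))  # 5 members, 4 examples, 10 classes
probs = ensemble_predict(member_logits)
preds = probs.argmax(axis=-1)
print(probs.shape, preds)
```

Averaging probabilities rather than logits keeps each row of the result a valid probability distribution.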

Results

We consider corrupted versions of the MNIST and CIFAR-10 datasets with both fully-connected (mlp_classification) and convolutional (lenet) architectures. Additionally, we consider domain shift problems from MNIST to SVHN and from CIFAR-10 to STL-10. We apply the EmpCov prior to the first layer of Bayesian neural networks (BNNs), and a Gaussian prior to all other layers, using the commands in the examples. The following figure shows the results for deep ensembles, the maximum a posteriori (MAP) estimate obtained through SGD, BNNs with a Gaussian prior, and BNNs with our novel EmpCov prior. The EmpCov prior improves the robustness of BNNs to covariate shift, leading to better results on most corruptions and performance competitive with deep ensembles for both fully-connected and convolutional architectures.

[Figure: combined_resolution]
