Code Repo for "Subspace Inference for Bayesian Deep Learning"

wjmaddox/drbayes

Subspace Inference for Bayesian Deep Learning

This repository contains PyTorch code for the subspace inference method introduced in the paper

Subspace Inference for Bayesian Deep Learning

by Pavel Izmailov, Wesley Maddox, Polina Kirichenko, Timur Garipov, Dmitry Vetrov, and Andrew Gordon Wilson

Introduction

For deep neural network models, exact Bayesian inference is intractable, and existing approximate inference methods suffer from many limitations, largely due to the high dimensionality of the parameter space. In subspace inference, we address this problem by performing inference in low-dimensional subspaces of the parameter space. In the paper, we show how to construct very low-dimensional subspaces that contain diverse, high-performing models. Bayesian inference within such subspaces is much easier than in the original parameter space and leads to strong results.

At a high level, our method proceeds as follows:

  1. We construct a low-dimensional subspace of the parameter space.
  2. We approximate the posterior distribution over the parameters in this subspace.
  3. We sample from the approximate posterior in the subspace and perform Bayesian model averaging.
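
The three steps above can be sketched on a toy problem. This is a minimal NumPy illustration, not the code in this repository: the linear `predict` function, the random projection `P`, and the Gaussian stand-in for the approximate posterior over the subspace coordinates `z` are all assumptions made for the example. The only idea taken from the description above is that full parameters are obtained as `w = w_hat + P @ z` and that predictions are averaged over samples of `z`.

```python
import numpy as np

rng = np.random.default_rng(0)
D, K = 10, 2  # toy sizes: D full parameters, K subspace dimensions


def predict(w, x):
    """Hypothetical model: a linear predictor standing in for a network."""
    return x @ w


# Step 1: a subspace {w_hat + P z}. Here P is random purely for illustration.
w_hat = rng.standard_normal(D)   # shift vector (e.g. an averaged solution)
P = rng.standard_normal((D, K))  # columns span the subspace

# Step 2: stand-in for an approximate posterior over z.
# A real run would obtain these samples from VI or MCMC in the subspace.
z_samples = 0.1 * rng.standard_normal((30, K))


# Step 3: Bayesian model averaging over the sampled subspace coordinates.
def bayesian_model_average(predict, w_hat, P, z_samples, x):
    preds = np.stack([predict(w_hat + P @ z, x) for z in z_samples])
    return preds.mean(axis=0)


x = rng.standard_normal((4, D))
bma = bayesian_model_average(predict, w_hat, P, z_samples, x)
print(bma.shape)
```

Only `K` numbers are inferred per sample, regardless of how large `D` is; the full weight vector is rebuilt on the fly for each prediction.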

For example, we can achieve state-of-the-art results for a WideResNet with 36 million parameters by performing approximate inference in only a 5-dimensional subspace.

If you use our work, please cite the following paper:

@article{izmailov_subspace_2019,
  title={Subspace Inference for Bayesian Deep Learning},
  author={Izmailov, Pavel and Maddox, Wesley and Kirichenko, Polina and Garipov, Timur and Vetrov, Dmitry and Wilson, Andrew Gordon},
  journal={Uncertainty in Artificial Intelligence (UAI)},
  year={2019}
}

Installation

You can install the package by running the following command from the repository root:

python setup.py install

Usage

We provide the scripts and example commands to reproduce the experiments from the paper.

Subspace Construction and Inference Procedures

In the paper, we show how to construct random, PCA, and mode-connected subspaces, over which we can perform Bayesian inference. The PCA subspace is often the most practical choice in terms of the runtime-accuracy trade-off. Once the subspace is constructed, we consider various approaches to inference within the subspace, such as MCMC and variational methods. The following examples show how to call subspace inference with your choice of subspace and inference procedure.
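
To make the subspace types concrete, here is a minimal NumPy sketch of a random subspace and a PCA subspace. It is an illustration under stated assumptions rather than the repository's implementation: the function names are hypothetical, `iterates` is assumed to be a `(T, D)` array of flattened weight vectors collected along a training trajectory, and the returned directions are left orthonormal (whether and how to rescale them is a design choice not fixed here).

```python
import numpy as np


def random_subspace(dim, rank, rng):
    """Return a (dim, rank) matrix of random unit-norm directions."""
    P = rng.standard_normal((dim, rank))
    return P / np.linalg.norm(P, axis=0, keepdims=True)


def pca_subspace(iterates, rank):
    """Return (w_hat, P): the mean iterate and the top principal directions.

    iterates: (T, D) array of flattened weights (assumed input format).
    P has shape (D, rank) with orthonormal columns.
    """
    iterates = np.asarray(iterates, dtype=float)
    w_hat = iterates.mean(axis=0)
    deviations = iterates - w_hat
    # Right singular vectors of the deviation matrix are the principal directions.
    _, _, vt = np.linalg.svd(deviations, full_matrices=False)
    return w_hat, vt[:rank].T


# Toy usage: 50 fake "iterates" in a 20-dimensional parameter space.
rng = np.random.default_rng(0)
fake_iterates = rng.standard_normal((50, 20))
w_hat, P = pca_subspace(fake_iterates, rank=5)
print(P.shape)
```

In both cases a point in the subspace is written as `w_hat + P @ z`, so the choice of subspace only changes how `w_hat` and `P` are obtained, not how inference over `z` is run afterwards.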

Visualizing Regression Uncertainty

We provide a Jupyter notebook in which we show how to apply subspace inference to a regression problem, using different types of subspaces and approximate inference methods.

Image Classification

The folder experiments/cifar_exps contains the implementation of subspace inference for the CIFAR-10 and CIFAR-100 datasets. To run subspace inference, you first need to pre-train the subspace, and then run either variational inference (VI) or elliptical slice sampling (ESS) in the subspace. For example, you can pre-train a PCA subspace for a VGG-16 by running:

python3 experiments/cifar_exps/swag.py --data_path=data --epochs=300 --dataset=CIFAR100 --save_freq=300 \
      --model=VGG16 --lr_init=0.05 --wd=5e-4 --swag --swag_start=161 --swag_lr=0.01 --cov_mat --use_test \
      --dir=ckpts/vgg16_run1

Then, you can run ESS in the constructed subspace using the following command:

python3 experiments/cifar_exps/subspace_ess.py --dir=ckpts/vgg16_run1/ess/ --dataset=CIFAR100 \
      --data_path=~/datasets/ --model=VGG16 --use_test --rank=5 --checkpoint=ckpts/vgg16_run1/swag-300.pt \
      --temperature=5000 --prior_std=2
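
For readers unfamiliar with elliptical slice sampling, the following is a minimal NumPy sketch of the generic ESS update applied to subspace coordinates `z`. It is not the code in subspace_ess.py. It assumes a zero-mean Gaussian prior on `z` with standard deviation `prior_std`, and it assumes that the temperature divides the log-likelihood (a tempered posterior); how the script actually uses `--temperature` and `--prior_std` should be checked against its source. All function names are hypothetical.

```python
import numpy as np


def make_subspace_log_lik(full_log_lik, w_hat, P, temperature):
    """Log-likelihood over z, tempered by dividing by `temperature` (assumption)."""
    return lambda z: full_log_lik(w_hat + P @ z) / temperature


def ess_step(z, log_lik, prior_std, rng):
    """One elliptical slice sampling update under a N(0, prior_std^2 I) prior."""
    nu = prior_std * rng.standard_normal(z.shape)     # prior draw defining the ellipse
    log_y = log_lik(z) + np.log(1.0 - rng.uniform())  # slice level; 1 - u is in (0, 1]
    theta = rng.uniform(0.0, 2.0 * np.pi)
    theta_min, theta_max = theta - 2.0 * np.pi, theta
    while True:
        z_new = z * np.cos(theta) + nu * np.sin(theta)
        if log_lik(z_new) > log_y:
            return z_new
        # Rejected: shrink the angle bracket toward 0 (the current state) and retry.
        if theta < 0.0:
            theta_min = theta
        else:
            theta_max = theta
        theta = rng.uniform(theta_min, theta_max)


def run_ess(z0, log_lik, prior_std, n_samples, rng):
    """Run a chain of ESS updates and return the samples as an array."""
    samples, z = [], np.asarray(z0, dtype=float)
    for _ in range(n_samples):
        z = ess_step(z, log_lik, prior_std, rng)
        samples.append(z)
    return np.stack(samples)


# Toy usage: a 2-dimensional subspace with a Gaussian toy likelihood.
rng = np.random.default_rng(0)
y = np.array([1.0, -1.0])
toy_log_lik = make_subspace_log_lik(
    lambda w: -0.5 * np.sum((y - w) ** 2), np.zeros(2), np.eye(2), temperature=1.0
)
samples = run_ess(np.zeros(2), toy_log_lik, prior_std=2.0, n_samples=2000, rng=rng)
print(samples.mean(axis=0))
```

ESS has no step size to tune, and every update is accepted after the bracket shrinks far enough, which makes it convenient in a subspace of only a few dimensions.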

Please refer to the README in experiments/cifar_exps for more detailed information.

UCI Regression

The folder experiments/uci_exps contains implementations of the subspace inference procedure for UCI regression problems.

Please refer to the README in experiments/uci_exps for more detailed information.

References for Code Base

Repositories:

  • Stochastic weight averaging (SWA): PyTorch repo. Many of the base methods and model definitions are built off of this repo.
  • SWA-Gaussian (SWAG): PyTorch repo. Our codebase is based on the SWAG repo, which also contains implementations of various baselines for image classification as well as the experiments on Hessian eigenvalues.
  • Mode-Connectivity: PyTorch repo. We use the scripts from this repository to construct our curve subspaces.
  • Bayesian benchmarks: Repo; the experiments/uci_exps folder is a clone of that repo.

Model implementations:
