Bayesian Neural Networks 101
(Deep) neural networks (that we all know and love) work, in the most general sense, by (1) defining a loss function over input-output space, (2) evaluating this loss on observed data points in this space, and then (3) optimizing the network's parameters such that this loss is minimized.
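As a point of reference, here is a minimal sketch of that recipe in PyTorch; the architecture, loss, and toy data are purely illustrative choices, not a specific recommendation.

```python
import torch
import torch.nn as nn

# (1) a loss over input-output space, (2) evaluated on observed data,
# (3) minimized with respect to the network's parameters.
model = nn.Sequential(nn.Linear(1, 32), nn.Tanh(), nn.Linear(32, 1))
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

# Toy 1D regression data, purely for illustration.
x = torch.linspace(-1, 1, 100).unsqueeze(-1)
y = torch.sin(3 * x) + 0.1 * torch.randn_like(x)

for step in range(1000):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)   # loss evaluated on the observed data
    loss.backward()
    optimizer.step()              # gradient step towards lower loss
```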
Bayesian neural networks (BNNs) operate differently. Instead of finding the single best set of parameters, we want to collect many sets of parameters, where a set is collected more frequently the better it explains the data (and agrees with our prior beliefs). In statistical terms, we treat the network's parameters as random variables (Bayesian POV) instead of unknown quantities for which a single true value exists (frequentist POV). Accordingly, we want to find the probability distribution of these parameters/random variables, conditional on our observed data.
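In symbols (writing $\theta$ for the network's parameters and $\mathcal{D}$ for the observed data; this notation is introduced here just for concreteness), the object we want to learn is the conditional distribution

$$
p(\theta \mid \mathcal{D}),
$$

rather than a single point estimate $\hat{\theta}$.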
The worldview of "parameter as random variable" allows us to naturally incorporate the notion of uncertainty. There are many sources of uncertainty inherent to the process of learning from data – for example, randomness due to environmental noise, or model uncertainty arising from the fact that many sets of parameters can explain the data equally well. Given these sources of uncertainty, it makes sense to learn an entire distribution over network parameters rather than optimize for the "best" set. Having access to predictive uncertainty is important in practice, as it prevents the model from making overconfident decisions at test time. This is particularly crucial in high-stakes applications like healthcare or robotics, where making the wrong decision can be costly.
There are many ways to obtain a distribution over parameters. One such way is the Bayesian framework, which relies on Bayes' Rule to obtain a distribution in which the probability of a set of parameters is proportional not only to how well it explains the observed data, but also to how well it adheres to our prior beliefs about these parameters. Accordingly, this distribution is known as the posterior distribution, and the process of obtaining it is known as (Bayesian) inference. The Metropolis-Hastings algorithm and Gibbs sampling are examples of inference methods you may have heard of; they tend not to work well for BNNs, since neural network parameter spaces are extremely high-dimensional, so we have to resort to specialized inference algorithms. Once we have learnt (or inferred) the posterior, we can use it to do all the things ordinary neural networks can do, and more.
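In the notation above, Bayes' Rule expresses this posterior as the product of the likelihood and the prior, up to a normalizing constant:

$$
p(\theta \mid \mathcal{D}) \;=\; \frac{p(\mathcal{D} \mid \theta)\, p(\theta)}{p(\mathcal{D})} \;\propto\; p(\mathcal{D} \mid \theta)\, p(\theta).
$$

As for using the posterior: a BNN predicts by averaging the network's output over posterior samples, i.e. by Monte Carlo approximating the posterior predictive $p(y \mid x, \mathcal{D}) = \int p(y \mid x, \theta)\, p(\theta \mid \mathcal{D})\, d\theta$ for a test input $x$. The PyTorch sketch below shows the idea; the "posterior samples" here are random perturbations of a single model, purely so the snippet runs end-to-end, and would in practice come from an actual inference method.

```python
import copy
import torch
import torch.nn as nn

base = nn.Sequential(nn.Linear(1, 32), nn.Tanh(), nn.Linear(32, 1))

def fake_posterior_sample(model):
    # Placeholder stand-in for a real posterior sample: a copy of the model
    # with randomly perturbed weights. A real sampler or variational
    # posterior would supply these instead.
    sample = copy.deepcopy(model)
    with torch.no_grad():
        for p in sample.parameters():
            p.add_(0.05 * torch.randn_like(p))
    return sample

samples = [fake_posterior_sample(base) for _ in range(50)]

x_test = torch.linspace(-1, 1, 10).unsqueeze(-1)
with torch.no_grad():
    preds = torch.stack([m(x_test) for m in samples])  # (num_samples, 10, 1)

mean = preds.mean(dim=0)  # Monte Carlo posterior predictive mean
std = preds.std(dim=0)    # spread across samples: a simple uncertainty estimate
```

The spread across samples (`std`) is exactly the kind of predictive uncertainty discussed above, and it is something a single point-estimated network cannot give you for free.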
Like every other learning model, BNNs are limited in certain ways, too:
- Performing inference to capture the true posterior in such high-dimensional spaces is immensely difficult (like many high-dimensional problems), so we often resort to methods that learn an approximation to the posterior instead, such as variational inference (a minimal sketch appears at the end of this section).
- Even with approximations, performing inference is computationally costly, especially for large network architectures.
- Generalizing BNNs to vast and complex architectures beyond multi-layer perceptrons (such as transformers or deep CNNs) is challenging.
- Defining an appropriate prior in high-dimensional parameter space is non-obvious and non-trivial, but important to get right if we want to obtain high-quality posteriors.
These are all active directions of research in the BNN community. Our research focuses, in particular, on this last limitation.
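To make the first limitation concrete, here is a minimal sketch of one popular approximation scheme: mean-field variational inference over the weights (in the style of Bayes by Backprop), trained by minimizing the negative ELBO. The architecture, Gaussian prior, observation-noise level, and toy data are all illustrative assumptions, not our actual setup.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MeanFieldLinear(nn.Module):
    """A linear layer with an independent Gaussian posterior over each weight and bias."""

    def __init__(self, in_features, out_features, prior_std=1.0):
        super().__init__()
        self.prior_std = prior_std
        self.w_mu = nn.Parameter(0.1 * torch.randn(out_features, in_features))
        self.w_rho = nn.Parameter(torch.full((out_features, in_features), -3.0))
        self.b_mu = nn.Parameter(torch.zeros(out_features))
        self.b_rho = nn.Parameter(torch.full((out_features,), -3.0))

    def forward(self, x):
        # Reparameterization trick: sample weights as mu + sigma * eps.
        w_sigma = F.softplus(self.w_rho)
        b_sigma = F.softplus(self.b_rho)
        w = self.w_mu + w_sigma * torch.randn_like(w_sigma)
        b = self.b_mu + b_sigma * torch.randn_like(b_sigma)
        return F.linear(x, w, b)

    def kl(self):
        # Closed-form KL between the factorized Gaussian posterior and a
        # zero-mean isotropic Gaussian prior with standard deviation prior_std.
        kl = 0.0
        for mu, rho in [(self.w_mu, self.w_rho), (self.b_mu, self.b_rho)]:
            sigma = F.softplus(rho)
            kl = kl + (torch.log(self.prior_std / sigma)
                       + (sigma ** 2 + mu ** 2) / (2 * self.prior_std ** 2)
                       - 0.5).sum()
        return kl

# Toy 1D regression data, purely for illustration.
x = torch.linspace(-1, 1, 100).unsqueeze(-1)
y = torch.sin(3 * x) + 0.1 * torch.randn_like(x)

model = nn.Sequential(MeanFieldLinear(1, 32), nn.Tanh(), MeanFieldLinear(32, 1))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
noise_var = 0.1 ** 2  # assumed (fixed) observation noise

for step in range(2000):
    optimizer.zero_grad()
    pred = model(x)  # one posterior sample of the weights per forward pass
    nll = F.gaussian_nll_loss(pred, y, torch.full_like(pred, noise_var))
    kl = sum(m.kl() for m in model if isinstance(m, MeanFieldLinear))
    loss = nll + kl / x.shape[0]  # negative ELBO (KL scaled per data point)
    loss.backward()
    optimizer.step()
```

After training, predictions and uncertainty estimates are obtained exactly as in the earlier sketch, by averaging multiple stochastic forward passes through `model`.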