Gaussian processes via gpytorch (#782)
This is a feature to add support for Gaussian Processes (GPs) via integration
with gpytorch.

Similar to "vanilla" PyTorch, the idea here is that skorch allows the user to
focus on what's important (implementing the mean function and kernel function)
and not to bother with stuff like the training loop, callbacks, etc. This is
probably best illustrated in the accompanying notebook.

GPs are primarily for regression, hence those are the main focus here.
Traditionally, there are "exact" solutions and approximations. skorch will
provide an ExactGPRegressor and a GPRegressor for those two use cases.

On top of that, a GPBinaryClassifier is offered, though I suspect it to be
rarely used. I couldn't get the GPClassifier for multiclass to work, so the code
is commented out.

The API is mostly the same as for the normal skorch estimators. There are some
additions to make working with GPs easier:

- predict method takes a return_std argument to return the standard deviation as
  well (as in sklearn's GaussianProcessRegressor; return_cov is not supported)
- sample method to sample from the model
- confidence_region method to get the confidence region

Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
BenjaminBossan and thomasjpfan authored Oct 9, 2021
1 parent 156efe6 commit 6ca5c4e
Showing 14 changed files with 4,411 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added a `get_all_learnable_params` method to retrieve the named parameters of all PyTorch modules defined on the net, including of criteria if applicable
- Added `MlflowLogger` callback for logging to Mlflow (#769)
- Added `InputShapeSetter` callback for automatically setting the input dimension of the PyTorch module
- Added a new module to support Gaussian Processes through [GPyTorch](https://gpytorch.ai/). To learn more about it, read the [GP documentation](https://skorch.readthedocs.io/en/latest/user/probabilistic.html) or take a look at the [GP notebook](https://nbviewer.jupyter.org/github/skorch-dev/skorch/blob/master/notebooks/Gaussian_Processes.ipynb). This feature is experimental, i.e. the API could be changed in the future in a backwards incompatible way.

### Changed

2 changes: 2 additions & 0 deletions README.rst
Expand Up @@ -31,6 +31,7 @@ Resources

- `Documentation <https://skorch.readthedocs.io/en/latest/?badge=latest>`_
- `Source Code <https://github.com/skorch-dev/skorch/>`_
- `Installation <https://github.com/skorch-dev/skorch#installation>`_

========
Examples
Expand Down Expand Up @@ -127,6 +128,7 @@ skorch also provides many convenient features, among others:
- `Parameter freezing/unfreezing <https://skorch.readthedocs.io/en/stable/callbacks.html#skorch.callbacks.Freezer>`_
- `Progress bar <https://skorch.readthedocs.io/en/stable/callbacks.html#skorch.callbacks.ProgressBar>`_ (for CLI as well as jupyter)
- `Automatic inference of CLI parameters <https://github.com/skorch-dev/skorch/tree/master/examples/cli>`_
- `Integration with GPyTorch for Gaussian Processes <https://skorch.readthedocs.io/en/latest/user/probabilistic.html>`_

============
Installation
7 changes: 4 additions & 3 deletions docs/conf.py
Expand Up @@ -50,10 +50,11 @@

intersphinx_mapping = {
'pytorch': ('https://pytorch.org/docs/stable/', None),
'sklearn': ('http://scikit-learn.org/stable/', None),
'numpy': ('http://docs.scipy.org/doc/numpy/', None),
'sklearn': ('https://scikit-learn.org/stable/', None),
'numpy': ('https://docs.scipy.org/doc/numpy/', None),
'python': ('https://docs.python.org/3', None),
'mlflow': ('https://mlflow.org/docs/latest/', None),
'gpytorch': ('https://docs.gpytorch.ai/en/stable/', None),
}

# Add any paths that contain templates here, relative to this directory.
Expand Down Expand Up @@ -118,7 +119,7 @@
# html_theme_options = {}

def setup(app):
app.add_stylesheet('css/my_theme.css')
app.add_css_file('css/my_theme.css')

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
5 changes: 3 additions & 2 deletions docs/index.rst
Expand Up @@ -52,6 +52,7 @@ User's Guide
user/callbacks
user/dataset
user/save_load
user/probabilistic
user/history
user/toy
user/helper
Expand Down Expand Up @@ -82,5 +83,5 @@ Indices and tables
* :ref:`search`


.. _pytorch: http://pytorch.org/
.. _sklearn: http://scikit-learn.org/
.. _pytorch: https://pytorch.org/
.. _sklearn: https://scikit-learn.org/
5 changes: 5 additions & 0 deletions docs/probabilistic.rst
@@ -0,0 +1,5 @@
skorch.probabilistic
====================

.. automodule:: skorch.probabilistic
:members:
1 change: 1 addition & 0 deletions docs/skorch.rst
Expand Up @@ -11,6 +11,7 @@ skorch
helper
history
net
probabilistic
regressor
scoring
toy
222 changes: 222 additions & 0 deletions docs/user/probabilistic.rst
@@ -0,0 +1,222 @@
==================
Gaussian Processes
==================

skorch integrates with GPyTorch_ to make it easy to train Gaussian Process (GP)
models. This section assumes that you already know how Gaussian Processes work;
please refer to other resources if you want to learn about them.

GPyTorch adopts many patterns from PyTorch, thus making it easy to pick up for
seasoned PyTorch users. Similarly, the skorch GPyTorch integration should look
familiar to seasoned skorch users. However, GPs are a different beast than the
more common, non-probabilistic machine learning techniques. It is important to
understand the basic concepts before using them in practice.

Installation
------------

In addition to the normal skorch dependencies and PyTorch, you need to install
GPyTorch as well. It wasn't added as a normal dependency since most users
probably are not interested in using skorch for GPs. To install GPyTorch, use
either pip or conda:

.. code:: bash

    # using pip
    pip install -U gpytorch
    # using conda
    conda install gpytorch -c gpytorch

When to use GPyTorch with skorch
--------------------------------

Here we want to quickly explain when it would be a good idea for you to use
GPyTorch with skorch. There are a couple of offerings in the Python ecosystem
when it comes to Gaussian Processes. We cannot provide an exhaustive list of
pros and cons of each possibility. There are, however, two obvious alternatives
that are worth discussing: using the sklearn_ implementation and using GPyTorch
without skorch.

When to use skorch + GPyTorch over sklearn:

* When you are more familiar with PyTorch than with sklearn
* When the kernels provided by sklearn are not sufficient for your use case and
  you would like to implement custom kernels with PyTorch
* When you want to use the rich set of optimizers available in PyTorch
* When sklearn is too slow and you want to use the GPU or scale across machines
* When you like to use the skorch extras, e.g. callbacks

When to use skorch + GPyTorch over pure GPyTorch:

* When you're already familiar with skorch and want an easy entry into GPs
* When you like to use the skorch extras, e.g. callbacks and grid search
* When you don't want to bother with writing your own training loop

However, if you are researching GPs and would like to have control over every
detail, using all the rich but very specific features that GPyTorch has on offer,
it is better to use it directly without skorch.

Examples
--------

Exact Gaussian Processes
^^^^^^^^^^^^^^^^^^^^^^^^

Like GPyTorch, skorch supports exact and approximate Gaussian Process
regression. For exact GPs, use the
:class:`~skorch.probabilistic.ExactGPRegressor`. The likelihood has to be a
:class:`~gpytorch.likelihoods.GaussianLikelihood` and the criterion
:class:`~gpytorch.mlls.ExactMarginalLogLikelihood`, but those are the defaults
and thus don't need to be specified. For exact GPs, the module needs to be an
:class:`~gpytorch.models.ExactGP`. For this example, we use a simple RBF kernel.

.. code:: python

    import gpytorch
    from skorch.probabilistic import ExactGPRegressor

    class RbfModule(gpytorch.models.ExactGP):
        def __init__(self, likelihood):
            # detail: We don't set train_inputs and train_targets here because
            # skorch will take care of that.
            super().__init__(train_inputs=None, train_targets=None, likelihood=likelihood)
            self.mean_module = gpytorch.means.ConstantMean()
            self.covar_module = gpytorch.kernels.RBFKernel()

        def forward(self, x):
            mean_x = self.mean_module(x)
            covar_x = self.covar_module(x)
            return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

    # X_train, y_train and X_test are assumed to be defined elsewhere
    gpr = ExactGPRegressor(RbfModule)
    gpr.fit(X_train, y_train)
    y_pred = gpr.predict(X_test)

As you can see, this almost looks like a normal skorch regressor with a normal
PyTorch module. We can fit as normal using the ``fit`` method and predict using
the ``predict`` method.

Inside the module, we determine the mean by using a mean function (just constant
in this case) and the covariance matrix using the RBF kernel function. You
should know about mean and kernel functions already. Having the mean and
covariance matrix, we assume that the output follows a multivariate normal
distribution, since exact GPs rely on this assumption. We could send the
``x`` through an MLP for `Deep Kernel Learning
<https://docs.gpytorch.ai/en/stable/examples/06_PyTorch_NN_Integration_DKL/index.html>`_
but left it out to keep the example simple.
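
To make the roles of the mean and kernel functions concrete, here is a minimal pure-Python sketch of the math for scalar inputs. It is not how GPyTorch implements ``ConstantMean`` or ``RBFKernel``; the helper names ``constant_mean``, ``rbf_kernel`` and ``covariance_matrix`` and the fixed ``lengthscale`` are assumptions made for illustration only.

```python
import math


def constant_mean(xs, constant=0.0):
    """Constant mean function: the same value for every input."""
    return [constant for _ in xs]


def rbf_kernel(x1, x2, lengthscale=1.0):
    """RBF (squared exponential) kernel for two scalar inputs."""
    return math.exp(-0.5 * ((x1 - x2) / lengthscale) ** 2)


def covariance_matrix(xs, lengthscale=1.0):
    """Evaluate the kernel on every pair of inputs."""
    return [[rbf_kernel(a, b, lengthscale) for b in xs] for a in xs]


xs = [0.0, 0.5, 3.0]
mean = constant_mean(xs)
cov = covariance_matrix(xs)

# Nearby inputs are strongly correlated, distant inputs are almost independent.
print(cov[0][1], cov[0][2])
```

The mean vector and covariance matrix together parameterize the multivariate normal distribution that the module's ``forward`` method returns.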

One major difference to usual deep learning models is that we actually predict a
distribution, not just a point estimate. That means that if we choose an
appropriate model that fits the data well, we can express the **uncertainty** of
the model:

.. code:: python

    y_pred, y_std = gpr.predict(X, return_std=True)
    lower_conf_region = y_pred - y_std
    upper_conf_region = y_pred + y_std

Here we not only returned the mean of the prediction, ``y_pred``, but also its
standard deviation, ``y_std``. This tells us how uncertain the model is about
its prediction. E.g., it could be the case that the model is fairly certain when
*interpolating* between data points but uncertain about *extrapolating*. This is
not possible to know when models only learn point predictions.
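
As a rough guide to reading such a band: if the predictive distribution at a point is normal, the share of probability mass within a given number of standard deviations of the mean follows from the normal CDF. The sketch below uses only the Python standard library; the helpers ``coverage`` and ``interval`` are hypothetical names for illustration and are not part of skorch or GPyTorch.

```python
from statistics import NormalDist


def coverage(sigmas):
    """Probability mass of a normal distribution within +/- sigmas standard deviations."""
    standard_normal = NormalDist(mu=0.0, sigma=1.0)
    return standard_normal.cdf(sigmas) - standard_normal.cdf(-sigmas)


def interval(mean, std, sigmas=2):
    """Lower and upper bound of the band mean +/- sigmas * std."""
    return mean - sigmas * std, mean + sigmas * std


one_sigma = coverage(1)  # roughly 0.68
two_sigma = coverage(2)  # roughly 0.95
lower, upper = interval(mean=1.5, std=0.25, sigmas=2)
print(one_sigma, two_sigma, lower, upper)
```

A band of one standard deviation therefore covers only about two thirds of the predictive mass, which is worth keeping in mind when plotting it.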

To obtain the confidence region, you can also use the ``confidence_region``
method:

.. code:: python

    # 1 standard deviation
    lower, upper = gpr.confidence_region(X, sigmas=1)
    # 2 standard deviations, the default
    lower, upper = gpr.confidence_region(X, sigmas=2)

Furthermore, a GP allows you to sample from the distribution even *before
fitting* it. The GP needs to be initialized, however:

.. code:: python

    gpr = ExactGPRegressor(...)
    gpr.initialize()
    samples = gpr.sample(X, n_samples=100)

By visualizing the samples and comparing them to the true underlying
distribution of the target, you can already get a feel for whether the model
you built is capable of generating the distribution of the target. If fitting
takes a long time, it is therefore recommended to check the distribution first,
otherwise you may try to fit a model that is incapable of generating the true
distribution and waste a lot of time.
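
To illustrate what sampling from an unfitted GP means mathematically, the following standard-library sketch draws functions from a zero-mean GP prior with an RBF kernel by multiplying standard normal noise with a Cholesky factor of the covariance matrix. This is an assumption-laden illustration, not skorch's ``sample`` implementation; ``rbf_kernel``, ``cholesky``, ``sample_prior`` and the ``jitter`` value are hypothetical choices.

```python
import math
import random


def rbf_kernel(x1, x2, lengthscale=1.0):
    """RBF kernel for two scalar inputs."""
    return math.exp(-0.5 * ((x1 - x2) / lengthscale) ** 2)


def cholesky(matrix):
    """Lower-triangular Cholesky factor of a symmetric positive-definite matrix."""
    n = len(matrix)
    lower = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            partial = sum(lower[i][k] * lower[j][k] for k in range(j))
            if i == j:
                lower[i][j] = math.sqrt(matrix[i][i] - partial)
            else:
                lower[i][j] = (matrix[i][j] - partial) / lower[j][j]
    return lower


def sample_prior(xs, n_samples, lengthscale=1.0, jitter=1e-6, seed=0):
    """Draw n_samples functions, evaluated at xs, from a zero-mean GP prior."""
    rng = random.Random(seed)
    n = len(xs)
    # A small jitter on the diagonal keeps the factorization numerically stable.
    cov = [
        [rbf_kernel(a, b, lengthscale) + (jitter if i == j else 0.0)
         for j, b in enumerate(xs)]
        for i, a in enumerate(xs)
    ]
    lower = cholesky(cov)
    samples = []
    for _ in range(n_samples):
        noise = [rng.gauss(0.0, 1.0) for _ in range(n)]
        # lower @ noise has covariance lower @ lower.T, i.e. cov.
        samples.append(
            [sum(lower[i][k] * noise[k] for k in range(i + 1)) for i in range(n)]
        )
    return samples


xs = [i / 2 for i in range(6)]  # 6 points between 0 and 2.5
samples = sample_prior(xs, n_samples=5)
```

Plotting each row of ``samples`` against ``xs`` shows the kind of functions the prior considers plausible before it has seen any data.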

Approximate Gaussian Processes
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

For some situations, fitting an exact GP might be infeasible, e.g. because the
distribution is not Gaussian or because you want to perform stochastic
optimization with mini-batches. For this, GPyTorch provides facilities to train
variational and approximate GPs. The module should inherit from
:class:`~gpytorch.models.ApproximateGP` and should define a *variational
strategy*. From the skorch side of things, use
:class:`~skorch.probabilistic.GPRegressor`.

.. code:: python

    import gpytorch
    from gpytorch.models import ApproximateGP
    from gpytorch.variational import CholeskyVariationalDistribution
    from gpytorch.variational import VariationalStrategy
    from skorch.probabilistic import GPRegressor

    class VariationalModule(ApproximateGP):
        def __init__(self, inducing_points):
            variational_distribution = CholeskyVariationalDistribution(inducing_points.size(0))
            variational_strategy = VariationalStrategy(
                self, inducing_points, variational_distribution, learn_inducing_locations=True,
            )
            super().__init__(variational_strategy)
            self.mean_module = gpytorch.means.ConstantMean()
            self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

        def forward(self, x):
            mean_x = self.mean_module(x)
            covar_x = self.covar_module(x)
            return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

    X, y = get_data(...)
    # split off the first 100 samples to use as inducing points
    X_inducing = X[:100]
    X_train, y_train = X[100:], y[100:]
    num_training_samples = len(X_train)

    gpr = GPRegressor(
        VariationalModule,
        module__inducing_points=X_inducing,
        criterion__num_data=num_training_samples,
    )
    gpr.fit(X_train, y_train)
    y_pred = gpr.predict(X_train)

As you can see, the variational strategy requires us to use inducing points. We
split off 100 of our training data samples to use as inducing points, assuming
that they are representative of the whole distribution. Apart from this, there
is basically no difference to using exact GP regression.
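
Taking the first 100 rows only yields representative inducing points if the data is not ordered. One possible alternative, sketched here as an assumption rather than a recommendation from skorch or GPyTorch, is to draw the inducing indices at random without replacement. The helper ``split_inducing`` is a hypothetical name.

```python
import random


def split_inducing(n_samples, n_inducing, seed=0):
    """Return (inducing_indices, training_indices), drawn without replacement."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    return sorted(indices[:n_inducing]), sorted(indices[n_inducing:])


inducing_idx, train_idx = split_inducing(n_samples=1000, n_inducing=100)

# With array-like data that supports index lists, the split would then be:
#   X_inducing = X[inducing_idx]
#   X_train, y_train = X[train_idx], y[train_idx]
```

The two index lists are disjoint, so no sample is used both as an inducing point and as a training sample, matching the split in the example above.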

Finally, skorch also provides :class:`~skorch.probabilistic.GPBinaryClassifier`
for binary classification with GPs. It uses a Bernoulli likelihood by default.
However, using GPs for classification is not very common; GPs are most commonly
used for regression tasks where data points have a known relationship to each
other (e.g. in time series forecasts).

Multiclass classification is not currently provided, but you can use
:class:`~skorch.probabilistic.GPBinaryClassifier` in conjunction with
:class:`~sklearn.multiclass.OneVsRestClassifier` to achieve the same result.

Further examples
----------------

To see all of this in action, we provide a notebook that shows using skorch with GPs on real-world data: `Gaussian Processes notebook <https://nbviewer.jupyter.org/github/skorch-dev/skorch/blob/master/notebooks/Gaussian_Processes.ipynb>`_.

.. _GPyTorch: https://gpytorch.ai/
.. _sklearn: https://scikit-learn.org/stable/modules/gaussian_process.html
2 changes: 2 additions & 0 deletions docs/user/tutorials.rst
Expand Up @@ -22,3 +22,5 @@ The following are examples and notebooks on how to use skorch.
* `Seq2Seq Translation using skorch <https://github.com/skorch-dev/skorch/tree/master/examples/translation>`_ - Translation with a sequence to sequence network.

* `Advanced Usage <https://nbviewer.jupyter.org/github/skorch-dev/skorch/blob/master/notebooks/Advanced_Usage.ipynb>`_ - Dives deep into the inner works of skorch. `Run in Google Colab 💻 <https://colab.research.google.com/github/skorch-dev/skorch/blob/master/notebooks/Advanced_Usage.ipynb>`_

* `Gaussian Processes <https://nbviewer.jupyter.org/github/skorch-dev/skorch/blob/master/notebooks/Gaussian_Processes.ipynb>`_ - Train Gaussian Processes with the help of GPyTorch. `Run in Google Colab 💻 <https://colab.research.google.com/github/skorch-dev/skorch/blob/master/notebooks/Gaussian_Processes.ipynb>`_