Gaussian processes via gpytorch #782

Merged
77 commits, merged on Oct 9, 2021
Changes shown from 72 commits
72d227b
[WIP] First working implementation of GPs
Mar 26, 2021
df0413e
Minor changes
Mar 28, 2021
108b434
Don't reinitialize uninitialized net bc set_params
Mar 28, 2021
a880987
Move cb_params update, improve comment
Mar 28, 2021
ab5f973
Update notebook with a warning on init likelihood
Mar 29, 2021
e572d65
Remove unnecessary imports
Mar 29, 2021
c6e92a5
Simplify initialize_* methods
Mar 29, 2021
15d6fa8
Add first few unit tests
Mar 29, 2021
0b19538
Add unit test for set_params on uninitialized net
Mar 29, 2021
d6e1bca
Print only when verbose
Mar 30, 2021
c87c932
Remove init code related to likelihood
Mar 30, 2021
0cda3ef
Further clean up of set_params re-initialization
Mar 30, 2021
cff5edb
Add more tests for re-initialization logic
Mar 30, 2021
6c50ef7
Rework logic of creating custom modules/optimizers
Mar 31, 2021
f02738a
Add battery of tests for custom modules/optimizers
Mar 31, 2021
435fc75
Implement changes to make tests pass
Apr 1, 2021
52bdefa
[WIP] Update CHANGES
Apr 1, 2021
eddd5aa
Merge branch 'changed/refactor-init-more-consistency-custom-modules' …
Apr 1, 2021
c471c78
Simplify implementation based on refactoring
Apr 1, 2021
be4c035
[WIP] Document an edge case not covered yet
Apr 1, 2021
f20982d
Remove _PYTORCH_COMPONENTS global
Apr 1, 2021
5480b8f
Update documentation reflecting the changes
Apr 2, 2021
7454e85
All optimizers perform updates automatically
Apr 3, 2021
e1d3c2f
Address reviewer comments
Apr 5, 2021
5510ef0
Merge branch 'changed/refactor-init-more-consistency-custom-modules' …
Apr 7, 2021
2eff978
Further updates based on new skorch refactoring
Apr 7, 2021
1bec3d3
Fix corner case with pre-initialized modules
Apr 8, 2021
ee23283
Merge branch 'changed/refactor-init-more-consistency-custom-modules' …
Apr 8, 2021
968da8c
Activate test about initialization message
Apr 8, 2021
0889eea
Extend test coverage, fix a typo
Apr 9, 2021
c6fb0aa
Custom modules are set to train/eval mode
Apr 9, 2021
b6fb645
Merge branch 'changed/refactor-init-more-consistency-custom-modules' …
Apr 9, 2021
66b2e80
Update docs about train/eval mode
Apr 9, 2021
9633b9e
Update notebook
Apr 9, 2021
d14c1e2
Move tests around, add comment about multioutput
Apr 9, 2021
1547600
Complete docstrings
Apr 10, 2021
44069bc
Complete entries in CHANGES.md
Apr 10, 2021
4cf1c26
Merge branch 'changed/refactor-init-more-consistency-custom-modules' …
Apr 10, 2021
fcc0c06
Complete docs, docstrings, fix linting
Apr 12, 2021
d90f6e6
check_is_fitted also checks for likelihood_
Apr 12, 2021
7784014
Fix a bug when likelihood/module are initialized
Apr 12, 2021
317efde
Update notebook
Apr 13, 2021
6a7c6b1
Update README
Apr 13, 2021
8cad813
Add documentation rst files
Apr 18, 2021
f6e8647
Merge branch 'master' into changed/refactor-init-more-consistency-cus…
BenjaminBossan Apr 23, 2021
4ca941d
Reviewer comment: Consider virtual params
Apr 24, 2021
1a33aec
Reviewer comment: Docs: No need to return self
Apr 24, 2021
bb4e573
Reviewer comment: Docs: explain NeuralNet.predict
Apr 24, 2021
d787d4a
Reviewer comment: Docs: When not calling super
Apr 24, 2021
a61a4c7
Reviewer comment: get_all_learnable_params
Apr 24, 2021
64b3380
Reviewer comment: facilitate module initialization
Apr 24, 2021
eee2922
Merge branch 'changed/refactor-init-more-consistency-custom-modules' …
Apr 24, 2021
8955975
Merge branch 'master' into feature/gpytorch-integration-copy
Jun 13, 2021
58f1e58
Fix a bug that led to double-registration
Jun 13, 2021
e22f59d
Merge branch 'bugfix/module-double-registration-after-clone' into fea…
Jun 13, 2021
307ebf7
Increment gpytorch minimum version
Jun 16, 2021
59d5da7
[WIP] Try to fix some tests
Jun 16, 2021
b8c061a
Fix duplicate parameter bug
Jun 20, 2021
969218b
Merge branch 'master' into feature/gaussian-processes-via-gpytorch
BenjaminBossan Jun 20, 2021
5c9396e
Revert changes in test_net.py
Jun 20, 2021
5fce890
Merge branch 'feature/gaussian-processes-via-gpytorch' of https://git…
Jun 20, 2021
df6c6be
Bump gpytorch version to 1.5
Jun 27, 2021
3795c19
Fix failing test caused by distribution shape
Jun 27, 2021
b3c5908
For testing, exclude Python 3.6, PyTorch 1.7.1
Jul 4, 2021
e97dc14
For testing, exclude Python 3.8 PyTorch 1.7.1
Jul 4, 2021
d85597f
Skip gpytorch tests for pytorch 1.7.1
Jul 4, 2021
02d1b6f
Modify pytorch version check
Jul 4, 2021
8befe3a
Reviewer comment: Use pytest.mark.skipif
BenjaminBossan Aug 15, 2021
cfa993e
Comment out code that is not currently needed
Aug 15, 2021
e3cd39b
Improve documentation example for GPs
Aug 15, 2021
663ee3d
Reviewer comments: some improvements to notebook
Aug 29, 2021
454e898
Use set_train_data for exact GPs
Aug 29, 2021
e793f93
Address reviewer comments by Jacob Gardner
Oct 3, 2021
9670c9c
Address reviewer comments by Immanuel Bayer
Oct 3, 2021
c1a1ace
Fix typo in README
Oct 9, 2021
cdc82a8
Add entry about GPs to CHANGES
Oct 9, 2021
777914a
Merge branch 'master' into feature/gaussian-processes-via-gpytorch
BenjaminBossan Oct 9, 2021
2 changes: 2 additions & 0 deletions README.rst
@@ -31,6 +31,7 @@ Resources

- `Documentation <https://skorch.readthedocs.io/en/latest/?badge=latest>`_
- `Source Code <https://github.com/skorch-dev/skorch/>`_
- `Installation <https://github.com/skorch-dev/skorch#installation>`_

========
Examples
@@ -127,6 +128,7 @@ skorch also provides many convenient features, among others:
- `Parameter freezing/unfreezing <https://skorch.readthedocs.io/en/stable/callbacks.html#skorch.callbacks.Freezer>`_
- `Progress bar <https://skorch.readthedocs.io/en/stable/callbacks.html#skorch.callbacks.ProgressBar>`_ (for CLI as well as jupyter)
- `Automatic inference of CLI parameters <https://github.com/skorch-dev/skorch/tree/master/examples/cli>`_
- `Integration with GPyTorch for Gaussian Processes <https://skorch.readthedocs.io/en/latest/user/probabilistic.html>`_

============
Installation
7 changes: 4 additions & 3 deletions docs/conf.py
@@ -50,10 +50,11 @@

 intersphinx_mapping = {
     'pytorch': ('https://pytorch.org/docs/stable/', None),
-    'sklearn': ('http://scikit-learn.org/stable/', None),
-    'numpy': ('http://docs.scipy.org/doc/numpy/', None),
+    'sklearn': ('https://scikit-learn.org/stable/', None),
+    'numpy': ('https://docs.scipy.org/doc/numpy/', None),
     'python': ('https://docs.python.org/3', None),
     'mlflow': ('https://mlflow.org/docs/latest/', None),
+    'gpytorch': ('https://docs.gpytorch.ai/en/stable/', None),
 }

# Add any paths that contain templates here, relative to this directory.
@@ -118,7 +119,7 @@
# html_theme_options = {}

 def setup(app):
-    app.add_stylesheet('css/my_theme.css')
+    app.add_css_file('css/my_theme.css')

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
5 changes: 3 additions & 2 deletions docs/index.rst
@@ -52,6 +52,7 @@ User's Guide
user/callbacks
user/dataset
user/save_load
user/probabilistic
user/history
user/toy
user/helper
@@ -82,5 +83,5 @@ Indices and tables
* :ref:`search`


-.. _pytorch: http://pytorch.org/
-.. _sklearn: http://scikit-learn.org/
+.. _pytorch: https://pytorch.org/
+.. _sklearn: https://scikit-learn.org/
5 changes: 5 additions & 0 deletions docs/probabilistic.rst
@@ -0,0 +1,5 @@
skorch.probabilistic
====================

.. automodule:: skorch.probabilistic
    :members:
1 change: 1 addition & 0 deletions docs/skorch.rst
@@ -11,6 +11,7 @@ skorch
helper
history
net
probabilistic
regressor
scoring
toy
193 changes: 193 additions & 0 deletions docs/user/probabilistic.rst
@@ -0,0 +1,193 @@
==================
Gaussian Processes
==================

skorch integrates with GPyTorch_ to make it easy to train Gaussian Process (GP)
models. This section assumes familiarity with how Gaussian Processes work;
please refer to other resources if you first want to learn about them.

GPyTorch adopts many patterns from PyTorch, thus making it easy to pick up for
seasoned PyTorch users. Similarly, the skorch GPyTorch integration should look
familiar to seasoned skorch users. However, GPs are a different beast than the
more common, non-probabilistic machine learning techniques. It is important to
understand the basic concepts before using them in practice.

Installation
------------

In addition to the normal skorch dependencies and PyTorch, you need to install
GPyTorch as well. It is not a regular skorch dependency, since most users are
probably not interested in using skorch for GPs. To install GPyTorch, use
either pip or conda:

.. code:: bash

    # using pip
    pip install -U gpytorch
    # using conda
    conda install gpytorch -c gpytorch


Examples
--------

Exact Gaussian Processes
^^^^^^^^^^^^^^^^^^^^^^^^

Like GPyTorch, skorch supports both exact and approximate Gaussian Process
regression. For exact GPs, use the
:class:`~skorch.probabilistic.ExactGPRegressor`. The likelihood has to be a
:class:`~gpytorch.likelihoods.GaussianLikelihood` and the criterion an
:class:`~gpytorch.mlls.ExactMarginalLogLikelihood`, but since those are the
defaults, they don't need to be specified. The module needs to be an
:class:`~gpytorch.models.ExactGP`. For this example, we use a simple RBF kernel.

.. code:: python

    import gpytorch
    from skorch.probabilistic import ExactGPRegressor

    class RbfModule(gpytorch.models.ExactGP):
        def __init__(self, likelihood):
            # detail: we don't set train_inputs and train_targets here because
            # skorch will take care of that
            super().__init__(train_inputs=None, train_targets=None, likelihood=likelihood)
            self.mean_module = gpytorch.means.ConstantMean()
            self.covar_module = gpytorch.kernels.RBFKernel()

        def forward(self, x):
            mean_x = self.mean_module(x)
            covar_x = self.covar_module(x)
            return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

    gpr = ExactGPRegressor(RbfModule)
    gpr.fit(X_train, y_train)
    y_pred = gpr.predict(X_test)

As you can see, this almost looks like a normal skorch regressor with a normal
PyTorch module. We can fit as normal using the ``fit`` method and predict using
the ``predict`` method.

Inside the module, we determine the mean using a mean function (just a constant
in this case) and the covariance matrix using the RBF kernel function. You
should already know about mean and kernel functions. Given the mean and
covariance matrix, we assume that the output distribution is a multivariate
normal distribution, since exact GPs rely on this assumption. We could send
``x`` through an MLP for `Deep Kernel Learning
<https://docs.gpytorch.ai/en/stable/examples/06_PyTorch_NN_Integration_DKL/index.html>`_
but left it out to keep the example simple.

One major difference to usual deep learning models is that we actually predict a
distribution, not just a point estimate. That means that if we choose an
appropriate model that fits the data well, we can express the **uncertainty** of
the model:

.. code:: python

    y_pred, y_std = gpr.predict(X, return_std=True)
    lower_conf_region = y_pred - y_std
    upper_conf_region = y_pred + y_std

Here we not only returned the mean of the prediction, ``y_pred``, but also its
standard deviation, ``y_std``. This tells us how uncertain the model is about
its prediction. E.g., it could be the case that the model is fairly certain when
*interpolating* between data points but uncertain about *extrapolating*. This is
not possible to know when models only learn point predictions.

To obtain the confidence region, you can also use the ``confidence_region``
method:

.. code:: python

    # 1 standard deviation
    lower, upper = gpr.confidence_region(X, sigmas=1)

    # 2 standard deviations, the default
    lower, upper = gpr.confidence_region(X, sigmas=2)

Furthermore, a GP allows you to sample from the distribution even *before
fitting* it. The GP needs to be initialized, however:

.. code:: python

    gpr = ExactGPRegressor(...)
    gpr.initialize()
    samples = gpr.sample(X, n_samples=100)

By visualizing the samples and comparing them to the true underlying
distribution of the target, you can get a feel for whether the model you built
is capable of generating the distribution of the target. If fitting takes a
long time, it is therefore recommended to check the samples first; otherwise
you may waste a lot of time trying to fit a model that is incapable of
generating the true distribution.

Approximate Gaussian Processes
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

For some situations, fitting an exact GP might be infeasible, e.g. because the
distribution is not Gaussian or because you want to perform stochastic
optimization with mini-batches. For this, GPyTorch provides facilities to train
variational and approximate GPs. The module should inherit from
:class:`~gpytorch.models.ApproximateGP` and should define a *variational
strategy*. From the skorch side of things, use
:class:`~skorch.probabilistic.GPRegressor`.

.. code:: python

    import gpytorch
    from gpytorch.models import ApproximateGP
    from gpytorch.variational import CholeskyVariationalDistribution
    from gpytorch.variational import VariationalStrategy
    from skorch.probabilistic import GPRegressor

    class VariationalModule(ApproximateGP):
        def __init__(self, inducing_points):
            variational_distribution = CholeskyVariationalDistribution(inducing_points.size(0))
            variational_strategy = VariationalStrategy(
                self, inducing_points, variational_distribution, learn_inducing_locations=True,
            )
            super().__init__(variational_strategy)
            self.mean_module = gpytorch.means.ConstantMean()
            self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

        def forward(self, x):
            mean_x = self.mean_module(x)
            covar_x = self.covar_module(x)
            return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

    X, y = get_data(...)
    X_inducing = X[:100]
    X_train, y_train = X[100:], y[100:]
    num_training_samples = len(X_train)

    gpr = GPRegressor(
        VariationalModule,
        module__inducing_points=X_inducing,
        criterion__num_data=num_training_samples,
    )

    gpr.fit(X_train, y_train)
    y_pred = gpr.predict(X_train)

Reviewer comment:

    The jupyter notebook you sent me currently highlights
    ``gpytorch.settings.fast_pred_samples`` when calling ``sample``. That
    setting won't actually do anything unless you are using KISS-GP. A setting
    that definitely will make a perf difference is wrapping ``predict`` in
    ``gpytorch.settings.fast_pred_var()`` though, assuming
    ``gpytorch.settings.skip_posterior_variances()`` isn't also on (see my
    comment about that below).

Author reply:

    I didn't know that, thanks for clarifying. I will remove the usage of
    ``gpytorch.settings.fast_pred_var`` in the notebook.

As you can see, the variational strategy requires us to use inducing points. We
split off 100 of our training data samples to use as inducing points, assuming
that they are representative of the whole distribution. Apart from this, there
is basically no difference to using exact GP regression.

Finally, skorch also provides :class:`~skorch.probabilistic.GPBinaryClassifier`
for binary classification with GPs. It uses a Bernoulli likelihood by default.
However, using GPs for classification is not very common; GPs are most commonly
used for regression tasks where data points have a known relationship to each
other (e.g. in time series forecasts).

Multiclass classification is not currently provided, but you can use
:class:`~skorch.probabilistic.GPBinaryClassifier` in conjunction with
:class:`~sklearn.multiclass.OneVsRestClassifier` to achieve the same result.

Further examples
----------------

To see all of this in action, we provide a notebook that shows how to use skorch with GPs on real world data: `Gaussian Processes notebook <https://nbviewer.jupyter.org/github/skorch-dev/skorch/blob/master/notebooks/Gaussian_Processes.ipynb>`_.

.. _GPyTorch: https://gpytorch.ai/
2 changes: 2 additions & 0 deletions docs/user/tutorials.rst
@@ -22,3 +22,5 @@ The following are examples and notebooks on how to use skorch.
* `Seq2Seq Translation using skorch <https://github.com/skorch-dev/skorch/tree/master/examples/translation>`_ - Translation with a sequence to sequence network.

* `Advanced Usage <https://nbviewer.jupyter.org/github/skorch-dev/skorch/blob/master/notebooks/Advanced_Usage.ipynb>`_ - Dives deep into the inner workings of skorch. `Run in Google Colab 💻 <https://colab.research.google.com/github/skorch-dev/skorch/blob/master/notebooks/Advanced_Usage.ipynb>`_

* `Gaussian Processes <https://nbviewer.jupyter.org/github/skorch-dev/skorch/blob/master/notebooks/Gaussian_Processes.ipynb>`_ - Train Gaussian Processes with the help of GPyTorch `Run in Google Colab 💻 <https://colab.research.google.com/github/skorch-dev/skorch/blob/master/notebooks/Gaussian_Processes.ipynb>`_