Kernel.output_shape semantics and further RandomProcess refactoring (#652)

marvinpfoertner authored Feb 26, 2022
1 parent db6e7d3 commit 7af755a
Showing 7 changed files with 272 additions and 185 deletions.
5 changes: 5 additions & 0 deletions src/probnum/_function.py
@@ -56,6 +56,11 @@ def output_shape(self) -> ShapeType:
For scalar-valued functions, this is an empty tuple."""
return self._output_shape

@property
def output_ndim(self) -> int:
"""Syntactic sugar for ``len(output_shape)``."""
return self._output_ndim

def __call__(self, x: ArrayLike) -> np.ndarray:
"""Evaluate the function at a given input.
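The new ``output_ndim`` property is pure bookkeeping on top of ``output_shape``. A minimal stand-in (not probnum's actual ``Function`` class) showing what it adds:

    class FunctionShapes:
        """Illustrative stand-in for the shape bookkeeping in `probnum.Function`."""

        def __init__(self, output_shape):
            self._output_shape = tuple(output_shape)
            self._output_ndim = len(self._output_shape)

        @property
        def output_shape(self):
            return self._output_shape

        @property
        def output_ndim(self):
            """Syntactic sugar for ``len(output_shape)``."""
            return self._output_ndim

    print(FunctionShapes(()).output_ndim)    # 0 -- scalar-valued function
    print(FunctionShapes((3,)).output_ndim)  # 1 -- function with values in R^3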
74 changes: 7 additions & 67 deletions src/probnum/randprocs/_gaussian_process.py
@@ -1,20 +1,14 @@
"""Gaussian processes."""

from typing import Type, Union

import numpy as np

from probnum import randvars
from probnum.typing import ShapeLike

from . import _random_process, kernels
from .. import _function
from ..typing import ArrayLike

_InputType = Union[np.floating, np.ndarray]
_OutputType = Union[np.floating, np.ndarray]


class GaussianProcess(_random_process.RandomProcess[_InputType, _OutputType]):
class GaussianProcess(_random_process.RandomProcess[ArrayLike, np.ndarray]):
"""Gaussian processes.
A Gaussian process is a continuous stochastic process which if evaluated at a
@@ -67,70 +61,16 @@ def __init__(
if not isinstance(mean, _function.Function):
raise TypeError("The mean function must have type `probnum.Function`.")

if not isinstance(cov, kernels.Kernel):
raise TypeError(
"The covariance functions must be implemented as a " "`Kernel`."
)

if len(mean.input_shape) > 1:
raise ValueError(
"The mean function must have input shape `()` or `(D_in,)`."
)

if len(mean.output_shape) > 1:
raise ValueError(
"The mean function must have output shape `()` or `(D_out,)`."
)

if mean.input_shape != cov.input_shape:
raise ValueError(
"The mean and covariance functions must have the same input shapes "
f"(`mean.input_shape` is {mean.input_shape} and `cov.input_shape` is "
f"{cov.input_shape})."
)

if 2 * mean.output_shape != cov.shape:
raise ValueError(
"The shape of the `Kernel` must be a tuple of the form "
"`(output_shape, output_shape)`, where `output_shape` is the output "
"shape of the mean function."
)

self._mean = mean
self._cov = cov

super().__init__(
input_shape=mean.input_shape,
output_shape=mean.output_shape,
dtype=np.dtype(np.float_),
mean=mean,
cov=cov,
)

def __call__(self, args: _InputType) -> randvars.Normal:
def __call__(self, args: ArrayLike) -> randvars.Normal:
return randvars.Normal(
mean=np.array(self.mean(args), copy=False), cov=self.cov.matrix(args)
mean=np.array(self.mean(args), copy=False), # pylint: disable=not-callable
cov=self.cov.matrix(args),
)

@property
def mean(self) -> _function.Function:
return self._mean

@property
def cov(self) -> kernels.Kernel:
return self._cov

def _sample_at_input(
self,
rng: np.random.Generator,
args: _InputType,
size: ShapeLike = (),
) -> _OutputType:
gaussian_rv = self.__call__(args)
return gaussian_rv.sample(rng=rng, size=size)

def push_forward(
self,
args: _InputType,
base_measure: Type[randvars.RandomVariable],
sample: np.ndarray,
) -> np.ndarray:
raise NotImplementedError
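With the shape checks and the ``mean``/``cov`` properties now living in ``RandomProcess``, ``GaussianProcess.__call__`` reduces to wrapping ``mean(args)`` and ``cov.matrix(args)`` in a ``randvars.Normal``. A NumPy-only sketch of the shapes involved (the squared-exponential Gram matrix below is purely illustrative):

    import numpy as np

    rng = np.random.default_rng(0)

    # Batch of n inputs for a process with input_shape == (2,), output_shape == ().
    n, d_in = 5, 2
    xs = rng.normal(size=(n, d_in))

    # What mean(args) and cov.matrix(args) would produce for such a process:
    mean_values = np.zeros(n)  # shape == batch_shape + output_shape == (5,)
    sq_dists = np.sum((xs[:, None, :] - xs[None, :, :]) ** 2, axis=-1)
    cov_matrix = np.exp(-0.5 * sq_dists)  # (5, 5) kernel Gram matrix

    # GaussianProcess.__call__ essentially returns
    #     randvars.Normal(mean=mean_values, cov=cov_matrix)
    print(mean_values.shape, cov_matrix.shape)  # (5,) (5, 5)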
165 changes: 114 additions & 51 deletions src/probnum/randprocs/_random_process.py
@@ -1,7 +1,7 @@
"""Random Processes."""

import abc
from typing import Callable, Generic, Type, TypeVar, Union
from typing import Callable, Generic, Optional, Type, TypeVar, Union

import numpy as np

@@ -30,6 +30,10 @@ class RandomProcess(Generic[_InputType, _OutputType], abc.ABC):
dtype :
Data type of the random process evaluated at an input. If not already a
``numpy.dtype``, it will be converted to one.
mean :
Mean function of the random process.
cov :
Covariance function of the random process.
See Also
--------
@@ -50,6 +54,8 @@ def __init__(
input_shape: ShapeLike,
output_shape: ShapeLike,
dtype: DTypeLike,
mean: Optional[_function.Function] = None,
cov: Optional[kernels.Kernel] = None,
):
self._input_shape = _utils.as_shape(input_shape)
self._input_ndim = len(self._input_shape)
@@ -65,6 +71,72 @@ def __init__(

self._dtype = np.dtype(dtype)

# Mean function
if mean is not None:
if not isinstance(mean, _function.Function):
raise TypeError("The mean function must have type `probnum.Function`.")

if mean.input_shape != self._input_shape:
raise ValueError(
f"The mean function must have the same `input_shape` as the random "
f"process (`{mean.input_shape}` != `{self._input_shape}`)."
)

if mean.output_shape != self._output_shape:
raise ValueError(
f"The mean function must have the same `output_shape` as the "
f"random process (`{mean.output_shape}` != `{self._output_shape}`)."
)

self._mean = mean

# Covariance function
if cov is not None:
if not isinstance(cov, kernels.Kernel):
raise TypeError(
"The covariance functions must be implemented as a " "`Kernel`."
)

if cov.input_shape != self._input_shape:
raise ValueError(
f"The covariance function must have the same `input_shape` as the "
f"random process (`{cov.input_shape}` != `{self._input_shape}`)."
)

if cov.output_shape != 2 * self._output_shape:
raise ValueError(
f"The `output_shape` of the covariance function must be given by "
f"`2 * self.output_shape` (`{cov.output_shape}` != "
f"`{2 * self._output_shape}`)."
)

self._cov = cov

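The ``cov.output_shape == 2 * self._output_shape`` comparison relies on Python's tuple repetition, so the same check covers scalar- and vector-valued processes. For example:

    output_shape = (3,)      # vector-valued process with values in R^3
    print(2 * output_shape)  # (3, 3) -- required Kernel.output_shape

    output_shape = ()        # scalar-valued process
    print(2 * output_shape)  # ()     -- a scalar-valued kernel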
@property
def input_shape(self) -> ShapeType:
"""Shape of inputs to the random process."""
return self._input_shape

@property
def input_ndim(self) -> int:
"""Syntactic sugar for ``len(input_shape)``."""
return self._input_ndim

@property
def output_shape(self) -> ShapeType:
"""Shape of the random process evaluated at an input."""
return self._output_shape

@property
def output_ndim(self) -> int:
"""Syntactic sugar for ``len(output_shape)``."""
return self._output_ndim

@property
def dtype(self) -> np.dtype:
"""Data type of (elements of) the random process evaluated at an input."""
return self._dtype

def __repr__(self) -> str:
return (
f"<{self.__class__.__name__} with "
@@ -79,7 +151,7 @@ def __call__(self, args: _InputType) -> randvars.RandomVariable[_OutputType]:
Parameters
----------
args
*shape=* ``batch_shape + `` :attr:`input_shape` -- (Batch of) input(s) at
*shape=* ``batch_shape +`` :attr:`input_shape` -- (Batch of) input(s) at
which to evaluate the random process. Currently, we require ``batch_shape``
to have at most one dimension.
@@ -89,46 +161,45 @@ def __call__(self, args: _InputType) -> randvars.RandomVariable[_OutputType]:
*shape=* ``batch_shape +`` :attr:`output_shape` -- Random process evaluated
at the input(s).
"""
raise NotImplementedError

@property
def input_shape(self) -> ShapeType:
"""Shape of inputs to the random process."""
return self._input_shape

@property
def output_shape(self) -> ShapeType:
"""Shape of the random process evaluated at an input."""
return self._output_shape

@property
def dtype(self) -> np.dtype:
"""Data type of (elements of) the random process evaluated at an input."""
return self._dtype

def marginal(self, args: _InputType) -> randvars._RandomVariableList:
"""Batch of random variables defining the marginal distributions at the inputs.
Parameters
----------
args
*shape=* ``batch_shape + `` :attr:`input_shape` -- (Batch of) input(s) at
*shape=* ``batch_shape +`` :attr:`input_shape` -- (Batch of) input(s) at
which to evaluate the random process. Currently, we require ``batch_shape``
to have at most one dimension.
"""
# return self.__call__(args).marginal()
raise NotImplementedError

@property
def mean(self) -> _function.Function:
r"""Mean function :math:`m(x) = \mathbb{E}[f(x)]` of the random process"""
raise NotImplementedError
r"""Mean function :math:`m(x) := \mathbb{E}[f(x)]` of the random process."""
if self._mean is None:
raise NotImplementedError

return self._mean

@property
def cov(self) -> kernels.Kernel:
r"""Covariance function :math:`k(x_0, x_1) = \mathbb{E}[(f(x_0) - \mathbb{E}[
f(x_0)])(f(x_0) - \mathbb{E}[f(x_0)])^\top]` of the random process."""
raise NotImplementedError
r"""Covariance function :math:`k(x_0, x_1)` of the random process.
.. math::
:nowrap:
\begin{equation}
k(x_0, x_1) := \mathbb{E} \left[
(f(x_0) - \mathbb{E}[f(x_0)])
(f(x_1) - \mathbb{E}[f(x_1)])^\top
\right]
\end{equation}
"""
if self._cov is None:
raise NotImplementedError

return self._cov

def var(self, args: _InputType) -> _OutputType:
"""Variance function.
@@ -139,53 +210,45 @@ def var(self, args: _InputType) -> _OutputType:
Parameters
----------
args
*shape=* ``batch_shape + input_shape_bcastable`` -- (Batch of) input(s) at
which to evaluate the variance function. ``input_shape_bcastable`` must be a
shape that can be broadcast to :attr:`input_shape`.
*shape=* ``batch_shape +`` :attr:`input_shape` -- (Batch of) input(s) at
which to evaluate the variance function.
Returns
-------
_OutputType
*shape=* ``batch_shape`` or ``output_shape[:1] + batch_shape`` -- Variance
of the process at ``args``.
*shape=* ``batch_shape +`` :attr:`output_shape` -- Variance of the process
at ``args``.
"""
try:
var = self.cov(args, None)
except NotImplementedError as exc:
raise NotImplementedError from exc
pointwise_covs = self.cov(args, None) # pylint: disable=not-callable

assert (
var.shape
== 2 * self._output_shape + args.shape[: args.ndim - self._input_ndim]
pointwise_covs.shape
== args.shape[: args.ndim - self._input_ndim] + 2 * self._output_shape
)

if self._output_ndim == 0:
return var
return pointwise_covs

assert self._output_ndim == 1

return np.diagonal(var, axis1=0, axis2=1)
return np.diagonal(pointwise_covs, axis1=-2, axis2=-1)
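The switch to ``axis1=-2, axis2=-1`` matches the new shape convention in which batch axes come first: ``self.cov(args, None)`` returns pointwise covariance matrices of shape ``batch_shape + 2 * output_shape``, and the variance is their diagonal. A small NumPy illustration with made-up shapes:

    import numpy as np

    rng = np.random.default_rng(1)

    # Hypothetical pointwise covariances for a batch of 4 inputs of a process
    # with output_shape == (3,): shape == batch_shape + 2 * output_shape == (4, 3, 3).
    a = rng.normal(size=(4, 3, 5))
    pointwise_covs = a @ a.transpose(0, 2, 1)  # batch of symmetric PSD matrices

    variances = np.diagonal(pointwise_covs, axis1=-2, axis2=-1)
    print(variances.shape)  # (4, 3) == batch_shape + output_shape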

def std(self, args: _InputType) -> _OutputType:
"""Standard deviation function.
Parameters
----------
args
*shape=* ``batch_shape + input_shape_bcastable`` -- (Batch of) input(s) at
which to evaluate the standard deviation function. ``input_shape_bcastable``
must be a shape that can be broadcast to :attr:`input_shape`.
*shape=* ``batch_shape +`` :attr:`input_shape` -- (Batch of) input(s) at
which to evaluate the standard deviation function.
Returns
-------
_OutputType
*shape=* ``batch_shape`` or ``output_shape[:1] + batch_shape`` -- Standard
deviation of the process at ``args``.
*shape=* ``batch_shape +`` :attr:`output_shape` -- Standard deviation of the
process at ``args``.
"""
try:
return np.sqrt(self.var(args=args))
except NotImplementedError as exc:
raise NotImplementedError from exc
return np.sqrt(self.var(args=args))

def push_forward(
self,
@@ -205,7 +268,7 @@ def push_forward(
base_measure
Base measure. Given as a type of random variable.
sample
*shape=* ``sample_shape + `` :attr:`input_shape` -- (Batch of) input(s) at
*shape=* ``sample_shape +`` :attr:`input_shape` -- (Batch of) input(s) at
which to evaluate the random process. Currently, we require ``sample_shape``
to have at most one dimension.
"""
Expand All @@ -228,7 +291,7 @@ def sample(
rng
Random number generator.
args
*shape=* ``size + `` :attr:`input_shape` -- (Batch of) input(s) at
*shape=* ``size +`` :attr:`input_shape` -- (Batch of) input(s) at
which the sample paths will be evaluated. Currently, we require
``size`` to have at most one dimension. If ``None``, sample paths,
i.e. callables are returned.
@@ -257,10 +320,10 @@ def _sample_at_input(
rng
Random number generator.
args
*shape=* ``size + `` :attr:`input_shape` -- (Batch of) input(s) at
*shape=* ``size +`` :attr:`input_shape` -- (Batch of) input(s) at
which the sample paths will be evaluated. Currently, we require
``size`` to have at most one dimension.
size
Size of the sample.
"""
raise NotImplementedError
return self(args).sample(rng, size=size)
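With this default in place, a subclass only has to implement ``__call__``: evaluating the process yields a random variable, and sampling at given inputs reduces to sampling that random variable, which is exactly what ``GaussianProcess._sample_at_input`` used to do by hand. A short sketch of the resulting sample shape (assumes probnum is installed; the shapes are illustrative):

    import numpy as np
    from probnum import randvars

    rng = np.random.default_rng(42)

    # Stand-in for `self(args)` at a batch of 4 inputs of a scalar-output process:
    evaluated_process = randvars.Normal(mean=np.zeros(4), cov=np.eye(4))

    samples = evaluated_process.sample(rng=rng, size=(3,))
    print(samples.shape)  # (3, 4) == size + batch_shape (+ output_shape, here ())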