Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Algebraic Operations on Functions #725

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ API Reference
+-------------------------------------------------+--------------------------------------------------------------+
| :mod:`~probnum.filtsmooth` | Bayesian filtering and smoothing. |
+-------------------------------------------------+--------------------------------------------------------------+
| :mod:`~probnum.functions` | Callables with in- and output shape information. |
+-------------------------------------------------+--------------------------------------------------------------+
| :mod:`~probnum.linalg` | Probabilistic numerical linear algebra. |
+-------------------------------------------------+--------------------------------------------------------------+
| :mod:`~probnum.linops` | Finite-dimensional linear operators. |
Expand All @@ -39,6 +41,7 @@ API Reference
api/config
api/diffeq
api/filtsmooth
api/functions
api/linalg
api/linops
api/problems
Expand Down
7 changes: 7 additions & 0 deletions docs/source/api/functions.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
*****************
probnum.functions
*****************

.. automodapi:: probnum.functions
:no-heading:
:headings: "="
1 change: 0 additions & 1 deletion docs/source/api/randprocs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,4 @@ probnum.randprocs
:hidden:

randprocs/markov
randprocs/mean_fns
randprocs/kernels
7 changes: 0 additions & 7 deletions docs/source/api/randprocs/mean_fns.rst

This file was deleted.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ exclude_lines = [
# Don't complain if non-runnable code isn't run:
'if 0:',
'if __name__ == .__main__.:',
# Don't complain if operator's are not overloaded
'return NotImplemented'
]

################################################################################
Expand Down
6 changes: 1 addition & 5 deletions src/probnum/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from . import (
diffeq,
filtsmooth,
functions,
linalg,
linops,
problems,
Expand All @@ -34,23 +35,18 @@
randvars,
utils,
)
from ._function import Function, LambdaFunction
from ._version import version as __version__
from .randvars import asrandvar

# Public classes and functions. Order is reflected in documentation.
__all__ = [
"asrandvar",
"Function",
"LambdaFunction",
"ProbabilisticNumericalMethod",
"StoppingCriterion",
"LambdaStoppingCriterion",
]

# Set correct module paths. Corrects links and module paths in documentation.
Function.__module__ = "probnum"
LambdaFunction.__module__ = "probnum"
ProbabilisticNumericalMethod.__module__ = "probnum"
StoppingCriterion.__module__ = "probnum"
LambdaStoppingCriterion.__module__ = "probnum"
6 changes: 6 additions & 0 deletions src/probnum/functions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Callables with in- and output shape information supporting algebraic operations."""

from . import _algebra
from ._algebra_fallbacks import ScaledFunction, SumFunction
from ._function import Function, LambdaFunction
from ._zero import Zero
69 changes: 69 additions & 0 deletions src/probnum/functions/_algebra.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
r"""Algebraic operations on :class:`Function`\ s."""

from ._algebra_fallbacks import SumFunction
from ._function import Function
from ._zero import Zero

############
# Function #
############


@Function.__add__.register # pylint: disable=no-member
def _(self, other: Function) -> SumFunction:
return SumFunction(self, other)


@Function.__add__.register # pylint: disable=no-member
def _(self, other: SumFunction) -> SumFunction:
return SumFunction(self, *other.summands)


@Function.__add__.register # pylint: disable=no-member
def _(self, other: Zero) -> Function: # pylint: disable=unused-argument
return self


@Function.__sub__.register # pylint: disable=no-member
def _(self, other: Function) -> SumFunction:
return SumFunction(self, -other)


@Function.__sub__.register # pylint: disable=no-member
def _(self, other: Zero) -> Function: # pylint: disable=unused-argument
return self


###############
# SumFunction #
###############


@SumFunction.__add__.register # pylint: disable=no-member
def _(self, other: Function) -> SumFunction:
return SumFunction(*self.summands, other)


@SumFunction.__add__.register # pylint: disable=no-member
def _(self, other: SumFunction) -> SumFunction:
return SumFunction(*self.summands, *other.summands)


@SumFunction.__sub__.register # pylint: disable=no-member
def _(self, other: Function) -> SumFunction:
return SumFunction(*self.summands, -other)


########
# Zero #
########


@Zero.__add__.register # pylint: disable=no-member
def _(self, other: Function) -> Function: # pylint: disable=unused-argument
return other


@Zero.__sub__.register # pylint: disable=no-member
def _(self, other: Function) -> Function: # pylint: disable=unused-argument
return -other
141 changes: 141 additions & 0 deletions src/probnum/functions/_algebra_fallbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
r"""Fallback implementation for algebraic operations on :class:`Function`\ s."""

from __future__ import annotations

import functools
import operator

import numpy as np

from probnum import utils
from probnum.typing import ScalarLike, ScalarType

from ._function import Function


class SumFunction(Function):
r"""Pointwise sum of :class:`Function`\ s.

Given functions :math:`f_1, \dotsc, f_n \colon \mathbb{R}^n \to \mathbb{R}^m`, this
defines a new function

.. math::
\sum_{i = 1}^n f_i \colon \mathbb{R}^n \to \mathbb{R}^m,
x \mapsto \sum_{i = 1}^n f_i(x).

Parameters
----------
*summands
The functions :math:`f_1, \dotsc, f_n`.
"""

def __init__(self, *summands: Function) -> None:
if not all(isinstance(summand, Function) for summand in summands):
raise TypeError(
"The functions to be added must be objects of type `Function`."
)

if not all(
summand.input_shape == summands[0].input_shape for summand in summands
):
raise ValueError(
"The functions to be added must all have the same input shape."
)

if not all(
summand.output_shape == summands[0].output_shape for summand in summands
):
raise ValueError(
"The functions to be added must all have the same output shape."
)

self._summands = summands

super().__init__(
input_shape=summands[0].input_shape,
output_shape=summands[0].output_shape,
)

@property
def summands(self) -> tuple[SumFunction, ...]:
r"""The functions :math:`f_1, \dotsc, f_n` to be added."""
return self._summands

def _evaluate(self, x: np.ndarray) -> np.ndarray:
return functools.reduce(
operator.add, (summand(x) for summand in self._summands)
)

@functools.singledispatchmethod
def __add__(self, other):
return super().__add__(other)

@functools.singledispatchmethod
def __sub__(self, other):
return super().__sub__(other)


class ScaledFunction(Function):
r"""Function multiplied pointwise with a scalar.

Given a function :math:`f \colon \mathbb{R}^n \to \mathbb{R}^m` and a scalar
:math:`\alpha \in \mathbb{R}`, this defines a new function

.. math::
\alpha f \colon \mathbb{R}^n \to \mathbb{R}^m,
x \mapsto (\alpha f)(x) = \alpha f(x).

Parameters
----------
function
The function :math:`f`.
scalar
The scalar :math:`\alpha`.
"""

def __init__(self, function: Function, scalar: ScalarLike):
if not isinstance(function, Function):
raise TypeError(
"The function to be scaled must be an object of type `Function`."
)

self._function = function
self._scalar = utils.as_numpy_scalar(scalar)

super().__init__(
input_shape=self._function.input_shape,
output_shape=self._function.output_shape,
)

@property
def function(self) -> Function:
r"""The function :math:`f`."""
return self._function

@property
def scalar(self) -> ScalarType:
r"""The scalar :math:`\alpha`."""
return self._scalar

def _evaluate(self, x: np.ndarray) -> np.ndarray:
return self._scalar * self._function(x)

@functools.singledispatchmethod
def __mul__(self, other):
if np.ndim(other) == 0:
return ScaledFunction(
function=self._function,
scalar=self._scalar * np.asarray(other),
)

return super().__mul__(other)

@functools.singledispatchmethod
def __rmul__(self, other):
if np.ndim(other) == 0:
return ScaledFunction(
function=self._function,
scalar=np.asarray(other) * self._scalar,
)

return super().__rmul__(other)
44 changes: 40 additions & 4 deletions src/probnum/_function.py → src/probnum/functions/_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
from __future__ import annotations

import abc
import functools
from typing import Callable

import numpy as np

from . import utils
from .typing import ArrayLike, ShapeLike, ShapeType
from probnum import utils
from probnum.typing import ArrayLike, ShapeLike, ShapeType


class Function(abc.ABC):
Expand All @@ -17,6 +18,8 @@ class Function(abc.ABC):
This class represents a, uni- or multivariate, scalar- or tensor-valued,
mathematical function. Hence, the call method should not have any observable
side-effects.
Instances of this class can be added and multiplied by a scalar, which means that
they are elements of a vector space.

Parameters
----------
Expand All @@ -29,7 +32,7 @@ class Function(abc.ABC):
See Also
--------
LambdaFunction : Define a :class:`Function` from an anonymous function.
~probnum.randprocs.mean_fns.Zero : Zero mean function of a random process.
~probnum.functions.Zero : Zero function.
"""

def __init__(self, input_shape: ShapeLike, output_shape: ShapeLike = ()) -> None:
Expand Down Expand Up @@ -112,6 +115,39 @@ def __call__(self, x: ArrayLike) -> np.ndarray:
def _evaluate(self, x: np.ndarray) -> np.ndarray:
pass

def __neg__(self):
return -1.0 * self

@functools.singledispatchmethod
def __add__(self, other):
return NotImplemented

@functools.singledispatchmethod
def __sub__(self, other):
return NotImplemented

@functools.singledispatchmethod
def __mul__(self, other):
if np.ndim(other) == 0:
from ._algebra_fallbacks import ( # pylint: disable=import-outside-toplevel
ScaledFunction,
)

return ScaledFunction(function=self, scalar=other)

return NotImplemented

@functools.singledispatchmethod
def __rmul__(self, other):
if np.ndim(other) == 0:
from ._algebra_fallbacks import ( # pylint: disable=import-outside-toplevel
ScaledFunction,
)

return ScaledFunction(function=self, scalar=other)

return NotImplemented


class LambdaFunction(Function):
"""Define a :class:`Function` from a given :class:`callable`.
Expand All @@ -131,7 +167,7 @@ class LambdaFunction(Function):
Examples
--------
>>> import numpy as np
>>> from probnum import LambdaFunction
>>> from probnum.functions import LambdaFunction
>>> fn = LambdaFunction(fn=lambda x: 2 * x + 1, input_shape=(2,), output_shape=(2,))
>>> fn(np.array([[1, 2], [4, 5]]))
array([[ 3, 5],
Expand Down
Loading