Skip to content

Commit

Permalink
Reorder algebra fallbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinpfoertner committed Aug 31, 2022
1 parent df033ab commit 67c5faf
Showing 1 changed file with 56 additions and 56 deletions.
112 changes: 56 additions & 56 deletions src/probnum/_function/_algebra_fallbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,62 +11,6 @@
from ._function import Function


class ScaledFunction(Function):
r"""Function multiplied pointwise with a scalar.
Given a function :math:`f \colon \mathbb{R}^n \to \mathbb{R}^m` and a scalar
:math:`\alpha \in \mathbb{R}`, this defines a new function
.. math::
\alpha f \colon \mathbb{R}^n \to \mathbb{R}^m,
x \masto (\alpha f)(x) = \alpha f(x).
Parameters
----------
function
The function :math:`f`.
scalar
The scalar :math:`\alpha`.
"""

def __init__(self, function: Function, scalar: ScalarLike):
if not isinstance(function, Function):
raise TypeError(
"The function to be scaled must be an object of type `Function`"
)

self._function = function
self._scalar = utils.as_numpy_scalar(scalar)

super().__init__(
input_shape=self._function.input_shape,
output_shape=self._function.output_shape,
)

@property
def function(self) -> Function:
r"""The function :math:`f`."""
return self._function

@property
def scalar(self) -> ScalarType:
r"""The scalar :math:`\alpha`."""
return self._scalar

def _evaluate(self, x: np.ndarray) -> np.ndarray:
return self._scalar * self._function(x)

@functools.singledispatchmethod
def __rmul__(self, other):
if np.ndim(other) == 0:
return ScaledFunction(
function=self._function,
scalar=np.asarray(other) * self._scalar,
)

return super().__rmul__(other)


class SumFunction(Function):
r"""Pointwise sum of :class:`Function`s.
Expand Down Expand Up @@ -127,3 +71,59 @@ def __add__(self, other):
@functools.singledispatchmethod
def __sub__(self, other):
return super().__sub__(other)


class ScaledFunction(Function):
r"""Function multiplied pointwise with a scalar.
Given a function :math:`f \colon \mathbb{R}^n \to \mathbb{R}^m` and a scalar
:math:`\alpha \in \mathbb{R}`, this defines a new function
.. math::
\alpha f \colon \mathbb{R}^n \to \mathbb{R}^m,
x \masto (\alpha f)(x) = \alpha f(x).
Parameters
----------
function
The function :math:`f`.
scalar
The scalar :math:`\alpha`.
"""

def __init__(self, function: Function, scalar: ScalarLike):
if not isinstance(function, Function):
raise TypeError(
"The function to be scaled must be an object of type `Function`"
)

self._function = function
self._scalar = utils.as_numpy_scalar(scalar)

super().__init__(
input_shape=self._function.input_shape,
output_shape=self._function.output_shape,
)

@property
def function(self) -> Function:
r"""The function :math:`f`."""
return self._function

@property
def scalar(self) -> ScalarType:
r"""The scalar :math:`\alpha`."""
return self._scalar

def _evaluate(self, x: np.ndarray) -> np.ndarray:
return self._scalar * self._function(x)

@functools.singledispatchmethod
def __rmul__(self, other):
if np.ndim(other) == 0:
return ScaledFunction(
function=self._function,
scalar=np.asarray(other) * self._scalar,
)

return super().__rmul__(other)

0 comments on commit 67c5faf

Please sign in to comment.