Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

docs: improve docstrings of function module #428

Merged
merged 5 commits into from
Apr 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@
"lambdification",
"lambdified",
"lambdify",
"lambdifygenerated",
"linestyle",
"linewidth",
"linkcheck",
Expand Down
1 change: 1 addition & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ rst-roles =
mod
ref
rst-directives =
automethod
deprecated
envvar
exception
Expand Down
6 changes: 0 additions & 6 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,6 @@ def fetch_logo(url: str, output_path: str) -> None:
"members": True,
"undoc-members": True,
"show-inheritance": True,
"special-members": ", ".join(
[
"__call__",
"__eq__",
]
),
}
autodoc_member_order = "bysource"
autodoc_type_aliases = {
Expand Down
4 changes: 0 additions & 4 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,6 @@ Upcoming features <https://github.com/ComPWA/tensorwaves/milestones?direction=as
Help developing <https://compwa-org.rtfd.io/en/stable/develop.html>
```

- {ref}`Python API <modindex>`
- {ref}`General Index <genindex>`
- {ref}`Search <search>`

```{toctree}
---
caption: Related projects
Expand Down
4 changes: 2 additions & 2 deletions docs/usage/basics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The {meth}`.ParametrizedBackendFunction.__call__` takes a {class}`dict` of variable names (here, `\"x\"` only) to the value(s) that should be used in their place."
"The {meth}`.ParametrizedFunction.__call__` takes a {class}`dict` of variable names (here, `\"x\"` only) to the value(s) that should be used in their place."
]
},
{
Expand Down Expand Up @@ -503,7 +503,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"First of all, we need to randomly generate values of $x$. In this simple, 1-dimensional example, we could just use a random generator like {class}`numpy.random.Generator` feed its output to the {meth}`.ParametrizedBackendFunction.__call__`. Generally, though, we want to cover $n$-dimensional cases. The class {class}`.NumpyDomainGenerator` allows us to generate such a **uniform** distribution for each variable within a certain range. It requires a {class}`.RealNumberGenerator` (here we use {class}`.NumpyUniformRNG`) and it also requires us to define boundaries for each variable in the resulting {obj}`.DataSample`."
"First of all, we need to randomly generate values of $x$. In this simple, 1-dimensional example, we could just use a random generator like {class}`numpy.random.Generator` feed its output to the {meth}`.ParametrizedFunction.__call__`. Generally, though, we want to cover $n$-dimensional cases. The class {class}`.NumpyDomainGenerator` allows us to generate such a **uniform** distribution for each variable within a certain range. It requires a {class}`.RealNumberGenerator` (here we use {class}`.NumpyUniformRNG`) and it also requires us to define boundaries for each variable in the resulting {obj}`.DataSample`."
]
},
{
Expand Down
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ addopts =
--ignore=docs/conf.py
-k "not benchmark"
-m "not slow"
doctest_optionflags = NORMALIZE_WHITESPACE
filterwarnings =
error
ignore:.* is deprecated and will be removed in Pillow 10.*:DeprecationWarning
Expand Down
24 changes: 20 additions & 4 deletions src/tensorwaves/function/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,13 @@ def _to_tuple(argument_order: Iterable[str]) -> Tuple[str, ...]:
class PositionalArgumentFunction(Function):
"""Wrapper around a function with positional arguments.

This class provides a :meth:`__call__` that can take a `.DataSample` for a
function with `positional arguments
This class provides a :meth:`~.Function.__call__` that can take a
`.DataSample` for a function with `positional arguments
<https://docs.python.org/3/glossary.html#term-positional-argument>`_. Its
:attr:`argument_order` redirect the keys in the `.DataSample` to the
argument positions in its underlying :attr:`function`.

.. seealso:: :func:`.create_function`
"""

function: Callable[..., np.ndarray] = field(validator=_validate_arguments)
Expand All @@ -84,7 +86,10 @@ def __call__(self, data: DataSample) -> np.ndarray:


class ParametrizedBackendFunction(ParametrizedFunction):
"""Implements `.ParametrizedFunction` for a specific computational back-end."""
"""Implements `.ParametrizedFunction` for a specific computational back-end.

.. seealso:: :func:`.create_parametrized_function`
"""

def __init__(
self,
Expand Down Expand Up @@ -126,7 +131,18 @@ def update_parameters(


def get_source_code(function: Function) -> str:
"""Get the backend source code used to compile this function."""
"""Get the backend source code used to compile this function.

>>> import sympy as sp
>>> from tensorwaves.function.sympy import create_function
>>> x, y = sp.symbols("x y")
>>> expr = x**2 + y**2
>>> func = create_function(expr, backend="jax", use_cse=False)
>>> src = get_source_code(func)
>>> print(src)
def _lambdifygenerated(x, y):
return x**2 + y**2
"""
if isinstance(
function, (PositionalArgumentFunction, ParametrizedBackendFunction)
):
Expand Down
62 changes: 60 additions & 2 deletions src/tensorwaves/function/sympy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,36 @@
def create_function(
expression: "sp.Expr",
backend: str,
max_complexity: Optional[int] = None,
use_cse: bool = True,
max_complexity: Optional[int] = None,
) -> PositionalArgumentFunction:
"""Convert a SymPy expression to a computational function.

Args:
expression: The SymPy expression that you want to
`~sympy.utilities.lambdify.lambdify`. Its
`~sympy.core.basic.Basic.free_symbols` become arguments to the
resulting `.PositionalArgumentFunction`.

backend: The computational backend in which to express the function.
use_cse: Identify common sub-expressions in the function. This usually
makes the function faster and speeds up lambdification.

max_complexity: See :ref:`usage/faster-lambdify:Specifying complexity`
and :doc:`compwa-org:report/002`.

Example:
>>> import numpy as np
>>> import sympy as sp
>>> from tensorwaves.function.sympy import create_function
>>> x, y = sp.symbols("x y")
>>> expression = x**2 + y**2
>>> function = create_function(expression, backend="jax")
>>> array = np.linspace(0, 3, num=4)
>>> data = {"x": array, "y": array}
>>> function(data)
DeviceArray([ 0., 2., 8., 18.], dtype=float64)
"""
free_symbols = _get_free_symbols(expression)
sorted_symbols = sorted(free_symbols, key=lambda s: s.name)
lambdified_function = _lambdify_normal_or_fast(
Expand All @@ -61,9 +88,40 @@ def create_parametrized_function(
expression: "sp.Expr",
parameters: Mapping["sp.Symbol", ParameterValue],
backend: str,
max_complexity: Optional[int] = None,
use_cse: bool = True,
max_complexity: Optional[int] = None,
) -> ParametrizedBackendFunction:
"""Convert a SymPy expression to a parametrized function.

This is an extended version of :func:`create_function`, which allows one to
identify certain symbols in the expression as parameters.

Args:
expression: See :func:`create_function`.
parameters: The symbols in the expression that are be identified as
`~.ParametrizedFunction.parameters` in the returned
`.ParametrizedBackendFunction`.
backend: See :func:`create_function`.
use_cse: See :func:`create_function`.
max_complexity: See :func:`create_function`.

Example:
>>> import numpy as np
>>> import sympy as sp
>>> from tensorwaves.function.sympy import create_parametrized_function
>>> a, b, x, y = sp.symbols("a b x y")
>>> expression = a * x**2 + b * y**2
>>> function = create_parametrized_function(
... expression,
... parameters={a: -1, b: 2.5},
... backend="jax",
... )
>>> array = np.linspace(0, 1, num=5)
>>> data = {"x": array, "y": array}
>>> function.update_parameters({"b": 1})
>>> function(data)
DeviceArray([0., 0., 0., 0., 0.], dtype=float64)
"""
free_symbols = _get_free_symbols(expression)
sorted_symbols = sorted(free_symbols, key=lambda s: s.name)
lambdified_function = _lambdify_normal_or_fast(
Expand Down
14 changes: 11 additions & 3 deletions src/tensorwaves/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class Function(ABC, Generic[InputType, OutputType]):
`.OutputType` values (co-domain) for a given set of `.InputType` values
(domain). Examples of `Function` are `ParametrizedFunction`, `Estimator`
and `DataTransformer`.

.. automethod:: __call__
"""

@abstractmethod
Expand All @@ -55,9 +57,11 @@ class ParametrizedFunction(Function[DataSample, np.ndarray]):
A `ParametrizedFunction` identifies certain variables in a mathematical
expression as **parameters**. Remaining variables are considered **domain
variables**. Domain variables are the argument of the evaluation (see
:func:`~Function.__call__`), while the parameters are controlled via
:attr:`parameters` (getter) and :meth:`update_parameters` (setter). This
mechanism is especially important for an `Estimator`.
:func:`~ParametrizedFunction.__call__`), while the parameters are
controlled via :attr:`parameters` (getter) and :meth:`update_parameters`
(setter). This mechanism is especially important for an `Estimator`.

.. automethod:: __call__
"""

@property
Expand Down Expand Up @@ -85,6 +89,8 @@ class Estimator(Function[Mapping[str, ParameterValue], float]):

See the :mod:`.estimator` module for different implementations of this
interface.

.. automethod:: __call__
"""

def __call__(self, parameters: Mapping[str, ParameterValue]) -> float:
Expand Down Expand Up @@ -210,6 +216,8 @@ class RealNumberGenerator(ABC):
"""Abstract class for generating real numbers within a certain range.

Implementations can be found in the `tensorwaves.data` module.

.. automethod:: __call__
"""

@abstractmethod
Expand Down