Merge pull request #170 from scipp/compute-series
Add `compute_mapped`
SimonHeybrock authored Jun 21, 2024
2 parents 35dbdef + f6d8058 commit bd8e6f3
Showing 14 changed files with 349 additions and 26 deletions.
11 changes: 11 additions & 0 deletions docs/api-reference/index.md
@@ -21,6 +21,17 @@
HandleAsComputeTimeException
```

## Top-level functions

```{eval-rst}
.. autosummary::
:toctree: ../generated/functions
:recursive:
compute_mapped
get_mapped_node_names
```

## Exceptions

```{eval-rst}
29 changes: 21 additions & 8 deletions docs/user-guide/parameter-tables.ipynb
@@ -130,7 +130,7 @@
"Equivalently the table could be represented as a `dict`, where each key corresponds to a column header and each value is a list of values for that column, i.e., `{Filename: filenames}`.\n",
"Specifying an index is currently not possible in this case, and it will default to a range index.\n",
"\n",
"We can now use [Pipeline.map](https://scipp.github.io/sciline/generated/classes/sciline.Pipeline.html#sciline.Pipeline.map) to create a modified pipeline that processes each row in the parameter table:"
"We can now use [Pipeline.map](../generated/classes/sciline.Pipeline.html#sciline.Pipeline.map) to create a modified pipeline that processes each row in the parameter table:"
]
},
{
@@ -146,8 +146,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Then we can compute `Result` for each index in the parameter table.\n",
"Currently there is no convenient way of accessing these, instead we manually define the target nodes to compute:"
"We can use the [compute_mapped](../generated/functions/sciline.compute_mapped.html) function to compute `Result` for each index in the parameter table:"
]
},
{
@@ -156,10 +155,8 @@
"metadata": {},
"outputs": [],
"source": [
"from cyclebane.graph import NodeName, IndexValues\n",
"\n",
"targets = [NodeName(Result, IndexValues(('run_id',), (i,))) for i in run_ids]\n",
"pipeline.compute(targets)"
"results = sciline.compute_mapped(pipeline, Result)\n",
"pd.DataFrame(results) # DataFrame for HTML rendering"
]
},
{
@@ -168,7 +165,22 @@
"source": [
"Note the use of the `run_id` index.\n",
"If the index axis of the DataFrame has no name, a default of `dim_0`, `dim_1`, etc. is used.\n",
"We can also visualize the task graph for computing the series of `Result` values:"
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
"**Note**\n",
"\n",
"[compute_mapped](../generated/functions/sciline.compute_mapped.html) depends on Pandas, which is not a dependency of Sciline and must be installed separately, e.g., using pip:\n",
"\n",
"```bash\n",
"pip install pandas\n",
"```\n",
"\n",
"</div>\n",
"\n",
"We can also visualize the task graph for computing the series of `Result` values.\n",
"For this, we need to get all the node names derived from `Result` via the `map` operation.\n",
"The [get_mapped_node_names](../generated/functions/sciline.get_mapped_node_names.html) function can be used to get a `pandas.Series` of these node names, which we can then visualize:"
]
},
{
@@ -177,6 +189,7 @@
"metadata": {},
"outputs": [],
"source": [
"targets = sciline.get_mapped_node_names(pipeline, Result)\n",
"pipeline.visualize(targets)"
]
},
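The mapping pattern the notebook walks through can be sketched without Sciline as a plain dict comprehension. This is a conceptual stand-in only: the filenames and provider functions below are hypothetical, not the notebook's actual data or the library's API.

```python
# Hypothetical stand-ins for the notebook's parameter table and providers.
run_ids = [102, 103, 104, 105]
filenames = [f"file{i}.txt" for i in run_ids]

def load(filename: str) -> str:
    # Provider stand-in: pretend to load raw data from a file.
    return f"raw({filename})"

def process(raw: str) -> str:
    # Provider stand-in: turn raw data into a Result.
    return f"result({raw})"

# Conceptually, Pipeline.map plus compute_mapped evaluate the same graph
# once per table row and index the results by the table's index (run_id).
results = {run_id: process(load(name)) for run_id, name in zip(run_ids, filenames)}
```

The returned `pandas.Series` in the real API plays the role of this dict: one computed `Result` per index value.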
2 changes: 1 addition & 1 deletion requirements/base.in
@@ -2,4 +2,4 @@
# will not be touched by ``make_base.py``
# --- END OF CUSTOM SECTION ---
# The following was generated by 'tox -e deps', DO NOT EDIT MANUALLY!
cyclebane >= 24.06.0
cyclebane>=24.06.0
2 changes: 1 addition & 1 deletion requirements/base.txt
@@ -1,4 +1,4 @@
# SHA1:44b03b700447e95874de7cd4d8b11558c0c972e5
# SHA1:1b4246f703135629f3fb69e65829cfb99abca695
#
# This file is autogenerated by pip-compile-multi
# To update, run:
1 change: 1 addition & 0 deletions requirements/basetest.in
@@ -4,4 +4,5 @@
graphviz
jsonschema
numpy
pandas
pytest
18 changes: 15 additions & 3 deletions requirements/basetest.txt
@@ -1,4 +1,4 @@
# SHA1:f7d11f6aab1600c37d922ffb857a414368b800cd
# SHA1:fa9ebd4f58fe57db20baa224ad46fac634f3f046
#
# This file is autogenerated by pip-compile-multi
# To update, run:
@@ -19,14 +19,22 @@ jsonschema==4.22.0
# via -r basetest.in
jsonschema-specifications==2023.12.1
# via jsonschema
numpy==1.26.4
# via -r basetest.in
numpy==2.0.0
# via
# -r basetest.in
# pandas
packaging==24.1
# via pytest
pandas==2.2.2
# via -r basetest.in
pluggy==1.5.0
# via pytest
pytest==8.2.2
# via -r basetest.in
python-dateutil==2.9.0.post0
# via pandas
pytz==2024.1
# via pandas
referencing==0.35.1
# via
# jsonschema
@@ -35,5 +43,9 @@ rpds-py==0.18.1
# via
# jsonschema
# referencing
six==1.16.0
# via python-dateutil
tomli==2.0.1
# via pytest
tzdata==2024.1
# via pandas
4 changes: 2 additions & 2 deletions requirements/ci.txt
@@ -17,7 +17,7 @@ colorama==0.4.6
# via tox
distlib==0.3.8
# via virtualenv
filelock==3.14.0
filelock==3.15.1
# via
# tox
# virtualenv
@@ -50,7 +50,7 @@ tomli==2.0.1
# tox
tox==4.15.1
# via -r ci.in
urllib3==2.2.1
urllib3==2.2.2
# via requests
virtualenv==20.26.2
# via tox
2 changes: 1 addition & 1 deletion requirements/dev.txt
@@ -92,7 +92,7 @@ prometheus-client==0.20.0
# via jupyter-server
pycparser==2.22
# via cffi
pydantic==2.7.3
pydantic==2.7.4
# via copier
pydantic-core==2.18.4
# via pydantic
6 changes: 3 additions & 3 deletions requirements/docs.txt
@@ -48,7 +48,7 @@ exceptiongroup==1.2.1
# via ipython
executing==2.0.1
# via stack-data
fastjsonschema==2.19.1
fastjsonschema==2.20.0
# via nbformat
graphviz==0.20.3
# via -r docs.in
@@ -120,7 +120,7 @@ nbsphinx==0.9.4
# via -r docs.in
nest-asyncio==1.6.0
# via ipykernel
numpy==1.26.4
numpy==2.0.0
# via pandas
packaging==24.1
# via
@@ -241,7 +241,7 @@ typing-extensions==4.12.2
# pydata-sphinx-theme
tzdata==2024.1
# via pandas
urllib3==2.2.1
urllib3==2.2.2
# via requests
wcwidth==0.2.13
# via prompt-toolkit
2 changes: 1 addition & 1 deletion requirements/static.txt
@@ -9,7 +9,7 @@ cfgv==3.4.0
# via pre-commit
distlib==0.3.8
# via virtualenv
filelock==3.14.0
filelock==3.15.1
# via virtualenv
identify==2.5.36
# via pre-commit
2 changes: 1 addition & 1 deletion requirements/test-dask.txt
@@ -10,7 +10,7 @@ click==8.1.7
# via dask
cloudpickle==3.0.0
# via dask
dask==2024.5.2
dask==2024.6.0
# via -r test-dask.in
fsspec==2024.6.0
# via dask
4 changes: 3 additions & 1 deletion src/sciline/__init__.py
@@ -17,7 +17,7 @@
HandleAsComputeTimeException,
UnsatisfiedRequirement,
)
from .pipeline import Pipeline
from .pipeline import Pipeline, compute_mapped, get_mapped_node_names
from .task_graph import TaskGraph

__all__ = [
@@ -30,6 +30,8 @@
"UnsatisfiedRequirement",
"HandleAsBuildTimeException",
"HandleAsComputeTimeException",
"compute_mapped",
"get_mapped_node_names",
]

del importlib
111 changes: 108 additions & 3 deletions src/sciline/pipeline.py
@@ -2,10 +2,10 @@
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
from __future__ import annotations

from collections.abc import Callable, Iterable
from collections.abc import Callable, Hashable, Iterable, Sequence
from itertools import chain
from types import UnionType
from typing import Any, TypeVar, get_args, get_type_hints, overload
from typing import TYPE_CHECKING, Any, TypeVar, get_args, get_type_hints, overload

from ._provider import Provider, ToProvider
from .data_graph import DataGraph, to_task_graph
@@ -15,6 +15,11 @@
from .task_graph import TaskGraph
from .typing import Key

if TYPE_CHECKING:
import graphviz
import pandas


T = TypeVar('T')
KeyType = TypeVar('KeyType', bound=Key)

@@ -84,7 +89,7 @@ def compute(self, tp: type | Iterable[type] | UnionType, **kwargs: Any) -> Any:
"""
return self.get(tp, **kwargs).compute()

def visualize(self, tp: type | Iterable[type], **kwargs: Any) -> graphviz.Digraph: # type: ignore[name-defined] # noqa: F821
def visualize(self, tp: type | Iterable[type], **kwargs: Any) -> graphviz.Digraph:
"""
Return a graphviz Digraph object representing the graph for the given keys.
@@ -194,3 +199,103 @@ def bind_and_call(
def _repr_html_(self) -> str:
nodes = ((key, data) for key, data in self._graph.nodes.items())
return pipeline_html_repr(nodes)


def get_mapped_node_names(
graph: DataGraph, base_name: type, *, index_names: Sequence[Hashable] | None = None
) -> pandas.Series:
"""
Given a graph with a mapped node with given base_name, return a series of
corresponding mapped names.
This is meant to be used in combination with :py:func:`DataGraph.map`.
If the mapped node depends on multiple indices, the index of the returned series
will have a multi-index.
Note that Pandas is not a dependency of Sciline and must be installed separately.
Parameters
----------
graph:
The data graph to get the mapped node names from.
base_name:
The base name of the mapped node to get the names for.
index_names:
Specifies the names of the indices of the mapped node. If not given this is
inferred from the graph, but the argument may be required to disambiguate
multiple mapped nodes with the same name.
Returns
-------
:
The series of node names corresponding to the mapped node.
"""
import pandas as pd
from cyclebane.graph import IndexValues, MappedNode, NodeName

candidates = [
node
for node in graph._cbgraph.graph.nodes
if isinstance(node, MappedNode) and node.name == base_name
]
if len(candidates) == 0:
raise ValueError(f"'{base_name}' is not a mapped node.")
if index_names is not None:
candidates = [
node for node in candidates if set(node.indices) == set(index_names)
]
if len(candidates) > 1:
raise ValueError(
f"Multiple mapped nodes with name '{base_name}' found: {candidates}"
)
# Drops unrelated indices
graph = graph[candidates[0]] # type: ignore[index]
indices = graph._cbgraph.indices
if index_names is not None:
indices = {name: indices[name] for name in indices if name in index_names}
index_names = tuple(indices)

index = pd.MultiIndex.from_product(indices.values(), names=index_names)
keys = tuple(NodeName(base_name, IndexValues(index_names, idx)) for idx in index)
if index.nlevels == 1: # Avoid more complicated MultiIndex if unnecessary
index = index.get_level_values(0)
return pd.Series(keys, index=index, name=base_name)


def compute_mapped(
pipeline: Pipeline,
base_name: type,
*,
index_names: Sequence[Hashable] | None = None,
) -> pandas.Series:
"""
Given a graph with a mapped node with given base_name, return a series of computed
results.
This is meant to be used in combination with :py:func:`Pipeline.map`.
If the mapped node depends on multiple indices, the index of the returned series
will have a multi-index.
Note that Pandas is not a dependency of Sciline and must be installed separately.
Parameters
----------
pipeline:
The pipeline to compute the mapped results from.
base_name:
The base name of the mapped node to get the names for.
index_names:
Specifies the names of the indices of the mapped node. If not given this is
inferred from the graph, but the argument may be required to disambiguate
multiple mapped nodes with the same name.
Returns
-------
:
The series of computed results corresponding to the mapped node.
"""
key_series = get_mapped_node_names(
graph=pipeline, base_name=base_name, index_names=index_names
)
results = pipeline.compute(key_series)
return key_series.apply(lambda x: results[x])
