Improve resume suggestions #3719

Merged: 13 commits, Apr 2, 2024
171 changes: 126 additions & 45 deletions kedro/runner/runner.py
@@ -15,7 +15,7 @@
as_completed,
wait,
)
from typing import Any, Iterable, Iterator
from typing import Any, Collection, Iterable, Iterator

from more_itertools import interleave
from pluggy import PluginManager
@@ -198,39 +198,64 @@ def _suggest_resume_scenario(

postfix = ""
if done_nodes:
node_names = (n.name for n in remaining_nodes)
resume_p = pipeline.only_nodes(*node_names)
start_p = resume_p.only_nodes_with_inputs(*resume_p.inputs())

# find the nearest persistent ancestors of the nodes in start_p
start_p_persistent_ancestors = _find_persistent_ancestors(
pipeline, start_p.nodes, catalog
start_node_names = find_nodes_to_resume_from(
pipeline=pipeline,
unfinished_nodes=remaining_nodes,
catalog=catalog,
)

start_node_names = (n.name for n in start_p_persistent_ancestors)
postfix += f" --from-nodes \"{','.join(start_node_names)}\""
start_nodes_str = ",".join(sorted(start_node_names))
postfix += f' --from-nodes "{start_nodes_str}"'

if not postfix:
self._logger.warning(
"No nodes ran. Repeat the previous command to attempt a new run."
)
else:
self._logger.warning(
"There are %d nodes that have not run.\n"
f"There are {len(remaining_nodes)} nodes that have not run.\n"
"You can resume the pipeline run from the nearest nodes with "
"persisted inputs by adding the following "
"argument to your previous command:\n%s",
len(remaining_nodes),
postfix,
f"argument to your previous command:\n{postfix}"
)
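For context, the warning emitted by this branch would look roughly like the following (the node names are hypothetical, chosen only for illustration):

There are 4 nodes that have not run.
You can resume the pipeline run from the nearest nodes with persisted inputs by adding the following argument to your previous command:
 --from-nodes "node3_A,node3_B"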


def _find_persistent_ancestors(
pipeline: Pipeline, children: Iterable[Node], catalog: DataCatalog
def find_nodes_to_resume_from(
pipeline: Pipeline, unfinished_nodes: Collection[Node], catalog: DataCatalog
) -> set[str]:
"""Given a collection of unfinished nodes in a pipeline using
a certain catalog, find the node names to pass to pipeline.from_nodes()
to cover all unfinished nodes, including any additional nodes
that should be re-run if their outputs are not persisted.

Args:
pipeline: the ``Pipeline`` to find starting nodes for.
unfinished_nodes: collection of ``Node``s that have not finished yet
catalog: the ``DataCatalog`` of the run.

Returns:
Set of node names to pass to pipeline.from_nodes() to continue
the run.

"""
all_nodes_that_need_to_run = _find_all_required_nodes(
pipeline, unfinished_nodes, catalog
)

# Find which of the remaining nodes would need to run first (in topo sort)
persistent_ancestors = _find_initial_node_group(
pipeline, all_nodes_that_need_to_run
)
@noklam (Contributor) commented on Mar 18, 2024:
I find it a bit hard to understand the purpose of this. Is it similar to the following?

p = pipeline(all_nodes_that_need_to_run)
p_inputs = p.inputs()
input_nodes = p.only_nodes_with_inputs(*p_inputs)
input_node_names = [n.name for n in input_nodes.nodes]

It's just pseudocode, so I don't guarantee it works.

@ondrejzacha (Contributor, Author) replied on Mar 18, 2024:
This is not exactly the same. The difference appears when nodes share the same external inputs:

Pipeline([
    node(
        name="first_node",
        func=...,
        inputs=["external_input"],
        outputs="intermediate_output",
    ),
    node(
        name="second_node",
        func=...,
        inputs=["external_input", "intermediate_output"],
        outputs="final_output",
    )
])

Your suggestion would produce ["first_node", "second_node"] here, whereas the topo-sort group approach produces only ["first_node"] (this pipeline has two topo-sort groups, each containing a single node).
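For illustration, a minimal runnable sketch of the topo-sort group approach applied to the example above (the first_arg helper and the imports are assumptions added to make the snippet self-contained, not part of this PR):

from kedro.pipeline import Pipeline, node

def first_arg(*args):
    return args[0]

p = Pipeline([
    node(first_arg, ["external_input"], "intermediate_output", name="first_node"),
    node(first_arg, ["external_input", "intermediate_output"], "final_output", name="second_node"),
])

# grouped_nodes returns the topologically sorted groups of the pipeline;
# the first group contains only "first_node", and "second_node" does not
# need to be listed because it is downstream of "first_node".
print([n.name for n in p.grouped_nodes[0]])  # ['first_node']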


return {n.name for n in persistent_ancestors}


def _find_all_required_nodes(
pipeline: Pipeline, unfinished_nodes: Iterable[Node], catalog: DataCatalog
) -> set[Node]:
"""Breadth-first search approach to finding the complete set of
persistent ancestors of an iterable of ``Node``s. Persistent
ancestors exclusively have persisted ``Dataset``s as inputs.
``Node``s which need to run to cover all unfinished nodes,
including any additional nodes that should be re-run if their outputs
are not persisted.

Args:
pipeline: the ``Pipeline`` to find ancestors in.
@@ -242,54 +267,110 @@ def _find_persistent_ancestors(
``Node``s.

"""
ancestor_nodes_to_run = set()
queue, visited = deque(children), set(children)
nodes_to_run = set(unfinished_nodes)
initial_nodes = _nodes_with_external_inputs(unfinished_nodes)

queue, visited = deque(initial_nodes), set(initial_nodes)
while queue:
current_node = queue.popleft()
if _has_persistent_inputs(current_node, catalog):
ancestor_nodes_to_run.add(current_node)
continue
for parent in _enumerate_parents(pipeline, current_node):
if parent in visited:
nodes_to_run.add(current_node)
non_persistent_inputs = _enumerate_non_persistent_inputs(current_node, catalog)
# Look for the nodes that produce non-persistent inputs (if those exist)
for node in _enumerate_nodes_with_outputs(pipeline, non_persistent_inputs):
if node in visited:
continue
visited.add(parent)
queue.append(parent)
return ancestor_nodes_to_run
visited.add(node)
queue.append(node)

# Make sure no downstream tasks are skipped
nodes_to_run = pipeline.from_nodes(*(n.name for n in nodes_to_run)).nodes

return set(nodes_to_run)


def _enumerate_parents(pipeline: Pipeline, child: Node) -> list[Node]:
"""For a given ``Node``, returns a list containing the direct parents
of that ``Node`` in the given ``Pipeline``.
def _nodes_with_external_inputs(nodes_of_interest: Iterable[Node]) -> set[Node]:
"""For given ``Node``s, find the subset that depends on
external inputs of the ``Pipeline`` they constitute.

Args:
pipeline: the ``Pipeline`` to search for direct parents in.
child: the ``Node`` to find parents of.
nodes_of_interest: the ``Node``s to analyze.

Returns:
A list of all ``Node``s that are direct parents of ``child``.
A set of ``Node``s that depend on external inputs
of nodes of interest.

"""
parent_pipeline = pipeline.only_nodes_with_outputs(*child.inputs)
return parent_pipeline.nodes
p_nodes_of_interest = Pipeline(nodes_of_interest)
p_nodes_with_external_inputs = p_nodes_of_interest.only_nodes_with_inputs(
*p_nodes_of_interest.inputs()
)
return set(p_nodes_with_external_inputs.nodes)


def _has_persistent_inputs(node: Node, catalog: DataCatalog) -> bool:
"""Check if a ``Node`` exclusively has persisted Datasets as inputs.
If at least one input is a ``MemoryDataset``, return False.
def _enumerate_non_persistent_inputs(node: Node, catalog: DataCatalog) -> set[str]:
"""Enumerate non-persistent input datasets of a ``Node``.

Args:
node: the ``Node`` to check the inputs of.
catalog: the ``DataCatalog`` of the run.

Returns:
True if the ``Node`` being checked exclusively has inputs that
are not ``MemoryDataset``, else False.
Set of names of non-persistent inputs of given ``Node``.

"""
# We use _datasets because it preserves the parameter name format
catalog_datasets = catalog._datasets
non_persistent_inputs: set[str] = set()
for node_input in node.inputs:
if isinstance(catalog._datasets[node_input], MemoryDataset):
return False
return True
if node_input.startswith("params:"):
continue
if node_input not in catalog_datasets or isinstance(
catalog_datasets[node_input], MemoryDataset
):
non_persistent_inputs.add(node_input)

return non_persistent_inputs


def _enumerate_nodes_with_outputs(
pipeline: Pipeline, outputs: Collection[str]
) -> list[Node]:
"""For given outputs, returns a list containing nodes that
generate them in the given ``Pipeline``.

Args:
pipeline: the ``Pipeline`` to search for nodes in.
outputs: the dataset names to find source nodes for.

Returns:
A list of all ``Node``s that are producing ``outputs``.

"""
parent_pipeline = pipeline.only_nodes_with_outputs(*outputs)
return parent_pipeline.nodes


def _find_initial_node_group(pipeline: Pipeline, nodes: Iterable[Node]) -> list[Node]:
"""Given a collection of ``Node``s in a ``Pipeline``,
find the initial group of ``Node``s to be run (in topological order).

This can be used to define a sub-pipeline with the smallest possible
set of nodes to pass to --from-nodes.

Args:
pipeline: the ``Pipeline`` to search for initial ``Node``s in.
nodes: the ``Node``s to find initial group for.

Returns:
A list of initial ``Node``s to run given inputs (in topological order).

"""
node_names = set(n.name for n in nodes)
if len(node_names) == 0:
return []
sub_pipeline = pipeline.only_nodes(*node_names)
initial_nodes = sub_pipeline.grouped_nodes[0]
return initial_nodes


def run_node(
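Taken together, the helpers added above roughly implement the following flow. A simplified, hypothetical sketch (dataset and node names are invented, and importing find_nodes_to_resume_from from kedro.runner.runner is an assumption about how the new function is exposed):

from kedro.io import DataCatalog, MemoryDataset
from kedro.pipeline import Pipeline, node
from kedro.runner.runner import find_nodes_to_resume_from

def first_arg(*args):
    return args[0]

# Two-step pipeline; only "raw" is registered in the catalog, so the
# intermediate dataset is treated as in-memory. (In a real project "raw"
# would typically be persisted; MemoryDataset is used only to keep the
# sketch self-contained.)
p = Pipeline([
    node(first_arg, "raw", "intermediate", name="preprocess"),
    node(first_arg, "intermediate", "model_input", name="create_model_input"),
])
catalog = DataCatalog({"raw": MemoryDataset(1)})

# If both nodes are still unfinished, the suggestion starts from "preprocess",
# because "intermediate" is not persisted and has to be recomputed.
start_nodes = find_nodes_to_resume_from(
    pipeline=p, unfinished_nodes=p.nodes, catalog=catalog
)
print(start_nodes)  # {'preprocess'}

_suggest_resume_scenario then turns this set into the --from-nodes suggestion shown above.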
111 changes: 103 additions & 8 deletions tests/runner/conftest.py
@@ -15,6 +15,10 @@ def identity(arg):
return arg


def first_arg(*args):
return args[0]


def sink(arg):
pass

@@ -36,7 +40,7 @@ def return_not_serialisable(arg):
return lambda x: x


def multi_input_list_output(arg1, arg2):
def multi_input_list_output(arg1, arg2, arg3=None):
return [arg1, arg2]


@@ -80,6 +84,9 @@ def _save(arg):
"ds0_B": persistent_dataset,
"ds2_A": persistent_dataset,
"ds2_B": persistent_dataset,
"dsX": persistent_dataset,
"dsY": persistent_dataset,
"params:p": MemoryDataset(1),
}
)

@@ -148,21 +155,31 @@ def unfinished_outputs_pipeline():

@pytest.fixture
def two_branches_crossed_pipeline():
"""A ``Pipeline`` with an X-shape (two branches with one common node)"""
r"""A ``Pipeline`` with an X-shape (two branches with one common node):

 (node1_A)   (node1_B)
         \    /
         (node2)
         /    \
 (node3_A)   (node3_B)
    /             \
(node4_A)       (node4_B)

"""
return pipeline(
[
node(identity, "ds0_A", "ds1_A", name="node1_A"),
node(identity, "ds0_B", "ds1_B", name="node1_B"),
node(first_arg, "ds0_A", "ds1_A", name="node1_A"),
node(first_arg, "ds0_B", "ds1_B", name="node1_B"),
node(
multi_input_list_output,
["ds1_A", "ds1_B"],
["ds2_A", "ds2_B"],
name="node2",
),
node(identity, "ds2_A", "ds3_A", name="node3_A"),
node(identity, "ds2_B", "ds3_B", name="node3_B"),
node(identity, "ds3_A", "ds4_A", name="node4_A"),
node(identity, "ds3_B", "ds4_B", name="node4_B"),
node(first_arg, "ds2_A", "ds3_A", name="node3_A"),
node(first_arg, "ds2_B", "ds3_B", name="node3_B"),
node(first_arg, "ds3_A", "ds4_A", name="node4_A"),
node(first_arg, "ds3_B", "ds4_B", name="node4_B"),
]
)

@@ -175,3 +192,81 @@ def pipeline_with_memory_datasets():
node(func=identity, inputs="Input2", outputs="MemOutput2", name="node2"),
]
)


@pytest.fixture
def pipeline_asymmetric():
r"""

 (node1)
     \
   (node3)   (node2)
        \     /
        (node4)

"""
return pipeline(
[
node(first_arg, ["ds0_A"], ["_ds1"], name="node1"),
node(first_arg, ["ds0_B"], ["_ds2"], name="node2"),
node(first_arg, ["_ds1"], ["_ds3"], name="node3"),
node(first_arg, ["_ds2", "_ds3"], ["_ds4"], name="node4"),
]
)


@pytest.fixture
def pipeline_triangular():
r"""

(node1)
| \
| (node2)
| /
(node3)

"""
return pipeline(
[
node(first_arg, ["ds0_A"], ["_ds1_A"], name="node1"),
node(first_arg, ["_ds1_A"], ["ds2_A"], name="node2"),
node(first_arg, ["ds2_A", "_ds1_A"], ["_ds3_A"], name="node3"),
]
)
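A rough sketch (not part of this diff) of what the new helper would suggest for this triangular shape when only node3 is left unfinished, assuming ds0_A and ds2_A are persisted while _ds1_A is not registered; LambdaDataset stands in for any persisted dataset, and the import of find_nodes_to_resume_from is an assumption:

from kedro.io import DataCatalog, LambdaDataset
from kedro.pipeline import Pipeline, node
from kedro.runner.runner import find_nodes_to_resume_from

def first_arg(*args):
    return args[0]

triangular = Pipeline([
    node(first_arg, ["ds0_A"], ["_ds1_A"], name="node1"),
    node(first_arg, ["_ds1_A"], ["ds2_A"], name="node2"),
    node(first_arg, ["ds2_A", "_ds1_A"], ["_ds3_A"], name="node3"),
])

def persistent():
    return LambdaDataset(load=lambda: 42, save=lambda data: None)

catalog = DataCatalog({"ds0_A": persistent(), "ds2_A": persistent()})

unfinished = [n for n in triangular.nodes if n.name == "node3"]
print(find_nodes_to_resume_from(pipeline=triangular, unfinished_nodes=unfinished, catalog=catalog))
# {'node1'}: _ds1_A only lives in memory, so node1 must re-run, and node2 is
# picked up automatically because it is downstream of node1.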


@pytest.fixture
def empty_pipeline():
return pipeline([])


@pytest.fixture(
params=[(), ("dsX",), ("params:p",)],
ids=[
"no_extras",
"extra_persistent_ds",
"extra_param",
],
)
def two_branches_crossed_pipeline_variable_inputs(request):
"""A ``Pipeline`` with an X-shape (two branches with one common node).
Non-persistent datasets (other than parameters) are prefixed with an underscore.
"""
extra_inputs = list(request.param)

return pipeline(
[
node(first_arg, ["ds0_A"] + extra_inputs, "_ds1_A", name="node1_A"),
node(first_arg, ["ds0_B"] + extra_inputs, "_ds1_B", name="node1_B"),
node(
multi_input_list_output,
["_ds1_A", "_ds1_B"] + extra_inputs,
["ds2_A", "ds2_B"],
name="node2",
),
node(first_arg, ["ds2_A"] + extra_inputs, "_ds3_A", name="node3_A"),
node(first_arg, ["ds2_B"] + extra_inputs, "_ds3_B", name="node3_B"),
node(first_arg, ["_ds3_A"] + extra_inputs, "_ds4_A", name="node4_A"),
node(first_arg, ["_ds3_B"] + extra_inputs, "_ds4_B", name="node4_B"),
]
)
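As a rough, hypothetical illustration of why the parametrisation distinguishes a persisted extra dataset from an extra parameter: _enumerate_non_persistent_inputs is expected to ignore both kinds of extras, so the suggested restart nodes stay the same across all three variants. A sketch (the import of the private helper and the inline catalog are assumptions for illustration only):

from kedro.io import DataCatalog, LambdaDataset, MemoryDataset
from kedro.pipeline import node
from kedro.runner.runner import _enumerate_non_persistent_inputs

def first_arg(*args):
    return args[0]

persistent = LambdaDataset(load=lambda: 42, save=lambda data: None)
catalog = DataCatalog(
    {"ds0_A": persistent, "dsX": persistent, "params:p": MemoryDataset(1)}
)

# "params:p" is skipped explicitly, and "ds0_A"/"dsX" are persisted,
# so nothing in this node's inputs forces an upstream re-run.
n = node(first_arg, ["ds0_A", "dsX", "params:p"], "_ds1_A", name="node1_A")
assert _enumerate_non_persistent_inputs(n, catalog) == set()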