
Enable the usage of generator functions in nodes #2161

Merged 12 commits on Jan 6, 2023
3 changes: 3 additions & 0 deletions Makefile
@@ -19,6 +19,9 @@ test:
 test-no-spark:
 	pytest tests --no-cov --ignore tests/extras/datasets/spark --numprocesses 4 --dist loadfile
 
+test-no-datasets:
+	pytest tests --no-cov --ignore tests/extras/datasets/ --numprocesses 4 --dist loadfile
+
 e2e-tests:
 	behave

3 changes: 3 additions & 0 deletions RELEASE.md
@@ -14,6 +14,9 @@
 * Added new `OmegaConfLoader` which uses `OmegaConf` for loading and merging configuration.
 * Added the `--conf-source` option to `kedro run`, allowing users to specify a source for project configuration for the run.
 * Added `omegaconf` syntax as option for `--params`. Keys and values can now be separated by colons or equals signs.
+* Added support for generator functions as nodes, i.e. using `yield` instead of `return`.
+* Enable chunk-wise processing in nodes with generator functions.
+* Save node outputs after every `yield` before proceeding with next chunk.
 
 ## Bug fixes and other changes
 * Fix bug where `micropkg` manifest section in `pyproject.toml` isn't recognised as allowed configuration.

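To make the release-note entries above concrete, here is a minimal sketch of a generator node. The function, dataset names, and pandas chunking are illustrative assumptions, not code from this PR.

```python
# Illustrative sketch only: the function, dataset names and chunk size below
# are assumptions for this example, not code from the PR.
from typing import Any, Dict, Iterator

import pandas as pd

from kedro.pipeline import node


def process_in_chunks(file_path: str) -> Iterator[Dict[str, Any]]:
    # Yield one processed chunk at a time instead of returning a single DataFrame.
    for chunk in pd.read_csv(file_path, chunksize=10_000):
        yield {"processed": chunk.dropna()}


# Each yielded chunk is saved to the "processed_data" dataset right after the
# corresponding `yield`, before the next chunk is read.
chunked_node = node(
    process_in_chunks,
    inputs="params:file_path",
    outputs={"processed": "processed_data"},
)
```
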
1 change: 1 addition & 0 deletions dependency/requirements.txt
@@ -10,6 +10,7 @@ importlib-metadata>=3.6; python_version >= '3.8'
 importlib_metadata>=3.6, <5.0; python_version < '3.8' # The "selectable" entry points were introduced in `importlib_metadata` 3.6 and Python 3.10. Bandit on Python 3.7 relies on a library with `importlib_metadata` < 5.0
 importlib_resources>=1.3 # The `files()` API was introduced in `importlib_resources` 1.3 and Python 3.9.
 jmespath>=0.9.5, <1.0
+more_itertools~=9.0
 omegaconf~=2.3
 pip-tools~=6.12
 pluggy~=1.0.0

53 changes: 37 additions & 16 deletions kedro/pipeline/node.py
@@ -9,6 +9,8 @@
 from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Union
 from warnings import warn
 
+from more_itertools import spy, unzip
+
 
 class Node:
     """``Node`` is an auxiliary class facilitating the operations required to
@@ -397,38 +399,57 @@ def _run_with_dict(self, inputs: Dict[str, Any], node_inputs: Dict[str, str]):
 
     def _outputs_to_dictionary(self, outputs):
         def _from_dict():
-            if set(self._outputs.keys()) != set(outputs.keys()):
+            result, iterator = outputs, None
+            # generator functions are lazy and we need a peek into their first output
+            if inspect.isgenerator(outputs):
+                (result,), iterator = spy(outputs)
+
+            keys = list(self._outputs.keys())
+            names = list(self._outputs.values())
+            if not isinstance(result, dict):
+                raise ValueError(
+                    f"Failed to save outputs of node {self}.\n"
+                    f"The node output is a dictionary, whereas the "
+                    f"function output is {type(result)}."
+                )
+            if set(keys) != set(result.keys()):
                 raise ValueError(
                     f"Failed to save outputs of node {str(self)}.\n"
-                    f"The node's output keys {set(outputs.keys())} do not match with "
-                    f"the returned output's keys {set(self._outputs.keys())}."
+                    f"The node's output keys {set(result.keys())} "
+                    f"do not match with the returned output's keys {set(keys)}."
                 )
-            return {name: outputs[key] for key, name in self._outputs.items()}
+            if iterator:
+                exploded = map(lambda x: tuple(x[k] for k in keys), iterator)
+                result = unzip(exploded)
+            else:
+                # evaluate this eagerly so we can reuse variable name
+                result = tuple(result[k] for k in keys)
+            return dict(zip(names, result))
 
         def _from_list():
-            if not isinstance(outputs, (list, tuple)):
+            result, iterator = outputs, None
+            # generator functions are lazy and we need a peek into their first output
+            if inspect.isgenerator(outputs):
+                (result,), iterator = spy(outputs)
+
+            if not isinstance(result, (list, tuple)):
                 raise ValueError(
                     f"Failed to save outputs of node {str(self)}.\n"
                     f"The node definition contains a list of "
                     f"outputs {self._outputs}, whereas the node function "
-                    f"returned a '{type(outputs).__name__}'."
+                    f"returned a '{type(result).__name__}'."
                 )
-            if len(outputs) != len(self._outputs):
+            if len(result) != len(self._outputs):
                 raise ValueError(
                     f"Failed to save outputs of node {str(self)}.\n"
-                    f"The node function returned {len(outputs)} output(s), "
+                    f"The node function returned {len(result)} output(s), "
                     f"whereas the node definition contains {len(self._outputs)} "
                     f"output(s)."
                 )
-
-            return dict(zip(self._outputs, outputs))
-
-        if isinstance(self._outputs, dict) and not isinstance(outputs, dict):
-            raise ValueError(
-                f"Failed to save outputs of node {self}.\n"
-                f"The node output is a dictionary, whereas the "
-                f"function output is not."
-            )
+            if iterator:
+                result = unzip(iterator)
+            return dict(zip(self._outputs, result))
 
         if self._outputs is None:
             return {}

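For readers unfamiliar with the two `more_itertools` helpers introduced above, the standalone sketch below shows what `spy` and `unzip` do; the sample data is made up and Kedro is not involved.

```python
# Standalone illustration of the more_itertools helpers used in node.py above.
from more_itertools import spy, unzip


def chunks():
    yield {"a": 1, "b": 10}
    yield {"a": 2, "b": 20}


# spy() peeks at the first item of a lazy iterable without losing it:
# it returns the head plus an iterator equivalent to the original one.
(first,), restored = spy(chunks())
assert first == {"a": 1, "b": 10}

# unzip() transposes an iterable of tuples into one lazy stream per position,
# which is how a generator's output becomes one stream per declared output.
stream_a, stream_b = unzip((d["a"], d["b"]) for d in restored)
assert list(stream_a) == [1, 2]
assert list(stream_b) == [10, 20]
```
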
33 changes: 31 additions & 2 deletions kedro/runner/runner.py
@@ -2,6 +2,8 @@
 implementations.
 """
 
+import inspect
+import itertools as it
 import logging
 from abc import ABC, abstractmethod
 from collections import deque
@@ -12,8 +14,9 @@
     as_completed,
     wait,
 )
-from typing import Any, Dict, Iterable, List, Set
+from typing import Any, Dict, Iterable, Iterator, List, Set
 
+from more_itertools import interleave
 from pluggy import PluginManager
 
 from kedro.framework.hooks.manager import _NullPluginManager
@@ -294,10 +297,22 @@ def run_node(
             asynchronously with threads. Defaults to False.
         session_id: The session id of the pipeline run.
 
+    Raises:
+        ValueError: Raised if is_async is set to True for nodes wrapping
+            generator functions.
+
     Returns:
         The node argument.
 
     """
+    if is_async and inspect.isgeneratorfunction(node.func):
+        raise ValueError(
+            f"Async data loading and saving does not work with "
+            f"nodes wrapping generator functions. Please make "
+            f"sure you don't use `yield` anywhere "
+            f"in node {str(node)}."
+        )
+
     if is_async:
         node = _run_node_async(node, catalog, hook_manager, session_id)
     else:
@@ -399,7 +414,21 @@ def _run_node_sequential(
         node, catalog, inputs, is_async, hook_manager, session_id=session_id
     )
 
-    for name, data in outputs.items():
+    items: Iterable = outputs.items()
+    # if all outputs are iterators, then the node is a generator node
+    if all(isinstance(d, Iterator) for d in outputs.values()):
+        # Python dictionaries are ordered so we are sure
+        # the keys and the chunk streams are in the same order
+        # [a, b, c]
+        keys = list(outputs.keys())
+        # [Iterator[chunk_a], Iterator[chunk_b], Iterator[chunk_c]]
+        streams = list(outputs.values())
+        # zip an endless cycle of the keys
+        # with an interleaved iterator of the streams
+        # [(a, chunk_a), (b, chunk_b), ...] until all outputs complete
+        items = zip(it.cycle(keys), interleave(*streams))
+
+    for name, data in items:
         hook_manager.hook.before_dataset_saved(dataset_name=name, data=data)
         catalog.save(name, data)
         hook_manager.hook.after_dataset_saved(dataset_name=name, data=data)

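As a standalone illustration of the saving loop above, the `it.cycle`/`interleave` pairing turns the per-output chunk streams into `(dataset_name, chunk)` tuples in round-robin order:

```python
# Standalone illustration of the key/chunk pairing used in _run_node_sequential.
import itertools as it

from more_itertools import interleave

# Two chunk streams, as produced by a generator node with outputs ["left", "right"].
keys = ["left", "right"]
streams = [iter([1, 2, 3]), iter(["a", "b", "c"])]

# interleave() alternates between the streams (1, "a", 2, "b", 3, "c");
# cycling the keys in step pairs every chunk with the dataset it belongs to.
items = list(zip(it.cycle(keys), interleave(*streams)))
assert items == [
    ("left", 1), ("right", "a"),
    ("left", 2), ("right", "b"),
    ("left", 3), ("right", "c"),
]
```
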
4 changes: 2 additions & 2 deletions tests/pipeline/test_node_run.py
@@ -109,8 +109,8 @@ def test_run_dict_diff_size(self, mocked_dataset):
 
 class TestNodeRunInvalidOutput:
     def test_miss_matching_output_types(self, mocked_dataset):
-        pattern = r"The node output is a dictionary, whereas the function "
-        pattern += r"output is not\."
+        pattern = "The node output is a dictionary, whereas the function "
+        pattern += "output is <class 'kedro.io.lambda_dataset.LambdaDataSet'>."
         with pytest.raises(ValueError, match=pattern):
             node(one_in_one_out, "ds1", dict(a="ds")).run(dict(ds1=mocked_dataset))

89 changes: 89 additions & 0 deletions tests/runner/test_run_node.py
@@ -0,0 +1,89 @@
import pytest

from kedro.framework.hooks.manager import _NullPluginManager
from kedro.pipeline import node
from kedro.runner import run_node


def generate_one():
    yield from range(10)


def generate_tuple():
    for i in range(10):
        yield i, i * i


def generate_list():
    for i in range(10):
        yield [i, i * i]


def generate_dict():
    for i in range(10):
        yield dict(idx=i, square=i * i)


class TestRunGeneratorNode:
    def test_generator_fail_async(self, mocker, catalog):
        fake_dataset = mocker.Mock()
        catalog.add("result", fake_dataset)
        n = node(generate_one, inputs=None, outputs="result")

        with pytest.raises(Exception, match="nodes wrapping generator functions"):
            run_node(n, catalog, _NullPluginManager(), is_async=True)

    def test_generator_node_one(self, mocker, catalog):
        fake_dataset = mocker.Mock()
        catalog.add("result", fake_dataset)
        n = node(generate_one, inputs=None, outputs="result")
        run_node(n, catalog, _NullPluginManager())

        expected = [((i,),) for i in range(10)]
        assert 10 == fake_dataset.save.call_count
        assert fake_dataset.save.call_args_list == expected

    def test_generator_node_tuple(self, mocker, catalog):
        left = mocker.Mock()
        right = mocker.Mock()
        catalog.add("left", left)
        catalog.add("right", right)
        n = node(generate_tuple, inputs=None, outputs=["left", "right"])
        run_node(n, catalog, _NullPluginManager())

        expected_left = [((i,),) for i in range(10)]
        expected_right = [((i * i,),) for i in range(10)]
        assert 10 == left.save.call_count
        assert left.save.call_args_list == expected_left
        assert 10 == right.save.call_count
        assert right.save.call_args_list == expected_right

    def test_generator_node_list(self, mocker, catalog):
        left = mocker.Mock()
        right = mocker.Mock()
        catalog.add("left", left)
        catalog.add("right", right)
        n = node(generate_list, inputs=None, outputs=["left", "right"])
        run_node(n, catalog, _NullPluginManager())

        expected_left = [((i,),) for i in range(10)]
        expected_right = [((i * i,),) for i in range(10)]
        assert 10 == left.save.call_count
        assert left.save.call_args_list == expected_left
        assert 10 == right.save.call_count
        assert right.save.call_args_list == expected_right

    def test_generator_node_dict(self, mocker, catalog):
        left = mocker.Mock()
        right = mocker.Mock()
        catalog.add("left", left)
        catalog.add("right", right)
        n = node(generate_dict, inputs=None, outputs=dict(idx="left", square="right"))
        run_node(n, catalog, _NullPluginManager())

        expected_left = [((i,),) for i in range(10)]
        expected_right = [((i * i,),) for i in range(10)]
        assert 10 == left.save.call_count
        assert left.save.call_args_list == expected_left
        assert 10 == right.save.call_count
        assert right.save.call_args_list == expected_right

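The mocks above only count `save` calls; in a real pipeline each chunk is passed to `catalog.save`, so a dataset that overwrites on every save keeps only the last chunk. One way to accumulate chunks is an append-on-save dataset along the lines of the sketch below; the class name, file handling, and pandas dependency are assumptions, not part of this PR.

```python
# Hypothetical append-on-save dataset, sketched to show how chunked saves from
# a generator node could be accumulated on disk; not part of this PR.
from pathlib import Path
from typing import Any, Dict

import pandas as pd

from kedro.io import AbstractDataSet


class AppendableCSVDataSet(AbstractDataSet):
    def __init__(self, filepath: str):
        self._filepath = Path(filepath)

    def _save(self, data: pd.DataFrame) -> None:
        # Append each chunk; only the first chunk writes the CSV header.
        write_header = not self._filepath.exists()
        data.to_csv(self._filepath, mode="a", header=write_header, index=False)

    def _load(self) -> pd.DataFrame:
        return pd.read_csv(self._filepath)

    def _describe(self) -> Dict[str, Any]:
        return {"filepath": str(self._filepath)}
```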