diff --git a/Makefile b/Makefile
index 56ba48978e..fa45cace83 100644
--- a/Makefile
+++ b/Makefile
@@ -19,6 +19,9 @@ test:
 test-no-spark:
 	pytest tests --no-cov --ignore tests/extras/datasets/spark --numprocesses 4 --dist loadfile
 
+test-no-datasets:
+	pytest tests --no-cov --ignore tests/extras/datasets/ --numprocesses 4 --dist loadfile
+
 e2e-tests:
 	behave
diff --git a/RELEASE.md b/RELEASE.md
index 76530bef14..70ebfba574 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -14,6 +14,9 @@
 * Added new `OmegaConfLoader` which uses `OmegaConf` for loading and merging configuration.
 * Added the `--conf-source` option to `kedro run`, allowing users to specify a source for project configuration for the run.
 * Added `omegaconf` syntax as option for `--params`. Keys and values can now be separated by colons or equals signs.
+* Added support for generator functions as nodes, i.e. using `yield` instead of `return`.
+  * Enable chunk-wise processing in nodes with generator functions.
+  * Save node outputs after every `yield` before proceeding with the next chunk.
 
 ## Bug fixes and other changes
 * Fix bug where `micropkg` manifest section in `pyproject.toml` isn't recognised as allowed configuration.
diff --git a/dependency/requirements.txt b/dependency/requirements.txt
index 5eeda39152..b7acadc38b 100644
--- a/dependency/requirements.txt
+++ b/dependency/requirements.txt
@@ -10,6 +10,7 @@ importlib-metadata>=3.6; python_version >= '3.8'
 importlib_metadata>=3.6, <5.0; python_version < '3.8'  # The "selectable" entry points were introduced in `importlib_metadata` 3.6 and Python 3.10. Bandit on Python 3.7 relies on a library with `importlib_metadata` < 5.0
 importlib_resources>=1.3  # The `files()` API was introduced in `importlib_resources` 1.3 and Python 3.9.
 jmespath>=0.9.5, <1.0
+more_itertools~=9.0
 omegaconf~=2.3
 pip-tools~=6.12
 pluggy~=1.0.0
diff --git a/kedro/pipeline/node.py b/kedro/pipeline/node.py
index 3a4439c7aa..4275ec94e1 100644
--- a/kedro/pipeline/node.py
+++ b/kedro/pipeline/node.py
@@ -9,6 +9,8 @@
 from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Union
 from warnings import warn
 
+from more_itertools import spy, unzip
+
 
 class Node:
     """``Node`` is an auxiliary class facilitating the operations required to
@@ -397,38 +399,57 @@ def _run_with_dict(self, inputs: Dict[str, Any], node_inputs: Dict[str, str]):
 
     def _outputs_to_dictionary(self, outputs):
         def _from_dict():
-            if set(self._outputs.keys()) != set(outputs.keys()):
+            result, iterator = outputs, None
+            # generator functions are lazy, so we need a peek into their first output
+            if inspect.isgenerator(outputs):
+                (result,), iterator = spy(outputs)
+
+            keys = list(self._outputs.keys())
+            names = list(self._outputs.values())
+            if not isinstance(result, dict):
+                raise ValueError(
+                    f"Failed to save outputs of node {self}.\n"
+                    f"The node output is a dictionary, whereas the "
+                    f"function output is {type(result)}."
+                )
+            if set(keys) != set(result.keys()):
                 raise ValueError(
                     f"Failed to save outputs of node {str(self)}.\n"
-                    f"The node's output keys {set(outputs.keys())} do not match with "
-                    f"the returned output's keys {set(self._outputs.keys())}."
+                    f"The node's output keys {set(keys)} "
+                    f"do not match with the returned output's keys {set(result.keys())}."
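Note for reviewers unfamiliar with `more_itertools`: below is a minimal, standalone sketch of the mechanism `_outputs_to_dictionary` relies on. `spy` peeks at the first chunk of a lazy generator without consuming it (so output keys can be validated before anything is saved), and `unzip` splits a stream of per-chunk tuples into one lazy iterator per declared output. The `chunks` function and key names are made up for illustration, not part of this PR.

```python
from more_itertools import spy, unzip

def chunks():
    # stands in for a generator node yielding dict outputs chunk by chunk
    for i in range(3):
        yield {"idx": i, "square": i * i}

# peek at the first chunk without losing it from the stream
(first,), stream = spy(chunks())
assert set(first.keys()) == {"idx", "square"}

# explode each chunk dict into a tuple ordered by the declared keys,
# then unzip into one lazy iterator per output
keys = ["idx", "square"]
exploded = map(lambda x: tuple(x[k] for k in keys), stream)
idx_stream, square_stream = unzip(exploded)
print(list(idx_stream))     # [0, 1, 2]
print(list(square_stream))  # [0, 1, 4]
```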
                )
-            return {name: outputs[key] for key, name in self._outputs.items()}
+            if iterator:
+                exploded = map(lambda x: tuple(x[k] for k in keys), iterator)
+                result = unzip(exploded)
+            else:
+                # evaluate this eagerly so we can reuse the variable name
+                result = tuple(result[k] for k in keys)
+            return dict(zip(names, result))
 
         def _from_list():
-            if not isinstance(outputs, (list, tuple)):
+            result, iterator = outputs, None
+            # generator functions are lazy, so we need a peek into their first output
+            if inspect.isgenerator(outputs):
+                (result,), iterator = spy(outputs)
+
+            if not isinstance(result, (list, tuple)):
                 raise ValueError(
                     f"Failed to save outputs of node {str(self)}.\n"
                     f"The node definition contains a list of "
                     f"outputs {self._outputs}, whereas the node function "
-                    f"returned a '{type(outputs).__name__}'."
+                    f"returned a '{type(result).__name__}'."
                 )
-            if len(outputs) != len(self._outputs):
+            if len(result) != len(self._outputs):
                 raise ValueError(
                     f"Failed to save outputs of node {str(self)}.\n"
-                    f"The node function returned {len(outputs)} output(s), "
+                    f"The node function returned {len(result)} output(s), "
                     f"whereas the node definition contains {len(self._outputs)} "
                     f"output(s)."
                 )
-            return dict(zip(self._outputs, outputs))
-
-        if isinstance(self._outputs, dict) and not isinstance(outputs, dict):
-            raise ValueError(
-                f"Failed to save outputs of node {self}.\n"
-                f"The node output is a dictionary, whereas the "
-                f"function output is not."
-            )
+            if iterator:
+                result = unzip(iterator)
+            return dict(zip(self._outputs, result))
 
         if self._outputs is None:
             return {}
diff --git a/kedro/runner/runner.py b/kedro/runner/runner.py
index 70f6b127e8..89947a0e09 100644
--- a/kedro/runner/runner.py
+++ b/kedro/runner/runner.py
@@ -2,6 +2,8 @@
 implementations.
 """
 
+import inspect
+import itertools as it
 import logging
 from abc import ABC, abstractmethod
 from collections import deque
@@ -12,8 +14,9 @@
     as_completed,
     wait,
 )
-from typing import Any, Dict, Iterable, List, Set
+from typing import Any, Dict, Iterable, Iterator, List, Set
 
+from more_itertools import interleave
 from pluggy import PluginManager
 
 from kedro.framework.hooks.manager import _NullPluginManager
@@ -294,10 +297,22 @@ def run_node(
             asynchronously with threads. Defaults to False.
         session_id: The session id of the pipeline run.
 
+    Raises:
+        ValueError: Raised if is_async is set to True for nodes wrapping
+            generator functions.
+
     Returns:
         The node argument.
 
     """
+    if is_async and inspect.isgeneratorfunction(node.func):
+        raise ValueError(
+            f"Async data loading and saving does not work with "
+            f"nodes wrapping generator functions. Please make "
+            f"sure you don't use `yield` anywhere "
+            f"in node {str(node)}."
+        )
+
     if is_async:
         node = _run_node_async(node, catalog, hook_manager, session_id)
     else:
@@ -399,7 +414,21 @@ def _run_node_sequential(
         node, catalog, inputs, is_async, hook_manager, session_id=session_id
     )
 
-    for name, data in outputs.items():
+    items: Iterable = outputs.items()
+    # if all outputs are iterators, then the node is a generator node
+    if all(isinstance(d, Iterator) for d in outputs.values()):
+        # Python dictionaries are ordered, so we are sure
+        # the keys and the chunk streams are in the same order
+        # [a, b, c]
+        keys = list(outputs.keys())
+        # [Iterator[chunk_a], Iterator[chunk_b], Iterator[chunk_c]]
+        streams = list(outputs.values())
+        # zip an endless cycle of the keys
+        # with an interleaved iterator of the streams
+        # [(a, chunk_a), (b, chunk_b), ...] until all outputs complete
+        items = zip(it.cycle(keys), interleave(*streams))
+
+    for name, data in items:
         hook_manager.hook.before_dataset_saved(dataset_name=name, data=data)
         catalog.save(name, data)
         hook_manager.hook.after_dataset_saved(dataset_name=name, data=data)
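A quick sketch of what the `it.cycle` / `interleave` pairing in `_run_node_sequential` produces, with hypothetical output names: chunks are saved round-robin, one dataset at a time, as soon as the generator yields them.

```python
import itertools as it
from more_itertools import interleave

keys = ["left", "right"]
streams = [iter([0, 1, 2]), iter([0, 1, 4])]

# interleave takes one chunk from each stream in turn;
# cycling the names keeps each chunk paired with its dataset
items = list(zip(it.cycle(keys), interleave(*streams)))
print(items)
# [('left', 0), ('right', 0), ('left', 1), ('right', 1), ('left', 2), ('right', 4)]
```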
diff --git a/tests/pipeline/test_node_run.py b/tests/pipeline/test_node_run.py
index bcf1778b46..331e6cd252 100644
--- a/tests/pipeline/test_node_run.py
+++ b/tests/pipeline/test_node_run.py
@@ -109,8 +109,8 @@ def test_run_dict_diff_size(self, mocked_dataset):
 
 
 class TestNodeRunInvalidOutput:
     def test_miss_matching_output_types(self, mocked_dataset):
-        pattern = r"The node output is a dictionary, whereas the function "
-        pattern += r"output is not\."
+        pattern = "The node output is a dictionary, whereas the function "
+        pattern += "output is ."
         with pytest.raises(ValueError, match=pattern):
             node(one_in_one_out, "ds1", dict(a="ds")).run(dict(ds1=mocked_dataset))
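The updated pattern matches the new message, which now names the offending type (the trailing `.` is a regex wildcard that matches the `<` of `<class ...>`). Concretely, and equally for generator nodes thanks to the `spy` peek above, the failure looks like this hypothetical snippet:

```python
from kedro.pipeline import node

def gen():
    yield [1, 2]  # a list, but the node declares dict outputs

n = node(gen, inputs=None, outputs=dict(a="ds"))
n.run()
# ValueError: The node output is a dictionary, whereas the function output is <class 'list'>.
```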
diff --git a/tests/runner/test_run_node.py b/tests/runner/test_run_node.py
new file mode 100644
index 0000000000..8cf98e461c
--- /dev/null
+++ b/tests/runner/test_run_node.py
@@ -0,0 +1,89 @@
+import pytest
+
+from kedro.framework.hooks.manager import _NullPluginManager
+from kedro.pipeline import node
+from kedro.runner import run_node
+
+
+def generate_one():
+    yield from range(10)
+
+
+def generate_tuple():
+    for i in range(10):
+        yield i, i * i
+
+
+def generate_list():
+    for i in range(10):
+        yield [i, i * i]
+
+
+def generate_dict():
+    for i in range(10):
+        yield dict(idx=i, square=i * i)
+
+
+class TestRunGeneratorNode:
+    def test_generator_fail_async(self, mocker, catalog):
+        fake_dataset = mocker.Mock()
+        catalog.add("result", fake_dataset)
+        n = node(generate_one, inputs=None, outputs="result")
+
+        with pytest.raises(Exception, match="nodes wrapping generator functions"):
+            run_node(n, catalog, _NullPluginManager(), is_async=True)
+
+    def test_generator_node_one(self, mocker, catalog):
+        fake_dataset = mocker.Mock()
+        catalog.add("result", fake_dataset)
+        n = node(generate_one, inputs=None, outputs="result")
+        run_node(n, catalog, _NullPluginManager())
+
+        expected = [((i,),) for i in range(10)]
+        assert 10 == fake_dataset.save.call_count
+        assert fake_dataset.save.call_args_list == expected
+
+    def test_generator_node_tuple(self, mocker, catalog):
+        left = mocker.Mock()
+        right = mocker.Mock()
+        catalog.add("left", left)
+        catalog.add("right", right)
+        n = node(generate_tuple, inputs=None, outputs=["left", "right"])
+        run_node(n, catalog, _NullPluginManager())
+
+        expected_left = [((i,),) for i in range(10)]
+        expected_right = [((i * i,),) for i in range(10)]
+        assert 10 == left.save.call_count
+        assert left.save.call_args_list == expected_left
+        assert 10 == right.save.call_count
+        assert right.save.call_args_list == expected_right
+
+    def test_generator_node_list(self, mocker, catalog):
+        left = mocker.Mock()
+        right = mocker.Mock()
+        catalog.add("left", left)
+        catalog.add("right", right)
+        n = node(generate_list, inputs=None, outputs=["left", "right"])
+        run_node(n, catalog, _NullPluginManager())
+
+        expected_left = [((i,),) for i in range(10)]
+        expected_right = [((i * i,),) for i in range(10)]
+        assert 10 == left.save.call_count
+        assert left.save.call_args_list == expected_left
+        assert 10 == right.save.call_count
+        assert right.save.call_args_list == expected_right
+
+    def test_generator_node_dict(self, mocker, catalog):
+        left = mocker.Mock()
+        right = mocker.Mock()
+        catalog.add("left", left)
+        catalog.add("right", right)
+        n = node(generate_dict, inputs=None, outputs=dict(idx="left", square="right"))
+        run_node(n, catalog, _NullPluginManager())
+
+        expected_left = [((i,),) for i in range(10)]
+        expected_right = [((i * i,),) for i in range(10)]
+        assert 10 == left.save.call_count
+        assert left.save.call_args_list == expected_left
+        assert 10 == right.save.call_count
+        assert right.save.call_args_list == expected_right
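Finally, a hypothetical end-to-end usage sketch. The `ChunkAppendDataSet` below is invented for illustration (it is not part of Kedro or this PR); a dataset whose `_save` appends makes it easy to see that the runner calls `save` once per `yield`.

```python
from kedro.framework.hooks.manager import _NullPluginManager
from kedro.io import AbstractDataSet, DataCatalog
from kedro.pipeline import node
from kedro.runner import run_node


class ChunkAppendDataSet(AbstractDataSet):
    """Toy dataset that collects every chunk saved by a generator node."""

    def __init__(self):
        self.chunks = []

    def _save(self, data):
        self.chunks.append(data)

    def _load(self):
        return self.chunks

    def _describe(self):
        return {"chunks": len(self.chunks)}


def generate_squares():
    for i in range(5):
        yield i * i


squares = ChunkAppendDataSet()
catalog = DataCatalog({"squares": squares})
run_node(node(generate_squares, inputs=None, outputs="squares"),
         catalog, _NullPluginManager())
print(squares.chunks)  # [0, 1, 4, 9, 16], one save per yield
```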