From ca6cac7446c6fe47acccabc9021c267c5fbb3239 Mon Sep 17 00:00:00 2001 From: Ivan Danov Date: Fri, 23 Dec 2022 16:04:35 +0000 Subject: [PATCH 01/11] Modify the node and run_node to enable generator nodes Signed-off-by: Ivan Danov --- Makefile | 3 +++ dependency/requirements.txt | 1 + kedro/pipeline/node.py | 30 +++++++++++++++++++++--------- kedro/runner/runner.py | 18 +++++++++++++++++- 4 files changed, 42 insertions(+), 10 deletions(-) diff --git a/Makefile b/Makefile index 56ba48978e..fa45cace83 100644 --- a/Makefile +++ b/Makefile @@ -19,6 +19,9 @@ test: test-no-spark: pytest tests --no-cov --ignore tests/extras/datasets/spark --numprocesses 4 --dist loadfile +test-no-datasets: + pytest tests --no-cov --ignore tests/extras/datasets/ --numprocesses 4 --dist loadfile + e2e-tests: behave diff --git a/dependency/requirements.txt b/dependency/requirements.txt index c622c875bc..8dfc92b54b 100644 --- a/dependency/requirements.txt +++ b/dependency/requirements.txt @@ -19,3 +19,4 @@ rope~=1.6.0 # subject to LGPLv3 license setuptools>=65.5.1 toml~=0.10 toposort~=1.7 # Needs to be at least 1.5 to be able to raise CircularDependencyError +more_itertools~=9.0 diff --git a/kedro/pipeline/node.py b/kedro/pipeline/node.py index 3a4439c7aa..2de69e6071 100644 --- a/kedro/pipeline/node.py +++ b/kedro/pipeline/node.py @@ -9,6 +9,8 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Union from warnings import warn +from more_itertools import spy, unzip + class Node: """``Node`` is an auxiliary class facilitating the operations required to @@ -397,31 +399,41 @@ def _run_with_dict(self, inputs: Dict[str, Any], node_inputs: Dict[str, str]): def _outputs_to_dictionary(self, outputs): def _from_dict(): - if set(self._outputs.keys()) != set(outputs.keys()): + (result, ), iterator = (outputs, ), None if not inspect.isgenerator(outputs) else spy(outputs, 1) + keys = list(self._outputs.keys()) + if set(keys) != set(result.keys()): raise ValueError( f"Failed to save outputs of node {str(self)}.\n" - f"The node's output keys {set(outputs.keys())} do not match with " - f"the returned output's keys {set(self._outputs.keys())}." + f"The node's output keys {set(result.keys())} " + f"do not match with the returned output's keys {set(keys)}." ) - return {name: outputs[key] for key, name in self._outputs.items()} + if iterator: + exploded = map(lambda x: tuple((x[k] for k in keys)), iterator) + result = unzip(exploded) + else: + result = tuple((result[k] for k in keys)) + return dict(zip([self._outputs[k] for k in keys], result)) def _from_list(): - if not isinstance(outputs, (list, tuple)): + (result, ), iterator = (outputs, ), None if not inspect.isgenerator(outputs) else spy(outputs, 1) + if not isinstance(result, (list, tuple)): raise ValueError( f"Failed to save outputs of node {str(self)}.\n" f"The node definition contains a list of " f"outputs {self._outputs}, whereas the node function " - f"returned a '{type(outputs).__name__}'." + f"returned a '{type(result).__name__}'." ) - if len(outputs) != len(self._outputs): + if len(result) != len(self._outputs): raise ValueError( f"Failed to save outputs of node {str(self)}.\n" - f"The node function returned {len(outputs)} output(s), " + f"The node function returned {len(result)} output(s), " f"whereas the node definition contains {len(self._outputs)} " f"output(s)." 
) - return dict(zip(self._outputs, outputs)) + if iterator: + result = unzip(iterator) + return dict(zip(self._outputs, result)) if isinstance(self._outputs, dict) and not isinstance(outputs, dict): raise ValueError( diff --git a/kedro/runner/runner.py b/kedro/runner/runner.py index 70f6b127e8..1d9274f071 100644 --- a/kedro/runner/runner.py +++ b/kedro/runner/runner.py @@ -15,6 +15,9 @@ from typing import Any, Dict, Iterable, List, Set from pluggy import PluginManager +from typing import Iterator +from more_itertools import interleave +import itertools as it from kedro.framework.hooks.manager import _NullPluginManager from kedro.io import AbstractDataSet, DataCatalog, MemoryDataSet @@ -399,7 +402,20 @@ def _run_node_sequential( node, catalog, inputs, is_async, hook_manager, session_id=session_id ) - for name, data in outputs.items(): + items = outputs.items() + # if all outputs are iterators, then the node is a generator node + if all((isinstance(d, Iterator) for d in outputs.values())): + # make sure we extract the keys and the chunk streams in the same order + # [a, b, c] + keys = list(outputs.keys()) + # [Iterator[chunk_a], Iterator[chunk_b], Iterator[chunk_c]] + streams = [outputs[k] for k in outputs] + # zip an endless cycle of the keys + # with an interleaved iterator of the streams + # [(a, chunk_a), (b, chunk_b), ...] until all outputs complete + items = zip(it.cycle(keys), interleave(*streams)) + + for name, data in items: hook_manager.hook.before_dataset_saved(dataset_name=name, data=data) catalog.save(name, data) hook_manager.hook.after_dataset_saved(dataset_name=name, data=data) From fe6afdbf216be6f0f98ae308992d621867a92b18 Mon Sep 17 00:00:00 2001 From: Ivan Danov Date: Fri, 23 Dec 2022 17:21:09 +0000 Subject: [PATCH 02/11] Add tests to cover all types of generator functions Signed-off-by: Ivan Danov --- kedro/pipeline/node.py | 17 ++++--- tests/runner/test_run_node.py | 83 +++++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 9 deletions(-) create mode 100644 tests/runner/test_run_node.py diff --git a/kedro/pipeline/node.py b/kedro/pipeline/node.py index 2de69e6071..a3465fd2e8 100644 --- a/kedro/pipeline/node.py +++ b/kedro/pipeline/node.py @@ -399,8 +399,14 @@ def _run_with_dict(self, inputs: Dict[str, Any], node_inputs: Dict[str, str]): def _outputs_to_dictionary(self, outputs): def _from_dict(): - (result, ), iterator = (outputs, ), None if not inspect.isgenerator(outputs) else spy(outputs, 1) + (result, ), iterator = ((outputs, ), None) if not inspect.isgenerator(outputs) else spy(outputs, 1) keys = list(self._outputs.keys()) + if not isinstance(result, dict): + raise ValueError( + f"Failed to save outputs of node {self}.\n" + f"The node output is a dictionary, whereas the " + f"function output is not." 
+ ) if set(keys) != set(result.keys()): raise ValueError( f"Failed to save outputs of node {str(self)}.\n" @@ -415,7 +421,7 @@ def _from_dict(): return dict(zip([self._outputs[k] for k in keys], result)) def _from_list(): - (result, ), iterator = (outputs, ), None if not inspect.isgenerator(outputs) else spy(outputs, 1) + (result, ), iterator = ((outputs, ), None) if not inspect.isgenerator(outputs) else spy(outputs, 1) if not isinstance(result, (list, tuple)): raise ValueError( f"Failed to save outputs of node {str(self)}.\n" @@ -435,13 +441,6 @@ def _from_list(): result = unzip(iterator) return dict(zip(self._outputs, result)) - if isinstance(self._outputs, dict) and not isinstance(outputs, dict): - raise ValueError( - f"Failed to save outputs of node {self}.\n" - f"The node output is a dictionary, whereas the " - f"function output is not." - ) - if self._outputs is None: return {} if isinstance(self._outputs, str): diff --git a/tests/runner/test_run_node.py b/tests/runner/test_run_node.py new file mode 100644 index 0000000000..3c13bbe467 --- /dev/null +++ b/tests/runner/test_run_node.py @@ -0,0 +1,83 @@ +from kedro.framework.hooks.manager import _NullPluginManager +from kedro.pipeline import node +from kedro.runner import run_node + + +def generate_one(): + for i in range(10): + yield i + + +def generate_tuple(): + for i in range(10): + yield i, i * i + + +def generate_list(): + for i in range(10): + yield [i, i*i] + + +def generate_dict(): + for i in range(10): + yield dict(idx=i, square=i*i) + + +class TestRunGeneratorNode: + def test_generator_node_one(self, mocker, catalog): + fake_dataset = mocker.Mock() + catalog.add("result", fake_dataset) + n = node(generate_one, inputs=None, outputs="result") + run_node(n, catalog, _NullPluginManager()) + + expected = [((i,),) for i in range(10)] + assert 10 == fake_dataset.save.call_count + assert fake_dataset.save.call_args_list == expected + + def test_generator_node_tuple(self, mocker, catalog): + left = mocker.Mock() + right = mocker.Mock() + catalog.add("left", left) + catalog.add("right", right) + n = node(generate_tuple, inputs=None, outputs=["left", "right"]) + run_node(n, catalog, _NullPluginManager()) + + expected_left = [((i,),) for i in range(10)] + expected_right = [((i*i,),) for i in range(10)] + assert 10 == left.save.call_count + assert left.save.call_args_list == expected_left + assert 10 == right.save.call_count + assert right.save.call_args_list == expected_right + + def test_generator_node_list(self, mocker, catalog): + left = mocker.Mock() + right = mocker.Mock() + catalog.add("left", left) + catalog.add("right", right) + n = node(generate_list, inputs=None, outputs=["left", "right"]) + run_node(n, catalog, _NullPluginManager()) + + expected_left = [((i,),) for i in range(10)] + expected_right = [((i*i,),) for i in range(10)] + assert 10 == left.save.call_count + assert left.save.call_args_list == expected_left + assert 10 == right.save.call_count + assert right.save.call_args_list == expected_right + + def test_generator_node_dict(self, mocker, catalog): + left = mocker.Mock() + right = mocker.Mock() + catalog.add("left", left) + catalog.add("right", right) + n = node(generate_dict, + inputs=None, + outputs=dict(idx="left", square="right") + ) + run_node(n, catalog, _NullPluginManager()) + + expected_left = [((i,),) for i in range(10)] + expected_right = [((i*i,),) for i in range(10)] + assert 10 == left.save.call_count + assert left.save.call_args_list == expected_left + assert 10 == right.save.call_count + assert 
right.save.call_args_list == expected_right From d283b9e69fb80fece4c9737419728f0d8fab07de Mon Sep 17 00:00:00 2001 From: Ivan Danov Date: Fri, 23 Dec 2022 17:46:05 +0000 Subject: [PATCH 03/11] Fail on running a generator node with async load/save Signed-off-by: Ivan Danov --- kedro/runner/runner.py | 8 +++++++- tests/runner/test_run_node.py | 10 ++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/kedro/runner/runner.py b/kedro/runner/runner.py index 1d9274f071..222845829c 100644 --- a/kedro/runner/runner.py +++ b/kedro/runner/runner.py @@ -23,6 +23,7 @@ from kedro.io import AbstractDataSet, DataCatalog, MemoryDataSet from kedro.pipeline import Pipeline from kedro.pipeline.node import Node +import inspect class AbstractRunner(ABC): @@ -301,7 +302,12 @@ def run_node( The node argument. """ - if is_async: + if is_async and inspect.isgeneratorfunction(node.func): + raise TypeError(f"Async data loading and saving does not work with " + f"nodes wrapping generator functions. Please make " + f"sure you don't use yield anywhere in your function " + f"in node {str(node)}") + elif is_async: node = _run_node_async(node, catalog, hook_manager, session_id) else: node = _run_node_sequential(node, catalog, hook_manager, session_id) diff --git a/tests/runner/test_run_node.py b/tests/runner/test_run_node.py index 3c13bbe467..2e0bb41d66 100644 --- a/tests/runner/test_run_node.py +++ b/tests/runner/test_run_node.py @@ -1,3 +1,5 @@ +import pytest + from kedro.framework.hooks.manager import _NullPluginManager from kedro.pipeline import node from kedro.runner import run_node @@ -24,6 +26,14 @@ def generate_dict(): class TestRunGeneratorNode: + def test_generator_fail_async(self, mocker, catalog): + fake_dataset = mocker.Mock() + catalog.add("result", fake_dataset) + n = node(generate_one, inputs=None, outputs="result") + + with pytest.raises(Exception, match="nodes wrapping generator functions"): + run_node(n, catalog, _NullPluginManager(), is_async=True) + def test_generator_node_one(self, mocker, catalog): fake_dataset = mocker.Mock() catalog.add("result", fake_dataset) From 65089a6b0d1500378b7091297905e3585b027ce4 Mon Sep 17 00:00:00 2001 From: Ivan Danov Date: Fri, 23 Dec 2022 18:03:24 +0000 Subject: [PATCH 04/11] Lint my code changes Signed-off-by: Ivan Danov --- dependency/requirements.txt | 2 +- kedro/pipeline/node.py | 16 ++++++++++++---- kedro/runner/runner.py | 30 ++++++++++++++++++------------ tests/runner/test_run_node.py | 18 +++++++----------- 4 files changed, 38 insertions(+), 28 deletions(-) diff --git a/dependency/requirements.txt b/dependency/requirements.txt index 8dfc92b54b..154bfef80f 100644 --- a/dependency/requirements.txt +++ b/dependency/requirements.txt @@ -10,6 +10,7 @@ importlib-metadata>=3.6; python_version >= '3.8' importlib_metadata>=3.6, <5.0; python_version < '3.8' # The "selectable" entry points were introduced in `importlib_metadata` 3.6 and Python 3.10. Bandit on Python 3.7 relies on a library with `importlib_metadata` < 5.0 importlib_resources>=1.3 # The `files()` API was introduced in `importlib_resources` 1.3 and Python 3.9. 
jmespath>=0.9.5, <1.0 +more_itertools~=9.0 omegaconf~=2.3 pip-tools~=6.12 pluggy~=1.0.0 @@ -19,4 +20,3 @@ rope~=1.6.0 # subject to LGPLv3 license setuptools>=65.5.1 toml~=0.10 toposort~=1.7 # Needs to be at least 1.5 to be able to raise CircularDependencyError -more_itertools~=9.0 diff --git a/kedro/pipeline/node.py b/kedro/pipeline/node.py index a3465fd2e8..f13cf5ee21 100644 --- a/kedro/pipeline/node.py +++ b/kedro/pipeline/node.py @@ -399,7 +399,11 @@ def _run_with_dict(self, inputs: Dict[str, Any], node_inputs: Dict[str, str]): def _outputs_to_dictionary(self, outputs): def _from_dict(): - (result, ), iterator = ((outputs, ), None) if not inspect.isgenerator(outputs) else spy(outputs, 1) + (result,), iterator = ( + ((outputs,), None) + if not inspect.isgenerator(outputs) + else spy(outputs, 1) + ) keys = list(self._outputs.keys()) if not isinstance(result, dict): raise ValueError( @@ -414,14 +418,18 @@ def _from_dict(): f"do not match with the returned output's keys {set(keys)}." ) if iterator: - exploded = map(lambda x: tuple((x[k] for k in keys)), iterator) + exploded = map(lambda x: tuple(x[k] for k in keys), iterator) result = unzip(exploded) else: - result = tuple((result[k] for k in keys)) + result = tuple(result[k] for k in keys) return dict(zip([self._outputs[k] for k in keys], result)) def _from_list(): - (result, ), iterator = ((outputs, ), None) if not inspect.isgenerator(outputs) else spy(outputs, 1) + (result,), iterator = ( + ((outputs,), None) + if not inspect.isgenerator(outputs) + else spy(outputs, 1) + ) if not isinstance(result, (list, tuple)): raise ValueError( f"Failed to save outputs of node {str(self)}.\n" diff --git a/kedro/runner/runner.py b/kedro/runner/runner.py index 222845829c..4b14622717 100644 --- a/kedro/runner/runner.py +++ b/kedro/runner/runner.py @@ -2,6 +2,8 @@ implementations. """ +import inspect +import itertools as it import logging from abc import ABC, abstractmethod from collections import deque @@ -12,18 +14,15 @@ as_completed, wait, ) -from typing import Any, Dict, Iterable, List, Set +from typing import Any, Dict, Iterable, Iterator, List, Set -from pluggy import PluginManager -from typing import Iterator from more_itertools import interleave -import itertools as it +from pluggy import PluginManager from kedro.framework.hooks.manager import _NullPluginManager from kedro.io import AbstractDataSet, DataCatalog, MemoryDataSet from kedro.pipeline import Pipeline from kedro.pipeline.node import Node -import inspect class AbstractRunner(ABC): @@ -298,16 +297,23 @@ def run_node( asynchronously with threads. Defaults to False. session_id: The session id of the pipeline run. + Raises: + ValueError: Raised if is_async is set to True for nodes wrapping + generator functions. + Returns: The node argument. """ if is_async and inspect.isgeneratorfunction(node.func): - raise TypeError(f"Async data loading and saving does not work with " - f"nodes wrapping generator functions. Please make " - f"sure you don't use yield anywhere in your function " - f"in node {str(node)}") - elif is_async: + raise ValueError( + f"Async data loading and saving does not work with " + f"nodes wrapping generator functions. 
Please make " + f"sure you don't use yield anywhere in your function " + f"in node {str(node)}" + ) + + if is_async: node = _run_node_async(node, catalog, hook_manager, session_id) else: node = _run_node_sequential(node, catalog, hook_manager, session_id) @@ -408,9 +414,9 @@ def _run_node_sequential( node, catalog, inputs, is_async, hook_manager, session_id=session_id ) - items = outputs.items() + items: Iterable = outputs.items() # if all outputs are iterators, then the node is a generator node - if all((isinstance(d, Iterator) for d in outputs.values())): + if all(isinstance(d, Iterator) for d in outputs.values()): # make sure we extract the keys and the chunk streams in the same order # [a, b, c] keys = list(outputs.keys()) diff --git a/tests/runner/test_run_node.py b/tests/runner/test_run_node.py index 2e0bb41d66..8cf98e461c 100644 --- a/tests/runner/test_run_node.py +++ b/tests/runner/test_run_node.py @@ -6,8 +6,7 @@ def generate_one(): - for i in range(10): - yield i + yield from range(10) def generate_tuple(): @@ -17,12 +16,12 @@ def generate_tuple(): def generate_list(): for i in range(10): - yield [i, i*i] + yield [i, i * i] def generate_dict(): for i in range(10): - yield dict(idx=i, square=i*i) + yield dict(idx=i, square=i * i) class TestRunGeneratorNode: @@ -53,7 +52,7 @@ def test_generator_node_tuple(self, mocker, catalog): run_node(n, catalog, _NullPluginManager()) expected_left = [((i,),) for i in range(10)] - expected_right = [((i*i,),) for i in range(10)] + expected_right = [((i * i,),) for i in range(10)] assert 10 == left.save.call_count assert left.save.call_args_list == expected_left assert 10 == right.save.call_count @@ -68,7 +67,7 @@ def test_generator_node_list(self, mocker, catalog): run_node(n, catalog, _NullPluginManager()) expected_left = [((i,),) for i in range(10)] - expected_right = [((i*i,),) for i in range(10)] + expected_right = [((i * i,),) for i in range(10)] assert 10 == left.save.call_count assert left.save.call_args_list == expected_left assert 10 == right.save.call_count @@ -79,14 +78,11 @@ def test_generator_node_dict(self, mocker, catalog): right = mocker.Mock() catalog.add("left", left) catalog.add("right", right) - n = node(generate_dict, - inputs=None, - outputs=dict(idx="left", square="right") - ) + n = node(generate_dict, inputs=None, outputs=dict(idx="left", square="right")) run_node(n, catalog, _NullPluginManager()) expected_left = [((i,),) for i in range(10)] - expected_right = [((i*i,),) for i in range(10)] + expected_right = [((i * i,),) for i in range(10)] assert 10 == left.save.call_count assert left.save.call_args_list == expected_left assert 10 == right.save.call_count From a71ea805923027b29d2db2b6c48fee77cb3944ed Mon Sep 17 00:00:00 2001 From: Ivan Danov Date: Fri, 23 Dec 2022 23:31:47 +0000 Subject: [PATCH 05/11] Add changelog to RELEASE.md Signed-off-by: Ivan Danov --- RELEASE.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/RELEASE.md b/RELEASE.md index 76530bef14..70ebfba574 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -14,6 +14,9 @@ * Added new `OmegaConfLoader` which uses `OmegaConf` for loading and merging configuration. * Added the `--conf-source` option to `kedro run`, allowing users to specify a source for project configuration for the run. * Added `omegaconf` syntax as option for `--params`. Keys and values can now be separated by colons or equals signs. +* Added support for generator functions as nodes, i.e. using `yield` instead of return. + * Enable chunk-wise processing in nodes with generator functions. 
+ * Save node outputs after every `yield` before proceeding with next chunk. ## Bug fixes and other changes * Fix bug where `micropkg` manifest section in `pyproject.toml` isn't recognised as allowed configuration. From 193c76522d68a542588169430f14952541ed95b0 Mon Sep 17 00:00:00 2001 From: Ivan Danov Date: Wed, 4 Jan 2023 16:43:17 +0000 Subject: [PATCH 06/11] Simplify the usage of spy and clarify with a comment Signed-off-by: Ivan Danov --- kedro/pipeline/node.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/kedro/pipeline/node.py b/kedro/pipeline/node.py index f13cf5ee21..d4c34e91b0 100644 --- a/kedro/pipeline/node.py +++ b/kedro/pipeline/node.py @@ -399,11 +399,11 @@ def _run_with_dict(self, inputs: Dict[str, Any], node_inputs: Dict[str, str]): def _outputs_to_dictionary(self, outputs): def _from_dict(): - (result,), iterator = ( - ((outputs,), None) - if not inspect.isgenerator(outputs) - else spy(outputs, 1) - ) + result, iterator = outputs, None + # generator functions are lazy and we need a peek into their first output + if inspect.isgenerator(outputs): + (result,), iterator = spy(outputs) + keys = list(self._outputs.keys()) if not isinstance(result, dict): raise ValueError( @@ -425,11 +425,11 @@ def _from_dict(): return dict(zip([self._outputs[k] for k in keys], result)) def _from_list(): - (result,), iterator = ( - ((outputs,), None) - if not inspect.isgenerator(outputs) - else spy(outputs, 1) - ) + result, iterator = outputs, None + # generator functions are lazy and we need a peek into their first output + if inspect.isgenerator(outputs): + (result,), iterator = spy(outputs) + if not isinstance(result, (list, tuple)): raise ValueError( f"Failed to save outputs of node {str(self)}.\n" From bb81cb06585ba7329acdc61dd7d122893441037f Mon Sep 17 00:00:00 2001 From: Ivan Danov Date: Wed, 4 Jan 2023 17:08:24 +0000 Subject: [PATCH 07/11] Improve error messaging --- kedro/pipeline/node.py | 2 +- kedro/runner/runner.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/kedro/pipeline/node.py b/kedro/pipeline/node.py index d4c34e91b0..c9380c1ac5 100644 --- a/kedro/pipeline/node.py +++ b/kedro/pipeline/node.py @@ -409,7 +409,7 @@ def _from_dict(): raise ValueError( f"Failed to save outputs of node {self}.\n" f"The node output is a dictionary, whereas the " - f"function output is not." + f"function output is {type(result)}." ) if set(keys) != set(result.keys()): raise ValueError( diff --git a/kedro/runner/runner.py b/kedro/runner/runner.py index 4b14622717..8c869b6af5 100644 --- a/kedro/runner/runner.py +++ b/kedro/runner/runner.py @@ -309,8 +309,8 @@ def run_node( raise ValueError( f"Async data loading and saving does not work with " f"nodes wrapping generator functions. Please make " - f"sure you don't use yield anywhere in your function " - f"in node {str(node)}" + f"sure you don't use `yield` anywhere " + f"in node {str(node)}." 
) if is_async: From 377730903d61c82f1bc900a4b215847395baf695 Mon Sep 17 00:00:00 2001 From: Ivan Danov Date: Wed, 4 Jan 2023 17:19:41 +0000 Subject: [PATCH 08/11] Improve readability slightly in certain places --- kedro/pipeline/node.py | 2 +- kedro/runner/runner.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/kedro/pipeline/node.py b/kedro/pipeline/node.py index c9380c1ac5..f2166c4f98 100644 --- a/kedro/pipeline/node.py +++ b/kedro/pipeline/node.py @@ -421,7 +421,7 @@ def _from_dict(): exploded = map(lambda x: tuple(x[k] for k in keys), iterator) result = unzip(exploded) else: - result = tuple(result[k] for k in keys) + result = (result[k] for k in keys) return dict(zip([self._outputs[k] for k in keys], result)) def _from_list(): diff --git a/kedro/runner/runner.py b/kedro/runner/runner.py index 8c869b6af5..89947a0e09 100644 --- a/kedro/runner/runner.py +++ b/kedro/runner/runner.py @@ -417,11 +417,12 @@ def _run_node_sequential( items: Iterable = outputs.items() # if all outputs are iterators, then the node is a generator node if all(isinstance(d, Iterator) for d in outputs.values()): - # make sure we extract the keys and the chunk streams in the same order + # Python dictionaries are ordered so we are sure + # the keys and the chunk streams are in the same order # [a, b, c] keys = list(outputs.keys()) # [Iterator[chunk_a], Iterator[chunk_b], Iterator[chunk_c]] - streams = [outputs[k] for k in outputs] + streams = list(outputs.values()) # zip an endless cycle of the keys # with an interleaved iterator of the streams # [(a, chunk_a), (b, chunk_b), ...] until all outputs complete From 7b59a7011466ad9b86a2880fdf6d663129caec09 Mon Sep 17 00:00:00 2001 From: Ivan Danov Date: Wed, 4 Jan 2023 20:40:05 +0000 Subject: [PATCH 09/11] Correct the expected error message in node run tests Signed-off-by: Ivan Danov --- tests/pipeline/test_node_run.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pipeline/test_node_run.py b/tests/pipeline/test_node_run.py index bcf1778b46..331e6cd252 100644 --- a/tests/pipeline/test_node_run.py +++ b/tests/pipeline/test_node_run.py @@ -109,8 +109,8 @@ def test_run_dict_diff_size(self, mocked_dataset): class TestNodeRunInvalidOutput: def test_miss_matching_output_types(self, mocked_dataset): - pattern = r"The node output is a dictionary, whereas the function " - pattern += r"output is not\." + pattern = "The node output is a dictionary, whereas the function " + pattern += "output is ." with pytest.raises(ValueError, match=pattern): node(one_in_one_out, "ds1", dict(a="ds")).run(dict(ds1=mocked_dataset)) From ad49b50a6da35d87f81c21cfe119e3aad3a5b2ee Mon Sep 17 00:00:00 2001 From: Ivan Danov Date: Wed, 4 Jan 2023 20:42:37 +0000 Subject: [PATCH 10/11] Revert the eager evaluation of the result in _from_dict Generators cannot refer to themselves in their definition, or they will fail when used. 
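As a minimal reproduction outside Kedro, using the same variable names as `_from_dict`, rebinding the name a generator expression reads from makes the expression consume itself at iteration time:

    result = {"idx": 0, "square": 0}
    keys = ["idx", "square"]
    # The genexp body looks `result` up lazily, at iteration time;
    # by then `result` is bound to the generator itself, not the dict.
    result = (result[k] for k in keys)
    tuple(result)  # TypeError: 'generator' object is not subscriptable

Evaluating the tuple eagerly keeps `result` bound to the real values
before the name is reused.
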
Signed-off-by: Ivan Danov --- kedro/pipeline/node.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/kedro/pipeline/node.py b/kedro/pipeline/node.py index f2166c4f98..8f50d4d209 100644 --- a/kedro/pipeline/node.py +++ b/kedro/pipeline/node.py @@ -405,6 +405,7 @@ def _from_dict(): (result,), iterator = spy(outputs) keys = list(self._outputs.keys()) + names = list(self._outputs.values()) if not isinstance(result, dict): raise ValueError( f"Failed to save outputs of node {self}.\n" @@ -421,8 +422,8 @@ def _from_dict(): exploded = map(lambda x: tuple(x[k] for k in keys), iterator) result = unzip(exploded) else: - result = (result[k] for k in keys) - return dict(zip([self._outputs[k] for k in keys], result)) + result = tuple(result[k] for k in keys) + return dict(zip(names, result)) def _from_list(): result, iterator = outputs, None From 293a02c88ad21edc3dcdfef9f4bcf5684c6f6728 Mon Sep 17 00:00:00 2001 From: Ivan Danov Date: Wed, 4 Jan 2023 20:50:08 +0000 Subject: [PATCH 11/11] Add a comment for the eager evaluation Signed-off-by: Ivan Danov --- kedro/pipeline/node.py | 1 + 1 file changed, 1 insertion(+) diff --git a/kedro/pipeline/node.py b/kedro/pipeline/node.py index 8f50d4d209..4275ec94e1 100644 --- a/kedro/pipeline/node.py +++ b/kedro/pipeline/node.py @@ -422,6 +422,7 @@ def _from_dict(): exploded = map(lambda x: tuple(x[k] for k in keys), iterator) result = unzip(exploded) else: + # evaluate this eagerly so we can reuse variable name result = tuple(result[k] for k in keys) return dict(zip(names, result))
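
The series hinges on three stream helpers: `more_itertools.spy` peeks at a generator's first output without consuming it (so the output shape can be validated before anything is saved), `more_itertools.unzip` splits the stream of per-chunk tuples into one lazy stream per dataset, and `itertools.cycle` plus `more_itertools.interleave` pair each chunk back up with its dataset name in save order. A standalone sketch of those mechanics, independent of Kedro and using the `left`/`right` dataset names from the tests above:

    import itertools as it
    from more_itertools import spy, unzip, interleave

    def generate_tuple():  # a generator node's function: one output tuple per chunk
        for i in range(3):
            yield i, i * i

    outputs = generate_tuple()

    # Peek at the first chunk; `iterator` will still yield it again later.
    (first,), iterator = spy(outputs)
    assert isinstance(first, tuple) and len(first) == 2

    # Split the tuple stream into one lazy stream per output dataset.
    left_stream, right_stream = unzip(iterator)

    # Pair names with chunks in the order the runner saves them.
    items = zip(it.cycle(["left", "right"]), interleave(left_stream, right_stream))
    print(list(items))
    # [('left', 0), ('right', 0), ('left', 1), ('right', 1), ('left', 2), ('right', 4)]

Each `(name, chunk)` pair drives one `catalog.save` call in `_run_node_sequential`, which is why the tests in `tests/runner/test_run_node.py` expect ten `save` calls per dataset for a ten-chunk generator.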