-
Notifications
You must be signed in to change notification settings - Fork 914
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Enable the usage of generator functions in nodes (#2161)
* Modify the node and run_node to enable generator nodes Signed-off-by: Ivan Danov <idanov@users.noreply.github.com> * Add tests to cover all types of generator functions Signed-off-by: Ivan Danov <idanov@users.noreply.github.com> * Fail on running a generator node with async load/save Signed-off-by: Ivan Danov <idanov@users.noreply.github.com> * Lint my code changes Signed-off-by: Ivan Danov <idanov@users.noreply.github.com> * Add changelog to RELEASE.md Signed-off-by: Ivan Danov <idanov@users.noreply.github.com> * Simplify the usage of spy and clarify with a comment Signed-off-by: Ivan Danov <idanov@users.noreply.github.com> * Improve error messaging * Improve readability slightly in certain places * Correct the expected error message in node run tests Signed-off-by: Ivan Danov <idanov@users.noreply.github.com> * Revert the eager evaluation of the result in _from_dict Generators cannot refer to themselves in their definition, or they will fail when used. Signed-off-by: Ivan Danov <idanov@users.noreply.github.com> * Add a comment for the eager evaluation Signed-off-by: Ivan Danov <idanov@users.noreply.github.com> Signed-off-by: Ivan Danov <idanov@users.noreply.github.com>
- Loading branch information
Showing
7 changed files
with
166 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
import pytest | ||
|
||
from kedro.framework.hooks.manager import _NullPluginManager | ||
from kedro.pipeline import node | ||
from kedro.runner import run_node | ||
|
||
|
||
def generate_one(): | ||
yield from range(10) | ||
|
||
|
||
def generate_tuple(): | ||
for i in range(10): | ||
yield i, i * i | ||
|
||
|
||
def generate_list(): | ||
for i in range(10): | ||
yield [i, i * i] | ||
|
||
|
||
def generate_dict(): | ||
for i in range(10): | ||
yield dict(idx=i, square=i * i) | ||
|
||
|
||
class TestRunGeneratorNode: | ||
def test_generator_fail_async(self, mocker, catalog): | ||
fake_dataset = mocker.Mock() | ||
catalog.add("result", fake_dataset) | ||
n = node(generate_one, inputs=None, outputs="result") | ||
|
||
with pytest.raises(Exception, match="nodes wrapping generator functions"): | ||
run_node(n, catalog, _NullPluginManager(), is_async=True) | ||
|
||
def test_generator_node_one(self, mocker, catalog): | ||
fake_dataset = mocker.Mock() | ||
catalog.add("result", fake_dataset) | ||
n = node(generate_one, inputs=None, outputs="result") | ||
run_node(n, catalog, _NullPluginManager()) | ||
|
||
expected = [((i,),) for i in range(10)] | ||
assert 10 == fake_dataset.save.call_count | ||
assert fake_dataset.save.call_args_list == expected | ||
|
||
def test_generator_node_tuple(self, mocker, catalog): | ||
left = mocker.Mock() | ||
right = mocker.Mock() | ||
catalog.add("left", left) | ||
catalog.add("right", right) | ||
n = node(generate_tuple, inputs=None, outputs=["left", "right"]) | ||
run_node(n, catalog, _NullPluginManager()) | ||
|
||
expected_left = [((i,),) for i in range(10)] | ||
expected_right = [((i * i,),) for i in range(10)] | ||
assert 10 == left.save.call_count | ||
assert left.save.call_args_list == expected_left | ||
assert 10 == right.save.call_count | ||
assert right.save.call_args_list == expected_right | ||
|
||
def test_generator_node_list(self, mocker, catalog): | ||
left = mocker.Mock() | ||
right = mocker.Mock() | ||
catalog.add("left", left) | ||
catalog.add("right", right) | ||
n = node(generate_list, inputs=None, outputs=["left", "right"]) | ||
run_node(n, catalog, _NullPluginManager()) | ||
|
||
expected_left = [((i,),) for i in range(10)] | ||
expected_right = [((i * i,),) for i in range(10)] | ||
assert 10 == left.save.call_count | ||
assert left.save.call_args_list == expected_left | ||
assert 10 == right.save.call_count | ||
assert right.save.call_args_list == expected_right | ||
|
||
def test_generator_node_dict(self, mocker, catalog): | ||
left = mocker.Mock() | ||
right = mocker.Mock() | ||
catalog.add("left", left) | ||
catalog.add("right", right) | ||
n = node(generate_dict, inputs=None, outputs=dict(idx="left", square="right")) | ||
run_node(n, catalog, _NullPluginManager()) | ||
|
||
expected_left = [((i,),) for i in range(10)] | ||
expected_right = [((i * i,),) for i in range(10)] | ||
assert 10 == left.save.call_count | ||
assert left.save.call_args_list == expected_left | ||
assert 10 == right.save.call_count | ||
assert right.save.call_args_list == expected_right |