From fa8c56fa2e510e6a449f5ac7356f76c167be978a Mon Sep 17 00:00:00 2001
From: Nok Lam Chan
Date: Mon, 13 Feb 2023 16:44:25 +0800
Subject: [PATCH] Add node as an argument to all datasets' hook (#2296)

* Add node as an argument to all datasets' hook

Signed-off-by: Nok Chan

* Fix missed node argument and conftest

Signed-off-by: Nok Chan

* Fix the tests, move the inner function out to make the node serializable

Signed-off-by: Nok

* Add docstring

Signed-off-by: Nok

* Fixed example and updated release notes

Signed-off-by: Nok Chan

* Fix docstring

Signed-off-by: Nok Chan

* More linting

Signed-off-by: Nok

---------

Signed-off-by: Nok Chan
Signed-off-by: Nok
---
 RELEASE.md                                  |  1 +
 docs/source/hooks/common_use_cases.md       | 40 ++++++++++++++-----
 kedro/framework/hooks/specs.py              | 12 ++++--
 kedro/runner/runner.py                      | 22 ++++++----
 tests/framework/session/conftest.py         | 21 ++++++----
 .../session/test_session_extension_hooks.py | 27 ++++++++-----
 6 files changed, 82 insertions(+), 41 deletions(-)

diff --git a/RELEASE.md b/RELEASE.md
index 3b218dee49..4eb89fb16f 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -19,6 +19,7 @@
 * Save node outputs after every `yield` before proceeding with next chunk.
 * Fixed incorrect parsing of Azure Data Lake Storage Gen2 URIs used in datasets.
 * Added support for loading credentials from environment variables using `OmegaConfigLoader`.
+* Added a new argument `node` for all four dataset hooks.
 
 ## Bug fixes and other changes
 * Commas surrounded by square brackets (only possible for nodes with default names) will no longer split the arguments to `kedro run` options which take a list of nodes as inputs (`--from-nodes` and `--to-nodes`).
diff --git a/docs/source/hooks/common_use_cases.md b/docs/source/hooks/common_use_cases.md
index 19baaced12..6f5547c2e2 100644
--- a/docs/source/hooks/common_use_cases.md
+++ b/docs/source/hooks/common_use_cases.md
@@ -97,21 +97,39 @@ We recommend using the `before_dataset_loaded`/`after_dataset_loaded` and `befor
 
 For example, you can add logging about the dataset load runtime as follows:
 
 ```python
-@property
-def _logger(self):
-    return logging.getLogger(self.__class__.__name__)
+import logging
+import time
+from typing import Any
+from kedro.framework.hooks import hook_impl
+from kedro.pipeline.node import Node
+
+
+class LoggingHook:
+    """A hook that logs how long it takes to load each dataset."""
 
-@hook_impl
-def before_dataset_loaded(self, dataset_name: str) -> None:
-    start = time.time()
-    self._logger.info("Loading dataset %s started at %0.3f", dataset_name, start)
+    def __init__(self):
+        self._timers = {}
 
+    @property
+    def _logger(self):
+        return logging.getLogger(__name__)
 
-@hook_impl
-def after_dataset_loaded(self, dataset_name: str, data: Any) -> None:
-    end = time.time()
-    self._logger.info("Loading dataset %s ended at %0.3f", dataset_name, end)
+    @hook_impl
+    def before_dataset_loaded(self, dataset_name: str, node: Node) -> None:
+        start = time.time()
+        self._timers[dataset_name] = start
+
+    @hook_impl
+    def after_dataset_loaded(self, dataset_name: str, data: Any, node: Node) -> None:
+        start = self._timers[dataset_name]
+        end = time.time()
+        self._logger.info(
+            "Loading dataset %s before node '%s' takes %0.2f seconds",
+            dataset_name,
+            node.name,
+            end - start,
+        )
 ```
 
 ## Use Hooks to load external credentials
diff --git a/kedro/framework/hooks/specs.py b/kedro/framework/hooks/specs.py
index df6b350e7a..8f91452c7d 100644
--- a/kedro/framework/hooks/specs.py
+++ b/kedro/framework/hooks/specs.py
@@ -240,41 +240,45 @@ class DatasetSpecs:
"""Namespace that defines all specifications for a dataset's lifecycle hooks.""" @hook_spec - def before_dataset_loaded(self, dataset_name: str) -> None: + def before_dataset_loaded(self, dataset_name: str, node: Node) -> None: """Hook to be invoked before a dataset is loaded from the catalog. Args: dataset_name: name of the dataset to be loaded from the catalog. + node: The ``Node`` to run. """ pass @hook_spec - def after_dataset_loaded(self, dataset_name: str, data: Any) -> None: + def after_dataset_loaded(self, dataset_name: str, data: Any, node: Node) -> None: """Hook to be invoked after a dataset is loaded from the catalog. Args: dataset_name: name of the dataset that was loaded from the catalog. data: the actual data that was loaded from the catalog. + node: The ``Node`` to run. """ pass @hook_spec - def before_dataset_saved(self, dataset_name: str, data: Any) -> None: + def before_dataset_saved(self, dataset_name: str, data: Any, node: Node) -> None: """Hook to be invoked before a dataset is saved to the catalog. Args: dataset_name: name of the dataset to be saved to the catalog. data: the actual data to be saved to the catalog. + node: The ``Node`` that ran. """ pass @hook_spec - def after_dataset_saved(self, dataset_name: str, data: Any) -> None: + def after_dataset_saved(self, dataset_name: str, data: Any, node: Node) -> None: """Hook to be invoked after a dataset is saved in the catalog. Args: dataset_name: name of the dataset that was saved to the catalog. data: the actual data that was saved to the catalog. + node: The ``Node`` that ran. """ pass diff --git a/kedro/runner/runner.py b/kedro/runner/runner.py index 0ef68ffd96..7a2444cc6d 100644 --- a/kedro/runner/runner.py +++ b/kedro/runner/runner.py @@ -399,9 +399,11 @@ def _run_node_sequential( inputs = {} for name in node.inputs: - hook_manager.hook.before_dataset_loaded(dataset_name=name) + hook_manager.hook.before_dataset_loaded(dataset_name=name, node=node) inputs[name] = catalog.load(name) - hook_manager.hook.after_dataset_loaded(dataset_name=name, data=inputs[name]) + hook_manager.hook.after_dataset_loaded( + dataset_name=name, data=inputs[name], node=node + ) is_async = False @@ -429,9 +431,9 @@ def _run_node_sequential( items = zip(it.cycle(keys), interleave(*streams)) for name, data in items: - hook_manager.hook.before_dataset_saved(dataset_name=name, data=data) + hook_manager.hook.before_dataset_saved(dataset_name=name, data=data, node=node) catalog.save(name, data) - hook_manager.hook.after_dataset_saved(dataset_name=name, data=data) + hook_manager.hook.after_dataset_saved(dataset_name=name, data=data, node=node) return node @@ -444,10 +446,10 @@ def _run_node_async( def _synchronous_dataset_load(dataset_name: str): """Minimal wrapper to ensure Hooks are run synchronously within an asynchronous dataset load.""" - hook_manager.hook.before_dataset_loaded(dataset_name=dataset_name) + hook_manager.hook.before_dataset_loaded(dataset_name=dataset_name, node=node) return_ds = catalog.load(dataset_name) hook_manager.hook.after_dataset_loaded( - dataset_name=dataset_name, data=return_ds + dataset_name=dataset_name, data=return_ds, node=node ) return return_ds @@ -471,7 +473,9 @@ def _synchronous_dataset_load(dataset_name: str): future_dataset_mapping = {} for name, data in outputs.items(): - hook_manager.hook.before_dataset_saved(dataset_name=name, data=data) + hook_manager.hook.before_dataset_saved( + dataset_name=name, data=data, node=node + ) future = pool.submit(catalog.save, name, data) future_dataset_mapping[future] = 
@@ -480,5 +484,7 @@ def _synchronous_dataset_load(dataset_name: str):
             if exception:
                 raise exception
             name, data = future_dataset_mapping[future]
-            hook_manager.hook.after_dataset_saved(dataset_name=name, data=data)
+            hook_manager.hook.after_dataset_saved(
+                dataset_name=name, data=data, node=node
+            )
     return node
diff --git a/tests/framework/session/conftest.py b/tests/framework/session/conftest.py
index 5df05077e9..4c40406e7d 100644
--- a/tests/framework/session/conftest.py
+++ b/tests/framework/session/conftest.py
@@ -296,25 +296,30 @@ def on_pipeline_error(
         )
 
     @hook_impl
-    def before_dataset_loaded(self, dataset_name: str) -> None:
-        logger.info("Before dataset loaded", extra={"dataset_name": dataset_name})
+    def before_dataset_loaded(self, dataset_name: str, node: Node) -> None:
+        logger.info(
+            "Before dataset loaded", extra={"dataset_name": dataset_name, "node": node}
+        )
 
     @hook_impl
-    def after_dataset_loaded(self, dataset_name: str, data: Any) -> None:
+    def after_dataset_loaded(self, dataset_name: str, data: Any, node: Node) -> None:
         logger.info(
-            "After dataset loaded", extra={"dataset_name": dataset_name, "data": data}
+            "After dataset loaded",
+            extra={"dataset_name": dataset_name, "data": data, "node": node},
         )
 
     @hook_impl
-    def before_dataset_saved(self, dataset_name: str, data: Any) -> None:
+    def before_dataset_saved(self, dataset_name: str, data: Any, node: Node) -> None:
         logger.info(
-            "Before dataset saved", extra={"dataset_name": dataset_name, "data": data}
+            "Before dataset saved",
+            extra={"dataset_name": dataset_name, "data": data, "node": node},
         )
 
     @hook_impl
-    def after_dataset_saved(self, dataset_name: str, data: Any) -> None:
+    def after_dataset_saved(self, dataset_name: str, data: Any, node: Node) -> None:
         logger.info(
-            "After dataset saved", extra={"dataset_name": dataset_name, "data": data}
+            "After dataset saved",
+            extra={"dataset_name": dataset_name, "data": data, "node": node},
         )
 
     @hook_impl
diff --git a/tests/framework/session/test_session_extension_hooks.py b/tests/framework/session/test_session_extension_hooks.py
index a1ca88bd80..75f3568a5e 100644
--- a/tests/framework/session/test_session_extension_hooks.py
+++ b/tests/framework/session/test_session_extension_hooks.py
@@ -523,21 +523,20 @@ def test_broken_input_update_parallel(
     mock_session_with_broken_before_node_run_hooks.run(runner=ParallelRunner())
 
 
+def wait_and_identity(*args: Any):
+    time.sleep(0.1)
+    if len(args) == 1:
+        return args[0]
+    return args
+
+
 @pytest.fixture
 def sample_node():
-    def wait_and_identity(x: Any):
-        time.sleep(0.1)
-        return x
-
     return node(wait_and_identity, inputs="ds1", outputs="ds2", name="test-node")
 
 
 @pytest.fixture
 def sample_node_multiple_outputs():
-    def wait_and_identity(x: Any, y: Any):
-        time.sleep(0.1)
-        return (x, y)
-
     return node(
         wait_and_identity,
         inputs=["ds1", "ds2"],
@@ -611,8 +610,16 @@ def test_after_dataset_load_hook_async_multiple_outputs(
     after_dataset_saved_mock.assert_has_calls(
         [
-            mocker.call(dataset_name="ds3", data={"data": 42}),
-            mocker.call(dataset_name="ds4", data={"data": 42}),
+            mocker.call(
+                dataset_name="ds3",
+                data={"data": 42},
+                node=sample_node_multiple_outputs,
+            ),
+            mocker.call(
+                dataset_name="ds4",
+                data={"data": 42},
+                node=sample_node_multiple_outputs,
+            ),
         ],
         any_order=True,
     )
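
As an illustration of the hook signatures this patch introduces, a project-side hook might consume the new `node` argument as in the sketch below. This example is not part of the patch: the `DatasetAuditHook` name and its bookkeeping attributes are invented for illustration, and only the hook specs shown above are assumed.

```python
import logging
from typing import Any

from kedro.framework.hooks import hook_impl
from kedro.pipeline.node import Node


class DatasetAuditHook:
    """Hypothetical hook: records which node each dataset was loaded for or saved by."""

    def __init__(self):
        # dataset name -> names of nodes that loaded it / name of the node that saved it
        self.loaded_by = {}
        self.saved_by = {}

    @property
    def _logger(self):
        return logging.getLogger(__name__)

    @hook_impl
    def after_dataset_loaded(self, dataset_name: str, data: Any, node: Node) -> None:
        # `node` is the node about to run with this dataset as one of its inputs.
        self.loaded_by.setdefault(dataset_name, []).append(node.name)
        self._logger.info("Dataset %s loaded for node '%s'", dataset_name, node.name)

    @hook_impl
    def after_dataset_saved(self, dataset_name: str, data: Any, node: Node) -> None:
        # `node` is the node whose output was just saved to the catalog.
        self.saved_by[dataset_name] = node.name
        self._logger.info("Dataset %s saved by node '%s'", dataset_name, node.name)
```

Such a hook would be registered the usual way, for example `HOOKS = (DatasetAuditHook(),)` in the project's `settings.py`.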