Add node as an argument to all datasets' hook (#2296)
* Add node as an argument to all datasets' hook

Signed-off-by: Nok Chan <nok.lam.chan@quantumblack.com>

* Fix missed node argument and conftest

Signed-off-by: Nok Chan <nok.lam.chan@quantumblack.com>

* Fix the tests, move the inner function out to make the node serializable

Signed-off-by: Nok <nok_lam_chan@mckinsey.com>

* Add docstring

Signed-off-by: Nok <nok_lam_chan@mckinsey.com>

* Fixed example and updated release notes

Signed-off-by: Nok Chan <nok.lam.chan@quantumblack.com>

* Fix docstring

Signed-off-by: Nok Chan <nok.lam.chan@quantumblack.com>

* More linting

Signed-off-by: Nok <nok_lam_chan@mckinsey.com>

---------

Signed-off-by: Nok Chan <nok.lam.chan@quantumblack.com>
Signed-off-by: Nok <nok_lam_chan@mckinsey.com>
noklam committed Feb 13, 2023
1 parent a34a2eb commit fa8c56f
Showing 6 changed files with 82 additions and 41 deletions.
1 change: 1 addition & 0 deletions RELEASE.md
@@ -19,6 +19,7 @@
 * Save node outputs after every `yield` before proceeding with next chunk.
 * Fixed incorrect parsing of Azure Data Lake Storage Gen2 URIs used in datasets.
 * Added support for loading credentials from environment variables using `OmegaConfigLoader`.
+* Added a new argument `node` for all four dataset hooks.
 
 ## Bug fixes and other changes
 * Commas surrounded by square brackets (only possible for nodes with default names) will no longer split the arguments to `kedro run` options which take a list of nodes as inputs (`--from-nodes` and `--to-nodes`).
40 changes: 29 additions & 11 deletions docs/source/hooks/common_use_cases.md
@@ -97,21 +97,39 @@ We recommend using the `before_dataset_loaded`/`after_dataset_loaded` and `befor
 For example, you can add logging about the dataset load runtime as follows:
 
 ```python
-@property
-def _logger(self):
-    return logging.getLogger(self.__class__.__name__)
+import logging
+import time
+from typing import Any
+
+from kedro.framework.hooks import hook_impl
+from kedro.pipeline.node import Node
+
+
+class LoggingHook:
+    """A hook that logs how long it takes to load each dataset."""
 
-@hook_impl
-def before_dataset_loaded(self, dataset_name: str) -> None:
-    start = time.time()
-    self._logger.info("Loading dataset %s started at %0.3f", dataset_name, start)
+    def __init__(self):
+        self._timers = {}
+
+    @property
+    def _logger(self):
+        return logging.getLogger(__name__)
 
-@hook_impl
-def after_dataset_loaded(self, dataset_name: str, data: Any) -> None:
-    end = time.time()
-    self._logger.info("Loading dataset %s ended at %0.3f", dataset_name, end)
+    @hook_impl
+    def before_dataset_loaded(self, dataset_name: str, node: Node) -> None:
+        start = time.time()
+        self._timers[dataset_name] = start
+
+    @hook_impl
+    def after_dataset_loaded(self, dataset_name: str, data: Any, node: Node) -> None:
+        start = self._timers[dataset_name]
+        end = time.time()
+        self._logger.info(
+            "Loading dataset %s before node '%s' takes %0.2f seconds",
+            dataset_name,
+            node.name,
+            end - start,
+        )
 ```
 
 ## Use Hooks to load external credentials
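Once a hook class like `LoggingHook` above exists, it is activated by registering an instance in the project's `settings.py`. A minimal sketch, assuming the class lives in a hypothetical `my_project/hooks.py` module:

```python
# src/my_project/settings.py -- module and package names are assumptions for illustration
from my_project.hooks import LoggingHook

# Kedro discovers project hooks through the HOOKS tuple in settings.py.
HOOKS = (LoggingHook(),)
```

After registration, the hook receives the new `node` argument on every dataset load performed by the runners.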
12 changes: 8 additions & 4 deletions kedro/framework/hooks/specs.py
@@ -240,41 +240,45 @@ class DatasetSpecs:
     """Namespace that defines all specifications for a dataset's lifecycle hooks."""
 
     @hook_spec
-    def before_dataset_loaded(self, dataset_name: str) -> None:
+    def before_dataset_loaded(self, dataset_name: str, node: Node) -> None:
         """Hook to be invoked before a dataset is loaded from the catalog.
 
         Args:
             dataset_name: name of the dataset to be loaded from the catalog.
+            node: The ``Node`` to run.
         """
         pass
 
     @hook_spec
-    def after_dataset_loaded(self, dataset_name: str, data: Any) -> None:
+    def after_dataset_loaded(self, dataset_name: str, data: Any, node: Node) -> None:
         """Hook to be invoked after a dataset is loaded from the catalog.
 
         Args:
             dataset_name: name of the dataset that was loaded from the catalog.
             data: the actual data that was loaded from the catalog.
+            node: The ``Node`` to run.
         """
         pass
 
     @hook_spec
-    def before_dataset_saved(self, dataset_name: str, data: Any) -> None:
+    def before_dataset_saved(self, dataset_name: str, data: Any, node: Node) -> None:
         """Hook to be invoked before a dataset is saved to the catalog.
 
         Args:
             dataset_name: name of the dataset to be saved to the catalog.
             data: the actual data to be saved to the catalog.
+            node: The ``Node`` that ran.
         """
         pass
 
     @hook_spec
-    def after_dataset_saved(self, dataset_name: str, data: Any) -> None:
+    def after_dataset_saved(self, dataset_name: str, data: Any, node: Node) -> None:
         """Hook to be invoked after a dataset is saved in the catalog.
 
         Args:
             dataset_name: name of the dataset that was saved to the catalog.
             data: the actual data that was saved to the catalog.
+            node: The ``Node`` that ran.
         """
         pass
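Because `node` is now part of all four specifications above, a plugin or project hook can attribute each load and save to the node that triggered it. A minimal sketch of such a hook (the class name and attributes are illustrative, not part of this commit):

```python
from collections import defaultdict
from typing import Any

from kedro.framework.hooks import hook_impl
from kedro.pipeline.node import Node


class DatasetLineageHook:
    """Records which nodes loaded and saved each dataset during a run."""

    def __init__(self):
        self.loaded_by = defaultdict(set)  # dataset name -> names of consuming nodes
        self.saved_by = defaultdict(set)   # dataset name -> names of producing nodes

    @hook_impl
    def after_dataset_loaded(self, dataset_name: str, data: Any, node: Node) -> None:
        # The new `node` argument tells us which node requested this input.
        self.loaded_by[dataset_name].add(node.name)

    @hook_impl
    def after_dataset_saved(self, dataset_name: str, data: Any, node: Node) -> None:
        # Likewise, record the node that produced this output.
        self.saved_by[dataset_name].add(node.name)
```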
22 changes: 14 additions & 8 deletions kedro/runner/runner.py
@@ -399,9 +399,11 @@ def _run_node_sequential(
     inputs = {}
 
     for name in node.inputs:
-        hook_manager.hook.before_dataset_loaded(dataset_name=name)
+        hook_manager.hook.before_dataset_loaded(dataset_name=name, node=node)
         inputs[name] = catalog.load(name)
-        hook_manager.hook.after_dataset_loaded(dataset_name=name, data=inputs[name])
+        hook_manager.hook.after_dataset_loaded(
+            dataset_name=name, data=inputs[name], node=node
+        )
 
     is_async = False
 
@@ -429,9 +431,9 @@ def _run_node_sequential(
     items = zip(it.cycle(keys), interleave(*streams))
 
     for name, data in items:
-        hook_manager.hook.before_dataset_saved(dataset_name=name, data=data)
+        hook_manager.hook.before_dataset_saved(dataset_name=name, data=data, node=node)
         catalog.save(name, data)
-        hook_manager.hook.after_dataset_saved(dataset_name=name, data=data)
+        hook_manager.hook.after_dataset_saved(dataset_name=name, data=data, node=node)
     return node
 
 
@@ -444,10 +446,10 @@ def _run_node_async(
     def _synchronous_dataset_load(dataset_name: str):
         """Minimal wrapper to ensure Hooks are run synchronously
         within an asynchronous dataset load."""
-        hook_manager.hook.before_dataset_loaded(dataset_name=dataset_name)
+        hook_manager.hook.before_dataset_loaded(dataset_name=dataset_name, node=node)
         return_ds = catalog.load(dataset_name)
         hook_manager.hook.after_dataset_loaded(
-            dataset_name=dataset_name, data=return_ds
+            dataset_name=dataset_name, data=return_ds, node=node
         )
         return return_ds
 
@@ -471,7 +473,9 @@ def _synchronous_dataset_load(dataset_name: str):
 
         future_dataset_mapping = {}
         for name, data in outputs.items():
-            hook_manager.hook.before_dataset_saved(dataset_name=name, data=data)
+            hook_manager.hook.before_dataset_saved(
+                dataset_name=name, data=data, node=node
+            )
             future = pool.submit(catalog.save, name, data)
             future_dataset_mapping[future] = (name, data)
 
@@ -480,5 +484,7 @@ def _synchronous_dataset_load(dataset_name: str):
             if exception:
                 raise exception
             name, data = future_dataset_mapping[future]
-            hook_manager.hook.after_dataset_saved(dataset_name=name, data=data)
+            hook_manager.hook.after_dataset_saved(
+                dataset_name=name, data=data, node=node
+            )
     return node
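A note on the asynchronous branch above: saves are submitted to a thread pool, so `before_dataset_saved` fires when a save is submitted and `after_dataset_saved` only once its future has completed. The standalone sketch below mirrors that flow, with plain callables standing in for the hook manager and catalog (all names are illustrative, not the actual runner code):

```python
from concurrent.futures import ThreadPoolExecutor, as_completed


def save_outputs_async(outputs, save, before_hook, after_hook, node):
    """Simplified mirror of the runner's async-save flow."""
    future_dataset_mapping = {}
    with ThreadPoolExecutor() as pool:
        for name, data in outputs.items():
            # The "before" hook fires synchronously, before the save is handed to the pool.
            before_hook(dataset_name=name, data=data, node=node)
            future = pool.submit(save, name, data)
            future_dataset_mapping[future] = (name, data)

        for future in as_completed(future_dataset_mapping):
            exception = future.exception()
            if exception:
                raise exception
            name, data = future_dataset_mapping[future]
            # The "after" hook fires only once the threaded save has finished.
            after_hook(dataset_name=name, data=data, node=node)
```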
21 changes: 13 additions & 8 deletions tests/framework/session/conftest.py
@@ -296,25 +296,30 @@ def on_pipeline_error(
         )
 
     @hook_impl
-    def before_dataset_loaded(self, dataset_name: str) -> None:
-        logger.info("Before dataset loaded", extra={"dataset_name": dataset_name})
+    def before_dataset_loaded(self, dataset_name: str, node: Node) -> None:
+        logger.info(
+            "Before dataset loaded", extra={"dataset_name": dataset_name, "node": node}
+        )
 
     @hook_impl
-    def after_dataset_loaded(self, dataset_name: str, data: Any) -> None:
+    def after_dataset_loaded(self, dataset_name: str, data: Any, node: Node) -> None:
         logger.info(
-            "After dataset loaded", extra={"dataset_name": dataset_name, "data": data}
+            "After dataset loaded",
+            extra={"dataset_name": dataset_name, "data": data, "node": node},
         )
 
     @hook_impl
-    def before_dataset_saved(self, dataset_name: str, data: Any) -> None:
+    def before_dataset_saved(self, dataset_name: str, data: Any, node: Node) -> None:
         logger.info(
-            "Before dataset saved", extra={"dataset_name": dataset_name, "data": data}
+            "Before dataset saved",
+            extra={"dataset_name": dataset_name, "data": data, "node": node},
         )
 
     @hook_impl
-    def after_dataset_saved(self, dataset_name: str, data: Any) -> None:
+    def after_dataset_saved(self, dataset_name: str, data: Any, node: Node) -> None:
         logger.info(
-            "After dataset saved", extra={"dataset_name": dataset_name, "data": data}
+            "After dataset saved",
+            extra={"dataset_name": dataset_name, "data": data, "node": node},
         )
 
     @hook_impl
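The fixture hooks above pass `node` through `logger.info(..., extra=...)`; Python's logging module attaches each key in `extra` as an attribute on the emitted `LogRecord`, which is what lets the tests assert on it later. A small standard-library-only demonstration of that mechanism (not taken from the test suite):

```python
import logging


class _CollectingHandler(logging.Handler):
    """Keeps emitted records so their attributes can be inspected."""

    def __init__(self):
        super().__init__()
        self.records = []

    def emit(self, record):
        self.records.append(record)


logger = logging.getLogger("extra_demo")
logger.setLevel(logging.INFO)
handler = _CollectingHandler()
logger.addHandler(handler)

logger.info("Before dataset loaded", extra={"dataset_name": "ds1", "node": "my_node"})

record = handler.records[0]
assert record.dataset_name == "ds1"  # keys in `extra` become LogRecord attributes
assert record.node == "my_node"
```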
27 changes: 17 additions & 10 deletions tests/framework/session/test_session_extension_hooks.py
@@ -523,21 +523,20 @@ def test_broken_input_update_parallel(
         mock_session_with_broken_before_node_run_hooks.run(runner=ParallelRunner())
 
 
+def wait_and_identity(*args: Any):
+    time.sleep(0.1)
+    if len(args) == 1:
+        return args[0]
+    return args
+
+
 @pytest.fixture
 def sample_node():
-    def wait_and_identity(x: Any):
-        time.sleep(0.1)
-        return x
-
     return node(wait_and_identity, inputs="ds1", outputs="ds2", name="test-node")
 
 
 @pytest.fixture
 def sample_node_multiple_outputs():
-    def wait_and_identity(x: Any, y: Any):
-        time.sleep(0.1)
-        return (x, y)
-
     return node(
         wait_and_identity,
         inputs=["ds1", "ds2"],
@@ -611,8 +610,16 @@ def test_after_dataset_load_hook_async_multiple_outputs(
 
         after_dataset_saved_mock.assert_has_calls(
             [
-                mocker.call(dataset_name="ds3", data={"data": 42}),
-                mocker.call(dataset_name="ds4", data={"data": 42}),
+                mocker.call(
+                    dataset_name="ds3",
+                    data={"data": 42},
+                    node=sample_node_multiple_outputs,
+                ),
+                mocker.call(
+                    dataset_name="ds4",
+                    data={"data": 42},
+                    node=sample_node_multiple_outputs,
+                ),
             ],
             any_order=True,
         )
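The fixture refactor above moves `wait_and_identity` to module level because, as the commit message notes, the node has to stay serialisable (for example when nodes are shipped to worker processes): Python can pickle a module-level function by its importable name, but not a function defined inside another function. A quick self-contained illustration (not from the test suite):

```python
import pickle


def module_level(x):
    # Picklable: referenced by its importable, qualified name.
    return x


def make_local():
    def local_func(x):
        return x

    return local_func


pickle.dumps(module_level)  # succeeds

try:
    pickle.dumps(make_local())  # local functions cannot be pickled by reference
except (AttributeError, pickle.PicklingError) as err:
    print(f"Cannot pickle local function: {err}")
```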
