plugins Kedro (#916)
Introduce a Kedro plugin for Hamilton. This commit includes several changes:

1. `hamilton.plugins.h_kedro` allows converting Kedro `Pipeline` objects to a Hamilton `Driver`.
   This supports all Hamilton features, including tracking with the Hamilton UI.

2. `hamilton.plugins.kedro_extensions` adds materializers that wrap Kedro `Dataset` objects.
   This allows Kedro users to reuse their already-defined Kedro `DataCatalog` with Hamilton
   and to use `Dataset` types that don't have Hamilton materializers yet (e.g., Snowpark).

3. Bug fix for `@extract_fields` on Python <3.11. The previous type check rejected the
   `Any` type annotation for extracted fields because `isinstance(Any, type) == False` before 3.11
   (see the snippet after this list).
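
For reference, here is the Python behavior that motivated change 3 (a standard-library fact, shown only for illustration):

```python
from typing import Any

# On Python < 3.11, `typing.Any` is a special form rather than a class,
# so the old `isinstance(field_type, type)` check rejected `Any`.
# On Python >= 3.11, `Any` is implemented as a class and the check passes.
print(isinstance(Any, type))  # False on 3.8-3.10, True on 3.11+

# The fix adds an explicit identity check alongside the type check:
field_type = Any
assert isinstance(field_type, type) or field_type is Any
```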


Design decisions:
1. In Kedro, the Python function and the node definition are decoupled. The syntax for defining
    node inputs and outputs is flexible, and we need to handle all cases (None, str, list, dict).

    node ref: https://docs.kedro.org/en/stable/nodes_and_pipelines/nodes.html#how-to-create-a-node

    Unlike Hamilton nodes, Kedro nodes can return more than one value. In the conversion
    process, we manually apply the `@extract_fields` decorator to expand these Kedro nodes
    (see the sketch after this list). Since no type annotations exist for the extracted fields,
    we set `Any` (which relates to the bug fix in change 3).

    Kedro has the concept of "parameters", which are lightweight values passed at execution time,
    typically defined in a YAML file. These correspond to "inputs" in Hamilton. In the Kedro node
    definition, parameters are prefixed with `params:`. In the conversion, we remove that prefix.

    The process to create the Hamilton `Driver` is very manual. We must pass the lifecycle adapters
    when creating the `FunctionGraph` from individual nodes and trigger the `post_graph_construct`
    hook ourselves.

2. The materializers are thin wrappers around Kedro datasets. The current API expects a
    `DataCatalog` instance along with a dataset name. This works best when users have an
    existing catalog (likely most users of this plugin). We could also extend this API by allowing
    users to pass a `Dataset` definition directly, reducing the boilerplate when no catalog is
    already defined.
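
To make decision 1 concrete, here is a small sketch of the conversion; the function, dataset, and parameter names are hypothetical, and the expected node names mirror the behavior exercised in `tests/plugins/test_h_kedro.py` below:

```python
import pandas as pd
from kedro.pipeline import node

from hamilton.plugins import h_kedro


def split_data(model_input: pd.DataFrame, test_size: float) -> dict:
    """Hypothetical Kedro function with 2 outputs and 1 parameter."""
    cutoff = int(len(model_input) * (1 - test_size))
    return dict(train_set=model_input[:cutoff], test_set=model_input[cutoff:])


kedro_node = node(
    func=split_data,
    inputs=["model_input", "params:test_size"],  # `params:` prefix gets stripped
    outputs=["train_set", "test_set"],           # >1 output -> `@extract_fields`
)

h_nodes = h_kedro.k_node_to_h_nodes(kedro_node)
# 3 Hamilton nodes: the base node `split_data`, plus `train_set` and
# `test_set` extracted with the `Any` annotation
print(sorted(n.name for n in h_nodes))  # ['split_data', 'test_set', 'train_set']
```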
  
---------

Co-authored-by: zilto <tjean@DESKTOP-V6JDCS2>
zilto and zilto authored May 23, 2024
1 parent 7000259 commit 96da310
Showing 11 changed files with 2,125 additions and 2 deletions.
3 changes: 2 additions & 1 deletion examples/kedro/README.md
@@ -6,6 +6,7 @@ This repository compares how to build dataflows with Kedro and Hamilton.
 ## Content
 - `kedro-code/` includes code from the [Kedro Spaceflight tutorial](https://docs.kedro.org/en/stable/tutorial/tutorial_template.html).
-- `hamilton-code/` is a refactor of `kedro-code/` using the Hamilton framework.
+- `hamilton-code/` is a refactor of `kedro-code/` using the Hamilton library.
+- `kedro-plugin/` showcases Hamilton plugins to integrate with the Kedro framework.
 
 Each directory contains a `README` with instructions on how to run the code. We suggest going through the Kedro code first, and then read the Hamilton refactor.
@@ -39,7 +39,7 @@ def train_model(X_train: pd.DataFrame, y_train: pd.Series) -> LinearRegression:
     return regressor
 
 
-def evaluate_model(regressor: LinearRegression, X_test: pd.DataFrame, y_test: pd.Series):
+def evaluate_model(regressor: LinearRegression, X_test: pd.DataFrame, y_test: pd.Series) -> None:
     """Calculates and logs the coefficient of determination.
 
     Args:
4 changes: 4 additions & 0 deletions examples/kedro/kedro-plugin/README.md
@@ -0,0 +1,4 @@
# Kedro plugin

## Content
- `kedro_to_hamilton.ipynb` contains a tutorial on how to execute your Kedro `Pipeline` using Hamilton, track your execution in the Hamilton UI, and use the Kedro materializers to load & save data.
1,795 changes: 1,795 additions & 0 deletions examples/kedro/kedro-plugin/kedro_to_hamilton.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions hamilton/function_modifiers/base.py
@@ -38,6 +38,7 @@
     "vaex",
     "ibis",
     "dlt",
+    "kedro",
     "huggingface",
 ]
 for plugin_module in plugins_modules:
2 changes: 2 additions & 0 deletions hamilton/function_modifiers/expanders.py
@@ -715,8 +715,10 @@ def _validate_extract_fields(fields: dict):
         if not isinstance(field, str):
             errors.append(f"{field} is not a string. All keys must be strings.")
 
+        # second condition needed because isinstance(Any, type) == False for Python <3.11
         if not (
             isinstance(field_type, type)
+            or field_type is Any
             or typing_inspect.is_generic_type(field_type)
             or typing_inspect.is_union_type(field_type)
         ):
135 changes: 135 additions & 0 deletions hamilton/plugins/h_kedro.py
@@ -0,0 +1,135 @@
import inspect
from typing import Any, Dict, List, Optional, Tuple, Type

from kedro.pipeline.node import Node as KNode
from kedro.pipeline.pipeline import Pipeline as KPipeline

from hamilton import driver, graph
from hamilton.function_modifiers.expanders import extract_fields
from hamilton.lifecycle import base as lifecycle_base
from hamilton.node import Node as HNode


def expand_k_node(base_node: HNode, outputs: List[str]) -> List[HNode]:
    """Manually apply `@extract_fields()` on a Hamilton node.Node for a Kedro
    node that specifies >1 `outputs`.
    The number of nodes == len(outputs) + 1 because it includes the `base_node`.
    """

    def _convert_output_from_tuple_to_dict(node_result: Any, node_kwargs: Dict[str, Any]):
        return {out: v for out, v in zip(outputs, node_result)}

    # NOTE isinstance(Any, type) is False for Python < 3.11
    extractor = extract_fields(fields={out: Any for out in outputs})
    func = base_node.originating_functions[0]
    if issubclass(func.__annotations__["return"], Tuple):
        base_node = base_node.transform_output(_convert_output_from_tuple_to_dict, Dict)
        func.__annotations__["return"] = Dict

    extractor.validate(func)
    return list(extractor.transform_node(base_node, {}, func))


def k_node_to_h_nodes(node: KNode) -> List[HNode]:
    """Convert a Kedro node to a list of Hamilton nodes.
    If the Kedro node specifies 1 output, generate 1 Hamilton node.
    If it generates >1 outputs, generate len(outputs) + 1 nodes to include the base node + extracted fields.
    """
    # determine if more than one output
    node_names = []
    if isinstance(node.outputs, list):
        node_names.extend(node.outputs)
    elif isinstance(node.outputs, dict):
        node_names.extend(node.outputs.values())

    # determine the base node name
    if len(node_names) == 1:
        base_node_name = node_names[0]
    elif isinstance(node.outputs, str):
        base_node_name = node.outputs
    else:
        base_node_name = node.func.__name__

    func_sig = inspect.signature(node.func)
    params = func_sig.parameters.values()
    output_type = func_sig.return_annotation
    if output_type is None:
        # manually creating `hamilton.node.Node` doesn't accept `typ=None`
        output_type = Type[None]  # NoneType was introduced in Python 3.10

    base_node = HNode(
        name=base_node_name,
        typ=output_type,
        doc_string=getattr(node.func, "__doc__", ""),
        callabl=node.func,
        originating_functions=(node.func,),
    )

    # if the Kedro node defines multiple outputs, use `@extract_fields()`
    if len(node_names) > 1:
        h_nodes = expand_k_node(base_node, node_names)
    else:
        h_nodes = [base_node]

    # remap the function parameters to the node `inputs` and clean Kedro `parameters` names
    new_params = {}
    for param, k_input in zip(params, node.inputs):
        if k_input.startswith("params:"):
            k_input = k_input.partition("params:")[-1]

        new_params[param.name] = k_input

    h_nodes = [n.reassign_inputs(input_names=new_params) for n in h_nodes]

    return h_nodes


def kedro_pipeline_to_driver(
    *pipelines: KPipeline,
    builder: Optional[driver.Builder] = None,
) -> driver.Driver:
    """Convert one or more Kedro `Pipeline` objects to a Hamilton `Driver`.
    Pass a Hamilton `Builder` to include lifecycle adapters in your `Driver`.

    :param pipelines: one or more Kedro `Pipeline` objects
    :param builder: a Hamilton `Builder` to use when building the `Driver`
    :return: the Hamilton `Driver` built from the Kedro `Pipeline` objects.

    .. code-block:: python

        from hamilton import driver
        from hamilton.plugins import h_kedro

        builder = driver.Builder().with_adapters(tracker)

        dr = h_kedro.kedro_pipeline_to_driver(
            data_science.create_pipeline(),  # Kedro Pipeline
            data_processing.create_pipeline(),  # Kedro Pipeline
            builder=builder
        )
    """
    # generate nodes
    h_nodes = []
    for pipe in pipelines:
        for node in pipe.nodes:
            h_nodes.extend(k_node_to_h_nodes(node))

    # resolve dependencies
    h_nodes = graph.update_dependencies(
        {n.name: n for n in h_nodes},
        lifecycle_base.LifecycleAdapterSet(),
    )

    builder = builder if builder else driver.Builder()
    dr = builder.build()
    # inject the function graph in the Driver
    dr.graph = graph.FunctionGraph(
        h_nodes, config={}, adapter=lifecycle_base.LifecycleAdapterSet(*builder.adapters)
    )
    # reapply lifecycle hooks
    if dr.adapter.does_hook("post_graph_construct", is_async=False):
        dr.adapter.call_all_lifecycle_hooks_sync(
            "post_graph_construct", graph=dr.graph, modules=dr.graph_modules, config={}
        )
    return dr
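
For context, a minimal sketch of what using the resulting `Driver` could look like; the Kedro project module and the node/input names are assumptions (modeled on the Spaceflights tutorial), not part of this commit:

```python
from hamilton import driver
from hamilton.plugins import h_kedro

# hypothetical Kedro pipeline module, e.g., from the Spaceflights tutorial
from spaceflights.pipelines import data_science

dr = h_kedro.kedro_pipeline_to_driver(data_science.create_pipeline())
dr.display_all_functions("dag.png")  # behaves like any other Hamilton Driver
# Kedro "parameters" become regular Hamilton inputs (the `params:` prefix is removed)
results = dr.execute(["regressor"], inputs={"test_size": 0.2, "random_state": 3})
```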
99 changes: 99 additions & 0 deletions hamilton/plugins/kedro_extensions.py
@@ -0,0 +1,99 @@
import dataclasses
from typing import Any, Collection, Dict, Optional, Tuple, Type

from kedro.io import DataCatalog

from hamilton import registry
from hamilton.io.data_adapters import DataLoader, DataSaver


@dataclasses.dataclass
class KedroSaver(DataSaver):
    """Use the Kedro `DataCatalog` and `Dataset` to save results.
    ref: https://docs.kedro.org/en/stable/data/advanced_data_catalog_usage.html

    .. code-block:: python

        from kedro.framework.session import KedroSession

        with KedroSession.create() as session:
            context = session.load_context()
            catalog = context.catalog

        dr.materialize(
            to.kedro(
                id="my_dataset__kedro",
                dependencies=["my_dataset"],
                dataset_name="my_dataset",
                catalog=catalog
            )
        )
    """

    dataset_name: str
    catalog: DataCatalog

    @classmethod
    def applicable_types(cls) -> Collection[Type]:
        return [Any]

    def save_data(self, data: Any) -> Dict[str, Any]:
        self.catalog.save(name=self.dataset_name, data=data)
        return dict(success=True)

    @classmethod
    def name(cls) -> str:
        return "kedro"


@dataclasses.dataclass
class KedroLoader(DataLoader):
    """Use the Kedro `DataCatalog` and `Dataset` to load data.
    ref: https://docs.kedro.org/en/stable/data/advanced_data_catalog_usage.html

    .. code-block:: python

        from kedro.framework.session import KedroSession

        with KedroSession.create() as session:
            context = session.load_context()
            catalog = context.catalog

        dr.materialize(
            from_.kedro(
                target="input_table",
                dataset_name="input_table",
                catalog=catalog
            )
        )
    """

    dataset_name: str
    catalog: DataCatalog
    version: Optional[str] = None

    @classmethod
    def applicable_types(cls) -> Collection[Type]:
        return [Any]

    def load_data(self, type_: Type) -> Tuple[Any, Dict[str, Any]]:
        data = self.catalog.load(name=self.dataset_name, version=self.version)
        metadata = dict(dataset_name=self.dataset_name, version=self.version)
        return data, metadata

    @classmethod
    def name(cls) -> str:
        return "kedro"


def register_data_loaders():
    for loader in [
        KedroSaver,
        KedroLoader,
    ]:
        registry.register_adapter(loader)


register_data_loaders()

COLUMN_FRIENDLY_DF_TYPE = False
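
As a usage note, both adapters can be combined in a single `dr.materialize()` call. Below is a sketch assembled from the docstring examples above; `my_dataflow`, `input_table`, and `output_table` are hypothetical names:

```python
from kedro.framework.session import KedroSession

from hamilton import driver
from hamilton.io.materialization import from_, to

import my_dataflow  # hypothetical module containing Hamilton functions

with KedroSession.create() as session:
    catalog = session.load_context().catalog  # existing Kedro DataCatalog

dr = driver.Builder().with_modules(my_dataflow).build()
materialization_results, additional_vars = dr.materialize(
    # load the `input_table` dataset from the catalog as the `input_table` node
    from_.kedro(target="input_table", dataset_name="input_table", catalog=catalog),
    # save the `output_table` node result back through the catalog
    to.kedro(
        id="output_table__kedro",
        dependencies=["output_table"],
        dataset_name="output_table",
        catalog=catalog,
    ),
)
```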
1 change: 1 addition & 0 deletions requirements-test.txt
@@ -7,6 +7,7 @@ dlt
 fsspec
 graphviz
 kaleido
+kedro
 lancedb
 lightgbm
 lxml
58 changes: 58 additions & 0 deletions tests/plugins/test_h_kedro.py
@@ -0,0 +1,58 @@
import inspect

import pandas as pd
from kedro.pipeline import node

from hamilton.plugins import h_kedro


def test_parse_k_node_str_output():
    def preprocess_companies(companies: pd.DataFrame) -> pd.DataFrame:
        """Preprocesses the data for companies."""
        companies["iata_approved"] = companies["iata_approved"].astype("category")
        return companies

    kedro_node = node(
        func=preprocess_companies,
        inputs="companies",
        outputs="preprocessed_companies",
        name="preprocess_companies_node",
    )
    h_nodes = h_kedro.k_node_to_h_nodes(kedro_node)
    assert len(h_nodes) == 1
    assert h_nodes[0].name == "preprocessed_companies"
    assert h_nodes[0].type == inspect.signature(preprocess_companies).return_annotation


def test_parse_k_node_list_outputs():
    def multi_outputs() -> dict:
        return dict(a=1, b=2)

    kedro_node = node(
        func=multi_outputs,
        inputs=None,
        outputs=["a", "b"],
    )
    h_nodes = h_kedro.k_node_to_h_nodes(kedro_node)
    node_names = [n.name for n in h_nodes]
    assert len(h_nodes) == 3
    assert "multi_outputs" in node_names
    assert "a" in node_names
    assert "b" in node_names


def test_parse_k_node_dict_outputs():
    def multi_outputs() -> dict:
        return dict(a=1, b=2)

    kedro_node = node(
        func=multi_outputs,
        inputs=None,
        outputs={"a": "a", "b": "b"},
    )
    h_nodes = h_kedro.k_node_to_h_nodes(kedro_node)
    node_names = [n.name for n in h_nodes]
    assert len(h_nodes) == 3
    assert "multi_outputs" in node_names
    assert "a" in node_names
    assert "b" in node_names
27 changes: 27 additions & 0 deletions tests/plugins/test_kedro_extensions.py
@@ -0,0 +1,27 @@
from kedro.io import DataCatalog
from kedro.io.memory_dataset import MemoryDataset

from hamilton.plugins import kedro_extensions


def test_kedro_saver():
    dataset_name = "in_memory"
    data = 37
    catalog = DataCatalog({dataset_name: MemoryDataset()})

    saver = kedro_extensions.KedroSaver(dataset_name=dataset_name, catalog=catalog)
    saver.save_data(data)
    loaded_data = catalog.load(dataset_name)

    assert loaded_data == data


def test_kedro_loader():
    dataset_name = "in_memory"
    data = 37
    catalog = DataCatalog({dataset_name: MemoryDataset(data=data)})

    loader = kedro_extensions.KedroLoader(dataset_name=dataset_name, catalog=catalog)
    loaded_data, metadata = loader.load_data(int)

    assert loaded_data == data
