Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allows layering of subdag and extract_* #120

Merged
merged 2 commits into from
Mar 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 82 additions & 3 deletions hamilton/function_modifiers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import functools
import itertools
import logging
from abc import ABC

try:
from types import EllipsisType
Expand All @@ -15,7 +16,6 @@

logger = logging.getLogger(__name__)


if not registry.INITIALIZED:
# Trigger load of extensions here because decorators are the only thing that use the registry
# right now. Side note: ray serializes things weirdly, so we need to do this here rather than in
Expand Down Expand Up @@ -273,6 +273,34 @@ def allows_multiple(cls) -> bool:


class NodeTransformer(SubDAGModifier):
@classmethod
def _early_validate_target(cls, target: TargetType, allow_multiple: bool):
    """Determines whether the target is valid, given that we may or may not
    want to allow multiple nodes to be transformed.

    If the target type is a single string then we're good.
    If the target type is a collection of strings, then it has to be a collection of size one.
    If the target type is None, then we delay checking until later (as there might be just
    one node transformed in the DAG).
    If the target type is ellipsis, then we delay checking until later (as there might be
    just one node transformed in the DAG).

    :param target: How to apply this node. See docs below.
    :param allow_multiple: Whether or not this can operate on multiple nodes.
    :raises InvalidDecoratorException: if the target is invalid given the value of allow_multiple.
    """
    if isinstance(target, str):
        # A single node name -- valid regardless of the value of allow_multiple.
        return
    elif isinstance(target, Collection) and all(isinstance(x, str) for x in target):
        if len(target) > 1 and not allow_multiple:
            # Fix: the original message read "for . Got" -- it was missing the subject.
            raise InvalidDecoratorException(
                f"Cannot have multiple targets for {cls.__name__}. Got {target}"
            )
        return
    elif target is None or target is Ellipsis:
        # Defer validation -- the DAG might end up producing exactly one transformable node.
        return
    else:
        raise InvalidDecoratorException(f"Invalid target type for NodeTransformer: {target}")

def __init__(self, target: TargetType):
"""Target determines to which node(s) this applies. This represents selection from a subDAG.
For the options, consider at the following graph:
Expand Down Expand Up @@ -357,6 +385,25 @@ def compliment(
"""
return [node_ for node_ in all_nodes if node_ not in nodes_to_transform]

def transform_targets(
    self, targets: Collection[node.Node], config: Dict[str, Any], fn: Callable
) -> Collection[node.Node]:
    """Transforms a set of target nodes. Note that this is just a loop,
    but abstracting it away gives subclasses control over how this is done,
    allowing them to validate beforehand. While we *could* just have this
    as a `validate`, or `transforms_multiple` function, this is a pretty clean/
    readable way to do it.

    :param targets: Node targets to transform
    :param config: Configuration with which to resolve the transformation
    :param fn: Function being decorated
    :return: Results of all transformations, flattened into a single list
    """
    out = []
    for node_to_transform in targets:
        # transform_node may fan a single node out into several -- flatten as we go.
        out.extend(self.transform_node(node_to_transform, config, fn))
    return out

def transform_dag(
self, nodes: Collection[node.Node], config: Dict[str, Any], fn: Callable
) -> Collection[node.Node]:
Expand All @@ -371,8 +418,7 @@ def transform_dag(
nodes_to_transform = self.select_nodes(self.target, nodes)
nodes_to_keep = self.compliment(nodes, nodes_to_transform)
out = list(nodes_to_keep)
for node_to_transform in nodes_to_transform:
out += list(self.transform_node(node_to_transform, config, fn))
out += self.transform_targets(nodes_to_transform, config, fn)
return out

@abc.abstractmethod
Expand All @@ -394,6 +440,39 @@ def allows_multiple(cls) -> bool:
return True


class SingleNodeNodeTransformer(NodeTransformer, ABC):
    """A node transformer that only allows a single node to be transformed.
    Specifically, this must be applied to a decorator operation that returns
    a single node (E.G. @subdag). Note that if you have multiple node transformations,
    the order *does* matter.

    This should end up killing NodeExpander, as it has the same impact, and the same API.
    """

    def __init__(self):
        """Initializes the node transformer to only allow a single node to be transformed.
        Note this passes target=None to the superclass, which means that it will only
        apply to the 'sink' nodes produced."""
        super().__init__(target=None)

    def transform_targets(
        self, targets: Collection[node.Node], config: Dict[str, Any], fn: Callable
    ) -> Collection[node.Node]:
        """Transforms the target set of nodes. Exists to validate the target set.

        :param targets: Targets to transform -- this has to be a collection of exactly 1.
        :param config: Configuration passed into the DAG.
        :param fn: Function that was decorated.
        :return: The resulting nodes.
        :raises InvalidDecoratorException: if more (or fewer) than one target node was produced.
        """
        if len(targets) != 1:
            # Fix: the two f-string fragments previously joined with a double space.
            raise InvalidDecoratorException(
                f"Expected a single node to transform, but got {len(targets)}. {self.__class__} "
                f"can only operate on a single node, but multiple nodes were created by {fn.__qualname__}"
            )
        return super().transform_targets(targets, config, fn)


class NodeDecorator(NodeTransformer, abc.ABC):
DECORATE_NODES = "decorate_nodes"

Expand Down
19 changes: 11 additions & 8 deletions hamilton/function_modifiers/expanders.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ class parameterized_inputs(parameterize_sources):
pass


class extract_columns(base.NodeExpander):
class extract_columns(base.SingleNodeNodeTransformer):
def __init__(self, *columns: Union[Tuple[str, str], str], fill_with: Any = None):
"""Constructor for a modifier that expands a single function into the following nodes:

Expand All @@ -577,6 +577,7 @@ def __init__(self, *columns: Union[Tuple[str, str], str], fill_with: Any = None)
value? Or do you want to error out? Leave empty/None to error out, set fill_value to dynamically create a \
column.
"""
super(extract_columns, self).__init__()
if not columns:
raise base.InvalidDecoratorException(
"Error empty arguments passed to extract_columns decorator."
Expand All @@ -599,7 +600,8 @@ def validate_return_type(fn: Callable):
except NotImplementedError:
raise base.InvalidDecoratorException(
# TODO: capture what dataframe libraries are supported and print here.
f"Error {fn} does not output a type we know about. Is it a dataframe type we support?"
f"Error {fn} does not output a type we know about. Is it a dataframe type we "
f"support? "
)

def validate(self, fn: Callable):
Expand All @@ -610,13 +612,13 @@ def validate(self, fn: Callable):
"""
extract_columns.validate_return_type(fn)

def expand_node(
def transform_node(
self, node_: node.Node, config: Dict[str, Any], fn: Callable
) -> Collection[node.Node]:
"""For each column to extract, output a node that extracts that column. Also, output the original dataframe
generator.

:param config:
:param node_: Node to transform
:param config: Config to use
:param fn: Function to extract columns from. Must output a dataframe.
:return: A collection of nodes --
one for the original dataframe generator, and another for each column to extract.
Expand Down Expand Up @@ -692,7 +694,7 @@ def extractor_fn(
return output_nodes


class extract_fields(base.NodeExpander):
class extract_fields(base.SingleNodeNodeTransformer):
"""Extracts fields from a dictionary of output."""

def __init__(self, fields: dict, fill_with: Any = None):
Expand All @@ -706,6 +708,7 @@ def __init__(self, fields: dict, fill_with: Any = None):
value? Or do you want to error out? Leave empty/None to error out, set fill_value to dynamically create a \
field value.
"""
super(extract_fields, self).__init__()
if not fields:
raise base.InvalidDecoratorException(
"Error an empty dict, or no dict, passed to extract_fields decorator."
Expand Down Expand Up @@ -755,7 +758,7 @@ def validate(self, fn: Callable):
f"For extracting fields, output type must be a dict or typing.Dict, not: {output_type}"
)

def expand_node(
def transform_node(
self, node_: node.Node, config: Dict[str, Any], fn: Callable
) -> Collection[node.Node]:
"""For each field to extract, output a node that extracts that field. Also, output the original TypedDict
Expand Down Expand Up @@ -924,7 +927,7 @@ def wrapper_fn(*args, _output_columns=parameterization.outputs, **kwargs):
)
extract_columns_decorator = extract_columns(*parameterization.outputs)
output_nodes.extend(
extract_columns_decorator.expand_node(
extract_columns_decorator.transform_node(
parameterized_node, config, parameterized_node.callable
)
)
Expand Down
29 changes: 25 additions & 4 deletions hamilton/function_modifiers/recursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ def feature_engineering(feature_df: pd.DataFrame) -> pd.DataFrame:

E.G. take a certain set of nodes, and run them with specified parameters.

@subdag declares parameters that are outputs of its subdags. Note that, if you want to use outputs of other
components of the DAG, you can use the `external_inputs` parameter to declare the parameters that do *not* come
from the subDAG.

Why might you want to use this? Let's take a look at some examples:

1. You have a feature engineering pipeline that you want to run on multiple datasets. If it's exactly the same, \
Expand All @@ -148,6 +152,7 @@ def __init__(
config: Dict[str, Any] = None,
namespace: str = None,
final_node_name: str = None,
external_inputs: List[str] = None,
):
"""Adds a subDAG to the main DAG.

Expand All @@ -160,10 +165,14 @@ def __init__(
this will default to the function name.
:param final_node_name: Name of the final node in the subDAG. This is optional -- if not included,
this will default to the function name.
:param external_inputs: Parameters in the function that are not produced by the functions
passed to the subdag. This is useful if you want to perform some logic with other inputs
in the subdag's processing function.
"""
self.subdag_functions = subdag.collect_functions(load_from)
self.inputs = inputs if inputs is not None else {}
self.config = config if config is not None else {}
self.external_inputs = external_inputs if external_inputs is not None else []
self._validate_config_inputs(self.config, self.inputs)
self.namespace = namespace
self.final_node_name = final_node_name
Expand Down Expand Up @@ -307,9 +316,14 @@ def add_final_node(self, fn: Callable, node_name: str, namespace: str):
:return:
"""
node_ = node.Node.from_fn(fn)
namespaced_input_map = {assign_namespace(key, namespace): key for key in node_.input_types}
namespaced_input_map = {
(assign_namespace(key, namespace) if key not in self.external_inputs else key): key
for key in node_.input_types
}

new_input_types = {
assign_namespace(key, namespace): value for key, value in node_.input_types.items()
(assign_namespace(key, namespace) if key not in self.external_inputs else key): value
for key, value in node_.input_types.items()
}

def new_function(**kwargs):
Expand Down Expand Up @@ -377,11 +391,12 @@ def validate(self, fn):
class SubdagParams(TypedDict):
    """Shape of a single per-subdag parameterization passed to @parameterized_subdag.

    All keys are optional. `inputs` and `config` are merged over the decorator-level
    values (per-key override), while `external_inputs`, if provided, replaces the
    decorator-level list entirely (it does not merge).
    """

    inputs: NotRequired[Dict[str, ParametrizedDependency]]
    config: NotRequired[Dict[str, Any]]
    external_inputs: NotRequired[List[str]]


class parameterized_subdag(base.NodeCreator):
"""parameterized subdag is when you want to create multiple subdags at one time.
Why do you want to do this?
Why might you want to do this?

1. You have multiple data sets you want to run the same feature engineering pipeline on.
2. You want to run some sort of optimization routine with a variety of results
Expand Down Expand Up @@ -444,13 +459,16 @@ def __init__(
str, Union[dependencies.ParametrizedDependency, dependencies.LiteralDependency]
] = None,
config: Dict[str, Any] = None,
external_inputs: List[str] = None,
**parameterization: SubdagParams,
):
"""Initializes a parameterized_subdag decorator.

:param load_from: Modules to load from
:param inputs: Inputs for each subdag generated by the decorated function
:param config: Config for each subdag generated by the decorated function
:param external_inputs: External inputs to all parameterized subdags. Note that
if you pass in any external inputs from local subdags, it overrides this (does not merge).
:param parameterization: Parameterizations for each subdag generated.
Note that this *overrides* any inputs/config passed to the decorator itself.

Expand All @@ -460,12 +478,14 @@ def __init__(
allowed to name these `load_from`, `inputs`, or `config`. That's a good thing, as these
are not good names for variables anyway.

2. Any empty items (not included) will default to an empty dict
2. Any empty items (not included) will default to an empty dict (or an empty list in
the case of parameterization)
"""
self.load_from = load_from
self.inputs = inputs if inputs is not None else {}
self.config = config if config is not None else {}
self.parameterization = parameterization
self.external_inputs = external_inputs if external_inputs is not None else []

def _gather_subdag_generators(self) -> List[subdag]:
subdag_generators = []
Expand All @@ -475,6 +495,7 @@ def _gather_subdag_generators(self) -> List[subdag]:
*self.load_from,
inputs={**self.inputs, **parameterization.get("inputs", {})},
config={**self.config, **parameterization.get("config", {})},
external_inputs=parameterization.get("external_inputs", self.external_inputs),
namespace=key,
final_node_name=key,
)
Expand Down
75 changes: 75 additions & 0 deletions tests/function_modifiers/test_combined.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""A few tests for combining different decorators.
While this should not be necessary -- we should be able to test the decorator lifecycle functions,
it is useful to have a few tests that demonstrate that common use-cases are supported.

Note we also have some more end-to-end cases in test_layered.py"""
from typing import Dict

import pandas as pd

from hamilton.function_modifiers import base as fm_base
from hamilton.function_modifiers import extract_columns, extract_fields, subdag, tag


def test_subdag_and_extract_columns():
    """@extract_columns layered on @subdag should extract from the subdag's sink node."""

    def foo() -> pd.Series:
        return pd.Series([1, 2, 3])

    def bar() -> pd.Series:
        return pd.Series([1, 2, 3])

    @extract_columns("foo", "bar")
    @subdag(foo, bar)
    def foo_bar(foo: pd.Series, bar: pd.Series) -> pd.DataFrame:
        return pd.DataFrame({"foo": foo, "bar": bar})

    by_name = {n.name: n for n in fm_base.resolve_nodes(foo_bar, {})}
    assert sorted(by_name) == ["bar", "foo", "foo_bar", "foo_bar.bar", "foo_bar.foo"]
    # Each extracted column should depend on the dataframe it was extracted from
    for extracted in ("foo", "bar"):
        assert sorted(by_name[extracted].input_types.keys()) == ["foo_bar"]


def test_subdag_and_extract_fields():
    """@extract_fields layered on @subdag should extract from the subdag's sink node."""

    def foo() -> int:
        return 1

    def bar() -> int:
        return 2

    @extract_fields({"foo": int, "bar": int})
    @subdag(foo, bar)
    # Fix: `bar` was annotated pd.Series (copy-paste from the columns test); the
    # subdag node `bar` returns int and extract_fields declares int.
    def foo_bar(foo: int, bar: int) -> Dict[str, int]:
        return {"foo": foo, "bar": bar}

    nodes = fm_base.resolve_nodes(foo_bar, {})
    nodes_by_name = {node.name: node for node in nodes}
    assert sorted(nodes_by_name) == ["bar", "foo", "foo_bar", "foo_bar.bar", "foo_bar.foo"]
    # The extracted fields should depend on the thing from which they are extracted
    assert sorted(nodes_by_name["foo"].input_types.keys()) == ["foo_bar"]
    assert sorted(nodes_by_name["bar"].input_types.keys()) == ["foo_bar"]


def test_subdag_and_extract_fields_with_tags():
    """Targeted @tag layered over @extract_fields + @subdag should tag only the named nodes."""

    def foo() -> int:
        return 1

    def bar() -> int:
        return 2

    @tag(a="b", target_="foo")
    @tag(a="c", target_="bar")
    @extract_fields({"foo": int, "bar": int})
    @subdag(foo, bar)
    # Fix: `bar` was annotated pd.Series (copy-paste from the columns test); the
    # subdag node `bar` returns int and extract_fields declares int.
    def foo_bar(foo: int, bar: int) -> Dict[str, int]:
        return {"foo": foo, "bar": bar}

    nodes = fm_base.resolve_nodes(foo_bar, {})
    nodes_by_name = {node.name: node for node in nodes}
    assert sorted(nodes_by_name) == ["bar", "foo", "foo_bar", "foo_bar.bar", "foo_bar.foo"]
    # The extracted fields should depend on the thing from which they are extracted
    assert sorted(nodes_by_name["foo"].input_types.keys()) == ["foo_bar"]
    assert sorted(nodes_by_name["bar"].input_types.keys()) == ["foo_bar"]
    # Each targeted @tag should land on exactly its named node
    assert nodes_by_name["foo"].tags["a"] == "b"
    assert nodes_by_name["bar"].tags["a"] == "c"
Loading