From eb9e04a8101c661b9fbd73c0e362c5474fc847f7 Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Wed, 22 Mar 2023 14:01:44 -0700 Subject: [PATCH 1/2] Allows layering of subdag and extract_* This solves #72. This is done by introducing a new base class decorator SingleNodeNodeTransformer that extends from NodeTransformer. This is meant to replace `NodeExpander`, as it allows transformers to run on a subdag that outputs a single "final" node (a sink). Note that, in this case, decorator ordering *will* matter. --- hamilton/function_modifiers/base.py | 85 +++++++++++++++++++++- hamilton/function_modifiers/expanders.py | 19 +++-- tests/function_modifiers/test_combined.py | 75 +++++++++++++++++++ tests/function_modifiers/test_expanders.py | 12 +-- 4 files changed, 174 insertions(+), 17 deletions(-) create mode 100644 tests/function_modifiers/test_combined.py diff --git a/hamilton/function_modifiers/base.py b/hamilton/function_modifiers/base.py index a60033b7b..e88cd19f9 100644 --- a/hamilton/function_modifiers/base.py +++ b/hamilton/function_modifiers/base.py @@ -3,6 +3,7 @@ import functools import itertools import logging +from abc import ABC try: from types import EllipsisType @@ -15,7 +16,6 @@ logger = logging.getLogger(__name__) - if not registry.INITIALIZED: # Trigger load of extensions here because decorators are the only thing that use the registry # right now. Side note: ray serializes things weirdly, so we need to do this here rather than in @@ -273,6 +273,34 @@ def allows_multiple(cls) -> bool: class NodeTransformer(SubDAGModifier): + @classmethod + def _early_validate_target(cls, target: TargetType, allow_multiple: bool): + """Determines whether the target is valid, given that we may or may not + want to allow multiple nodes to be transformed. + + If the target type is a single string then we're good. + If the target type is a collection of strings, then it has to be a collection of size one. + If the target type is None, then we delay checking until later (as there might be just + one node transformed in the DAG). + If the target type is ellipsis, then we delay checking until later (as there might be + just one node transformed in the DAG) + + :param target: How to appply this node. See docs below. + :param allow_multiple: Whether or not this can operate on multiple nodes. + :raises InvalidDecoratorException: if the target is invalid given the value of allow_multiple. + """ + if isinstance(target, str): + # We're good -- regardless of the value of allow_multiple we'll pass + return + elif isinstance(target, Collection) and all(isinstance(x, str) for x in target): + if len(target) > 1 and not allow_multiple: + raise InvalidDecoratorException(f"Cannot have multiple targets for . Got {target}") + return + elif target is None or target is Ellipsis: + return + else: + raise InvalidDecoratorException(f"Invalid target type for NodeTransformer: {target}") + def __init__(self, target: TargetType): """Target determines to which node(s) this applies. This represents selection from a subDAG. For the options, consider at the following graph: @@ -357,6 +385,25 @@ def compliment( """ return [node_ for node_ in all_nodes if node_ not in nodes_to_transform] + def transform_targets( + self, targets: Collection[node.Node], config: Dict[str, Any], fn: Callable + ) -> Collection[node.Node]: + """Transforms a set of target nodes. Note that this is just a loop, + but abstracting t away gives subclasses control over how this is done, + allowing them to validate beforehand. While we *could* just have this + as a `validate`, or `transforms_multiple` function, this is a pretty clean/ + readable way to do it. + + :param targets: Node Targets to transform + :param config: Configuration to use to + :param fn: Function being decorated + :return: Results of transformations + """ + out = [] + for node_to_transform in targets: + out += list(self.transform_node(node_to_transform, config, fn)) + return out + def transform_dag( self, nodes: Collection[node.Node], config: Dict[str, Any], fn: Callable ) -> Collection[node.Node]: @@ -371,8 +418,7 @@ def transform_dag( nodes_to_transform = self.select_nodes(self.target, nodes) nodes_to_keep = self.compliment(nodes, nodes_to_transform) out = list(nodes_to_keep) - for node_to_transform in nodes_to_transform: - out += list(self.transform_node(node_to_transform, config, fn)) + out += self.transform_targets(nodes_to_transform, config, fn) return out @abc.abstractmethod @@ -394,6 +440,39 @@ def allows_multiple(cls) -> bool: return True +class SingleNodeNodeTransformer(NodeTransformer, ABC): + """A node transformer that only allows a single node to be transformed. + Specifically, this must be applied to a decorator operation that returns + a single node (E.G. @subdag). Note that if you have multiple node transformations, + the order *does* matter. + + This should end up killing NodeExpander, as it has the same impact, and the same API. + """ + + def __init__(self): + """Initializes the node transformer to only allow a single node to be transformed. + Note this passes target=None to the superclass, which means that it will only + apply to the 'sink' nodes produced.""" + super().__init__(target=None) + + def transform_targets( + self, targets: Collection[node.Node], config: Dict[str, Any], fn: Callable + ) -> Collection[node.Node]: + """Transforms the target set of nodes. Exists to validate the target set. + + :param targets: Targets to transform -- this has to be an array of 1. + :param config: Configuration passed into the DAG. + :param fn: Function that was decorated. + :return: The resulting nodes. + """ + if len(targets) != 1: + raise InvalidDecoratorException( + f"Expected a single node to transform, but got {len(targets)}. {self.__class__} " + f" can only operate on a single node, but multiple nodes were created by {fn.__qualname__}" + ) + return super().transform_targets(targets, config, fn) + + class NodeDecorator(NodeTransformer, abc.ABC): DECORATE_NODES = "decorate_nodes" diff --git a/hamilton/function_modifiers/expanders.py b/hamilton/function_modifiers/expanders.py index a8c9c25c0..de3bc42cf 100644 --- a/hamilton/function_modifiers/expanders.py +++ b/hamilton/function_modifiers/expanders.py @@ -565,7 +565,7 @@ class parameterized_inputs(parameterize_sources): pass -class extract_columns(base.NodeExpander): +class extract_columns(base.SingleNodeNodeTransformer): def __init__(self, *columns: Union[Tuple[str, str], str], fill_with: Any = None): """Constructor for a modifier that expands a single function into the following nodes: @@ -577,6 +577,7 @@ def __init__(self, *columns: Union[Tuple[str, str], str], fill_with: Any = None) value? Or do you want to error out? Leave empty/None to error out, set fill_value to dynamically create a \ column. """ + super(extract_columns, self).__init__() if not columns: raise base.InvalidDecoratorException( "Error empty arguments passed to extract_columns decorator." @@ -599,7 +600,8 @@ def validate_return_type(fn: Callable): except NotImplementedError: raise base.InvalidDecoratorException( # TODO: capture was dataframe libraries are supported and print here. - f"Error {fn} does not output a type we know about. Is it a dataframe type we support?" + f"Error {fn} does not output a type we know about. Is it a dataframe type we " + f"support? " ) def validate(self, fn: Callable): @@ -610,13 +612,13 @@ def validate(self, fn: Callable): """ extract_columns.validate_return_type(fn) - def expand_node( + def transform_node( self, node_: node.Node, config: Dict[str, Any], fn: Callable ) -> Collection[node.Node]: """For each column to extract, output a node that extracts that column. Also, output the original dataframe generator. - - :param config: + :param node_: Node to transform + :param config: Config to use :param fn: Function to extract columns from. Must output a dataframe. :return: A collection of nodes -- one for the original dataframe generator, and another for each column to extract. @@ -692,7 +694,7 @@ def extractor_fn( return output_nodes -class extract_fields(base.NodeExpander): +class extract_fields(base.SingleNodeNodeTransformer): """Extracts fields from a dictionary of output.""" def __init__(self, fields: dict, fill_with: Any = None): @@ -706,6 +708,7 @@ def __init__(self, fields: dict, fill_with: Any = None): value? Or do you want to error out? Leave empty/None to error out, set fill_value to dynamically create a \ field value. """ + super(extract_fields, self).__init__() if not fields: raise base.InvalidDecoratorException( "Error an empty dict, or no dict, passed to extract_fields decorator." @@ -755,7 +758,7 @@ def validate(self, fn: Callable): f"For extracting fields, output type must be a dict or typing.Dict, not: {output_type}" ) - def expand_node( + def transform_node( self, node_: node.Node, config: Dict[str, Any], fn: Callable ) -> Collection[node.Node]: """For each field to extract, output a node that extracts that field. Also, output the original TypedDict @@ -924,7 +927,7 @@ def wrapper_fn(*args, _output_columns=parameterization.outputs, **kwargs): ) extract_columns_decorator = extract_columns(*parameterization.outputs) output_nodes.extend( - extract_columns_decorator.expand_node( + extract_columns_decorator.transform_node( parameterized_node, config, parameterized_node.callable ) ) diff --git a/tests/function_modifiers/test_combined.py b/tests/function_modifiers/test_combined.py new file mode 100644 index 000000000..acff4b513 --- /dev/null +++ b/tests/function_modifiers/test_combined.py @@ -0,0 +1,75 @@ +"""A few tests for combining different decorators. +While this should not be necessary -- we should be able to test the decorator lifecycle functions, +it is useful to have a few tests that demonstrate that common use-cases are supported. + +Note we also have some more end-to-end cases in test_layered.py""" +from typing import Dict + +import pandas as pd + +from hamilton.function_modifiers import base as fm_base +from hamilton.function_modifiers import extract_columns, extract_fields, subdag, tag + + +def test_subdag_and_extract_columns(): + def foo() -> pd.Series: + return pd.Series([1, 2, 3]) + + def bar() -> pd.Series: + return pd.Series([1, 2, 3]) + + @extract_columns("foo", "bar") + @subdag(foo, bar) + def foo_bar(foo: pd.Series, bar: pd.Series) -> pd.DataFrame: + return pd.DataFrame({"foo": foo, "bar": bar}) + + nodes = fm_base.resolve_nodes(foo_bar, {}) + nodes_by_name = {node.name: node for node in nodes} + assert sorted(nodes_by_name) == ["bar", "foo", "foo_bar", "foo_bar.bar", "foo_bar.foo"] + # The extraction columns should depend on the thing from which they are extracted + assert sorted(nodes_by_name["foo"].input_types.keys()) == ["foo_bar"] + assert sorted(nodes_by_name["bar"].input_types.keys()) == ["foo_bar"] + + +def test_subdag_and_extract_fields(): + def foo() -> int: + return 1 + + def bar() -> int: + return 2 + + @extract_fields({"foo": int, "bar": int}) + @subdag(foo, bar) + def foo_bar(foo: int, bar: pd.Series) -> Dict[str, int]: + return {"foo": foo, "bar": bar} + + nodes = fm_base.resolve_nodes(foo_bar, {}) + nodes_by_name = {node.name: node for node in nodes} + assert sorted(nodes_by_name) == ["bar", "foo", "foo_bar", "foo_bar.bar", "foo_bar.foo"] + # The extraction columns should depend on the thing from which they are extracted + assert sorted(nodes_by_name["foo"].input_types.keys()) == ["foo_bar"] + assert sorted(nodes_by_name["bar"].input_types.keys()) == ["foo_bar"] + + +def test_subdag_and_extract_fields_with_tags(): + def foo() -> int: + return 1 + + def bar() -> int: + return 2 + + @tag(a="b", target_="foo") + @tag(a="c", target_="bar") + @extract_fields({"foo": int, "bar": int}) + @subdag(foo, bar) + def foo_bar(foo: int, bar: pd.Series) -> Dict[str, int]: + return {"foo": foo, "bar": bar} + + nodes = fm_base.resolve_nodes(foo_bar, {}) + nodes_by_name = {node.name: node for node in nodes} + assert sorted(nodes_by_name) == ["bar", "foo", "foo_bar", "foo_bar.bar", "foo_bar.foo"] + # The extraction columns should depend on the thing from which they are extracted + assert sorted(nodes_by_name["foo"].input_types.keys()) == ["foo_bar"] + assert sorted(nodes_by_name["bar"].input_types.keys()) == ["foo_bar"] + assert nodes_by_name["foo"].tags["a"] == "b" + assert nodes_by_name["bar"].tags["a"] == "c" diff --git a/tests/function_modifiers/test_expanders.py b/tests/function_modifiers/test_expanders.py index 87b1e1594..1d191d215 100644 --- a/tests/function_modifiers/test_expanders.py +++ b/tests/function_modifiers/test_expanders.py @@ -228,7 +228,7 @@ def dummy_df_generator() -> pd.DataFrame: return pd.DataFrame({"col_1": [1, 2, 3, 4], "col_2": [11, 12, 13, 14]}) nodes = list( - annotation.expand_node(node.Node.from_fn(dummy_df_generator), {}, dummy_df_generator) + annotation.transform_node(node.Node.from_fn(dummy_df_generator), {}, dummy_df_generator) ) assert len(nodes) == 3 assert nodes[0] == node.Node( @@ -258,7 +258,7 @@ def dummy_df() -> pd.DataFrame: return pd.DataFrame({"col_1": [1, 2, 3, 4], "col_2": [11, 12, 13, 14]}) annotation = function_modifiers.extract_columns("col_3", fill_with=0) - original_node, extracted_column_node = annotation.expand_node( + original_node, extracted_column_node = annotation.transform_node( node.Node.from_fn(dummy_df), {}, dummy_df ) original_df = original_node.callable() @@ -276,7 +276,7 @@ def dummy_df_generator() -> pd.DataFrame: annotation = function_modifiers.extract_columns("col_3") nodes = list( - annotation.expand_node(node.Node.from_fn(dummy_df_generator), {}, dummy_df_generator) + annotation.transform_node(node.Node.from_fn(dummy_df_generator), {}, dummy_df_generator) ) with pytest.raises(hamilton.function_modifiers.base.InvalidDecoratorException): nodes[1].callable(dummy_df_generator=dummy_df_generator()) @@ -348,7 +348,7 @@ def dummy_dict_generator() -> dict: return {"col_1": [1, 2, 3, 4], "col_2": 1, "col_3": np.ndarray([1, 2, 3, 4])} nodes = list( - annotation.expand_node(node.Node.from_fn(dummy_dict_generator), {}, dummy_dict_generator) + annotation.transform_node(node.Node.from_fn(dummy_dict_generator), {}, dummy_dict_generator) ) assert len(nodes) == 4 assert nodes[0] == node.Node( @@ -378,7 +378,7 @@ def dummy_dict() -> dict: return {"col_1": [1, 2, 3, 4], "col_2": 1, "col_3": np.ndarray([1, 2, 3, 4])} annotation = function_modifiers.extract_fields({"col_2": int, "col_4": float}, fill_with=1.0) - original_node, extracted_field_node, missing_field_node = annotation.expand_node( + original_node, extracted_field_node, missing_field_node = annotation.transform_node( node.Node.from_fn(dummy_dict), {}, dummy_dict ) original_dict = original_node.callable() @@ -394,7 +394,7 @@ def dummy_dict() -> dict: return {"col_1": [1, 2, 3, 4], "col_2": 1, "col_3": np.ndarray([1, 2, 3, 4])} annotation = function_modifiers.extract_fields({"col_4": int}) - nodes = list(annotation.expand_node(node.Node.from_fn(dummy_dict), {}, dummy_dict)) + nodes = list(annotation.transform_node(node.Node.from_fn(dummy_dict), {}, dummy_dict)) with pytest.raises(hamilton.function_modifiers.base.InvalidDecoratorException): nodes[1].callable(dummy_dict=dummy_dict()) From 523da3613ccb44985bbf3654e56f9dc52955da34 Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Wed, 22 Mar 2023 19:54:08 -0700 Subject: [PATCH 2/2] Adds ability to specify external inputs in subdag/parameterized_subdag See https://github.com/DAGWorks-Inc/hamilton/issues/115 --- hamilton/function_modifiers/recursive.py | 29 ++++++++++-- tests/function_modifiers/test_recursive.py | 51 ++++++++++++++++++++++ 2 files changed, 76 insertions(+), 4 deletions(-) diff --git a/hamilton/function_modifiers/recursive.py b/hamilton/function_modifiers/recursive.py index 566e4a70b..283ffa5c7 100644 --- a/hamilton/function_modifiers/recursive.py +++ b/hamilton/function_modifiers/recursive.py @@ -127,6 +127,10 @@ def feature_engineering(feature_df: pd.DataFrame) -> pd.DataFrame: E.G. take a certain set of nodes, and run them with specified parameters. + @subdag declares parameters that are outputs of its subdags. Note that, if you want to use outputs of other + components of the DAG, you can use the `external_inputs` parameter to declare the parameters that do *not* come + from the subDAG. + Why might you want to use this? Let's take a look at some examples: 1. You have a feature engineering pipeline that you want to run on multiple datasets. If its exactly the same, \ @@ -148,6 +152,7 @@ def __init__( config: Dict[str, Any] = None, namespace: str = None, final_node_name: str = None, + external_inputs: List[str] = None, ): """Adds a subDAG to the main DAG. @@ -160,10 +165,14 @@ def __init__( this will default to the function name. :param final_node_name: Name of the final node in the subDAG. This is optional -- if not included, this will default to the function name. + :param external_inputs: Parameters in the function that are not produced by the functions + passed to the subdag. This is useful if you want to perform some logic with other inputs + in the subdag's processing function. """ self.subdag_functions = subdag.collect_functions(load_from) self.inputs = inputs if inputs is not None else {} self.config = config if config is not None else {} + self.external_inputs = external_inputs if external_inputs is not None else [] self._validate_config_inputs(self.config, self.inputs) self.namespace = namespace self.final_node_name = final_node_name @@ -307,9 +316,14 @@ def add_final_node(self, fn: Callable, node_name: str, namespace: str): :return: """ node_ = node.Node.from_fn(fn) - namespaced_input_map = {assign_namespace(key, namespace): key for key in node_.input_types} + namespaced_input_map = { + (assign_namespace(key, namespace) if key not in self.external_inputs else key): key + for key in node_.input_types + } + new_input_types = { - assign_namespace(key, namespace): value for key, value in node_.input_types.items() + (assign_namespace(key, namespace) if key not in self.external_inputs else key): value + for key, value in node_.input_types.items() } def new_function(**kwargs): @@ -377,11 +391,12 @@ def validate(self, fn): class SubdagParams(TypedDict): inputs: NotRequired[Dict[str, ParametrizedDependency]] config: NotRequired[Dict[str, Any]] + external_inputs: NotRequired[List[str]] class parameterized_subdag(base.NodeCreator): """parameterized subdag is when you want to create multiple subdags at one time. - Why do you want to do this? + Why might you want to do this? 1. You have multiple data sets you want to run the same feature engineering pipeline on. 2. You want to run some sort of optimization routine with a variety of results @@ -444,6 +459,7 @@ def __init__( str, Union[dependencies.ParametrizedDependency, dependencies.LiteralDependency] ] = None, config: Dict[str, Any] = None, + external_inputs: List[str] = None, **parameterization: SubdagParams, ): """Initializes a parameterized_subdag decorator. @@ -451,6 +467,8 @@ def __init__( :param load_from: Modules to load from :param inputs: Inputs for each subdag generated by the decorated function :param config: Config for each subdag generated by the decorated function + :param external_inputs: External inputs to all parameterized subdags. Note that + if you pass in any external inputs from local subdags, it overrides this (does not merge). :param parameterization: Parameterizations for each subdag generated. Note that this *overrides* any inputs/config passed to the decorator itself. @@ -460,12 +478,14 @@ def __init__( allowed to name these `load_from`, `inputs`, or `config`. That's a good thing, as these are not good names for variables anyway. - 2. Any empty items (not included) will default to an empty dict + 2. Any empty items (not included) will default to an empty dict (or an empty list in + the case of parameterization) """ self.load_from = load_from self.inputs = inputs if inputs is not None else {} self.config = config if config is not None else {} self.parameterization = parameterization + self.external_inputs = external_inputs if external_inputs is not None else [] def _gather_subdag_generators(self) -> List[subdag]: subdag_generators = [] @@ -475,6 +495,7 @@ def _gather_subdag_generators(self) -> List[subdag]: *self.load_from, inputs={**self.inputs, **parameterization.get("inputs", {})}, config={**self.config, **parameterization.get("config", {})}, + external_inputs=parameterization.get("external_inputs", self.external_inputs), namespace=key, final_node_name=key, ) diff --git a/tests/function_modifiers/test_recursive.py b/tests/function_modifiers/test_recursive.py index 8d61508d5..b1dcedbbf 100644 --- a/tests/function_modifiers/test_recursive.py +++ b/tests/function_modifiers/test_recursive.py @@ -307,3 +307,54 @@ def sum_all(outer_subdag_1: int, outer_subdag_2: int) -> int: assert res["sum_all"] == sum_all( outer_subdag_1(inner_subdag(bar(2), foo(10))), outer_subdag_2(inner_subdag(bar(2), foo(3))) ) + + +def test_subdag_with_external_nodes_input(): + def bar(input_1: int) -> int: + return input_1 + 1 + + def foo(input_2: int) -> int: + return input_2 + 1 + + @subdag(foo, bar, external_inputs=["baz"]) + def foo_bar_baz(foo: int, bar: int, baz: int) -> int: + return foo + bar + baz + + full_module = ad_hoc_utils.create_temporary_module(foo_bar_baz) + fg = graph.FunctionGraph(full_module, config={}) + # since we've provided it above, + assert "baz" in fg.nodes + assert fg.nodes["baz"].user_defined + res = fg.execute(nodes=[fg.nodes["foo_bar_baz"]], inputs={"input_1": 2, "input_2": 3, "baz": 4}) + assert res["foo_bar_baz"] == foo_bar_baz(foo(3), bar(2), 4) + + +def test_parameterized_subdag_with_external_inputs_global(): + def bar(input_1: int) -> int: + return input_1 + 1 + + def foo(input_2: int) -> int: + return input_2 + 1 + + @parameterized_subdag( + foo, + bar, + external_inputs=["baz"], + foo_bar_baz_input_1={"inputs": {"input_1": value(10), "input_2": value(20)}}, + foo_bar_baz_input_2={"inputs": {"input_1": value(30), "input_2": value(40)}}, + ) + def foo_bar_baz(foo: int, bar: int, baz: int) -> int: + return foo + bar + baz + + def foo_bar_baz_summed(foo_bar_baz_input_1: int, foo_bar_baz_input_2: int) -> int: + return foo_bar_baz_input_1 + foo_bar_baz_input_2 + + full_module = ad_hoc_utils.create_temporary_module(foo_bar_baz, foo_bar_baz_summed) + fg = graph.FunctionGraph(full_module, config={}) + # since we've provided it above, + assert "baz" in fg.nodes + assert fg.nodes["baz"].user_defined + res = fg.execute(nodes=[fg.nodes["foo_bar_baz_summed"]], inputs={"baz": 100}) + assert res["foo_bar_baz_summed"] == foo_bar_baz(foo(10), foo(20), 100) + foo_bar_baz( + bar(30), foo(40), 100 + )