From 208990a8af19dac894038f5db365845a0f3946b2 Mon Sep 17 00:00:00 2001 From: Stefan Krawczyk Date: Wed, 14 Dec 2022 00:33:48 -0800 Subject: [PATCH 1/4] Fixes does function wrapping #244 In issue #244 types were not being propagated of the function being wrapped. This changes that by adding the functools.wraps decorator. Adds a unit test to test for two decorators playing nice together and verified that it's broken before and fixed afterwards. --- hamilton/function_modifiers.py | 5 +++-- tests/resources/multiple_decorators_together.py | 14 ++++++++++++++ tests/test_graph.py | 15 +++++++++++++++ 3 files changed, 32 insertions(+), 2 deletions(-) create mode 100644 tests/resources/multiple_decorators_together.py diff --git a/hamilton/function_modifiers.py b/hamilton/function_modifiers.py index a2c75bba..7237218b 100644 --- a/hamilton/function_modifiers.py +++ b/hamilton/function_modifiers.py @@ -744,7 +744,8 @@ def generate_node(self, fn: Callable, config) -> node.Node: and the same parameters/types as the original function. """ - def replacing_function(__fn=fn, **kwargs): + @functools.wraps(fn) + def wrapper_function(**kwargs): final_kwarg_values = { key: param_spec.default for key, param_spec in inspect.signature(fn).parameters.items() @@ -754,7 +755,7 @@ def replacing_function(__fn=fn, **kwargs): final_kwarg_values = does.map_kwargs(final_kwarg_values, self.argument_mapping) return self.replacing_function(**final_kwarg_values) - return node.Node.from_fn(fn).copy_with(callabl=replacing_function) + return node.Node.from_fn(fn).copy_with(callabl=wrapper_function) class dynamic_transform(function_modifiers_base.NodeCreator): diff --git a/tests/resources/multiple_decorators_together.py b/tests/resources/multiple_decorators_together.py new file mode 100644 index 00000000..e9a67b2a --- /dev/null +++ b/tests/resources/multiple_decorators_together.py @@ -0,0 +1,14 @@ +import pandas as pd + +from hamilton.function_modifiers import does, extract_columns + + +def _sum_multiply(param0: int, param1: int, param2: int) -> pd.DataFrame: + return pd.DataFrame([{"param0a": param0, "param1b": param1, "param2c": param2}]) + + +@extract_columns("param1b") +@does(_sum_multiply) +def to_modify(param0: int, param1: int, param2: int = 2) -> pd.DataFrame: + """This sums the inputs it gets...""" + pass diff --git a/tests/test_graph.py b/tests/test_graph.py index 3d131630..85583dcb 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -14,6 +14,7 @@ import tests.resources.extract_columns_execution_count import tests.resources.functions_with_generics import tests.resources.layered_decorators +import tests.resources.multiple_decorators_together import tests.resources.optional_dependencies import tests.resources.parametrized_inputs import tests.resources.parametrized_nodes @@ -359,6 +360,20 @@ def test_end_to_end_with_column_extractor_nodes(): ) +def test_end_to_end_with_multiple_decorators(): + """Tests that a simple function graph with multiple decorators on a function works end-to-end""" + fg = graph.FunctionGraph( + tests.resources.multiple_decorators_together, config={"param0": 3, "param1": 1} + ) + nodes = fg.get_nodes() + results = fg.execute(nodes, {}, {}) + print(results) + df_expected = tests.resources.multiple_decorators_together._sum_multiply(3, 1, 2) + pd.testing.assert_series_equal(results["param1b"], df_expected["param1b"]) + pd.testing.assert_frame_equal(results["to_modify"], df_expected) + assert nodes[0].documentation == "This sums the inputs it gets..." + + def test_end_to_end_with_config_modifier(): config = { "fn_1_version": 1, From cdea1f8062f88c3b2cd97101df3d975aaf1fb392 Mon Sep 17 00:00:00 2001 From: Stefan Krawczyk Date: Wed, 14 Dec 2022 10:29:07 -0800 Subject: [PATCH 2/4] Removes print statement --- tests/test_graph.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_graph.py b/tests/test_graph.py index 85583dcb..e2dbf52a 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -367,7 +367,6 @@ def test_end_to_end_with_multiple_decorators(): ) nodes = fg.get_nodes() results = fg.execute(nodes, {}, {}) - print(results) df_expected = tests.resources.multiple_decorators_together._sum_multiply(3, 1, 2) pd.testing.assert_series_equal(results["param1b"], df_expected["param1b"]) pd.testing.assert_frame_equal(results["to_modify"], df_expected) From 0f8ae0d86b3118253713f7b8d1a4ca0f188da772 Mon Sep 17 00:00:00 2001 From: "elijah.benizzy" Date: Wed, 14 Dec 2022 11:45:44 -0800 Subject: [PATCH 3/4] Cleaner fix for `@does`. The function is only useful for the node generator -- not for the rest. This ensures that we don't use it to handle signatures, rather, we just use the node's return types. --- hamilton/function_modifiers.py | 22 ++++------------------ 1 file changed, 4 insertions(+), 18 deletions(-) diff --git a/hamilton/function_modifiers.py b/hamilton/function_modifiers.py index 7237218b..72c08a27 100644 --- a/hamilton/function_modifiers.py +++ b/hamilton/function_modifiers.py @@ -431,8 +431,7 @@ def expand_node( fn = node_.callable base_doc = node_.documentation - @functools.wraps(fn) - def df_generator(*args, **kwargs): + def df_generator(*args, **kwargs) -> pd.DataFrame: df_generated = fn(*args, **kwargs) if self.fill_with is not None: for col in self.columns: @@ -441,12 +440,8 @@ def df_generator(*args, **kwargs): return df_generated output_nodes = [ - node.Node( - node_.name, - typ=pd.DataFrame, - doc_string=base_doc, + node_.copy_with( callabl=df_generator, - tags=node_.tags.copy(), ) ] @@ -553,7 +548,6 @@ def expand_node( fn = node_.callable base_doc = node_.documentation - @functools.wraps(fn) def dict_generator(*args, **kwargs): dict_generated = fn(*args, **kwargs) if self.fill_with is not None: @@ -562,15 +556,7 @@ def dict_generator(*args, **kwargs): dict_generated[field] = self.fill_with return dict_generated - output_nodes = [ - node.Node( - node_.name, - typ=dict, - doc_string=base_doc, - callabl=dict_generator, - tags=node_.tags.copy(), - ) - ] + output_nodes = [node_.copy_with(callabl=dict_generator)] for field, field_type in self.fields.items(): doc_string = base_doc # default doc string of base function. @@ -744,7 +730,7 @@ def generate_node(self, fn: Callable, config) -> node.Node: and the same parameters/types as the original function. """ - @functools.wraps(fn) + # @functools.wraps(fn) def wrapper_function(**kwargs): final_kwarg_values = { key: param_spec.default From c0343d8e1a11fb70855461b3541142108bef09d1 Mon Sep 17 00:00:00 2001 From: Stefan Krawczyk Date: Sat, 17 Dec 2022 21:37:01 -0800 Subject: [PATCH 4/4] Adds unit test to cover extract_fields with does Adds a test to cover Elijah's changes. In the process I created https://github.com/stitchfix/hamilton/issues/249. Since I think we need to add more unit tests around decorator interactions. --- .../resources/multiple_decorators_together.py | 27 ++++++++++++- tests/test_graph.py | 40 ++++++++++++++++++- 2 files changed, 63 insertions(+), 4 deletions(-) diff --git a/tests/resources/multiple_decorators_together.py b/tests/resources/multiple_decorators_together.py index e9a67b2a..d911d967 100644 --- a/tests/resources/multiple_decorators_together.py +++ b/tests/resources/multiple_decorators_together.py @@ -1,14 +1,37 @@ import pandas as pd -from hamilton.function_modifiers import does, extract_columns +from hamilton.function_modifiers import does, extract_columns, extract_fields, tag def _sum_multiply(param0: int, param1: int, param2: int) -> pd.DataFrame: return pd.DataFrame([{"param0a": param0, "param1b": param1, "param2c": param2}]) +def _sum(param0: int, param1: int, param2: int) -> dict: + return {"total": param0 + param1 + param2} + + @extract_columns("param1b") @does(_sum_multiply) def to_modify(param0: int, param1: int, param2: int = 2) -> pd.DataFrame: - """This sums the inputs it gets...""" + """This is a dummy function showing extract_columns with does.""" + pass + + +@extract_fields({"total": int}) +@does(_sum) +def to_modify_2(param0: int, param1: int, param2: int = 2) -> dict: + """This is a dummy function showing extract_fields with does.""" + pass + + +def _dummy(**values) -> dict: + return {f"out_{k.split('_')[1]}": v for k, v in values.items()} + + +@extract_fields({"out_value1": int, "out_value2": str}) +@tag(test_key="test-value") +# @check_output(data_type=dict, importance="fail") To fix see https://github.com/stitchfix/hamilton/issues/249 +@does(_dummy) +def uber_decorated_function(in_value1: int, in_value2: str) -> dict: pass diff --git a/tests/test_graph.py b/tests/test_graph.py index e2dbf52a..c556ab09 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -363,14 +363,50 @@ def test_end_to_end_with_column_extractor_nodes(): def test_end_to_end_with_multiple_decorators(): """Tests that a simple function graph with multiple decorators on a function works end-to-end""" fg = graph.FunctionGraph( - tests.resources.multiple_decorators_together, config={"param0": 3, "param1": 1} + tests.resources.multiple_decorators_together, + config={"param0": 3, "param1": 1, "in_value1": 42, "in_value2": "string_value"}, ) nodes = fg.get_nodes() + # To help debug issues: + # nodez, user_nodes = fg.get_upstream_nodes([n.name for n in nodes], + # {"param0": 3, "param1": 1, + # "in_value1": 42, "in_value2": "string_value"}) + # fg.display( + # nodez, + # user_nodes, + # "all_multiple_decorators", + # render_kwargs=None, + # graphviz_kwargs=None, + # ) results = fg.execute(nodes, {}, {}) df_expected = tests.resources.multiple_decorators_together._sum_multiply(3, 1, 2) + dict_expected = tests.resources.multiple_decorators_together._sum(3, 1, 2) pd.testing.assert_series_equal(results["param1b"], df_expected["param1b"]) pd.testing.assert_frame_equal(results["to_modify"], df_expected) - assert nodes[0].documentation == "This sums the inputs it gets..." + assert results["total"] == dict_expected["total"] + assert results["to_modify_2"] == dict_expected + node_dict = {n.name: n for n in nodes} + print(sorted(list(node_dict.keys()))) + assert ( + node_dict["to_modify"].documentation + == "This is a dummy function showing extract_columns with does." + ) + assert ( + node_dict["to_modify_2"].documentation + == "This is a dummy function showing extract_fields with does." + ) + # tag only applies right now to outer most node layer + assert node_dict["uber_decorated_function"].tags == { + "module": "tests.resources.multiple_decorators_together" + } # tags are not propagated + assert node_dict["out_value1"].tags == { + "module": "tests.resources.multiple_decorators_together", + "test_key": "test-value", + } + assert node_dict["out_value2"].tags == { + "module": "tests.resources.multiple_decorators_together", + "test_key": "test-value", + } def test_end_to_end_with_config_modifier():