diff --git a/hamilton/function_modifiers/adapters.py b/hamilton/function_modifiers/adapters.py index 2f975f49f..11133822d 100644 --- a/hamilton/function_modifiers/adapters.py +++ b/hamilton/function_modifiers/adapters.py @@ -17,6 +17,7 @@ ParametrizedDependency, UpstreamDependency, ) +from hamilton.htypes import custom_subclass_check from hamilton.io.data_adapters import AdapterCommon, DataLoader, DataSaver from hamilton.node import DependencyType from hamilton.registry import LOADER_REGISTRY, SAVER_REGISTRY @@ -733,7 +734,7 @@ def load_json_data(json_path: str = "data/my_data.json") -> tuple[pd.DataFrame, def validate(self, fn: Callable): """Validates that the output type is correctly annotated.""" - return_annotation = inspect.signature(fn).return_annotation + return_annotation = typing.get_type_hints(fn).get("return") if return_annotation is inspect.Signature.empty: raise InvalidDecoratorException( f"Function: {fn.__qualname__} must have a return annotation." @@ -748,10 +749,17 @@ def validate(self, fn: Callable): ) # check that the second is a dict second_arg = typing_inspect.get_args(return_annotation)[1] - if not (second_arg == dict or second_arg == Dict): + if not (custom_subclass_check(second_arg, dict)): raise InvalidDecoratorException( f"Function: {fn.__qualname__} must return a tuple of type (SOME_TYPE, dict)." ) + second_arg_params = typing_inspect.get_args(second_arg) + if ( + len(second_arg_params) > 0 and not second_arg_params[0] == str + ): # metadata must have string keys + raise InvalidDecoratorException( + f"Function: {fn.__qualname__} must return a tuple of type (SOME_TYPE, dict[str, ...]). Instead got (SOME_TYPE, dict[{second_arg_params[0]}, ...]" + ) def generate_nodes(self, fn: Callable, config) -> List[node.Node]: """Generates two nodes. We have to add tags appropriately. diff --git a/hamilton/version.py b/hamilton/version.py index 938b3df22..e5a652705 100644 --- a/hamilton/version.py +++ b/hamilton/version.py @@ -1 +1 @@ -VERSION = (1, 85, 0) +VERSION = (1, 85, 1) diff --git a/tests/function_modifiers/test_adapters.py b/tests/function_modifiers/test_adapters.py index 08f2bb067..06ba57b6b 100644 --- a/tests/function_modifiers/test_adapters.py +++ b/tests/function_modifiers/test_adapters.py @@ -19,6 +19,7 @@ resolve_kwargs, ) from hamilton.function_modifiers.base import DefaultNodeCreator +from hamilton.htypes import custom_subclass_check from hamilton.io.data_adapters import DataLoader, DataSaver from hamilton.io.default_data_loaders import JSONDataSaver from hamilton.registry import LOADER_REGISTRY @@ -696,19 +697,23 @@ def fn(data1: dict, data2: dict) -> dict: import sys if sys.version_info >= (3, 9): - dl_type = tuple[int, dict] - ds_type = dict + dict_ = dict + tuple_ = tuple else: - dl_type = Tuple[int, Dict] - ds_type = Dict + dict_ = Dict + tuple_ = Tuple # Mock functions for dataloader & datasaver testing -def correct_dl_function(foo: int) -> dl_type: +def correct_dl_function(foo: int) -> tuple_[int, dict_]: return 1, {} -def correct_ds_function(data: float) -> ds_type: +def correct_dl_function_with_subscripts(foo: int) -> tuple_[Dict[str, int], Dict[str, str]]: + return {"a": 1}, {"b": "c"} + + +def correct_ds_function(data: float) -> dict_: return {} @@ -720,19 +725,24 @@ def non_tuple_return_function() -> int: return 1 -def incorrect_tuple_length_function() -> Tuple[int]: +def incorrect_tuple_length_function() -> tuple_[int]: return (1,) -def incorrect_second_element_function() -> Tuple[int, list]: +def incorrect_second_element_function() -> tuple_[int, list]: return 1, [] +def incorrect_dict_subscript() -> tuple_[int, Dict[int, str]]: + return 1, {1: "a"} + + incorrect_funcs = [ no_return_annotation_function, non_tuple_return_function, incorrect_tuple_length_function, incorrect_second_element_function, + incorrect_dict_subscript, ] @@ -743,6 +753,10 @@ def test_dl_validate_incorrect_functions(func): dl.validate(func) +@pytest.mark.skipif( + sys.version_info < (3, 9, 0), + reason="dataloader not guarenteed to work with subscripted tuples on 3.8", +) def test_dl_validate_with_correct_function(): dl = dataloader() try: @@ -752,6 +766,15 @@ def test_dl_validate_with_correct_function(): pytest.fail("validate() raised InvalidDecoratorException unexpectedly!") +def test_dl_validate_with_subscripts(): + dl = dataloader() + try: + dl.validate(correct_dl_function_with_subscripts) + except InvalidDecoratorException: + # i.e. fail the test if there's an error + pytest.fail("validate() raised InvalidDecoratorException unexpectedly!") + + def test_ds_validate_with_correct_function(): dl = datasaver() try: @@ -792,6 +815,18 @@ def test_dataloader(): } +def test_dataloader_future_annotations(): + from tests.resources import nodes_with_future_annotation + + fn_to_collect = nodes_with_future_annotation.sample_dataloader + fg = graph.create_function_graph( + ad_hoc_utils.create_temporary_module(fn_to_collect), + config={}, + ) + # the data loaded is a list + assert custom_subclass_check(fg["sample_dataloader"].type, list) + + def test_datasaver(): annotation = datasaver() (node1,) = annotation.generate_nodes(correct_ds_function, {}) diff --git a/tests/resources/nodes_with_future_annotation.py b/tests/resources/nodes_with_future_annotation.py index 885c8b834..b8566d362 100644 --- a/tests/resources/nodes_with_future_annotation.py +++ b/tests/resources/nodes_with_future_annotation.py @@ -1,9 +1,16 @@ from __future__ import annotations +import sys +from typing import List, Tuple + +from hamilton.function_modifiers import dataloader from hamilton.htypes import Collect, Parallelizable """Tests future annotations with common node types""" +tuple_ = Tuple if sys.version_info < (3, 9, 0) else tuple +list_ = List if sys.version_info < (3, 9, 0) else list + def parallelized() -> Parallelizable[int]: yield 1 @@ -17,3 +24,9 @@ def standard(parallelized: int) -> int: def collected(standard: Collect[int]) -> int: return sum(standard) + + +@dataloader() +def sample_dataloader() -> tuple_[list_[str], dict]: + """Grouping here as the rest test annotations""" + return ["a", "b", "c"], {}