From 8691e430cdb6e7df49c75eccd74d33591672031a Mon Sep 17 00:00:00 2001
From: elijahbenizzy
Date: Tue, 17 Dec 2024 12:24:55 -0800
Subject: [PATCH] Fixes issue in which dataloader did not accept subscripted generics as the output type

Note: the metadata (second) element of the returned tuple may only be a bare
dict or dict[str, ...]; nothing else is accepted.
---
 hamilton/function_modifiers/adapters.py   | 10 +++++++-
 tests/function_modifiers/test_adapters.py | 28 ++++++++++++++++++-----
 2 files changed, 31 insertions(+), 7 deletions(-)

diff --git a/hamilton/function_modifiers/adapters.py b/hamilton/function_modifiers/adapters.py
index ea6aa0d10..11133822d 100644
--- a/hamilton/function_modifiers/adapters.py
+++ b/hamilton/function_modifiers/adapters.py
@@ -17,6 +17,7 @@
     ParametrizedDependency,
     UpstreamDependency,
 )
+from hamilton.htypes import custom_subclass_check
 from hamilton.io.data_adapters import AdapterCommon, DataLoader, DataSaver
 from hamilton.node import DependencyType
 from hamilton.registry import LOADER_REGISTRY, SAVER_REGISTRY
@@ -748,10 +749,17 @@ def validate(self, fn: Callable):
             )
         # check that the second is a dict
         second_arg = typing_inspect.get_args(return_annotation)[1]
-        if not (second_arg == dict or second_arg == Dict):
+        if not (custom_subclass_check(second_arg, dict)):
             raise InvalidDecoratorException(
                 f"Function: {fn.__qualname__} must return a tuple of type (SOME_TYPE, dict)."
             )
+        second_arg_params = typing_inspect.get_args(second_arg)
+        if (
+            len(second_arg_params) > 0 and not second_arg_params[0] == str
+        ):  # metadata must have string keys
+            raise InvalidDecoratorException(
+                f"Function: {fn.__qualname__} must return a tuple of type (SOME_TYPE, dict[str, ...]). Instead got (SOME_TYPE, dict[{second_arg_params[0]}, ...]"
+            )
 
     def generate_nodes(self, fn: Callable, config) -> List[node.Node]:
         """Generates two nodes. We have to add tags appropriately.
diff --git a/tests/function_modifiers/test_adapters.py b/tests/function_modifiers/test_adapters.py
index 06307dd96..d44ab5842 100644
--- a/tests/function_modifiers/test_adapters.py
+++ b/tests/function_modifiers/test_adapters.py
@@ -697,19 +697,21 @@ def fn(data1: dict, data2: dict) -> dict:
 import sys
 
 if sys.version_info >= (3, 9):
-    dl_type = tuple[int, dict]
-    ds_type = dict
+    dict_ = dict
 else:
-    dl_type = Tuple[int, Dict]
-    ds_type = Dict
+    dict_ = Dict
 
 
 # Mock functions for dataloader & datasaver testing
-def correct_dl_function(foo: int) -> dl_type:
+def correct_dl_function(foo: int) -> Tuple[int, dict_]:
     return 1, {}
 
 
-def correct_ds_function(data: float) -> ds_type:
+def correct_dl_function_with_subscripts(foo: int) -> tuple[Dict[str, int], Dict[str, str]]:
+    return {"a": 1}, {"b": "c"}
+
+
+def correct_ds_function(data: float) -> dict_:
     return {}
 
 
@@ -729,11 +731,16 @@ def incorrect_second_element_function() -> Tuple[int, list]:
     return 1, []
 
 
+def incorrect_dict_subscript() -> Tuple[int, Dict[int, str]]:
+    return 1, {1: "a"}
+
+
 incorrect_funcs = [
     no_return_annotation_function,
     non_tuple_return_function,
     incorrect_tuple_length_function,
     incorrect_second_element_function,
+    incorrect_dict_subscript,
 ]
 
 
@@ -753,6 +760,15 @@ def test_dl_validate_with_correct_function():
         pytest.fail("validate() raised InvalidDecoratorException unexpectedly!")
 
 
+def test_dl_validate_with_subscripts():
+    dl = dataloader()
+    try:
+        dl.validate(correct_dl_function_with_subscripts)
+    except InvalidDecoratorException:
+        # i.e. fail the test if there's an error
+        pytest.fail("validate() raised InvalidDecoratorException unexpectedly!")
+
+
 def test_ds_validate_with_correct_function():
     dl = datasaver()
     try: