Skip to content

Commit

Permalink
Fixes issue in which dataloader did not accept subscripted generics a…
Browse files Browse the repository at this point in the history
…s the output type

Note we only allow dict[str, ...], nothing else.
  • Loading branch information
elijahbenizzy committed Dec 17, 2024
1 parent 65f6bcd commit 20bcab8
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 9 deletions.
10 changes: 9 additions & 1 deletion hamilton/function_modifiers/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
ParametrizedDependency,
UpstreamDependency,
)
from hamilton.htypes import custom_subclass_check
from hamilton.io.data_adapters import AdapterCommon, DataLoader, DataSaver
from hamilton.node import DependencyType
from hamilton.registry import LOADER_REGISTRY, SAVER_REGISTRY
Expand Down Expand Up @@ -748,10 +749,17 @@ def validate(self, fn: Callable):
)
# check that the second is a dict
second_arg = typing_inspect.get_args(return_annotation)[1]
if not (second_arg == dict or second_arg == Dict):
if not (custom_subclass_check(second_arg, dict)):
raise InvalidDecoratorException(
f"Function: {fn.__qualname__} must return a tuple of type (SOME_TYPE, dict)."
)
second_arg_params = typing_inspect.get_args(second_arg)
if (
len(second_arg_params) > 0 and not second_arg_params[0] == str
): # metadata must have string keys
raise InvalidDecoratorException(
f"Function: {fn.__qualname__} must return a tuple of type (SOME_TYPE, dict[str, ...]). Instead got (SOME_TYPE, dict[{second_arg_params[0]}, ...]"
)

def generate_nodes(self, fn: Callable, config) -> List[node.Node]:
"""Generates two nodes. We have to add tags appropriately.
Expand Down
34 changes: 26 additions & 8 deletions tests/function_modifiers/test_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,19 +697,23 @@ def fn(data1: dict, data2: dict) -> dict:
import sys

if sys.version_info >= (3, 9):
dl_type = tuple[int, dict]
ds_type = dict
dict_ = dict
tuple_ = tuple
else:
dl_type = Tuple[int, Dict]
ds_type = Dict
dict_ = Dict
tuple_ = Tuple


# Mock functions for dataloader & datasaver testing
def correct_dl_function(foo: int) -> dl_type:
def correct_dl_function(foo: int) -> tuple_[int, dict_]:
return 1, {}


def correct_ds_function(data: float) -> ds_type:
def correct_dl_function_with_subscripts(foo: int) -> tuple_[Dict[str, int], Dict[str, str]]:
return {"a": 1}, {"b": "c"}


def correct_ds_function(data: float) -> dict_:
return {}


Expand All @@ -721,19 +725,24 @@ def non_tuple_return_function() -> int:
return 1


def incorrect_tuple_length_function() -> Tuple[int]:
def incorrect_tuple_length_function() -> tuple_[int]:
return (1,)


def incorrect_second_element_function() -> Tuple[int, list]:
def incorrect_second_element_function() -> tuple_[int, list]:
return 1, []


def incorrect_dict_subscript() -> tuple_[int, Dict[int, str]]:
return 1, {1: "a"}


incorrect_funcs = [
no_return_annotation_function,
non_tuple_return_function,
incorrect_tuple_length_function,
incorrect_second_element_function,
incorrect_dict_subscript,
]


Expand All @@ -753,6 +762,15 @@ def test_dl_validate_with_correct_function():
pytest.fail("validate() raised InvalidDecoratorException unexpectedly!")


def test_dl_validate_with_subscripts():
dl = dataloader()
try:
dl.validate(correct_dl_function_with_subscripts)
except InvalidDecoratorException:
# i.e. fail the test if there's an error
pytest.fail("validate() raised InvalidDecoratorException unexpectedly!")


def test_ds_validate_with_correct_function():
dl = datasaver()
try:
Expand Down

0 comments on commit 20bcab8

Please sign in to comment.