Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes dataloader to use the correct type hinting #1261

Merged
merged 3 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions hamilton/function_modifiers/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
ParametrizedDependency,
UpstreamDependency,
)
from hamilton.htypes import custom_subclass_check
from hamilton.io.data_adapters import AdapterCommon, DataLoader, DataSaver
from hamilton.node import DependencyType
from hamilton.registry import LOADER_REGISTRY, SAVER_REGISTRY
Expand Down Expand Up @@ -733,7 +734,7 @@ def load_json_data(json_path: str = "data/my_data.json") -> tuple[pd.DataFrame,

def validate(self, fn: Callable):
"""Validates that the output type is correctly annotated."""
return_annotation = inspect.signature(fn).return_annotation
return_annotation = typing.get_type_hints(fn).get("return")
if return_annotation is inspect.Signature.empty:
raise InvalidDecoratorException(
f"Function: {fn.__qualname__} must have a return annotation."
Expand All @@ -748,10 +749,17 @@ def validate(self, fn: Callable):
)
# check that the second is a dict
second_arg = typing_inspect.get_args(return_annotation)[1]
if not (second_arg == dict or second_arg == Dict):
if not (custom_subclass_check(second_arg, dict)):
raise InvalidDecoratorException(
f"Function: {fn.__qualname__} must return a tuple of type (SOME_TYPE, dict)."
)
second_arg_params = typing_inspect.get_args(second_arg)
if (
len(second_arg_params) > 0 and not second_arg_params[0] == str
): # metadata must have string keys
raise InvalidDecoratorException(
f"Function: {fn.__qualname__} must return a tuple of type (SOME_TYPE, dict[str, ...]). Instead got (SOME_TYPE, dict[{second_arg_params[0]}, ...]"
)

def generate_nodes(self, fn: Callable, config) -> List[node.Node]:
"""Generates two nodes. We have to add tags appropriately.
Expand Down
2 changes: 1 addition & 1 deletion hamilton/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
VERSION = (1, 85, 0)
VERSION = (1, 85, 1)
51 changes: 43 additions & 8 deletions tests/function_modifiers/test_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
resolve_kwargs,
)
from hamilton.function_modifiers.base import DefaultNodeCreator
from hamilton.htypes import custom_subclass_check
from hamilton.io.data_adapters import DataLoader, DataSaver
from hamilton.io.default_data_loaders import JSONDataSaver
from hamilton.registry import LOADER_REGISTRY
Expand Down Expand Up @@ -696,19 +697,23 @@ def fn(data1: dict, data2: dict) -> dict:
import sys

if sys.version_info >= (3, 9):
dl_type = tuple[int, dict]
ds_type = dict
dict_ = dict
tuple_ = tuple
else:
dl_type = Tuple[int, Dict]
ds_type = Dict
dict_ = Dict
tuple_ = Tuple


# Mock functions for dataloader & datasaver testing
def correct_dl_function(foo: int) -> dl_type:
def correct_dl_function(foo: int) -> tuple_[int, dict_]:
return 1, {}


def correct_ds_function(data: float) -> ds_type:
def correct_dl_function_with_subscripts(foo: int) -> tuple_[Dict[str, int], Dict[str, str]]:
return {"a": 1}, {"b": "c"}


def correct_ds_function(data: float) -> dict_:
return {}


Expand All @@ -720,19 +725,24 @@ def non_tuple_return_function() -> int:
return 1


def incorrect_tuple_length_function() -> Tuple[int]:
def incorrect_tuple_length_function() -> tuple_[int]:
return (1,)


def incorrect_second_element_function() -> Tuple[int, list]:
def incorrect_second_element_function() -> tuple_[int, list]:
return 1, []


def incorrect_dict_subscript() -> tuple_[int, Dict[int, str]]:
return 1, {1: "a"}


incorrect_funcs = [
no_return_annotation_function,
non_tuple_return_function,
incorrect_tuple_length_function,
incorrect_second_element_function,
incorrect_dict_subscript,
]


Expand All @@ -743,6 +753,10 @@ def test_dl_validate_incorrect_functions(func):
dl.validate(func)


@pytest.mark.skipif(
sys.version_info < (3, 9, 0),
reason="dataloader not guarenteed to work with subscripted tuples on 3.8",
)
def test_dl_validate_with_correct_function():
dl = dataloader()
try:
Expand All @@ -752,6 +766,15 @@ def test_dl_validate_with_correct_function():
pytest.fail("validate() raised InvalidDecoratorException unexpectedly!")


def test_dl_validate_with_subscripts():
dl = dataloader()
try:
dl.validate(correct_dl_function_with_subscripts)
except InvalidDecoratorException:
# i.e. fail the test if there's an error
pytest.fail("validate() raised InvalidDecoratorException unexpectedly!")


def test_ds_validate_with_correct_function():
dl = datasaver()
try:
Expand Down Expand Up @@ -792,6 +815,18 @@ def test_dataloader():
}


def test_dataloader_future_annotations():
from tests.resources import nodes_with_future_annotation

fn_to_collect = nodes_with_future_annotation.sample_dataloader
fg = graph.create_function_graph(
ad_hoc_utils.create_temporary_module(fn_to_collect),
config={},
)
# the data loaded is a list
assert custom_subclass_check(fg["sample_dataloader"].type, list)


def test_datasaver():
annotation = datasaver()
(node1,) = annotation.generate_nodes(correct_ds_function, {})
Expand Down
13 changes: 13 additions & 0 deletions tests/resources/nodes_with_future_annotation.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
from __future__ import annotations

import sys
from typing import List, Tuple

from hamilton.function_modifiers import dataloader
from hamilton.htypes import Collect, Parallelizable

"""Tests future annotations with common node types"""

tuple_ = Tuple if sys.version_info < (3, 9, 0) else tuple
elijahbenizzy marked this conversation as resolved.
Show resolved Hide resolved
list_ = List if sys.version_info < (3, 9, 0) else list


def parallelized() -> Parallelizable[int]:
yield 1
Expand All @@ -17,3 +24,9 @@ def standard(parallelized: int) -> int:

def collected(standard: Collect[int]) -> int:
return sum(standard)


@dataloader()
def sample_dataloader() -> tuple_[list_[str], dict]:
"""Grouping here as the rest test annotations"""
return ["a", "b", "c"], {}