Skip to content

Commit

Permalink
Adds Typeddict Extract fields subclass type check and test for it
Browse files Browse the repository at this point in the history
  • Loading branch information
skrawcz committed Dec 12, 2024
1 parent 1e7e136 commit 2f3e9b9
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 3 deletions.
13 changes: 10 additions & 3 deletions hamilton/function_modifiers/expanders.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import typing_extensions
import typing_inspect

from hamilton import node, registry
from hamilton import htypes, node, registry
from hamilton.dev_utils import deprecation
from hamilton.function_modifiers import base
from hamilton.function_modifiers.dependencies import (
Expand Down Expand Up @@ -772,8 +772,15 @@ def validate(self, fn: Callable):
else:
# check that fields is a subset of TypedDict that is defined
typed_dict_fields = typing.get_type_hints(output_type)
for k, v in self.fields.items():
if typed_dict_fields.get(k, None) != v:
for field_name, field_type in self.fields.items():
expected_type = typed_dict_fields.get(field_name, None)
if expected_type == field_type:
pass # we're definitely good
elif expected_type is not None and htypes.custom_subclass_check(
field_type, expected_type
):
pass
else:
raise base.InvalidDecoratorException(
f"Error {self.fields} did not match a subset of the TypedDict annotation's fields {typed_dict_fields}."
)
Expand Down
35 changes: 35 additions & 0 deletions tests/function_modifiers/test_expanders.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,41 @@ def return_dict() -> return_type:
annotation.validate(return_dict)


class SomeObject:
pass


class InheritedObject(SomeObject):
pass


class MyDictInheritance(TypedDict):
test: SomeObject
test2: str


class MyDictInheritanceBadCase(TypedDict):
test: InheritedObject
test2: str


def test_extract_fields_validate_happy_inheritance():
def return_dict() -> MyDictInheritance:
return {}

annotation = function_modifiers.extract_fields({"test": InheritedObject})
annotation.validate(return_dict)


def test_extract_fields_validate_not_subclass():
def return_dict() -> MyDictInheritanceBadCase:
return {}

annotation = function_modifiers.extract_fields({"test": SomeObject})
with pytest.raises(base.InvalidDecoratorException):
annotation.validate(return_dict)


@pytest.mark.parametrize(
"return_type",
[(int), (list), (np.ndarray), (pd.DataFrame), (MyDictBad)],
Expand Down

0 comments on commit 2f3e9b9

Please sign in to comment.