Skip to content

Commit

Permalink
Fixes edge cases input type checking missed
Browse files Browse the repository at this point in the history
This handles sequences and literals.
  • Loading branch information
skrawcz committed Aug 12, 2024
1 parent f35f6b7 commit 6dec56e
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
17 changes: 17 additions & 0 deletions hamilton/htypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,23 @@ def check_input_type(node_type: Type, input_value: Any) -> bool:
return any([check_input_type(ut, input_value) for ut in union_types])
elif node_type == type(input_value):
return True
# check for literal and that the value is in the literals listed.
elif typing_inspect.is_literal_type(node_type) and input_value in typing_inspect.get_args(
node_type
):
return True
# check for sequence and that the value is a sequence
elif (
typing_inspect.is_generic_type(node_type)
and typing_inspect.get_origin(node_type)
in (list, tuple, set, typing_inspect.get_origin(typing.Sequence))
and isinstance(input_value, (list, tuple, set, typing_inspect.get_origin(typing.Sequence)))
):
if typing_inspect.get_args(node_type):
# check first value in sequence -- if the type is specified.
for i in input_value: # this handles empty input case, e.g. [] or (), set()
return check_input_type(typing_inspect.get_args(node_type)[0], i)
return True
return False


Expand Down
6 changes: 6 additions & 0 deletions tests/test_type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,9 @@ def test_check_input_type_mismatch(node_type, input_value):
(str, "abc"),
(typing.Union[int, pd.Series], pd.Series([1, 2, 3])),
(typing.Union[int, pd.Series], 1),
(typing.Literal["csv", "prq"], "csv"),
(typing.Sequence[str], ("a", "b")),
(typing.Sequence, ("a", "b")),
],
ids=[
"test-any",
Expand All @@ -253,6 +256,9 @@ def test_check_input_type_mismatch(node_type, input_value):
"test-type-match-str",
"test-union-match-series",
"test-union-match-int",
"test-literal-match-str",
"test-sequence-str-match-tuple-str",
"test-sequence-match-tuple-str",
],
)
def test_check_input_type_match(node_type, input_value):
Expand Down

0 comments on commit 6dec56e

Please sign in to comment.