diff --git a/hamilton/htypes.py b/hamilton/htypes.py index 2aac81920..74bf4789c 100644 --- a/hamilton/htypes.py +++ b/hamilton/htypes.py @@ -321,6 +321,23 @@ def check_input_type(node_type: Type, input_value: Any) -> bool: return any([check_input_type(ut, input_value) for ut in union_types]) elif node_type == type(input_value): return True + # check for literal and that the value is in the literals listed. + elif typing_inspect.is_literal_type(node_type) and input_value in typing_inspect.get_args( + node_type + ): + return True + # check for sequence and that the value is a sequence + elif ( + typing_inspect.is_generic_type(node_type) + and typing_inspect.get_origin(node_type) + in (list, tuple, set, typing_inspect.get_origin(typing.Sequence)) + and isinstance(input_value, (list, tuple, set, typing_inspect.get_origin(typing.Sequence))) + ): + if typing_inspect.get_args(node_type): + # check first value in sequence -- if the type is specified. + for i in input_value: # this handles empty input case, e.g. [] or (), set() + return check_input_type(typing_inspect.get_args(node_type)[0], i) + return True return False diff --git a/tests/test_type_utils.py b/tests/test_type_utils.py index 9d0747153..66abe5fcd 100644 --- a/tests/test_type_utils.py +++ b/tests/test_type_utils.py @@ -239,6 +239,9 @@ def test_check_input_type_mismatch(node_type, input_value): (str, "abc"), (typing.Union[int, pd.Series], pd.Series([1, 2, 3])), (typing.Union[int, pd.Series], 1), + (typing.Literal["csv", "prq"], "csv"), + (typing.Sequence[str], ("a", "b")), + (typing.Sequence, ("a", "b")), ], ids=[ "test-any", @@ -253,6 +256,9 @@ def test_check_input_type_mismatch(node_type, input_value): "test-type-match-str", "test-union-match-series", "test-union-match-int", + "test-literal-match-str", + "test-sequence-str-match-tuple-str", + "test-sequence-match-tuple-str", ], ) def test_check_input_type_match(node_type, input_value):