Fixes Annotated use as input type. Fixes #554 (#557)
* Fixes Annotated use as input type. Fixes #554

Turns out we were not inspecting the input types like we do the output types.
This change fixes that and adds a unit test to ensure the types are what
we think they should be.

Note I had to change the Spark plugin a little, since we didn't take into account
that Annotated types would be passed all the way through in the Spark code.
This fixes that and adds a unit test for it.
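
For context, a minimal sketch of the root cause described above (not part of the commit; assumes Python 3.9+ with pandas installed, and the function name and metadata string are purely illustrative): typing.get_type_hints() strips Annotated metadata from parameter hints unless include_extras=True is passed, which is what the hamilton/node.py change below adds.

from typing import Annotated, get_type_hints

import pandas as pd


def my_func(a: Annotated[pd.Series, "illustrative metadata"]) -> pd.Series:
    return a * 2


# Without include_extras, the Annotated wrapper around the parameter hint is dropped:
print(get_type_hints(my_func)["a"])                       # pandas Series class, metadata gone
# With include_extras=True (Python 3.9+), the Annotated wrapper and its metadata survive:
print(get_type_hints(my_func, include_extras=True)["a"])  # Annotated[Series, 'illustrative metadata']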
skrawcz authored Nov 25, 2023
1 parent b0887e0 commit 5830ef0
Showing 4 changed files with 48 additions and 3 deletions.
4 changes: 3 additions & 1 deletion hamilton/node.py
@@ -97,7 +97,9 @@ def __init__(
                 for key, value in input_types.items()
             }
         else:
-            input_types = typing.get_type_hints(callabl)
+            # TODO -- remove this when we no longer support 3.8 -- 10/14/2024
+            type_hint_kwargs = {} if sys.version_info < (3, 9) else {"include_extras": True}
+            input_types = typing.get_type_hints(callabl, **type_hint_kwargs)
             signature = inspect.signature(callabl)
             for key, value in signature.parameters.items():
                 if key not in input_types:
8 changes: 7 additions & 1 deletion hamilton/plugins/h_spark.py
@@ -231,8 +231,14 @@ def _get_pandas_annotations(node_: node.Node, bound_parameters: Dict[str, Any])
     :param hamilton_udf: the function to check.
     :return: dictionary of parameter names to boolean indicating if they are pandas series.
     """
+
+    def _get_type_from_annotation(annotation: Any) -> Any:
+        """Gets the type from the annotation if there is one."""
+        actual_type, extras = htypes.get_type_information(annotation)
+        return actual_type
+
     return {
-        name: type_ == pd.Series
+        name: _get_type_from_annotation(type_) == pd.Series
         for name, (type_, dep_type) in node_.input_types.items()
         if name not in bound_parameters and dep_type == node.DependencyType.REQUIRED
     }
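
As an illustration only (not the plugin's code): the new _get_type_from_annotation helper above boils down to unwrapping an Annotated parameter to its underlying type before the pd.Series comparison. A rough stdlib-only sketch of that idea, assuming Python 3.9+ (the real code delegates to htypes.get_type_information, which returns (actual_type, extras)):

from typing import Annotated, Any, get_args

import pandas as pd


def unwrap_annotation(annotation: Any) -> Any:
    """Hypothetical helper: return the underlying type, dropping any Annotated metadata."""
    if hasattr(annotation, "__metadata__"):  # only Annotated aliases carry __metadata__
        return get_args(annotation)[0]  # the first argument is the wrapped type
    return annotation


assert unwrap_annotation(Annotated[pd.Series, int]) is pd.Series  # metadata stripped
assert unwrap_annotation(pd.Series) is pd.Series  # plain types pass through unchanged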
9 changes: 9 additions & 0 deletions plugin_tests/h_spark/test_h_spark.py
@@ -141,6 +141,15 @@ def with_pandas(a: pd.Series) -> pd.Series:
     assert h_spark._get_pandas_annotations(node.Node.from_fn(with_pandas), {}) == {"a": True}
 
 
+def test__get_pandas_annotations_with_annotated_pandas():
+    IntSeries = htypes.column[pd.Series, int]
+
+    def with_pandas(a: IntSeries) -> IntSeries:
+        return a * 2
+
+    assert h_spark._get_pandas_annotations(node.Node.from_fn(with_pandas), {}) == {"a": True}
+
+
 def test__get_pandas_annotations_with_pandas_and_other_default():
     def with_pandas_and_other_default(a: pd.Series, b: int) -> pd.Series:
         return a * b
30 changes: 29 additions & 1 deletion tests/test_node.py
@@ -1,6 +1,11 @@
+import sys
+from typing import Any, Literal, TypeVar
+
+import numpy as np
+import numpy.typing as npt
 import pytest
 
-from hamilton.node import Node
+from hamilton.node import DependencyType, Node
 
 
 def test_node_from_fn_happy():
@@ -33,3 +38,26 @@ def fn() -> int:
     node_copy_copy = node_copy.copy_with(name="rename_fn_again")
     assert node_copy_copy.originating_functions == (fn,)
     assert node_copy_copy.name == "rename_fn_again"
+
+
+@pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python 3.9 or higher")
+def test_node_handles_annotated():
+    from typing import Annotated
+
+    DType = TypeVar("DType", bound=np.generic)
+    ArrayN = Annotated[npt.NDArray[DType], Literal["N"]]
+
+    def annotated_func(first: ArrayN[np.float64], other: float = 2.0) -> ArrayN[np.float64]:
+        return first * other
+
+    node = Node.from_fn(annotated_func)
+    assert node.name == "annotated_func"
+    expected = {
+        "first": (
+            Annotated[np.ndarray[Any, np.dtype[np.float64]], Literal["N"]],
+            DependencyType.REQUIRED,
+        ),
+        "other": (float, DependencyType.OPTIONAL),
+    }
+    assert node.input_types == expected
+    assert node.type == Annotated[np.ndarray[Any, np.dtype[np.float64]], Literal["N"]]
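
The expectation this new test encodes can also be reproduced with plain typing; a small sketch (assuming Python 3.9+ and a NumPy version that allows runtime np.ndarray subscripting, as the test itself requires): parameterizing the Annotated alias substitutes the TypeVar inside it, and include_extras=True keeps the Literal["N"] metadata attached.

from typing import Annotated, Any, Literal, TypeVar, get_type_hints

import numpy as np
import numpy.typing as npt

DType = TypeVar("DType", bound=np.generic)
ArrayN = Annotated[npt.NDArray[DType], Literal["N"]]


def annotated_func(first: ArrayN[np.float64], other: float = 2.0) -> ArrayN[np.float64]:
    return first * other


hints = get_type_hints(annotated_func, include_extras=True)
# ArrayN[np.float64] resolves to the same Annotated type the test asserts for input_types["first"].
assert hints["first"] == Annotated[np.ndarray[Any, np.dtype[np.float64]], Literal["N"]]
assert hints["other"] is float  # parameters with plain hints are unaffected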
