Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes Annotated use as input type. Fixes #554 #557

Merged
merged 2 commits into from
Nov 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion hamilton/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ def __init__(
for key, value in input_types.items()
}
else:
input_types = typing.get_type_hints(callabl)
# TODO -- remove this when we no longer support 3.8 -- 10/14/2024
type_hint_kwargs = {} if sys.version_info < (3, 9) else {"include_extras": True}
input_types = typing.get_type_hints(callabl, **type_hint_kwargs)
signature = inspect.signature(callabl)
for key, value in signature.parameters.items():
if key not in input_types:
Expand Down
8 changes: 7 additions & 1 deletion hamilton/plugins/h_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,14 @@ def _get_pandas_annotations(node_: node.Node, bound_parameters: Dict[str, Any])
:param hamilton_udf: the function to check.
:return: dictionary of parameter names to boolean indicating if they are pandas series.
"""

def _get_type_from_annotation(annotation: Any) -> Any:
    """Return the first item of ``htypes.get_type_information(annotation)``.

    The helper returns a two-item result; only the first item (the type that
    is later compared against ``pd.Series``) is needed here, so the second
    item is discarded.
    """
    base_type, _ = htypes.get_type_information(annotation)
    return base_type

return {
name: type_ == pd.Series
name: _get_type_from_annotation(type_) == pd.Series
for name, (type_, dep_type) in node_.input_types.items()
if name not in bound_parameters and dep_type == node.DependencyType.REQUIRED
}
Expand Down
9 changes: 9 additions & 0 deletions plugin_tests/h_spark/test_h_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,15 @@ def with_pandas(a: pd.Series) -> pd.Series:
assert h_spark._get_pandas_annotations(node.Node.from_fn(with_pandas), {}) == {"a": True}


def test__get_pandas_annotations_with_annotated_pandas():
    """A parameter typed as ``htypes.column[pd.Series, int]`` is flagged as a pandas series."""
    IntSeries = htypes.column[pd.Series, int]

    def with_pandas(a: IntSeries) -> IntSeries:
        return a * 2

    annotated_node = node.Node.from_fn(with_pandas)
    detected = h_spark._get_pandas_annotations(annotated_node, {})
    assert detected == {"a": True}


def test__get_pandas_annotations_with_pandas_and_other_default():
def with_pandas_and_other_default(a: pd.Series, b: int) -> pd.Series:
return a * b
Expand Down
30 changes: 29 additions & 1 deletion tests/test_node.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import sys
from typing import Any, Literal, TypeVar

import numpy as np
import numpy.typing as npt
import pytest

from hamilton.node import Node
from hamilton.node import DependencyType, Node


def test_node_from_fn_happy():
Expand Down Expand Up @@ -33,3 +38,26 @@ def fn() -> int:
node_copy_copy = node_copy.copy_with(name="rename_fn_again")
assert node_copy_copy.originating_functions == (fn,)
assert node_copy_copy.name == "rename_fn_again"


@pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python 3.9 or higher")
def test_node_handles_annotated():
    """Check that ``Node.from_fn`` keeps ``Annotated`` metadata on input and return types."""
    from typing import Annotated

    DType = TypeVar("DType", bound=np.generic)
    ArrayN = Annotated[npt.NDArray[DType], Literal["N"]]

    def annotated_func(first: ArrayN[np.float64], other: float = 2.0) -> ArrayN[np.float64]:
        return first * other

    # Build the expected type through ``npt.NDArray`` instead of hard-coding
    # what the alias expands to, so the test does not depend on how a given
    # NumPy release defines ``NDArray`` internally.
    expected_array_type = Annotated[npt.NDArray[np.float64], Literal["N"]]

    node = Node.from_fn(annotated_func)
    assert node.name == "annotated_func"
    expected = {
        "first": (expected_array_type, DependencyType.REQUIRED),
        # ``other`` has a default value, so it is an optional dependency.
        "other": (float, DependencyType.OPTIONAL),
    }
    assert node.input_types == expected
    assert node.type == expected_array_type