Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix future annotations parallelizable #1113

Merged
merged 2 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion hamilton/htypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def get_type_information(some_type: Any) -> Tuple[Type[Type], list]:
# pass


class Parallelizable(typing.Generator[U, None, None], ABC):
class Parallelizable(Generator[U, None, None], ABC):
pass


Expand Down
3 changes: 1 addition & 2 deletions hamilton/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,7 @@ def from_fn(fn: Callable, name: str = None) -> "Node":
if typing_inspect.is_generic_type(return_type):
if typing_inspect.get_origin(return_type) == Parallelizable:
node_source = NodeType.EXPAND
for parameter in inspect.signature(fn).parameters.values():
hint = parameter.annotation
for hint in typing.get_type_hints(fn, **type_hint_kwargs).values():
if typing_inspect.is_generic_type(hint):
if typing_inspect.get_origin(hint) == Collect:
node_source = NodeType.COLLECT
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ extend-ignore = [
"E501", # line too long
"E721", # Use `is` and `is not` for type comparisons, or `isinstance()` for isinstance checks
"W605", # invalid escape sequence
"TCH001" # TYPE_CHECKING block for first-class imports -- this is a bit ugly for the hamilton codebase
]
exclude = [
"docs/*",
Expand Down
19 changes: 19 additions & 0 deletions tests/resources/nodes_with_future_annotation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from __future__ import annotations

from hamilton.htypes import Collect, Parallelizable

"""Tests future annotations with common node types"""


def parallelized() -> Parallelizable[int]:
yield 1
yield 2
yield 3


def standard(parallelized: int) -> int:
return parallelized + 1


def collected(standard: Collect[int]) -> int:
return sum(standard)
20 changes: 20 additions & 0 deletions tests/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@
import numpy.typing as npt
import pytest

from hamilton import node
from hamilton.node import DependencyType, Node, matches_query

from tests.resources import nodes_with_future_annotation


def test_node_from_fn_happy():
def fn() -> int:
Expand Down Expand Up @@ -117,3 +120,20 @@ def foo(b: BrokenEquals = BrokenEquals()): # noqa

param = DependencyType.from_parameter(inspect.signature(foo).parameters["b"])
assert param == DependencyType.OPTIONAL


# Tests parsing for future annotations
# TODO -- we should generalize this but doing this for specific points is OK for now
def test_node_from_future_annotation_parallelizable():
parallelized = nodes_with_future_annotation.parallelized
assert node.Node.from_fn(parallelized).node_role == node.NodeType.EXPAND


def test_node_from_future_annotation_standard():
standard = nodes_with_future_annotation.standard
assert node.Node.from_fn(standard).node_role == node.NodeType.STANDARD


def test_node_from_future_annotation_collected():
collected = nodes_with_future_annotation.collected
assert node.Node.from_fn(collected).node_role == node.NodeType.COLLECT