From 4199a142d9898369af16871fe25965568ef5e651 Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Sat, 31 Aug 2024 00:13:03 -0400 Subject: [PATCH 1/3] Change Parallelizable and Collect to structural subtypes --- hamilton/htypes.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/hamilton/htypes.py b/hamilton/htypes.py index 2ff95ad54..8f1848b82 100644 --- a/hamilton/htypes.py +++ b/hamilton/htypes.py @@ -1,8 +1,7 @@ import inspect import sys import typing -from abc import ABC -from typing import Any, Generator, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Iterable, Optional, Protocol, Tuple, Type, TypeVar, Union import typing_inspect @@ -270,8 +269,8 @@ def get_type_information(some_type: Any) -> Tuple[Type[Type], list]: # Type variables for annotations below T = TypeVar("T") -U = TypeVar("U") -V = TypeVar("V") +U = TypeVar("U", covariant=True) +V = TypeVar("V", covariant=True) # TODO -- support sequential operation @@ -279,16 +278,14 @@ def get_type_information(some_type: Any) -> Tuple[Type[Type], list]: # pass -class Parallelizable(Generator[U, None, None], ABC): - pass +class Parallelizable(Iterable[U], Protocol[U]): ... def is_parallelizable_type(type_: Type) -> bool: - return issubclass(type_, Parallelizable) + return _get_origin(type_) == Parallelizable -class Collect(Generator[V, None, None], ABC): - pass +class Collect(Iterable[V], Protocol[V]): ... def check_input_type(node_type: Type, input_value: Any) -> bool: From 6804a0299404b4a5ed7e3ba728a0651c97b2d03f Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Sat, 31 Aug 2024 08:02:10 -0400 Subject: [PATCH 2/3] Fix formatting errors --- hamilton/htypes.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/hamilton/htypes.py b/hamilton/htypes.py index 8f1848b82..5a44ed3ea 100644 --- a/hamilton/htypes.py +++ b/hamilton/htypes.py @@ -136,7 +136,11 @@ def types_match(param_type: Type[Type], required_node_type: Any) -> bool: _sys_version_info = sys.version_info -_version_tuple = (_sys_version_info.major, _sys_version_info.minor, _sys_version_info.micro) +_version_tuple = ( + _sys_version_info.major, + _sys_version_info.minor, + _sys_version_info.micro, +) """ The following is purely for backwards compatibility @@ -278,14 +282,16 @@ def get_type_information(some_type: Any) -> Tuple[Type[Type], list]: # pass -class Parallelizable(Iterable[U], Protocol[U]): ... +class Parallelizable(Iterable[U], Protocol[U]): + pass def is_parallelizable_type(type_: Type) -> bool: return _get_origin(type_) == Parallelizable -class Collect(Iterable[V], Protocol[V]): ... +class Collect(Iterable[V], Protocol[V]): + pass def check_input_type(node_type: Type, input_value: Any) -> bool: From b2c521de29b60166a6f2476e40f233f65bd4965b Mon Sep 17 00:00:00 2001 From: Charles Swartz Date: Sat, 31 Aug 2024 23:33:37 -0400 Subject: [PATCH 3/3] Update type parameter names and protocol comments --- hamilton/htypes.py | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/hamilton/htypes.py b/hamilton/htypes.py index 5a44ed3ea..e1db63a26 100644 --- a/hamilton/htypes.py +++ b/hamilton/htypes.py @@ -272,17 +272,27 @@ def get_type_information(some_type: Any) -> Tuple[Type[Type], list]: # Type variables for annotations below -T = TypeVar("T") -U = TypeVar("U", covariant=True) -V = TypeVar("V", covariant=True) +SequentialElement = TypeVar("SequentialElement", covariant=True) +ParallelizableElement = TypeVar("ParallelizableElement", covariant=True) +CollectElement = TypeVar("CollectElement", covariant=True) # TODO -- support sequential operation -# class Sequential(Generator[T, None, None], ABC): +# class Sequential(Iterable[SequentialElement], Protocol[SequentialElement]): # pass -class Parallelizable(Iterable[U], Protocol[U]): +class Parallelizable(Iterable[ParallelizableElement], Protocol[ParallelizableElement]): + """Marks the output of a function node as parallelizable. + + Parallelizable outputs are expected to be iterable, where each element dynamically + generates a node. When using dynamic execution, each of these dynamic nodes can be + executed in parallel. + + Because this uses dynamic execution, the builder method `enable_dynamic_execution` + must be called with `allow_experimental_mode=True`. + """ + pass @@ -290,8 +300,15 @@ def is_parallelizable_type(type_: Type) -> bool: return _get_origin(type_) == Parallelizable -class Collect(Iterable[V], Protocol[V]): - pass +class Collect(Iterable[CollectElement], Protocol[CollectElement]): + """Marks a function node parameter as collectable. + + Collectable inputs are expected to be iterable, where each element is populated with + the results of dynamic nodes derived from parallelizable outputs. + + Because this uses dynamic execution, the builder method `enable_dynamic_execution` + must be called with `allow_experimental_mode=True`. + """ def check_input_type(node_type: Type, input_value: Any) -> bool: