From 4f136d40d88c0f9c06acdb41b7a11da689543f3e Mon Sep 17 00:00:00 2001 From: Stefan Krawczyk Date: Thu, 2 Jan 2025 09:46:19 -0800 Subject: [PATCH] Adds wrapping a result builder So that people can adjust the result accordingly. --- hamilton/plugins/h_threadpool.py | 31 +++++++++++++++- tests/plugins/test_h_threadpool.py | 59 ++++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+), 2 deletions(-) diff --git a/hamilton/plugins/h_threadpool.py b/hamilton/plugins/h_threadpool.py index 644c41fcf..7f275095e 100644 --- a/hamilton/plugins/h_threadpool.py +++ b/hamilton/plugins/h_threadpool.py @@ -1,5 +1,5 @@ from concurrent.futures import Future, ThreadPoolExecutor -from typing import Any, Callable, Dict +from typing import Any, Callable, Dict, List, Type from hamilton import registry @@ -45,14 +45,39 @@ class FutureAdapter(base.BaseDoRemoteExecute, lifecycle.ResultBuilder): """ - def __init__(self, max_workers: int = None, thread_name_prefix: str = ""): + def __init__( + self, + max_workers: int = None, + thread_name_prefix: str = "", + result_builder: lifecycle.ResultBuilder = None, + ): """Constructor. :param max_workers: The maximum number of threads that can be used to execute the given calls. :param thread_name_prefix: An optional name prefix to give our threads. + :param result_builder: Optional. Result builder to use for building the result. """ self.executor = ThreadPoolExecutor( max_workers=max_workers, thread_name_prefix=thread_name_prefix ) + self.result_builder = result_builder + + def input_types(self) -> List[Type[Type]]: + """Gives the applicable types to this result builder. + This is optional for backwards compatibility, but is recommended. + + :return: A list of types that this can apply to. + """ + # since this wraps a potential result builder, exposing the input types of the wrapped + # result builder doesn't make sense.
+ return [Any] + + def output_type(self) -> Type: + """Returns the output type of this result builder + :return: the type that this creates + """ + if self.result_builder: + return self.result_builder.output_type() + return Any def do_remote_execute( self, @@ -81,4 +106,6 @@ def build_result(self, **outputs: Any) -> Any: for k, v in outputs.items(): if isinstance(v, Future): outputs[k] = v.result() + if self.result_builder: + return self.result_builder.build_result(**outputs) return outputs diff --git a/tests/plugins/test_h_threadpool.py b/tests/plugins/test_h_threadpool.py index ead9f9a3a..a92c8ca33 100644 --- a/tests/plugins/test_h_threadpool.py +++ b/tests/plugins/test_h_threadpool.py @@ -1,5 +1,7 @@ from concurrent.futures import Future +from typing import Any +from hamilton import lifecycle from hamilton.plugins.h_threadpool import FutureAdapter, _new_fn @@ -58,3 +60,60 @@ def test_future_adapter_build_result(): result = adapter.build_result(a=future_a, b=future_b) assert result == {"a": 1, "b": 2} + + +def test_future_adapter_input_types(): + adapter = FutureAdapter() + assert adapter.input_types() == [Any] + + +def test_future_adapter_output_type(): + adapter = FutureAdapter() + assert adapter.output_type() == Any + + +def test_future_adapter_input_types_with_result_builder(): + """Tests that we ignore exposing the input types of the wrapped result builder.""" + + class MockResultBuilder(lifecycle.ResultBuilder): + def build_result(self, **outputs: Any) -> Any: + pass + + def input_types(self): + return [int, str] + + adapter = FutureAdapter(result_builder=MockResultBuilder()) + assert adapter.input_types() == [Any] + + +def test_future_adapter_output_type_with_result_builder(): + class MockResultBuilder(lifecycle.ResultBuilder): + def build_result(self, **outputs: Any) -> Any: + pass + + def output_type(self): + return dict + + adapter = FutureAdapter(result_builder=MockResultBuilder()) + assert adapter.output_type() == dict + + +def 
test_future_adapter_build_result_with_result_builder(): + class MockResultBuilder(lifecycle.ResultBuilder): + def build_result(self, **outputs): + return sum(outputs.values()) + + def input_types(self): + return [int] + + def output_type(self): + return int + + adapter = FutureAdapter(result_builder=MockResultBuilder()) + future_a = Future() + future_b = Future() + future_a.set_result(1) + future_b.set_result(2) + + result = adapter.build_result(a=future_a, b=future_b) + assert result == 3