Skip to content

Commit

Permalink
Adds wrapping a result builder
Browse files Browse the repository at this point in the history
So that people can adjust the result accordingly.
  • Loading branch information
skrawcz authored and elijahbenizzy committed Jan 3, 2025
1 parent 7f36808 commit 161b2a6
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 2 deletions.
31 changes: 29 additions & 2 deletions hamilton/plugins/h_threadpool.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Callable, Dict
from typing import Any, Callable, Dict, List, Type

from hamilton import registry

Expand Down Expand Up @@ -45,14 +45,39 @@ class FutureAdapter(base.BaseDoRemoteExecute, lifecycle.ResultBuilder):
"""

def __init__(self, max_workers: int = None, thread_name_prefix: str = ""):
def __init__(
self,
max_workers: int = None,
thread_name_prefix: str = "",
result_builder: lifecycle.ResultBuilder = None,
):
"""Constructor.
:param max_workers: The maximum number of threads that can be used to execute the given calls.
:param thread_name_prefix: An optional name prefix to give our threads.
:param result_builder: Optional. Result builder to use for building the result.
"""
self.executor = ThreadPoolExecutor(
max_workers=max_workers, thread_name_prefix=thread_name_prefix
)
self.result_builder = result_builder

def input_types(self) -> List[Type[Type]]:
"""Gives the applicable types to this result builder.
This is optional for backwards compatibility, but is recommended.
:return: A list of types that this can apply to.
"""
# since this wraps a potential result builder, expose the input types of the wrapped
# result builder doesn't make sense.
return [Any]

def output_type(self) -> Type:
"""Returns the output type of this result builder
:return: the type that this creates
"""
if self.result_builder:
return self.result_builder.output_type()
return Any

def do_remote_execute(
self,
Expand Down Expand Up @@ -81,4 +106,6 @@ def build_result(self, **outputs: Any) -> Any:
for k, v in outputs.items():
if isinstance(v, Future):
outputs[k] = v.result()
if self.result_builder:
return self.result_builder.build_result(**outputs)
return outputs
59 changes: 59 additions & 0 deletions tests/plugins/test_h_threadpool.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from concurrent.futures import Future
from typing import Any

from hamilton import lifecycle
from hamilton.plugins.h_threadpool import FutureAdapter, _new_fn


Expand Down Expand Up @@ -58,3 +60,60 @@ def test_future_adapter_build_result():

result = adapter.build_result(a=future_a, b=future_b)
assert result == {"a": 1, "b": 2}


def test_future_adapter_input_types():
adapter = FutureAdapter()
assert adapter.input_types() == [Any]


def test_future_adapter_output_type():
adapter = FutureAdapter()
assert adapter.output_type() == Any


def test_future_adapter_input_types_with_result_builder():
"""Tests that we ignore exposing the input types of the wrapped result builder."""

class MockResultBuilder(lifecycle.ResultBuilder):
def build_result(self, **outputs: Any) -> Any:
pass

def input_types(self):
return [int, str]

adapter = FutureAdapter(result_builder=MockResultBuilder())
assert adapter.input_types() == [Any]


def test_future_adapter_output_type_with_result_builder():
class MockResultBuilder(lifecycle.ResultBuilder):
def build_result(self, **outputs: Any) -> Any:
pass

def output_type(self):
return dict

adapter = FutureAdapter(result_builder=MockResultBuilder())
assert adapter.output_type() == dict


def test_future_adapter_build_result_with_result_builder():
class MockResultBuilder(lifecycle.ResultBuilder):
def build_result(self, **outputs):
return sum(outputs.values())

def input_types(self):
return [int]

def output_type(self):
return int

adapter = FutureAdapter(result_builder=MockResultBuilder())
future_a = Future()
future_b = Future()
future_a.set_result(1)
future_b.set_result(2)

result = adapter.build_result(a=future_a, b=future_b)
assert result == 3

0 comments on commit 161b2a6

Please sign in to comment.