diff --git a/docs/concepts/parallel-task.rst b/docs/concepts/parallel-task.rst
index 6159b753d..1b09a5c82 100644
--- a/docs/concepts/parallel-task.rst
+++ b/docs/concepts/parallel-task.rst
@@ -1,6 +1,48 @@
 Dynamic DAGs/Parallel Execution
 ----------------------------------
+There are two approaches to parallel execution in Hamilton:
+
+1. Using an adapter that submits each node/function to a system that handles execution, e.g. Ray, Dask, async, or a threadpool.
+2. Using the `Parallelizable[]` and `Collect[]` types + delegating to an executor.
+
+Using an Adapter
+================
+The adapter approach effectively farms out the execution of each node/function to a system that can resolve
+futures. That is, Hamilton walks the DAG and submits each node to the adapter; the adapter then submits the node for execution,
+and the execution internally resolves any futures from previously submitted nodes.
+
+To make use of this, the general pattern is to apply an adapter to the driver -- you don't need to touch your Hamilton functions:
+
+.. code-block:: python
+
+    from hamilton import driver
+    from hamilton.plugins.h_threadpool import FutureAdapter
+    # from hamilton.plugins.h_ray import RayGraphAdapter
+    # from hamilton.plugins.h_dask import DaskGraphAdapter
+
+    dr = (
+        driver.Builder()
+        .with_modules(foo_module)
+        .with_adapter(FutureAdapter())
+        .build()
+    )
+
+    dr.execute(["my_variable"], inputs={...}, overrides={...})
+
+The code above will execute the DAG, submitting each node to a `ThreadPoolExecutor` (see :doc:`../reference/graph-adapters/ThreadPoolFutureAdapter`),
+which is great if you're doing a lot of I/O-bound work, e.g. making API calls, reading from a database, etc.
+
+See this `Threadpool based example `_ for a complete example.
+
+Other adapters, e.g. Ray (:doc:`../reference/graph-adapters/RayGraphAdapter`), Dask (:doc:`../reference/graph-adapters/DaskGraphAdapter`), etc., will submit to their respective executors, but will involve object serialization
+(see caveats below).
+
+Using the `Parallelizable[]` and `Collect[]` types
+==================================================
+
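+As a quick preview of the pattern, here is a minimal sketch (the function names are illustrative, and the functions must live in a module you pass to the driver):
+
+.. code-block:: python
+
+    from hamilton.htypes import Collect, Parallelizable
+
+    def url(urls: list[str]) -> Parallelizable[str]:
+        # fan out: each yielded item becomes its own branch of the DAG
+        for u in urls:
+            yield u
+
+    def downloaded(url: str) -> str:
+        # placeholder for real I/O-bound work on a single item
+        return f"contents of {url}"
+
+    def all_downloaded(downloaded: Collect[str]) -> list[str]:
+        # fan in: gather the results of every branch
+        return list(downloaded)
+
+Running this requires a driver with dynamic execution enabled, i.e. ``driver.Builder().enable_dynamic_execution(allow_experimental_mode=True)``.
+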
 Hamilton now has pluggable execution, which allows for the following:
 1. Grouping of nodes into "tasks" (discrete execution unit between serialization boundaries)
diff --git a/docs/reference/graph-adapters/ThreadPoolFutureAdapter.rst b/docs/reference/graph-adapters/ThreadPoolFutureAdapter.rst
new file mode 100644
index 000000000..ca34d1461
--- /dev/null
+++ b/docs/reference/graph-adapters/ThreadPoolFutureAdapter.rst
@@ -0,0 +1,11 @@
+==========================
+h_threadpool.FutureAdapter
+==========================
+
+This is an adapter that delegates execution of the individual nodes in a Hamilton graph to a threadpool.
+This is useful when you have a graph with many nodes that can be executed in parallel.
+
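+A minimal usage sketch (``my_module`` stands in for your own module of Hamilton functions):
+
+.. code-block:: python
+
+    from hamilton import driver
+    from hamilton.plugins.h_threadpool import FutureAdapter
+
+    import my_module  # hypothetical module of Hamilton functions
+
+    dr = (
+        driver.Builder()
+        .with_modules(my_module)
+        .with_adapters(FutureAdapter(max_workers=4))  # max_workers is optional
+        .build()
+    )
+    results = dr.execute(["some_output"])  # independent nodes run concurrently
+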
+.. autoclass:: hamilton.plugins.h_threadpool.FutureAdapter
+   :special-members: __init__
+   :members:
+   :inherited-members:
diff --git a/docs/reference/graph-adapters/index.rst b/docs/reference/graph-adapters/index.rst
index 4f9f54de4..ffb43dd66 100644
--- a/docs/reference/graph-adapters/index.rst
+++ b/docs/reference/graph-adapters/index.rst
@@ -15,6 +15,7 @@ Reference
     SimplePythonGraphAdapter
     HamiltonGraphAdapter
     AsyncGraphAdapter
+    ThreadPoolFutureAdapter
     CachingGraphAdapter
     DaskGraphAdapter
     PySparkUDFGraphAdapter
diff --git a/examples/parallelism/lazy_threadpool_execution/README.md b/examples/parallelism/lazy_threadpool_execution/README.md
new file mode 100644
index 000000000..f82b70f33
--- /dev/null
+++ b/examples/parallelism/lazy_threadpool_execution/README.md
@@ -0,0 +1,53 @@
+# Lazy threadpool execution
+
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dagworks-inc/hamilton/blob/main/examples/parallelism/lazy_threadpool_execution/notebook.ipynb)
+
+This example is different from the other examples under /parallelism/ in that
+it demonstrates how to use an adapter to put each
+function into a threadpool, which allows for lazy DAG evaluation and for parallelism
+to be achieved. This is useful when you have a lot of
+functions doing I/O-bound tasks and you want to speed
+up the execution of your program, e.g. doing lots of
+HTTP requests, reading/writing to disk, LLM API calls, etc.
+
+> Note: this adapter does not support DAGs with `Parallelizable` and `Collect` functions; create an issue if you need this feature.
+
+![DAG](my_functions.png)
+
+The above image shows the DAG that will be executed. You can see from the structure
+that the DAG can be parallelized, i.e. the left-most nodes can be executed in parallel.
+
+When you execute `run.py`, you will see output that shows:
+
+1. The DAG running in parallel -- check the image against what is printed.
+2. The DAG logging to the Hamilton UI -- please adjust for your project.
+3. The DAG running without the adapter -- this is to show the difference in execution time.
+4. An async version of the DAG running in parallel -- this is to show that the performance of this approach is similar.
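+
+To run the example, first install the dependencies (they're listed in `requirements.txt`; the Hamilton SDK/UI packages are optional):
+
+```bash
+pip install -r requirements.txt
+```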
+ }, + "outputs": [], + "source": [ + "# Execute this cell to install dependencies\n", + "%pip install sf-hamilton[visualization]" + ] + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "# run me in google colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dagworks-inc/hamilton/blob/main/examples/parallelism/lazy_threadpool_execution/notebook.ipynb) [![GitHub badge](https://img.shields.io/badge/github-view_source-2b3137?logo=github)](https://github.com/dagworks-inc/hamilton/blob/main/examples/parallelism/lazy_threadpool_execution/notebook.ipynb)\n", + "id": "7b55978b426b6e42" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-02T05:00:01.808537Z", + "start_time": "2025-01-02T04:59:53.847872Z" + } + }, + "cell_type": "code", + "source": "%load_ext hamilton.plugins.jupyter_magic", + "id": "a1f2f8937a0b0488", + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/stefankrawczyk/.pyenv/versions/knowledge_retrieval-py39/lib/python3.9/site-packages/pyspark/pandas/__init__.py:50: UserWarning: 'PYARROW_IGNORE_TIMEZONE' environment variable was not set. It is required to set this environment variable to '1' in both driver and executor sides if you use pyarrow>=2.0.0. pandas-on-Spark will set it for you but it does not work if there is a Spark context already launched.\n", + " warnings.warn(\n" + ] + } + ], + "execution_count": 1 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "# Create a module with some functions\n", + "This hopefully shows a good example of what could be parallelized given the structure of the DAG." + ], + "id": "32c01561386fc348" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-02T05:00:16.514305Z", + "start_time": "2025-01-02T05:00:16.101379Z" + } + }, + "cell_type": "code", + "source": [ + "%%cell_to_module my_functions --display\n", + "\n", + "import time\n", + "\n", + "\n", + "def a() -> str:\n", + " print(\"a\")\n", + " time.sleep(3)\n", + " return \"a\"\n", + "\n", + "\n", + "def b() -> str:\n", + " print(\"b\")\n", + " time.sleep(3)\n", + " return \"b\"\n", + "\n", + "\n", + "def c(a: str, b: str) -> str:\n", + " print(\"c\")\n", + " time.sleep(3)\n", + " return a + \" \" + b\n", + "\n", + "\n", + "def d() -> str:\n", + " print(\"d\")\n", + " time.sleep(3)\n", + " return \"d\"\n", + "\n", + "\n", + "def e(c: str, d: str) -> str:\n", + " print(\"e\")\n", + " time.sleep(3)\n", + " return c + \" \" + d\n", + "\n", + "\n", + "def z() -> str:\n", + " print(\"z\")\n", + " time.sleep(3)\n", + " return \"z\"\n", + "\n", + "\n", + "def y() -> str:\n", + " print(\"y\")\n", + " time.sleep(3)\n", + " return \"y\"\n", + "\n", + "\n", + "def x(z: str, y: str) -> str:\n", + " print(\"x\")\n", + " time.sleep(3)\n", + " return z + \" \" + y\n", + "\n", + "\n", + "def s(x: str, e: str) -> str:\n", + " print(\"s\")\n", + " time.sleep(3)\n", + " return x + \" \" + e\n", + "\n" + ], + "id": "8e0e3b7a96ca1d44", + "outputs": [ + { + "data": { + "image/svg+xml": "\n\n\n\n\n\n\n\ncluster__legend\n\nLegend\n\n\n\nc\n\nc\nstr\n\n\n\ne\n\ne\nstr\n\n\n\nc->e\n\n\n\n\n\nx\n\nx\nstr\n\n\n\ns\n\ns\nstr\n\n\n\nx->s\n\n\n\n\n\na\n\na\nstr\n\n\n\na->c\n\n\n\n\n\ne->s\n\n\n\n\n\nb\n\nb\nstr\n\n\n\nb->c\n\n\n\n\n\nz\n\nz\nstr\n\n\n\nz->x\n\n\n\n\n\nd\n\nd\nstr\n\n\n\nd->e\n\n\n\n\n\ny\n\ny\nstr\n\n\n\ny->x\n\n\n\n\n\nfunction\n\nfunction\n\n\n\n", + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + 
diff --git a/examples/parallelism/lazy_threadpool_execution/my_functions.png b/examples/parallelism/lazy_threadpool_execution/my_functions.png
new file mode 100644
index 000000000..cf4bd1927
Binary files /dev/null and b/examples/parallelism/lazy_threadpool_execution/my_functions.png differ
diff --git a/examples/parallelism/lazy_threadpool_execution/my_functions.py b/examples/parallelism/lazy_threadpool_execution/my_functions.py
new file mode 100644
index 000000000..59e9b9762
--- /dev/null
+++ b/examples/parallelism/lazy_threadpool_execution/my_functions.py
@@ -0,0 +1,55 @@
+import time
+
+
+def a() -> str:
+    print("a")
+    time.sleep(3)
+    return "a"
+
+
+def b() -> str:
+    print("b")
+    time.sleep(3)
+    return "b"
+
+
+def c(a: str, b: str) -> str:
+    print("c")
+    time.sleep(3)
+    return a + " " + b
+
+
+def d() -> str:
+    print("d")
+    time.sleep(3)
+    return "d"
+
+
+def e(c: str, d: str) -> str:
+    print("e")
+    time.sleep(3)
+    return c + " " + d
+
+
+def z() -> str:
+    print("z")
+    time.sleep(3)
+    return "z"
+
+
+def y() -> str:
+    print("y")
+    time.sleep(3)
+    return "y"
+
+
+def x(z: str, y: str) -> str:
+    print("x")
+    time.sleep(3)
+    return z + " " + y
+
+
+def s(x: str, e: str) -> str:
+    print("s")
+    time.sleep(3)
+    return x + " " + e
diff --git a/examples/parallelism/lazy_threadpool_execution/my_functions_async.png b/examples/parallelism/lazy_threadpool_execution/my_functions_async.png
new file mode 100644
index 000000000..cf4bd1927
Binary files /dev/null and b/examples/parallelism/lazy_threadpool_execution/my_functions_async.png differ
diff --git a/examples/parallelism/lazy_threadpool_execution/my_functions_async.py b/examples/parallelism/lazy_threadpool_execution/my_functions_async.py
new file mode 100644
index 000000000..a639ed963
--- /dev/null
+++ b/examples/parallelism/lazy_threadpool_execution/my_functions_async.py
@@ -0,0 +1,55 @@
+import asyncio
+
+
+async def a() -> str:
+    print("a")
+    await asyncio.sleep(3)
+    return "a"
+
+
+async def b() -> str:
+    print("b")
+    await asyncio.sleep(3)
+    return "b"
+
+
+async def c(a: str, b: str) -> str:
+    print("c")
+    await asyncio.sleep(3)
+    return a + " " + b
+
+
+async def d() -> str:
+    print("d")
+    await asyncio.sleep(3)
+    return "d"
+
+
+async def e(c: str, d: str) -> str:
+    print("e")
+    await asyncio.sleep(3)
+    return c + " " + d
+
+
+async def z() -> str:
+    print("z")
+    await asyncio.sleep(3)
+    return "z"
+
+
+async def y() -> str:
+    print("y")
+    await asyncio.sleep(3)
+    return "y"
+
+
+async def x(z: str, y: str) -> str:
+    print("x")
+    await asyncio.sleep(3)
+    return z + " " + y
+
+
+async def s(x: str, e: str) -> str:
+    print("s")
+    await asyncio.sleep(3)
+    return x + " " + e
diff --git a/examples/parallelism/lazy_threadpool_execution/notebook.ipynb b/examples/parallelism/lazy_threadpool_execution/notebook.ipynb
new file mode 100644
index 000000000..2b66908b9
--- /dev/null
+++ b/examples/parallelism/lazy_threadpool_execution/notebook.ipynb
@@ -0,0 +1,325 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "initial_id",
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "# Execute this cell to install dependencies\n",
+    "%pip install sf-hamilton[visualization]"
+   ]
+  },
+  {
+   "metadata": {},
+   "cell_type": "markdown",
+   "source": "# Run me in Google Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dagworks-inc/hamilton/blob/main/examples/parallelism/lazy_threadpool_execution/notebook.ipynb) [![GitHub badge](https://img.shields.io/badge/github-view_source-2b3137?logo=github)](https://github.com/dagworks-inc/hamilton/blob/main/examples/parallelism/lazy_threadpool_execution/notebook.ipynb)\n",
+   "id": "7b55978b426b6e42"
+  },
+  {
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-01-02T05:00:01.808537Z",
+     "start_time": "2025-01-02T04:59:53.847872Z"
+    }
+   },
+   "cell_type": "code",
+   "source": "%load_ext hamilton.plugins.jupyter_magic",
+   "id": "a1f2f8937a0b0488",
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/Users/stefankrawczyk/.pyenv/versions/knowledge_retrieval-py39/lib/python3.9/site-packages/pyspark/pandas/__init__.py:50: UserWarning: 'PYARROW_IGNORE_TIMEZONE' environment variable was not set. It is required to set this environment variable to '1' in both driver and executor sides if you use pyarrow>=2.0.0. pandas-on-Spark will set it for you but it does not work if there is a Spark context already launched.\n",
+      "  warnings.warn(\n"
+     ]
+    }
+   ],
+   "execution_count": 1
+  },
+  {
+   "metadata": {},
+   "cell_type": "markdown",
+   "source": [
+    "# Create a module with some functions\n",
+    "This shows a good example of what can be parallelized, given the structure of the DAG."
+   ],
+   "id": "32c01561386fc348"
+  },
+  {
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-01-02T05:00:16.514305Z",
+     "start_time": "2025-01-02T05:00:16.101379Z"
+    }
+   },
+   "cell_type": "code",
+   "source": [
+    "%%cell_to_module my_functions --display\n",
+    "\n",
+    "import time\n",
+    "\n",
+    "\n",
+    "def a() -> str:\n",
+    "    print(\"a\")\n",
+    "    time.sleep(3)\n",
+    "    return \"a\"\n",
+    "\n",
+    "\n",
+    "def b() -> str:\n",
+    "    print(\"b\")\n",
+    "    time.sleep(3)\n",
+    "    return \"b\"\n",
+    "\n",
+    "\n",
+    "def c(a: str, b: str) -> str:\n",
+    "    print(\"c\")\n",
+    "    time.sleep(3)\n",
+    "    return a + \" \" + b\n",
+    "\n",
+    "\n",
+    "def d() -> str:\n",
+    "    print(\"d\")\n",
+    "    time.sleep(3)\n",
+    "    return \"d\"\n",
+    "\n",
+    "\n",
+    "def e(c: str, d: str) -> str:\n",
+    "    print(\"e\")\n",
+    "    time.sleep(3)\n",
+    "    return c + \" \" + d\n",
+    "\n",
+    "\n",
+    "def z() -> str:\n",
+    "    print(\"z\")\n",
+    "    time.sleep(3)\n",
+    "    return \"z\"\n",
+    "\n",
+    "\n",
+    "def y() -> str:\n",
+    "    print(\"y\")\n",
+    "    time.sleep(3)\n",
+    "    return \"y\"\n",
+    "\n",
+    "\n",
+    "def x(z: str, y: str) -> str:\n",
+    "    print(\"x\")\n",
+    "    time.sleep(3)\n",
+    "    return z + \" \" + y\n",
+    "\n",
+    "\n",
+    "def s(x: str, e: str) -> str:\n",
+    "    print(\"s\")\n",
+    "    time.sleep(3)\n",
+    "    return x + \" \" + e\n"
+   ],
+   "id": "8e0e3b7a96ca1d44",
+   "outputs": [
+    {
+     "data": {
+      "image/svg+xml": "(SVG of the DAG omitted: a, b -> c; c, d -> e; z, y -> x; x, e -> s)",
+      "text/plain": [
+       ""
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "execution_count": 2
+  },
+  {
+   "metadata": {},
+   "cell_type": "markdown",
+   "source": "# Run it without the adapter",
+   "id": "60355598274f9b79"
+  },
+  {
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-01-02T05:02:44.719265Z",
+     "start_time": "2025-01-02T05:02:44.423066Z"
+    }
+   },
+   "cell_type": "code",
+   "source": [
+    "from hamilton import driver\n",
+    "dr = driver.Builder().with_modules(my_functions).build()\n",
+    "dr"
+   ],
+   "id": "fcb0677daf5b4a31",
+   "outputs": [
+    {
+     "data": {
+      "image/svg+xml": "(SVG of the DAG omitted: a, b -> c; c, d -> e; z, y -> x; x, e -> s)",
+      "text/plain": [
+       ""
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "execution_count": 6
+  },
+  {
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-01-02T05:03:18.620774Z",
+     "start_time": "2025-01-02T05:02:51.536385Z"
+    }
+   },
+   "cell_type": "code",
+   "source": [
+    "import time\n",
+    "\n",
+    "start = time.time()\n",
+    "r = dr.execute([\"s\", \"x\", \"a\"])\n",
+    "print(\"got return from dr\")\n",
+    "print(r)\n",
+    "print(\"Time taken with\", time.time() - start)"
+   ],
+   "id": "960f9f5d5f018b38",
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "z\n",
+      "y\n",
+      "x\n",
+      "a\n",
+      "b\n",
+      "c\n",
+      "d\n",
+      "e\n",
+      "s\n",
+      "got return from dr\n",
+      "{'s': 'z y a b d', 'x': 'z y', 'a': 'a'}\n",
+      "Time taken with 27.080925941467285\n"
+     ]
+    }
+   ],
+   "execution_count": 7
+  },
+  {
+   "metadata": {},
+   "cell_type": "markdown",
+   "source": "# Run it with the adapter -- note the parallelism & time taken",
+   "id": "8a1d2b183b914034"
+  },
+  {
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-01-02T05:03:18.904861Z",
+     "start_time": "2025-01-02T05:03:18.632385Z"
+    }
+   },
+   "cell_type": "code",
+   "source": [
+    "from hamilton import driver\n",
+    "from hamilton.plugins import h_threadpool\n",
+    "\n",
+    "adapter = h_threadpool.FutureAdapter()\n",
+    "dr = driver.Builder().with_modules(my_functions).with_adapters(adapter).build()\n",
+    "dr"
+   ],
+   "id": "63853f111ef28439",
+   "outputs": [
+    {
+     "data": {
+      "image/svg+xml": "(SVG of the DAG omitted: a, b -> c; c, d -> e; z, y -> x; x, e -> s)",
+      "text/plain": [
+       ""
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "execution_count": 8
+  },
+  {
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-01-02T05:03:30.949086Z",
+     "start_time": "2025-01-02T05:03:18.925667Z"
+    }
+   },
+   "cell_type": "code",
+   "source": [
+    "start = time.time()\n",
+    "r = dr.execute([\"s\", \"x\", \"a\"])\n",
+    "print(\"got return from dr\")\n",
+    "print(r)\n",
+    "print(\"Time taken with\", time.time() - start)"
+   ],
+   "id": "1bb057f4277705de",
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "z\n",
+      "y\n",
+      "a\n",
+      "b\n",
+      "d\n",
+      "x\n",
+      "c\n",
+      "e\n",
+      "s\n",
+      "got return from dr\n",
+      "{'s': 'z y a b d', 'x': 'z y', 'a': 'a'}\n",
+      "Time taken with 12.019250869750977\n"
+     ]
+    }
+   ],
+   "execution_count": 9
+  },
+  {
+   "metadata": {},
+   "cell_type": "markdown",
+   "source": [
+    "# 27 seconds vs 12 seconds\n",
+    "\n",
+    "With the adapter we see a significant improvement in the time taken to execute the DAG. This is because the adapter is able to parallelize the execution."
+   ],
+   "id": "56e9fb7445639984"
+  },
+  {
+   "metadata": {},
+   "cell_type": "code",
+   "outputs": [],
+   "execution_count": null,
+   "source": "",
+   "id": "e31132a4fd211887"
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 2
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython2",
+   "version": "2.7.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/examples/parallelism/lazy_threadpool_execution/requirements.txt b/examples/parallelism/lazy_threadpool_execution/requirements.txt
new file mode 100644
index 000000000..0bd997a6e
--- /dev/null
+++ b/examples/parallelism/lazy_threadpool_execution/requirements.txt
@@ -0,0 +1,3 @@
+sf-hamilton-sdk  # optional
+sf-hamilton-ui  # optional
+sf-hamilton[visualization]
diff --git a/examples/parallelism/lazy_threadpool_execution/run.py b/examples/parallelism/lazy_threadpool_execution/run.py
new file mode 100644
index 000000000..734842c39
--- /dev/null
+++ b/examples/parallelism/lazy_threadpool_execution/run.py
@@ -0,0 +1,73 @@
+import time
+
+import my_functions
+
+from hamilton import async_driver, driver
+from hamilton.plugins import h_threadpool
+
+start = time.time()
+adapter = h_threadpool.FutureAdapter()
+dr = driver.Builder().with_modules(my_functions).with_adapters(adapter).build()
+dr.display_all_functions("my_functions.png")
+r = dr.execute(["s", "x", "a"])
+print("got return from dr")
+print(r)
+print("Time taken with", time.time() - start)
+
+from hamilton_sdk import adapters
+
+tracker = adapters.HamiltonTracker(
+    project_id=21,  # modify this as needed
+    username="elijah@dagworks.io",
+    dag_name="with_caching",
+    tags={"environment": "DEV", "cached": "False", "team": "MY_TEAM", "version": "1"},
+)
+
+start = time.time()
+dr = (
+    driver.Builder().with_modules(my_functions).with_adapters(tracker, adapter).with_cache().build()
+)
+r = dr.execute(["s", "x", "a"])
+print("got return from dr")
+print(r)
+print("Time taken with cold cache", time.time() - start)
+
+tracker = adapters.HamiltonTracker(
+    project_id=21,  # modify this as needed
+    username="elijah@dagworks.io",
+    dag_name="with_caching",
+    tags={"environment": "DEV", "cached": "True", "team": "MY_TEAM", "version": "1"},
+)
+
+start = time.time()
+dr = (
+    driver.Builder().with_modules(my_functions).with_adapters(tracker, adapter).with_cache().build()
+)
+r = dr.execute(["s", "x", "a"])
+print("got return from dr")
+print(r)
+print("Time taken with warm cache", time.time() - start)
+
+start = time.time()
+dr = driver.Builder().with_modules(my_functions).build()
+r = dr.execute(["s", "x", "a"])
+print("got return from dr")
+print(r)
+print("Time taken without", time.time() - start)
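+
+# Comparison point: the async version below achieves similar parallelism via the
+# event loop instead of threads (see my_functions_async.py and the README).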
+
+
+async def run_async():
+    import my_functions_async
+
+    start = time.time()
+    dr = await async_driver.Builder().with_modules(my_functions_async).build()
+    dr.display_all_functions("my_functions_async.png")
+    r = await dr.execute(["s", "x", "a"])
+    print("got return from dr")
+    print(r)
+    print("Async Time taken without", time.time() - start)
+
+
+import asyncio
+
+asyncio.run(run_async())
diff --git a/hamilton/plugins/h_threadpool.py b/hamilton/plugins/h_threadpool.py
new file mode 100644
index 000000000..7f275095e
--- /dev/null
+++ b/hamilton/plugins/h_threadpool.py
@@ -0,0 +1,111 @@
+from concurrent.futures import Future, ThreadPoolExecutor
+from typing import Any, Callable, Dict, List, Type
+
+from hamilton import registry
+
+registry.disable_autoload()
+
+from hamilton import lifecycle, node
+from hamilton.lifecycle import base
+
+
+def _new_fn(fn: Callable, **fn_kwargs) -> Any:
+    """Function that runs in the thread.
+
+    It can recursively check for Futures because we don't have to worry about
+    process serialization.
+
+    :param fn: Function to run
+    :param fn_kwargs: Keyword arguments to pass to the function
+    """
+    for k, v in fn_kwargs.items():
+        while isinstance(v, Future):
+            v = v.result()  # this blocks until the future is resolved
+        fn_kwargs[k] = v
+    # execute the function once all the futures are resolved
+    return fn(**fn_kwargs)
+
+
+class FutureAdapter(base.BaseDoRemoteExecute, lifecycle.ResultBuilder):
+    """Adapter that lazily submits each function for execution to a ThreadPoolExecutor.
+
+    This adapter has similar behavior to the async Hamilton driver, which allows for parallel execution of functions.
+
+    This adapter works because we don't have to worry about object serialization.
+
+    Caveats:
+
+      - DAGs with lots of CPU-intensive functions will limit the usefulness of this adapter, unless they release the GIL.
+      - DAGs with lots of I/O-bound work will benefit from this adapter, e.g. making API calls.
+      - The maximum parallelism is limited by the number of threads in the ThreadPoolExecutor.
+
+    Unsupported behavior:
+
+      - The FutureAdapter does not support DAGs with Parallelizable & Collect functions. This is due to laziness
+        rather than anything inherently technical. If you'd like this feature, please open an issue on the Hamilton
+        repository.
+    """
+
+    def __init__(
+        self,
+        max_workers: int = None,
+        thread_name_prefix: str = "",
+        result_builder: lifecycle.ResultBuilder = None,
+    ):
+        """Constructor.
+
+        :param max_workers: The maximum number of threads that can be used to execute the given calls.
+        :param thread_name_prefix: An optional name prefix to give our threads.
+        :param result_builder: Optional. Result builder to use for building the result.
+        """
+        self.executor = ThreadPoolExecutor(
+            max_workers=max_workers, thread_name_prefix=thread_name_prefix
+        )
+        self.result_builder = result_builder
+
+    def input_types(self) -> List[Type[Type]]:
+        """Gives the applicable types to this result builder.
+
+        This is optional for backwards compatibility, but is recommended.
+
+        :return: A list of types that this can apply to.
+        """
+        # Since this wraps a potential result builder, exposing the input types of the
+        # wrapped result builder doesn't make sense.
+        return [Any]
+
+    def output_type(self) -> Type:
+        """Returns the output type of this result builder.
+
+        :return: the type that this creates
+        """
+        if self.result_builder:
+            return self.result_builder.output_type()
+        return Any
+
+    def do_remote_execute(
+        self,
+        *,
+        execute_lifecycle_for_node: Callable,
+        node: node.Node,
+        **kwargs: Dict[str, Any],
+    ) -> Any:
+        """Submits the passed-in function to the ThreadPoolExecutor to be executed,
+        after wrapping it with the _new_fn function.
+
+        :param node: Node that is being executed
+        :param execute_lifecycle_for_node: Function executing lifecycle hooks and lifecycle methods
+        :param kwargs: Keyword arguments that are being passed into the function
+        """
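+        # Return the Future immediately rather than blocking on it -- this is what
+        # makes execution lazy; _new_fn resolves upstream Futures inside the worker thread.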
+        return self.executor.submit(_new_fn, execute_lifecycle_for_node, **kwargs)
+
+    def build_result(self, **outputs: Any) -> Any:
+        """Given a set of outputs, build the result.
+
+        This function will block until all futures are resolved.
+
+        :param outputs: the outputs from the execution of the graph.
+        :return: the result of the execution of the graph.
+        """
+        for k, v in outputs.items():
+            if isinstance(v, Future):
+                outputs[k] = v.result()
+        if self.result_builder:
+            return self.result_builder.build_result(**outputs)
+        return outputs
diff --git a/tests/plugins/test_h_threadpool.py b/tests/plugins/test_h_threadpool.py
new file mode 100644
index 000000000..a92c8ca33
--- /dev/null
+++ b/tests/plugins/test_h_threadpool.py
@@ -0,0 +1,119 @@
+from concurrent.futures import Future
+from typing import Any
+
+from hamilton import lifecycle
+from hamilton.plugins.h_threadpool import FutureAdapter, _new_fn
+
+
+def test_new_fn_with_no_futures():
+    def sample_fn(a, b):
+        return a + b
+
+    result = _new_fn(sample_fn, a=1, b=2)
+    assert result == 3
+
+
+def test_new_fn_with_futures():
+    def sample_fn(a, b):
+        return a + b
+
+    future_a = Future()
+    future_b = Future()
+    future_a.set_result(1)
+    future_b.set_result(2)
+
+    result = _new_fn(sample_fn, a=future_a, b=future_b)
+    assert result == 3
+
+
+def test_future_adapter_do_remote_execute():
+    def sample_fn(a, b):
+        return a + b
+
+    adapter = FutureAdapter(max_workers=2)
+    future = adapter.do_remote_execute(execute_lifecycle_for_node=sample_fn, node=None, a=1, b=2)
+    assert future.result() == 3
+
+
+def test_future_adapter_do_remote_execute_with_futures():
+    def sample_fn(a, b):
+        return a + b
+
+    future_a = Future()
+    future_b = Future()
+    future_a.set_result(1)
+    future_b.set_result(2)
+
+    adapter = FutureAdapter(max_workers=2)
+    future = adapter.do_remote_execute(
+        execute_lifecycle_for_node=sample_fn, node=None, a=future_a, b=future_b
+    )
+    assert future.result() == 3
+
+
+def test_future_adapter_build_result():
+    adapter = FutureAdapter(max_workers=2)
+    future_a = Future()
+    future_b = Future()
+    future_a.set_result(1)
+    future_b.set_result(2)
+
+    result = adapter.build_result(a=future_a, b=future_b)
+    assert result == {"a": 1, "b": 2}
+
+
+def test_future_adapter_input_types():
+    adapter = FutureAdapter()
+    assert adapter.input_types() == [Any]
+
+
+def test_future_adapter_output_type():
+    adapter = FutureAdapter()
+    assert adapter.output_type() == Any
+
+
+def test_future_adapter_input_types_with_result_builder():
+    """Tests that we ignore exposing the input types of the wrapped result builder."""
+
+    class MockResultBuilder(lifecycle.ResultBuilder):
+        def build_result(self, **outputs: Any) -> Any:
+            pass
+
+        def input_types(self):
+            return [int, str]
+
+    adapter = FutureAdapter(result_builder=MockResultBuilder())
+    assert adapter.input_types() == [Any]
+
+
+def test_future_adapter_output_type_with_result_builder():
+    class MockResultBuilder(lifecycle.ResultBuilder):
+        def build_result(self, **outputs: Any) -> Any:
+            pass
+
+        def output_type(self):
+            return dict
+
+    adapter = FutureAdapter(result_builder=MockResultBuilder())
+    assert adapter.output_type() == dict
+
+
+def test_future_adapter_build_result_with_result_builder():
+    class MockResultBuilder(lifecycle.ResultBuilder):
+        def build_result(self, **outputs):
+            return sum(outputs.values())
+
+        def input_types(self):
+            return [int]
+
+        def output_type(self):
+            return int
+
+    adapter = FutureAdapter(result_builder=MockResultBuilder())
+    future_a = Future()
+    future_b = Future()
+    future_a.set_result(1)
+    future_b.set_result(2)
+
+    result = adapter.build_result(a=future_a, b=future_b)
+    assert result == 3