diff --git a/hamilton/driver.py b/hamilton/driver.py index 67799bb37..a2ba558d5 100644 --- a/hamilton/driver.py +++ b/hamilton/driver.py @@ -1,5 +1,7 @@ import abc import functools +import importlib +import importlib.util import json import logging import operator @@ -260,6 +262,33 @@ class Driver: dr = driver.Driver(config, module, adapter=adapter) """ + def __getstate__(self): + """Used for serialization.""" + # Copy the object's state from self.__dict__ + state = self.__dict__.copy() + # Remove the unpicklable entries -- right now it's the modules tracked. + state["__graph_module_names"] = [ + importlib.util.find_spec(m.__name__).name for m in state["graph_modules"] + ] + del state["graph_modules"] # remove from state + return state + + def __setstate__(self, state): + """Used for deserialization.""" + # Restore instance attributes + self.__dict__.update(state) + # Reinitialize the unpicklable entries + # assumption is that the modules are importable in the new process + self.graph_modules = [] + for n in state["__graph_module_names"]: + try: + g_module = importlib.import_module(n) + except ImportError: + logger.error(f"Could not import module {n}") + continue + else: + self.graph_modules.append(g_module) + @staticmethod def normalize_adapter_input( adapter: Optional[ diff --git a/tests/resources/test_driver_serde_mapper.py b/tests/resources/test_driver_serde_mapper.py new file mode 100644 index 000000000..4ddb9f392 --- /dev/null +++ b/tests/resources/test_driver_serde_mapper.py @@ -0,0 +1,33 @@ +from typing import Any + +from hamilton.htypes import Collect, Parallelizable + + +def mapper( + drivers: list, + inputs: list, + final_vars: list = None, +) -> Parallelizable[dict]: + if final_vars is None: + final_vars = [] + for dr, input_ in zip(drivers, inputs): + yield { + "dr": dr, + "final_vars": final_vars or dr.list_available_variables(), + "input": input_, + } + + +def inside(mapper: dict) -> dict: + _dr = mapper["dr"] + _inputs = mapper["input"] + _final_var = mapper["final_vars"] + return _dr.execute(final_vars=_final_var, inputs=_inputs) + + +def passthrough(inside: dict) -> dict: + return inside + + +def reducer(passthrough: Collect[dict]) -> Any: + return passthrough diff --git a/tests/resources/test_driver_serde_worker.py b/tests/resources/test_driver_serde_worker.py new file mode 100644 index 000000000..9b4b421a1 --- /dev/null +++ b/tests/resources/test_driver_serde_worker.py @@ -0,0 +1,2 @@ +def double(a: int) -> int: + return a * 2 diff --git a/tests/test_hamilton_driver.py b/tests/test_hamilton_driver.py index 54443fd1c..23fe8ac71 100644 --- a/tests/test_hamilton_driver.py +++ b/tests/test_hamilton_driver.py @@ -19,6 +19,8 @@ import tests.resources.dynamic_parallelism.parallel_linear_basic import tests.resources.tagging import tests.resources.test_default_args +import tests.resources.test_driver_serde_mapper +import tests.resources.test_driver_serde_worker import tests.resources.test_for_materialization import tests.resources.very_simple_dag @@ -665,3 +667,28 @@ def func_to_test(a: int) -> int: assert v.tags == n.tags assert v.documentation == n.documentation == "This is a doctstring" assert v.originating_functions == n.originating_functions + + +def test_driver_setstate_getstate(): + """This is an integration test testing serializability of the hamilton driver.""" + from hamilton.execution import executors + + drivers = [] + inputs = [] + for i in range(4): + dr = Builder().with_modules(tests.resources.test_driver_serde_worker).build() + drivers.append(dr) + inputs.append({"a": i}) + + dr = ( + Builder() + .with_modules(tests.resources.test_driver_serde_mapper) + .enable_dynamic_execution(allow_experimental_mode=True) + .with_remote_executor(executors.MultiProcessingExecutor(4)) + .build() + ) + r = dr.execute( + final_vars=["reducer"], + inputs={"drivers": drivers, "inputs": inputs, "final_vars": ["double"]}, + ) + assert r == {"reducer": [{"double": 0}, {"double": 2}, {"double": 4}, {"double": 6}]}