Skip to content

Commit

Permalink
Adds setstate getstate to driver and fixes 1093 (#1100)
Browse files Browse the repository at this point in the history
* Adds setstate getstate to driver and fixes 1093

This fixes #1093.

Modules were not picklable. So this fixes that by serializing
their fully qualified names, and then when the driver object
is deserialized, they are reinstantiated as module objects.
  • Loading branch information
skrawcz authored Aug 20, 2024
1 parent 3ce39a1 commit 05161d7
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 0 deletions.
29 changes: 29 additions & 0 deletions hamilton/driver.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import abc
import functools
import importlib
import importlib.util
import json
import logging
import operator
Expand Down Expand Up @@ -260,6 +262,33 @@ class Driver:
dr = driver.Driver(config, module, adapter=adapter)
"""

def __getstate__(self):
"""Used for serialization."""
# Copy the object's state from self.__dict__
state = self.__dict__.copy()
# Remove the unpicklable entries -- right now it's the modules tracked.
state["__graph_module_names"] = [
importlib.util.find_spec(m.__name__).name for m in state["graph_modules"]
]
del state["graph_modules"] # remove from state
return state

def __setstate__(self, state):
"""Used for deserialization."""
# Restore instance attributes
self.__dict__.update(state)
# Reinitialize the unpicklable entries
# assumption is that the modules are importable in the new process
self.graph_modules = []
for n in state["__graph_module_names"]:
try:
g_module = importlib.import_module(n)
except ImportError:
logger.error(f"Could not import module {n}")
continue
else:
self.graph_modules.append(g_module)

@staticmethod
def normalize_adapter_input(
adapter: Optional[
Expand Down
33 changes: 33 additions & 0 deletions tests/resources/test_driver_serde_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from typing import Any

from hamilton.htypes import Collect, Parallelizable


def mapper(
drivers: list,
inputs: list,
final_vars: list = None,
) -> Parallelizable[dict]:
if final_vars is None:
final_vars = []
for dr, input_ in zip(drivers, inputs):
yield {
"dr": dr,
"final_vars": final_vars or dr.list_available_variables(),
"input": input_,
}


def inside(mapper: dict) -> dict:
_dr = mapper["dr"]
_inputs = mapper["input"]
_final_var = mapper["final_vars"]
return _dr.execute(final_vars=_final_var, inputs=_inputs)


def passthrough(inside: dict) -> dict:
return inside


def reducer(passthrough: Collect[dict]) -> Any:
return passthrough
2 changes: 2 additions & 0 deletions tests/resources/test_driver_serde_worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def double(a: int) -> int:
return a * 2
27 changes: 27 additions & 0 deletions tests/test_hamilton_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import tests.resources.dynamic_parallelism.parallel_linear_basic
import tests.resources.tagging
import tests.resources.test_default_args
import tests.resources.test_driver_serde_mapper
import tests.resources.test_driver_serde_worker
import tests.resources.test_for_materialization
import tests.resources.very_simple_dag

Expand Down Expand Up @@ -665,3 +667,28 @@ def func_to_test(a: int) -> int:
assert v.tags == n.tags
assert v.documentation == n.documentation == "This is a doctstring"
assert v.originating_functions == n.originating_functions


def test_driver_setstate_getstate():
"""This is an integration test testing serializability of the hamilton driver."""
from hamilton.execution import executors

drivers = []
inputs = []
for i in range(4):
dr = Builder().with_modules(tests.resources.test_driver_serde_worker).build()
drivers.append(dr)
inputs.append({"a": i})

dr = (
Builder()
.with_modules(tests.resources.test_driver_serde_mapper)
.enable_dynamic_execution(allow_experimental_mode=True)
.with_remote_executor(executors.MultiProcessingExecutor(4))
.build()
)
r = dr.execute(
final_vars=["reducer"],
inputs={"drivers": drivers, "inputs": inputs, "final_vars": ["double"]},
)
assert r == {"reducer": [{"double": 0}, {"double": 2}, {"double": 4}, {"double": 6}]}

0 comments on commit 05161d7

Please sign in to comment.