Skip to content

Commit

Permalink
refactored to support Parallel/Collect
Browse files Browse the repository at this point in the history
  • Loading branch information
zilto authored and zilto committed Oct 28, 2024
1 parent 4273aba commit e1a7b00
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 44 deletions.
6 changes: 4 additions & 2 deletions hamilton/caching/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,9 @@ def __getstate__(self) -> dict:
state = self.__dict__.copy()
# store the classes to reinstantiate the same backend in __setstate__
state["metadata_store_cls"] = self.metadata_store.__class__
state["metadata_store_init"] = self.metadata_store.__getstate__()
state["result_store_cls"] = self.result_store.__class__
state["result_store_init"] = self.result_store.__getstate__()
del state["metadata_store"]
del state["result_store"]
return state
Expand All @@ -288,8 +290,8 @@ def __setstate__(self, state: dict) -> None:
"""
# instantiate the backend from the class, then delete the attribute before
# setting it on the adapter instance.
self.metadata_store = state["metadata_store_cls"](path=state["_path"])
self.result_store = state["result_store_cls"](path=state["_path"])
self.metadata_store = state["metadata_store_cls"](**state["metadata_store_init"])
self.result_store = state["result_store_cls"](**state["result_store_init"])
del state["metadata_store_cls"]
del state["result_store_cls"]
self.__dict__.update(state)
Expand Down
14 changes: 12 additions & 2 deletions hamilton/caching/stores/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@ def __init__(
# `sqlite3.OperationalError` because tables are missing.
self._create_tables_if_not_exists()

def __getstate__(self) -> dict:
"""Serialized `__init__` arguments required to initialize the
MetadataStore in a new thread or process.
"""
state = {}
# NOTE kwarg `path` is not equivalent to `self._path`
state["path"] = self._directory
state["connection_kwargs"] = self.connection_kwargs
return state

def _get_connection(self) -> sqlite3.Connection:
if not hasattr(self._thread_local, "connection"):
self._thread_local.connection = sqlite3.connect(
Expand Down Expand Up @@ -127,8 +137,8 @@ def set(
except BaseException as e:
raise ValueError(
f"Failed decoding the cache_key: {cache_key}.\n",
"Was it manually created? Do `code_version` and `data_version` found in ",
"``dependencies_data_versions`` have the proper encoding?",
"The `cache_key` must be created by `hamilton.caching.cache_key.create_cache_key()` ",
"if `node_name` and `code_version` are not provided.",
) from e

node_name = decoded_key["node_name"]
Expand Down
2 changes: 0 additions & 2 deletions tests/caching/metadata_store/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@ def metadata_store(request, tmp_path):

yield metadata_store

metadata_store.delete_all()


@pytest.mark.parametrize("metadata_store", IMPLEMENTATIONS, indirect=True)
def test_initialize_empty(metadata_store):
Expand Down
28 changes: 28 additions & 0 deletions tests/caching/result_store/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import pytest

from hamilton.caching.stores.file import FileResultStore
from hamilton.caching.stores.memory import InMemoryResultStore


def _instantiate_result_store(result_store_cls, tmp_path):
if result_store_cls == FileResultStore:
return FileResultStore(path=tmp_path)
elif result_store_cls == InMemoryResultStore:
return InMemoryResultStore()
else:
raise ValueError(
f"Class `{result_store_cls}` isn't defined in `_instantiate_metadata_store()`"
)


@pytest.fixture
def result_store(request, tmp_path):
result_store_cls = request.param
result_store = _instantiate_result_store(result_store_cls, tmp_path)

yield result_store

result_store.delete_all()


# NOTE add tests that check properties shared across result store implementations below
31 changes: 7 additions & 24 deletions tests/caching/result_store/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,13 @@
from hamilton.caching.stores.file import FileResultStore
from hamilton.caching.stores.memory import InMemoryResultStore

# `result_store` is imported but not directly used because it's
# a pytest fixture automatically provided to tests
from .test_base import result_store # noqa: F401

def _instantiate_result_store(result_store_cls, tmp_path):
if result_store_cls == FileResultStore:
return FileResultStore(path=tmp_path)
elif result_store_cls == InMemoryResultStore:
return InMemoryResultStore()
else:
raise ValueError(
f"Class `{result_store_cls}` isn't defined in `_instantiate_metadata_store()`"
)


@pytest.fixture
def result_store(request, tmp_path):
result_store_cls = request.param
result_store = _instantiate_result_store(result_store_cls, tmp_path)

yield result_store

result_store.delete_all()


def _store_size(result_store: InMemoryResultStore) -> int:
return len(result_store._results)
def _store_size(memory_store: InMemoryResultStore) -> int:
return len(memory_store._results)


@pytest.fixture
Expand Down Expand Up @@ -148,6 +131,6 @@ def test_load_from(result_store): # noqa: F811


def test_load_from_must_have_metadata_store_or_data_versions(tmp_path):
result_store = FileResultStore(tmp_path)
file_result_store = FileResultStore(tmp_path)
with pytest.raises(ValueError):
InMemoryResultStore.load_from(result_store)
InMemoryResultStore.load_from(result_store=file_result_store)
65 changes: 51 additions & 14 deletions tests/caching/test_integration.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,45 @@
import itertools
from typing import List

import pandas as pd
import pytest

from hamilton import ad_hoc_utils, driver
from hamilton.caching.adapter import CachingEventType, HamiltonCacheAdapter
from hamilton.caching.stores.file import FileResultStore
from hamilton.caching.stores.memory import InMemoryMetadataStore, InMemoryResultStore
from hamilton.caching.stores.sqlite import SQLiteMetadataStore
from hamilton.execution.executors import (
MultiProcessingExecutor,
MultiThreadingExecutor,
SynchronousLocalTaskExecutor,
)
from hamilton.function_modifiers import cache as cache_decorator

# `metadata_store` and `result_store` are imported but not directly used because they
# are pytest fixtures automatically provided to tests
from .metadata_store.test_base import metadata_store # noqa: F401
from .result_store.test_base import result_store # noqa: F401
from tests.resources.dynamic_parallelism import parallel_linear_basic, parallelism_with_caching


def _instantiate_executor(executor_cls):
if executor_cls == SynchronousLocalTaskExecutor:
return SynchronousLocalTaskExecutor()
elif executor_cls == MultiProcessingExecutor:
return MultiProcessingExecutor(max_tasks=10)
elif executor_cls == MultiThreadingExecutor:
return MultiThreadingExecutor(max_tasks=10)
else:
raise ValueError(f"Class `{executor_cls}` isn't defined in `_instantiate_executor()`")


@pytest.fixture
def executor(request):
executor_cls = request.param
return _instantiate_executor(executor_cls)


@pytest.fixture
def dr(request, tmp_path):
module = request.param
Expand Down Expand Up @@ -483,19 +508,34 @@ def foo() -> dict:
assert result[node_name] == retrieved_result


EXECUTORS_AND_STORES_CONFIGURATIONS = list(
itertools.product(
[SynchronousLocalTaskExecutor, MultiThreadingExecutor, MultiProcessingExecutor],
[SQLiteMetadataStore],
[FileResultStore],
)
)

# InMemory stores can't be used with multiprocessing because they don't share memory.
IN_MEMORY_CONFIGURATIONS = list(
itertools.product(
[SynchronousLocalTaskExecutor, MultiThreadingExecutor],
[InMemoryMetadataStore, SQLiteMetadataStore],
[InMemoryResultStore, FileResultStore],
)
)

EXECUTORS_AND_STORES_CONFIGURATIONS += IN_MEMORY_CONFIGURATIONS


@pytest.mark.parametrize(
"executor",
[
SynchronousLocalTaskExecutor(),
MultiProcessingExecutor(max_tasks=10),
MultiThreadingExecutor(max_tasks=10),
],
"executor,metadata_store,result_store", EXECUTORS_AND_STORES_CONFIGURATIONS, indirect=True
)
def test_parallel_synchronous_step_by_step(tmp_path, executor):
def test_parallel_synchronous_step_by_step(executor, metadata_store, result_store): # noqa: F811
dr = (
driver.Builder()
.with_modules(parallel_linear_basic)
.with_cache(path=tmp_path)
.with_cache(metadata_store=metadata_store, result_store=result_store)
.enable_dynamic_execution(allow_experimental_mode=True)
.with_remote_executor(executor)
.build()
Expand Down Expand Up @@ -559,11 +599,8 @@ def test_parallel_synchronous_step_by_step(tmp_path, executor):

@pytest.mark.parametrize(
"executor",
[
SynchronousLocalTaskExecutor(),
MultiProcessingExecutor(max_tasks=10),
MultiThreadingExecutor(max_tasks=10),
],
[SynchronousLocalTaskExecutor, MultiProcessingExecutor, MultiThreadingExecutor],
indirect=True,
)
def test_materialize_parallel_branches(tmp_path, executor):
# NOTE the module can't be defined here because multithreading requires functions to be top-level.
Expand Down Expand Up @@ -596,7 +633,7 @@ def test_materialize_parallel_branches(tmp_path, executor):
)


def test_consistent_cache_key_with_or_without_defaut_parameter(tmp_path):
def test_consistent_cache_key_with_or_without_default_parameter(tmp_path):
def foo(external_dep: int = 3) -> int:
return external_dep + 1

Expand Down

0 comments on commit e1a7b00

Please sign in to comment.