feat: in-memory caching #1207

Merged: 8 commits, Oct 28, 2024
3 changes: 2 additions & 1 deletion examples/caching/README.md
@@ -3,4 +3,5 @@
This directory contains tutorial notebooks for the Hamilton caching feature.

- `tutorial.ipynb`: the main tutorial for caching
- `materializer_tutorial.ipynb`: tutorial on the interactions between `DataLoader/DataSaver` and caching. This is a more advanced tutorial for materializer users. You should complete the `tutorial.ipynb` first.
- `in_memory_tutorial.ipynb`: How to use caching without writing metadata and results to file.
- `materializer_tutorial.ipynb`: Learn about the interactions between `DataLoader/DataSaver` and caching. This is a more advanced tutorial.
673 changes: 673 additions & 0 deletions examples/caching/in_memory_tutorial.ipynb

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions hamilton/caching/adapter.py
@@ -274,7 +274,9 @@ def __getstate__(self) -> dict:
state = self.__dict__.copy()
# store the classes to reinstantiate the same backend in __setstate__
state["metadata_store_cls"] = self.metadata_store.__class__
state["metadata_store_init"] = self.metadata_store.__getstate__()
state["result_store_cls"] = self.result_store.__class__
state["result_store_init"] = self.result_store.__getstate__()
del state["metadata_store"]
del state["result_store"]
return state
@@ -288,8 +290,8 @@ def __setstate__(self, state: dict) -> None:
"""
# instantiate the backend from the class, then delete the attribute before
# setting it on the adapter instance.
self.metadata_store = state["metadata_store_cls"](path=state["_path"])
self.result_store = state["result_store_cls"](path=state["_path"])
self.metadata_store = state["metadata_store_cls"](**state["metadata_store_init"])
self.result_store = state["result_store_cls"](**state["result_store_init"])
del state["metadata_store_cls"]
del state["result_store_cls"]
self.__dict__.update(state)
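The adapter change above stops assuming every backend takes a single `path` argument: each store now reports its own `__init__` kwargs via `__getstate__`, and `__setstate__` rebuilds the store from its class plus those kwargs. A minimal, self-contained sketch of the pattern (class names here are illustrative, not Hamilton's):

import pickle

class Store:
    # stand-in for a cache backend that knows its own init kwargs
    def __init__(self, path: str) -> None:
        self.path = path

    def __getstate__(self) -> dict:
        # serialize only the __init__ kwargs, not live handles
        return {"path": self.path}

class Adapter:
    # stand-in for the caching adapter
    def __init__(self, store: Store) -> None:
        self.store = store

    def __getstate__(self) -> dict:
        state = self.__dict__.copy()
        # store the class and its init kwargs instead of the live store
        state["store_cls"] = self.store.__class__
        state["store_init"] = self.store.__getstate__()
        del state["store"]
        return state

    def __setstate__(self, state: dict) -> None:
        # reinstantiate the backend, then restore the remaining attributes
        self.store = state.pop("store_cls")(**state.pop("store_init"))
        self.__dict__.update(state)

adapter = pickle.loads(pickle.dumps(Adapter(Store(path="./.hamilton_cache"))))
assert adapter.store.path == "./.hamilton_cache"

This is what lets the adapter cross process boundaries in Parallelizable branches: the worker gets fresh store instances rather than pickled live ones.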
19 changes: 18 additions & 1 deletion hamilton/caching/stores/file.py
@@ -2,9 +2,15 @@
from pathlib import Path
from typing import Any, Optional

from hamilton.caching.stores.base import ResultStore, StoredResult
try:
from typing import override
except ImportError:
override = lambda x: x # noqa E731

from hamilton.io.data_adapters import DataLoader, DataSaver

from .base import ResultStore, StoredResult


class FileResultStore(ResultStore):
def __init__(self, path: str, create_dir: bool = True) -> None:
@@ -14,6 +20,12 @@ def __init__(self, path: str, create_dir: bool = True) -> None:
if self.create_dir:
self.path.mkdir(exist_ok=True, parents=True)

def __getstate__(self) -> dict:
Review comment (Contributor): The __getstate__ method only serializes the path attribute. If create_dir is also important for the object's state, consider including it in the serialization.

"""Serialize the `__init__` kwargs to pass in Parallelizable branches
when using multiprocessing.
"""
return {"path": str(self.path)}

@staticmethod
def _write_result(file_path: Path, stored_result: StoredResult) -> None:
file_path.write_bytes(stored_result.save())
@@ -33,10 +45,12 @@ def _materialized_path(self, data_version: str, saver_cls: DataSaver) -> Path:
# TODO allow a more flexible mechanism to specify file path extension
return self._path_from_data_version(data_version).with_suffix(f".{saver_cls.name()}")

@override
def exists(self, data_version: str) -> bool:
result_path = self._path_from_data_version(data_version)
return result_path.exists()

@override
def set(
self,
data_version: str,
@@ -65,6 +79,7 @@ def set(
stored_result = StoredResult.new(value=result, saver=saver, loader=loader)
self._write_result(result_path, stored_result)

@override
def get(self, data_version: str) -> Optional[Any]:
result_path = self._path_from_data_version(data_version)
stored_result = self._load_result_from_path(result_path)
@@ -74,10 +89,12 @@ def get(self, data_version: str) -> Optional[Any]:

return stored_result.value

@override
def delete(self, data_version: str) -> None:
result_path = self._path_from_data_version(data_version)
result_path.unlink(missing_ok=True)

@override
def delete_all(self) -> None:
shutil.rmtree(self.path)
self.path.mkdir(exist_ok=True)
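With the new `__getstate__` on `FileResultStore`, the adapter can rebuild the store in a worker process by calling the class with the returned kwargs. A small sketch of that round trip (note that `create_dir` is not serialized, per the review comment above, so a rebuilt store falls back to the default `create_dir=True`):

from hamilton.caching.stores.file import FileResultStore

store = FileResultStore(path="./.hamilton_cache")

# __getstate__ returns only the __init__ kwargs: {"path": "./.hamilton_cache"}
init_kwargs = store.__getstate__()

# the caching adapter's __setstate__ rebuilds the store the same way
rebuilt = FileResultStore(**init_kwargs)
assert rebuilt.path == store.path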
279 changes: 279 additions & 0 deletions hamilton/caching/stores/memory.py
@@ -0,0 +1,279 @@
from typing import Any, Dict, List, Optional, Sequence

try:
from typing import override
except ImportError:
override = lambda x: x # noqa E731

from hamilton.caching.cache_key import decode_key

from .base import MetadataStore, ResultStore, StoredResult
from .file import FileResultStore
from .sqlite import SQLiteMetadataStore


class InMemoryMetadataStore(MetadataStore):
def __init__(self) -> None:
self._data_versions: Dict[str, str] = {} # {cache_key: data_version}
self._cache_keys_by_run: Dict[str, List[str]] = {} # {run_id: [cache_key]}
self._run_ids: List[str] = []

@override
def __len__(self) -> int:
"""Number of unique ``cache_key`` values."""
return len(self._data_versions.keys())

@override
def exists(self, cache_key: str) -> bool:
"""Indicate if ``cache_key`` exists and it can retrieve a ``data_version``."""
return cache_key in self._data_versions.keys()

@override
def initialize(self, run_id: str) -> None:
"""Set up and log the beginning of the run."""
self._cache_keys_by_run[run_id] = []
self._run_ids.append(run_id)

@override
def set(self, cache_key: str, data_version: str, run_id: str, **kwargs) -> Optional[Any]:
"""Set the ``data_version`` for ``cache_key`` and associate it with the ``run_id``."""
self._data_versions[cache_key] = data_version
self._cache_keys_by_run[run_id].append(cache_key)

@override
def get(self, cache_key: str) -> Optional[str]:
"""Retrieve the ``data_version`` for ``cache_key``."""
return self._data_versions.get(cache_key, None)

@override
def delete(self, cache_key: str) -> None:
"""Delete the ``data_version`` for ``cache_key``."""
del self._data_versions[cache_key]

@override
def delete_all(self) -> None:
"""Delete all stored metadata."""
self._data_versions.clear()

def persist_to(self, metadata_store: Optional[MetadataStore] = None) -> None:
"""Persist in-memory metadata using another MetadataStore implementation.

:param metadata_store: MetadataStore implementation to use for persistence.
If None, a SQLiteMetadataStore is created with the default path "./.hamilton_cache".

.. code-block:: python

from hamilton import driver
from hamilton.caching.stores.sqlite import SQLiteMetadataStore
from hamilton.caching.stores.memory import InMemoryMetadataStore
import my_dataflow

dr = (
driver.Builder()
.with_modules(my_dataflow)
.with_cache(metadata_store=InMemoryMetadataStore())
.build()
)

# execute the Driver several times. This will populate the in-memory metadata store
dr.execute(...)

# persist the in-memory metadata to disk
dr.cache.metadata_store.persist_to(SQLiteMetadataStore(path="./.hamilton_cache"))

"""
if metadata_store is None:
metadata_store = SQLiteMetadataStore(path="./.hamilton_cache")

for run_id in self._run_ids:
metadata_store.initialize(run_id)

for run_id, cache_keys in self._cache_keys_by_run.items():
for cache_key in cache_keys:
data_version = self._data_versions[cache_key]
metadata_store.set(
cache_key=cache_key,
data_version=data_version,
run_id=run_id,
)

@classmethod
def load_from(cls, metadata_store: MetadataStore) -> "InMemoryMetadataStore":
"""Load in-memory metadata from another MetadataStore instance.

:param metadata_store: MetadataStore instance to load from.
:return: InMemoryMetadataStore copy of the ``metadata_store``.

.. code-block:: python

from hamilton import driver
from hamilton.caching.stores.sqlite import SQLiteMetadataStore
from hamilton.caching.stores.memory import InMemoryMetadataStore
import my_dataflow

sqlite_metadata_store = SQLiteMetadataStore(path="./.hamilton_cache")
in_memory_metadata_store = InMemoryMetadataStore.load_from(sqlite_metadata_store)

# create the Driver with the in-memory metadata store
dr = (
driver.Builder()
.with_modules(my_dataflow)
.with_cache(metadata_store=in_memory_metadata_store)
.build()
)

"""
in_memory_metadata_store = InMemoryMetadataStore()

for run_id in metadata_store.get_run_ids():
in_memory_metadata_store.initialize(run_id)

for node_metadata in metadata_store.get_run(run_id):
in_memory_metadata_store.set(
cache_key=node_metadata["cache_key"],
data_version=node_metadata["data_version"],
run_id=run_id,
)

return in_memory_metadata_store

@override
def get_run_ids(self) -> List[str]:
"""Return a list of all ``run_id`` values stored."""
return self._run_ids

@override
def get_run(self, run_id: str) -> List[Dict[str, str]]:
"""Return a list of node metadata associated with a run."""
if self._cache_keys_by_run.get(run_id, None) is None:
raise IndexError(f"Run ID not found: {run_id}")

nodes_metadata = []
for cache_key in self._cache_keys_by_run[run_id]:
decoded_key = decode_key(cache_key)
nodes_metadata.append(
dict(
cache_key=cache_key,
data_version=self._data_versions[cache_key],
node_name=decoded_key["node_name"],
code_version=decoded_key["code_version"],
dependencies_data_versions=decoded_key["dependencies_data_versions"],
)
)

return nodes_metadata


class InMemoryResultStore(ResultStore):
def __init__(self, persist_on_exit: bool = False) -> None:
self._results: Dict[str, StoredResult] = {} # {data_version: result}

@override
def exists(self, data_version: str) -> bool:
return data_version in self._results.keys()

# TODO handle materialization
@override
def set(self, data_version: str, result: Any, **kwargs) -> None:
self._results[data_version] = StoredResult.new(value=result)

@override
def get(self, data_version: str) -> Optional[Any]:
stored_result = self._results.get(data_version, None)
if stored_result is None:
return None

return stored_result.value

@override
def delete(self, data_version: str) -> None:
del self._results[data_version]

@override
def delete_all(self) -> None:
self._results.clear()

def delete_expired(self) -> None:
to_delete = [
data_version
for data_version, stored_result in self._results.items()
if stored_result.expired
]

# first collect the keys, then delete, because you can't delete from a dictionary
# while iterating over it
for data_version in to_delete:
self.delete(data_version)

def persist_to(self, result_store: Optional[ResultStore] = None) -> None:
"""Persist in-memory results using another ``ResultStore`` implementation.

:param result_store: ResultStore implementation to use for persistence.
If None, a FileResultStore is created with the default path "./.hamilton_cache".
"""
if result_store is None:
result_store = FileResultStore(path="./.hamilton_cache")

for data_version, stored_result in self._results.items():
result_store.set(data_version, stored_result.value)

@classmethod
def load_from(
cls,
result_store: ResultStore,
metadata_store: Optional[MetadataStore] = None,
data_versions: Optional[Sequence[str]] = None,
) -> "InMemoryResultStore":
"""Load in-memory results from another ResultStore instance.

Since result stores do not store an index of their keys, you must provide a
``MetadataStore`` instance or a sequence of ``data_version`` values for which
results should be loaded in memory.

:param result_store: ``ResultStore`` instance to load results from.
:param metadata_store: ``MetadataStore`` instance from which all ``data_version`` values are retrieved.
:param data_versions: Sequence of ``data_version`` values for which to load results.
:return: InMemoryResultStore copy of the ``result_store``.

.. code-block:: python

from hamilton.caching.stores.file import FileResultStore
from hamilton.caching.stores.sqlite import SQLiteMetadataStore
from hamilton.caching.stores.memory import InMemoryResultStore

metadata_store = SQLiteMetadataStore(path="./.hamilton_cache")
result_store = FileResultStore(path="./.hamilton_cache")

# load into memory all results referenced by the metadata store
in_memory_result_store = InMemoryResultStore.load_from(
    result_store=result_store,
    metadata_store=metadata_store,
)


"""
if metadata_store is None and data_versions is None:
raise ValueError(
"A `metadata_store` or `data_versions` must be provided to load results."
)

in_memory_result_store = InMemoryResultStore()

data_versions_to_retrieve = set()
if data_versions is not None:
data_versions_to_retrieve.update(data_versions)

if metadata_store is not None:
for run_id in metadata_store.get_run_ids():
for node_metadata in metadata_store.get_run(run_id):
data_versions_to_retrieve.add(node_metadata["data_version"])

for data_version in data_versions_to_retrieve:
# TODO disambiguate "result is None" from the sentinel value when `data_version`
# is not found in `result_store`.
result = result_store.get(data_version)
in_memory_result_store.set(data_version, result)

return in_memory_result_store
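
Taken together, the two stores let a dataflow cache entirely in memory and persist on demand. A sketch of the end-to-end flow, assuming `with_cache` accepts a `result_store` argument alongside the `metadata_store` shown in the docstrings above, and with `my_dataflow` as a stand-in module:

from hamilton import driver
from hamilton.caching.stores.memory import InMemoryMetadataStore, InMemoryResultStore
import my_dataflow  # stand-in for your dataflow module

# cache entirely in memory: nothing is written to disk during execution
dr = (
    driver.Builder()
    .with_modules(my_dataflow)
    .with_cache(
        metadata_store=InMemoryMetadataStore(),
        result_store=InMemoryResultStore(),  # assumes with_cache accepts result_store
    )
    .build()
)

dr.execute(...)  # repeated executions hit the in-memory cache

# optionally persist both stores at the end of the session;
# defaults are SQLiteMetadataStore and FileResultStore at "./.hamilton_cache"
dr.cache.metadata_store.persist_to()
dr.cache.result_store.persist_to()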