diff --git a/examples/caching/README.md b/examples/caching/README.md
index 4b53b2218..82a038bf3 100644
--- a/examples/caching/README.md
+++ b/examples/caching/README.md
@@ -3,4 +3,5 @@
This directory contains tutorial notebooks for the Hamilton caching feature.
- `tutorial.ipynb`: the main tutorial for caching
-- `materializer_tutorial.ipynb`: tutorial on the interactions between `DataLoader/DataSaver` and caching. This is a more advanced tutorial for materializer users. You should complete the `tutorial.ipynb` first.
+- `in_memory_tutorial.ipynb`: how to use caching without writing metadata and results to disk.
+- `materializer_tutorial.ipynb`: learn about the interactions between `DataLoader/DataSaver` and caching. This is a more advanced tutorial; complete `tutorial.ipynb` first.
diff --git a/examples/caching/in_memory_tutorial.ipynb b/examples/caching/in_memory_tutorial.ipynb
new file mode 100644
index 000000000..712422a22
--- /dev/null
+++ b/examples/caching/in_memory_tutorial.ipynb
@@ -0,0 +1,673 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# In-memory caching tutorial\n",
+ "\n",
+ "This notebook shows how to use in-memory caching, which allows to cache results between runs without writing to disk. This uses the `InMemoryResultStore` and `InMemoryMetadataStore` classes.\n",
+ "\n",
+ "> ⛔ In-memory caching can consume a lot of memory if you're using storing large results. Selectively caching nodes is recommended.\n",
+ "\n",
+ "If you're new to caching, you should take a look at the [caching tutorial](./tutorial.ipynb) first!"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Setup\n",
+ "Throughout this tutorial, we'll be using the Hamilton notebook extension to define dataflows directly in the notebook ([see tutorial](https://github.com/DAGWorks-Inc/hamilton/blob/main/examples/jupyter_notebook_magic/example.ipynb)).\n",
+ "\n",
+ "Then, we get the logger for caching and clear previously cached results."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import logging\n",
+ "import shutil\n",
+ "\n",
+ "# avoid loading all available plugins for fast startup time\n",
+ "from hamilton import registry\n",
+ "registry.disable_autoload()\n",
+ "registry.load_extension(\"pandas\")\n",
+ "\n",
+ "from hamilton import driver\n",
+ "\n",
+ "# load the notebook extension\n",
+ "%reload_ext hamilton.plugins.jupyter_magic\n",
+ "\n",
+ "logger = logging.getLogger(\"hamilton.caching\")\n",
+ "logger.setLevel(logging.INFO)\n",
+ "logger.addHandler(logging.StreamHandler())\n",
+ "\n",
+ "shutil.rmtree(\"./.hamilton_cache\", ignore_errors=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Define a dataflow\n",
+ "We define a simple dataflow that loads a dataframe of transactions, filters by date, converts currency to USD, and sums the amount per country."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "%%cell_to_module dataflow_module --display\n",
+ "import pandas as pd\n",
+ "\n",
+ "DATA = {\n",
+ " \"cities\": [\"New York\", \"Los Angeles\", \"Chicago\", \"Montréal\", \"Vancouver\"],\n",
+ " \"date\": [\"2024-09-13\", \"2024-09-12\", \"2024-09-11\", \"2024-09-11\", \"2024-09-09\"],\n",
+ " \"amount\": [478.23, 251.67, 989.34, 742.14, 584.56],\n",
+ " \"country\": [\"USA\", \"USA\", \"USA\", \"Canada\", \"Canada\"],\n",
+ " \"currency\": [\"USD\", \"USD\", \"USD\", \"CAD\", \"CAD\"],\n",
+ "}\n",
+ "\n",
+ "def raw_data() -> pd.DataFrame:\n",
+ " \"\"\"Loading raw data. This simulates loading from a file, database, or external service.\"\"\"\n",
+ " return pd.DataFrame(DATA)\n",
+ "\n",
+ "def processed_data(raw_data: pd.DataFrame, cutoff_date: str) -> pd.DataFrame:\n",
+ " \"\"\"Filter out rows before cutoff date and convert currency to USD.\"\"\"\n",
+ " df = raw_data.loc[raw_data.date > cutoff_date].copy()\n",
+ " df[\"amound_in_usd\"] = df[\"amount\"]\n",
+ " df.loc[df.country == \"Canada\", \"amound_in_usd\"] *= 0.71 \n",
+ " df.loc[df.country == \"Brazil\", \"amound_in_usd\"] *= 0.18 # <- LINE ADDED\n",
+ " df.loc[df.country == \"Mexico\", \"amound_in_usd\"] *= 0.05 # <- LINE ADDED\n",
+ " return df\n",
+ "\n",
+ "def amount_per_country(processed_data: pd.DataFrame) -> pd.DataFrame:\n",
+ " \"\"\"Sum the amount in USD per country\"\"\"\n",
+ " return processed_data.groupby(\"country\")[\"amound_in_usd\"].sum().to_frame()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## In-memory caching\n",
+ "To use in-memory caching, pass `InMemoryResultStore` and `InMemoryMetadataStore` to `Builder().with_cache()`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from hamilton.caching.stores.memory import InMemoryMetadataStore, InMemoryResultStore\n",
+ "\n",
+ "dr = (\n",
+ " driver.Builder()\n",
+ " .with_modules(dataflow_module)\n",
+ " .with_cache(\n",
+ " result_store=InMemoryResultStore(),\n",
+ " metadata_store=InMemoryMetadataStore(),\n",
+ " )\n",
+ " .build()\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Execution 1\n",
+ "For execution 1, we see that all nodes are executed."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "raw_data::adapter::execute_node\n",
+ "processed_data::adapter::execute_node\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ " cities date amount country currency amound_in_usd\n",
+ "0 New York 2024-09-13 478.23 USA USD 478.2300\n",
+ "1 Los Angeles 2024-09-12 251.67 USA USD 251.6700\n",
+ "2 Chicago 2024-09-11 989.34 USA USD 989.3400\n",
+ "3 Montréal 2024-09-11 742.14 Canada CAD 526.9194\n",
+ "4 Vancouver 2024-09-09 584.56 Canada CAD 415.0376\n"
+ ]
+ },
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "results = dr.execute([\"processed_data\"], inputs={\"cutoff_date\": \"2024-09-01\"})\n",
+ "print()\n",
+ "print(results[\"processed_data\"].head())\n",
+ "dr.cache.view_run()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Execution 2\n",
+ "For execution 2, we see that all nodes are retrieved from cache."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "raw_data::result_store::get_result::hit\n",
+ "processed_data::result_store::get_result::hit\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ " cities date amount country currency amound_in_usd\n",
+ "0 New York 2024-09-13 478.23 USA USD 478.2300\n",
+ "1 Los Angeles 2024-09-12 251.67 USA USD 251.6700\n",
+ "2 Chicago 2024-09-11 989.34 USA USD 989.3400\n",
+ "3 Montréal 2024-09-11 742.14 Canada CAD 526.9194\n",
+ "4 Vancouver 2024-09-09 584.56 Canada CAD 415.0376\n"
+ ]
+ },
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "results = dr.execute([\"processed_data\"], inputs={\"cutoff_date\": \"2024-09-01\"})\n",
+ "print()\n",
+ "print(results[\"processed_data\"].head())\n",
+ "dr.cache.view_run()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Persisting in-memory data\n",
+ "\n",
+ "Now, we import `SQLiteMetadataStore` and `FileResultStore` to persist the data to disk. We access the in-memory stores via `dr.cache.result_store` and `dr.cache.metadata_store` and call the `.persist_to()` method on each.\n",
+ "\n",
+ "After executing the cell, you should see a new directory `./.persisted_cache` with results and metadata."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from hamilton.caching.stores.sqlite import SQLiteMetadataStore\n",
+ "from hamilton.caching.stores.file import FileResultStore\n",
+ "\n",
+ "path = \"./.persisted_cache\"\n",
+ "on_disk_results = FileResultStore(path=path)\n",
+ "on_disk_metadata = SQLiteMetadataStore(path=path)\n",
+ "\n",
+ "dr.cache.result_store.persist_to(on_disk_results)\n",
+ "dr.cache.metadata_store.persist_to(on_disk_metadata)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Loading persisted data\n",
+ "\n",
+ "Now, we create a new `Driver`. Instead of starting with empty in-memory stores, we will load the previously persisted results by calling `.load_from()` on the `InMemoryResultStore` and `InMemoryMetadataStore` classes.\n",
+ "\n",
+ "For `InMemoryResultStore.load_from()`, we must provide a `MetadataStore` or a list of `data_version` to load results for."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dr = (\n",
+ " driver.Builder()\n",
+ " .with_modules(dataflow_module)\n",
+ " .with_cache(\n",
+ " result_store=InMemoryResultStore.load_from(\n",
+ " on_disk_results,\n",
+ " metadata_store=on_disk_metadata,\n",
+ " ),\n",
+ " metadata_store=InMemoryMetadataStore.load_from(on_disk_metadata),\n",
+ " )\n",
+ " .build()\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We print the size of the metadata store to show it contains 2 entries (one for `raw_data` and another for `processed_data`). Also, we see that results load from `FileResultStore`are successfully retrieved from the in-memory stores."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "raw_data::result_store::get_result::hit\n",
+ "processed_data::result_store::get_result::hit\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "2\n",
+ "\n",
+ " cities date amount country currency amound_in_usd\n",
+ "0 New York 2024-09-13 478.23 USA USD 478.2300\n",
+ "1 Los Angeles 2024-09-12 251.67 USA USD 251.6700\n",
+ "2 Chicago 2024-09-11 989.34 USA USD 989.3400\n",
+ "3 Montréal 2024-09-11 742.14 Canada CAD 526.9194\n",
+ "4 Vancouver 2024-09-09 584.56 Canada CAD 415.0376\n"
+ ]
+ },
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "print(dr.cache.metadata_store.size)\n",
+ "\n",
+ "results = dr.execute([\"processed_data\"], inputs={\"cutoff_date\": \"2024-09-01\"})\n",
+ "print()\n",
+ "print(results[\"processed_data\"].head())\n",
+ "dr.cache.view_run()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Use cases\n",
+ "\n",
+ "In-memory caching can be useful when you're doing a lot of experimentation in a notebook or an interactive session and don't want to persist results for future use. \n",
+ "\n",
+ "It can also speed up execution in some cases because you're no longer doing read/write to disk for each node execution."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/hamilton/caching/adapter.py b/hamilton/caching/adapter.py
index c41bfa4fd..59fd18139 100644
--- a/hamilton/caching/adapter.py
+++ b/hamilton/caching/adapter.py
@@ -274,7 +274,9 @@ def __getstate__(self) -> dict:
state = self.__dict__.copy()
# store the classes to reinstantiate the same backend in __setstate__
state["metadata_store_cls"] = self.metadata_store.__class__
+ state["metadata_store_init"] = self.metadata_store.__getstate__()
state["result_store_cls"] = self.result_store.__class__
+ state["result_store_init"] = self.result_store.__getstate__()
del state["metadata_store"]
del state["result_store"]
return state
@@ -288,8 +290,8 @@ def __setstate__(self, state: dict) -> None:
"""
# instantiate the backend from the class, then delete the attribute before
# setting it on the adapter instance.
- self.metadata_store = state["metadata_store_cls"](path=state["_path"])
- self.result_store = state["result_store_cls"](path=state["_path"])
+ self.metadata_store = state["metadata_store_cls"](**state["metadata_store_init"])
+ self.result_store = state["result_store_cls"](**state["result_store_init"])
del state["metadata_store_cls"]
del state["result_store_cls"]
self.__dict__.update(state)
diff --git a/hamilton/caching/stores/file.py b/hamilton/caching/stores/file.py
index d483ca292..34144d9c1 100644
--- a/hamilton/caching/stores/file.py
+++ b/hamilton/caching/stores/file.py
@@ -2,9 +2,15 @@
from pathlib import Path
from typing import Any, Optional
-from hamilton.caching.stores.base import ResultStore, StoredResult
+try:
+ from typing import override
+except ImportError:
+ override = lambda x: x # noqa E731
+
from hamilton.io.data_adapters import DataLoader, DataSaver
+from .base import ResultStore, StoredResult
+
class FileResultStore(ResultStore):
def __init__(self, path: str, create_dir: bool = True) -> None:
@@ -14,6 +20,12 @@ def __init__(self, path: str, create_dir: bool = True) -> None:
if self.create_dir:
self.path.mkdir(exist_ok=True, parents=True)
+ def __getstate__(self) -> dict:
+ """Serialize the `__init__` kwargs to pass in Parallelizable branches
+ when using multiprocessing.
+ """
+ return {"path": str(self.path)}
+
@staticmethod
def _write_result(file_path: Path, stored_result: StoredResult) -> None:
file_path.write_bytes(stored_result.save())
@@ -33,10 +45,12 @@ def _materialized_path(self, data_version: str, saver_cls: DataSaver) -> Path:
# TODO allow a more flexible mechanism to specify file path extension
return self._path_from_data_version(data_version).with_suffix(f".{saver_cls.name()}")
+ @override
def exists(self, data_version: str) -> bool:
result_path = self._path_from_data_version(data_version)
return result_path.exists()
+ @override
def set(
self,
data_version: str,
@@ -65,6 +79,7 @@ def set(
stored_result = StoredResult.new(value=result, saver=saver, loader=loader)
self._write_result(result_path, stored_result)
+ @override
def get(self, data_version: str) -> Optional[Any]:
result_path = self._path_from_data_version(data_version)
stored_result = self._load_result_from_path(result_path)
@@ -74,10 +89,12 @@ def get(self, data_version: str) -> Optional[Any]:
return stored_result.value
+ @override
def delete(self, data_version: str) -> None:
result_path = self._path_from_data_version(data_version)
result_path.unlink(missing_ok=True)
+ @override
def delete_all(self) -> None:
shutil.rmtree(self.path)
self.path.mkdir(exist_ok=True)
diff --git a/hamilton/caching/stores/memory.py b/hamilton/caching/stores/memory.py
new file mode 100644
index 000000000..37eee89d5
--- /dev/null
+++ b/hamilton/caching/stores/memory.py
@@ -0,0 +1,279 @@
+from typing import Any, Dict, List, Optional, Sequence
+
+try:
+ from typing import override
+except ImportError:
+ override = lambda x: x # noqa E731
+
+from hamilton.caching.cache_key import decode_key
+
+from .base import MetadataStore, ResultStore, StoredResult
+from .file import FileResultStore
+from .sqlite import SQLiteMetadataStore
+
+
+class InMemoryMetadataStore(MetadataStore):
+ def __init__(self) -> None:
+ self._data_versions: Dict[str, str] = {} # {cache_key: data_version}
+ self._cache_keys_by_run: Dict[str, List[str]] = {} # {run_id: [cache_key]}
+ self._run_ids: List[str] = []
+
+ @override
+ def __len__(self) -> int:
+ """Number of unique ``cache_key`` values."""
+ return len(self._data_versions.keys())
+
+ @override
+ def exists(self, cache_key: str) -> bool:
+ """Indicate if ``cache_key`` exists and it can retrieve a ``data_version``."""
+ return cache_key in self._data_versions.keys()
+
+ @override
+ def initialize(self, run_id: str) -> None:
+ """Set up and log the beginning of the run."""
+ self._cache_keys_by_run[run_id] = []
+ self._run_ids.append(run_id)
+
+ @override
+ def set(self, cache_key: str, data_version: str, run_id: str, **kwargs) -> Optional[Any]:
+ """Set the ``data_version`` for ``cache_key`` and associate it with the ``run_id``."""
+ self._data_versions[cache_key] = data_version
+ self._cache_keys_by_run[run_id].append(cache_key)
+
+ @override
+ def get(self, cache_key: str) -> Optional[str]:
+ """Retrieve the ``data_version`` for ``cache_key``."""
+ return self._data_versions.get(cache_key, None)
+
+ @override
+ def delete(self, cache_key: str) -> None:
+ """Delete the ``data_version`` for ``cache_key``."""
+ del self._data_versions[cache_key]
+
+ @override
+ def delete_all(self) -> None:
+ """Delete all stored metadata."""
+ self._data_versions.clear()
+
+ def persist_to(self, metadata_store: Optional[MetadataStore] = None) -> None:
+ """Persist in-memory metadata using another MetadataStore implementation.
+
+ :param metadata_store: MetadataStore implementation to use for persistence.
+ If None, a SQLiteMetadataStore is created with the default path "./.hamilton_cache".
+
+ .. code-block:: python
+
+ from hamilton import driver
+ from hamilton.caching.stores.sqlite import SQLiteMetadataStore
+ from hamilton.caching.stores.memory import InMemoryMetadataStore
+ import my_dataflow
+
+ dr = (
+ driver.Builder()
+ .with_modules(my_dataflow)
+ .with_cache(metadata_store=InMemoryMetadataStore())
+ .build()
+ )
+
+ # execute the Driver several times. This will populate the in-memory metadata store
+ dr.execute(...)
+
+ # persist the in-memory metadata to disk
+ dr.cache.metadata_store.persist_to(SQLiteMetadataStore(path="./.hamilton_cache"))
+
+ """
+ if metadata_store is None:
+ metadata_store = SQLiteMetadataStore(path="./.hamilton_cache")
+
+ for run_id in self._run_ids:
+ metadata_store.initialize(run_id)
+
+ for run_id, cache_keys in self._cache_keys_by_run.items():
+ for cache_key in cache_keys:
+ data_version = self._data_versions[cache_key]
+ metadata_store.set(
+ cache_key=cache_key,
+ data_version=data_version,
+ run_id=run_id,
+ )
+
+ @classmethod
+ def load_from(cls, metadata_store: MetadataStore) -> "InMemoryMetadataStore":
+ """Load in-memory metadata from another MetadataStore instance.
+
+ :param metadata_store: MetadataStore instance to load from.
+ :return: InMemoryMetadataStore copy of the ``metadata_store``.
+
+ .. code-block:: python
+
+ from hamilton import driver
+ from hamilton.caching.stores.sqlite import SQLiteMetadataStore
+ from hamilton.caching.stores.memory import InMemoryMetadataStore
+ import my_dataflow
+
+ sqlite_metadata_store = SQLiteMetadataStore(path="./.hamilton_cache")
+ in_memory_metadata_store = InMemoryMetadataStore.load_from(sqlite_metadata_store)
+
+ # create the Driver with the in-memory metadata store
+ dr = (
+ driver.Builder()
+ .with_modules(my_dataflow)
+ .with_cache(metadata_store=in_memory_metadata_store)
+ .build()
+ )
+
+ """
+ in_memory_metadata_store = InMemoryMetadataStore()
+
+ for run_id in metadata_store.get_run_ids():
+ in_memory_metadata_store.initialize(run_id)
+
+ for node_metadata in metadata_store.get_run(run_id):
+ in_memory_metadata_store.set(
+ cache_key=node_metadata["cache_key"],
+ data_version=node_metadata["data_version"],
+ run_id=run_id,
+ )
+
+ return in_memory_metadata_store
+
+ @override
+ def get_run_ids(self) -> List[str]:
+ """Return a list of all ``run_id`` values stored."""
+ return self._run_ids
+
+ @override
+ def get_run(self, run_id: str) -> List[Dict[str, str]]:
+ """Return a list of node metadata associated with a run."""
+ if self._cache_keys_by_run.get(run_id, None) is None:
+ raise IndexError(f"Run ID not found: {run_id}")
+
+ nodes_metadata = []
+ for cache_key in self._cache_keys_by_run[run_id]:
+ decoded_key = decode_key(cache_key)
+ nodes_metadata.append(
+ dict(
+ cache_key=cache_key,
+ data_version=self._data_versions[cache_key],
+ node_name=decoded_key["node_name"],
+ code_version=decoded_key["code_version"],
+ dependencies_data_versions=decoded_key["dependencies_data_versions"],
+ )
+ )
+
+ return nodes_metadata
+
+
+class InMemoryResultStore(ResultStore):
+ def __init__(self, persist_on_exit: bool = False) -> None:
+ self._results: Dict[str, StoredResult] = {} # {data_version: result}
+
+ @override
+ def exists(self, data_version: str) -> bool:
+ return data_version in self._results.keys()
+
+ # TODO handle materialization
+ @override
+ def set(self, data_version: str, result: Any, **kwargs) -> None:
+ self._results[data_version] = StoredResult.new(value=result)
+
+ @override
+ def get(self, data_version: str) -> Optional[Any]:
+ stored_result = self._results.get(data_version, None)
+ if stored_result is None:
+ return None
+
+ return stored_result.value
+
+ @override
+ def delete(self, data_version: str) -> None:
+ del self._results[data_version]
+
+ @override
+ def delete_all(self) -> None:
+ self._results.clear()
+
+ def delete_expired(self) -> None:
+ to_delete = [
+ data_version
+ for data_version, stored_result in self._results.items()
+ if stored_result.expired
+ ]
+
+ # collect the keys first, then delete, because you can't delete from a
+ # dictionary while iterating over it
+ for data_version in to_delete:
+ self.delete(data_version)
+
+ def persist_to(self, result_store: Optional[ResultStore] = None) -> None:
+ """Persist in-memory results using another ``ResultStore`` implementation.
+
+ :param result_store: ResultStore implementation to use for persistence.
+ If None, a FileResultStore is created with the default path "./.hamilton_cache".
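+
+ A minimal sketch, mirroring the notebook example (``my_dataflow`` is an illustrative module name):
+
+ .. code-block:: python
+
+ from hamilton import driver
+ from hamilton.caching.stores.file import FileResultStore
+ from hamilton.caching.stores.memory import InMemoryResultStore
+ import my_dataflow
+
+ dr = (
+ driver.Builder()
+ .with_modules(my_dataflow)
+ .with_cache(result_store=InMemoryResultStore())
+ .build()
+ )
+
+ # execute the Driver several times to populate the in-memory result store
+ dr.execute(...)
+
+ # persist the in-memory results to disk
+ dr.cache.result_store.persist_to(FileResultStore(path="./.hamilton_cache"))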
+ """
+ if result_store is None:
+ result_store = FileResultStore(path="./.hamilton_cache")
+
+ for data_version, stored_result in self._results.items():
+ result_store.set(data_version, stored_result.value)
+
+ @classmethod
+ def load_from(
+ cls,
+ result_store: ResultStore,
+ metadata_store: Optional[MetadataStore] = None,
+ data_versions: Optional[Sequence[str]] = None,
+ ) -> "InMemoryResultStore":
+ """Load in-memory results from another ResultStore instance.
+
+ Since result stores do not store an index of their keys, you must provide a
+ ``MetadataStore`` instance or a list of ``data_version`` values for which results
+ should be loaded in memory.
+
+ :param result_store: ``ResultStore`` instance to load results from.
+ :param metadata_store: ``MetadataStore`` instance from which all ``data_version`` values are retrieved.
+ :param data_versions: Sequence of ``data_version`` values for which results should be loaded.
+ :return: InMemoryResultStore copy of the ``result_store``.
+
+ .. code-block:: python
+
+ from hamilton import driver
+ from hamilton.caching.stores.file import FileResultStore
+ from hamilton.caching.stores.memory import InMemoryResultStore
+ from hamilton.caching.stores.sqlite import SQLiteMetadataStore
+ import my_dataflow
+
+ file_result_store = FileResultStore(path="./.hamilton_cache")
+ sqlite_metadata_store = SQLiteMetadataStore(path="./.hamilton_cache")
+ in_memory_result_store = InMemoryResultStore.load_from(
+ file_result_store, metadata_store=sqlite_metadata_store
+ )
+
+ # create the Driver with the in-memory result store
+ dr = (
+ driver.Builder()
+ .with_modules(my_dataflow)
+ .with_cache(result_store=in_memory_result_store)
+ .build()
+ )
+
+
+ """
+ if metadata_store is None and data_versions is None:
+ raise ValueError(
+ "A `metadata_store` or `data_versions` must be provided to load results."
+ )
+
+ in_memory_result_store = InMemoryResultStore()
+
+ data_versions_to_retrieve = set()
+ if data_versions is not None:
+ data_versions_to_retrieve.update(data_versions)
+
+ if metadata_store is not None:
+ for run_id in metadata_store.get_run_ids():
+ for node_metadata in metadata_store.get_run(run_id):
+ data_versions_to_retrieve.add(node_metadata["data_version"])
+
+ for data_version in data_versions_to_retrieve:
+ # TODO disambiguate "result is None" from the sentinel value when `data_version`
+ # is not found in `result_store`.
+ result = result_store.get(data_version)
+ in_memory_result_store.set(data_version, result)
+
+ return in_memory_result_store
diff --git a/hamilton/caching/stores/sqlite.py b/hamilton/caching/stores/sqlite.py
index 4ad0eb480..1e434844b 100644
--- a/hamilton/caching/stores/sqlite.py
+++ b/hamilton/caching/stores/sqlite.py
@@ -20,6 +20,20 @@ def __init__(
self._thread_local = threading.local()
+ # creating tables at `__init__` prevents other methods from encountering
+ # `sqlite3.OperationalError` because tables are missing.
+ self._create_tables_if_not_exists()
+
+ def __getstate__(self) -> dict:
+ """Serialized `__init__` arguments required to initialize the
+ MetadataStore in a new thread or process.
+ """
+ state = {}
+ # NOTE kwarg `path` is not equivalent to `self._path`
+ state["path"] = self._directory
+ state["connection_kwargs"] = self.connection_kwargs
+ return state
+
def _get_connection(self) -> sqlite3.Connection:
if not hasattr(self._thread_local, "connection"):
self._thread_local.connection = sqlite3.connect(
@@ -33,14 +47,15 @@ def _close_connection(self) -> None:
del self._thread_local.connection
@property
- def connection(self):
+ def connection(self) -> sqlite3.Connection:
+ """Connection to the SQLite database."""
return self._get_connection()
def __del__(self):
"""Close the SQLite connection when the object is deleted"""
self._close_connection()
- def _create_tables_if_not_exists(self):
+ def _create_tables_if_not_exists(self) -> None:
"""Create the tables necessary for the cache:
run_ids: queue of run_ids, ordered by start time.
@@ -92,12 +107,11 @@ def initialize(self, run_id) -> None:
"""Call initialize when starting a run. This will create database tables
if necessary.
"""
- self._create_tables_if_not_exists()
cur = self.connection.cursor()
cur.execute("INSERT INTO run_ids (run_id) VALUES (?)", (run_id,))
self.connection.commit()
- def __len__(self):
+ def __len__(self) -> int:
"""Number of entries in cache_metadata"""
cur = self.connection.cursor()
cur.execute("SELECT COUNT(*) FROM cache_metadata")
@@ -118,7 +132,15 @@ def set(
# if the caller of ``.set()`` directly provides the ``node_name`` and ``code_version``,
# we can skip the decoding step.
if (node_name is None) or (code_version is None):
- decoded_key = decode_key(cache_key)
+ try:
+ decoded_key = decode_key(cache_key)
+ except BaseException as e:
+ raise ValueError(
+ f"Failed decoding the cache_key: {cache_key}.\n",
+ "The `cache_key` must be created by `hamilton.caching.cache_key.create_cache_key()` ",
+ "if `node_name` and `code_version` are not provided.",
+ ) from e
+
node_name = decoded_key["node_name"]
code_version = decoded_key["code_version"]
diff --git a/hamilton/caching/stores/utils.py b/hamilton/caching/stores/utils.py
index 1a69b41eb..e4df5203f 100644
--- a/hamilton/caching/stores/utils.py
+++ b/hamilton/caching/stores/utils.py
@@ -2,6 +2,7 @@
def get_directory_size(directory: str) -> float:
+ """Get the size of the content of a directory in bytes."""
total_size = 0
for p in pathlib.Path(directory).rglob("*"):
if p.is_file():
@@ -11,6 +12,7 @@ def get_directory_size(directory: str) -> float:
def readable_bytes_size(n_bytes: float) -> str:
+ """Convert a number of bytes to a human-readable unit."""
labels = ["B", "KB", "MB", "GB", "TB"]
exponent = 0
diff --git a/tests/caching/metadata_store/__init__.py b/tests/caching/metadata_store/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/caching/metadata_store/test_base.py b/tests/caching/metadata_store/test_base.py
new file mode 100644
index 000000000..0e6078bb6
--- /dev/null
+++ b/tests/caching/metadata_store/test_base.py
@@ -0,0 +1,173 @@
+from typing import Dict
+
+import pytest
+
+from hamilton.caching.cache_key import create_cache_key
+from hamilton.caching.fingerprinting import hash_value
+from hamilton.caching.stores.memory import InMemoryMetadataStore
+from hamilton.caching.stores.sqlite import SQLiteMetadataStore
+from hamilton.graph_types import hash_source_code
+
+# if you're adding a new `MetadataStore` implementation, add it to this list.
+# Your implementation should successfully pass all the tests.
+# Implementation-specific tests should be added to `test_{implementation}.py`.
+IMPLEMENTATIONS = [SQLiteMetadataStore, InMemoryMetadataStore]
+
+
+def _instantiate_metadata_store(metadata_store_cls, tmp_path):
+ if metadata_store_cls == SQLiteMetadataStore:
+ return SQLiteMetadataStore(path=tmp_path)
+ elif metadata_store_cls == InMemoryMetadataStore:
+ return InMemoryMetadataStore()
+ else:
+ raise ValueError(
+ f"Class `{metadata_store_cls}` isn't defined in `_instantiate_metadata_store()`"
+ )
+
+
+def _mock_cache_key(
+ node_name: str = "foo",
+ code_version: str = "FOO-1",
+ dependencies_data_versions: Dict[str, str] = None,
+) -> str:
+ """Utility to create a valid cache key from mock values.
+ This is helpful because ``code_version`` and the ``data_version`` values found in
+ ``dependencies_data_versions`` must follow a specific encoding.
+ """
+ dependencies_data_versions = (
+ dependencies_data_versions if dependencies_data_versions is not None else {}
+ )
+ return create_cache_key(
+ node_name=node_name,
+ code_version=hash_source_code(code_version),
+ dependencies_data_versions={k: hash_value(v) for k, v in dependencies_data_versions.items()},
+ )
+
+
+@pytest.fixture
+def metadata_store(request, tmp_path):
+ metadata_store_cls = request.param
+ metadata_store = _instantiate_metadata_store(metadata_store_cls, tmp_path)
+
+ yield metadata_store
+
+
+@pytest.mark.parametrize("metadata_store", IMPLEMENTATIONS, indirect=True)
+def test_initialize_empty(metadata_store):
+ metadata_store.initialize(run_id="test-run-id")
+ assert metadata_store.size == 0
+
+
+@pytest.mark.parametrize("metadata_store", IMPLEMENTATIONS, indirect=True)
+def test_not_empty_after_set(metadata_store):
+ cache_key = _mock_cache_key()
+ run_id = "test-run-id"
+ metadata_store.initialize(run_id=run_id)
+
+ metadata_store.set(
+ cache_key=cache_key,
+ data_version="foo-a",
+ run_id=run_id,
+ )
+
+ assert metadata_store.size > 0
+
+
+@pytest.mark.parametrize("metadata_store", IMPLEMENTATIONS, indirect=True)
+def test_set_doesnt_produce_duplicates(metadata_store):
+ cache_key = _mock_cache_key()
+ data_version = "foo-a"
+ run_id = "test-run-id"
+ metadata_store.initialize(run_id=run_id)
+
+ metadata_store.set(
+ cache_key=cache_key,
+ data_version=data_version,
+ run_id=run_id,
+ )
+ assert metadata_store.size == 1
+
+ metadata_store.set(
+ cache_key=cache_key,
+ data_version=data_version,
+ run_id=run_id,
+ )
+ assert metadata_store.size == 1
+
+
+@pytest.mark.parametrize("metadata_store", IMPLEMENTATIONS, indirect=True)
+def test_get_miss_returns_none(metadata_store):
+ cache_key = _mock_cache_key()
+ run_id = "test-run-id"
+ metadata_store.initialize(run_id=run_id)
+
+ data_version = metadata_store.get(cache_key=cache_key)
+
+ assert data_version is None
+
+
+@pytest.mark.parametrize("metadata_store", IMPLEMENTATIONS, indirect=True)
+def test_set_and_get_with_empty_dependencies(metadata_store):
+ cache_key = _mock_cache_key()
+ data_version = "foo-a"
+ run_id = "test-run-id"
+ metadata_store.initialize(run_id=run_id)
+
+ metadata_store.set(
+ cache_key=cache_key,
+ data_version=data_version,
+ run_id=run_id,
+ )
+ retrieved_data_version = metadata_store.get(cache_key=cache_key)
+
+ assert retrieved_data_version == data_version
+
+
+@pytest.mark.parametrize("metadata_store", IMPLEMENTATIONS, indirect=True)
+def test_get_run_ids_returns_ordered_list(metadata_store):
+ pre_run_ids = metadata_store.get_run_ids()
+ assert pre_run_ids == []
+
+ metadata_store.initialize(run_id="foo")
+ metadata_store.initialize(run_id="bar")
+ metadata_store.initialize(run_id="baz")
+
+ post_run_ids = metadata_store.get_run_ids()
+ assert post_run_ids == ["foo", "bar", "baz"]
+
+
+@pytest.mark.parametrize("metadata_store", IMPLEMENTATIONS, indirect=True)
+def test_get_run_results_include_cache_key_and_data_version(metadata_store):
+ cache_key = _mock_cache_key()
+ data_version = "foo-a"
+ run_id = "test-run-id"
+ metadata_store.initialize(run_id=run_id)
+
+ metadata_store.set(
+ cache_key=cache_key,
+ data_version=data_version,
+ run_id=run_id,
+ )
+
+ run_info = metadata_store.get_run(run_id=run_id)
+
+ assert isinstance(run_info, list)
+ assert len(run_info) == 1
+ assert isinstance(run_info[0], dict)
+ assert run_info[0]["cache_key"] == cache_key
+ assert run_info[0]["data_version"] == data_version
+
+
+@pytest.mark.parametrize("metadata_store", IMPLEMENTATIONS, indirect=True)
+def test_get_run_returns_empty_list_if_run_started_but_no_execution_recorded(metadata_store):
+ run_id = "test-run-id"
+ metadata_store.initialize(run_id=run_id)
+ run_info = metadata_store.get_run(run_id=run_id)
+ assert run_info == []
+
+
+@pytest.mark.parametrize("metadata_store", IMPLEMENTATIONS, indirect=True)
+def test_get_run_raises_error_if_run_id_not_found(metadata_store):
+ metadata_store.initialize(run_id="test-run-id")
+ with pytest.raises(IndexError):
+ metadata_store.get_run(run_id="foo")
diff --git a/tests/caching/metadata_store/test_memory.py b/tests/caching/metadata_store/test_memory.py
new file mode 100644
index 000000000..d0f76d1b4
--- /dev/null
+++ b/tests/caching/metadata_store/test_memory.py
@@ -0,0 +1,60 @@
+import pytest
+
+from hamilton.caching.stores.memory import InMemoryMetadataStore
+from hamilton.caching.stores.sqlite import SQLiteMetadataStore
+
+# `metadata_store` is imported but not directly used because it's
+# a pytest fixture automatically provided to tests
+from .test_base import _mock_cache_key, metadata_store # noqa: F401
+
+ # implementations that the in-memory metadata store can `.persist_to()` to and `.load_from()` from
+PERSISTENT_IMPLEMENTATIONS = [SQLiteMetadataStore]
+
+
+@pytest.mark.parametrize("metadata_store", PERSISTENT_IMPLEMENTATIONS, indirect=True)
+def test_persist_to(metadata_store): # noqa: F811
+ cache_key = _mock_cache_key()
+ data_version = "foo-a"
+ run_id = "test-run-id"
+ in_memory_metadata_store = InMemoryMetadataStore()
+
+ # set values in-memory
+ in_memory_metadata_store.initialize(run_id=run_id)
+ in_memory_metadata_store.set(
+ cache_key=cache_key,
+ data_version=data_version,
+ run_id=run_id,
+ )
+
+ # values exist in memory, but not in destination
+ assert in_memory_metadata_store.get(cache_key) == data_version
+ assert metadata_store.get(cache_key) is None
+
+ # persist to destination
+ in_memory_metadata_store.persist_to(metadata_store)
+ assert metadata_store.get(cache_key) == data_version
+ assert in_memory_metadata_store.size == metadata_store.size
+ assert in_memory_metadata_store.get_run_ids() == metadata_store.get_run_ids()
+
+
+@pytest.mark.parametrize("metadata_store", PERSISTENT_IMPLEMENTATIONS, indirect=True)
+def test_load_from(metadata_store): # noqa: F811
+ cache_key = _mock_cache_key()
+ data_version = "foo-a"
+ run_id = "test-run-id"
+
+ # set values in source
+ metadata_store.initialize(run_id=run_id)
+ metadata_store.set(
+ cache_key=cache_key,
+ data_version=data_version,
+ run_id=run_id,
+ )
+
+ # values exist in source
+ assert metadata_store.get(cache_key) == data_version
+
+ in_memory_metadata_store = InMemoryMetadataStore.load_from(metadata_store)
+ assert in_memory_metadata_store.get(cache_key) == data_version
+ assert in_memory_metadata_store.size == metadata_store.size
+ assert in_memory_metadata_store.get_run_ids() == metadata_store.get_run_ids()
diff --git a/tests/caching/result_store/__init__.py b/tests/caching/result_store/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/caching/result_store/test_base.py b/tests/caching/result_store/test_base.py
new file mode 100644
index 000000000..c7b461e4f
--- /dev/null
+++ b/tests/caching/result_store/test_base.py
@@ -0,0 +1,28 @@
+import pytest
+
+from hamilton.caching.stores.file import FileResultStore
+from hamilton.caching.stores.memory import InMemoryResultStore
+
+
+def _instantiate_result_store(result_store_cls, tmp_path):
+ if result_store_cls == FileResultStore:
+ return FileResultStore(path=tmp_path)
+ elif result_store_cls == InMemoryResultStore:
+ return InMemoryResultStore()
+ else:
+ raise ValueError(
+ f"Class `{result_store_cls}` isn't defined in `_instantiate_metadata_store()`"
+ )
+
+
+@pytest.fixture
+def result_store(request, tmp_path):
+ result_store_cls = request.param
+ result_store = _instantiate_result_store(result_store_cls, tmp_path)
+
+ yield result_store
+
+ result_store.delete_all()
+
+
+# NOTE add tests that check properties shared across result store implementations below
diff --git a/tests/caching/result_store/test_file.py b/tests/caching/result_store/test_file.py
new file mode 100644
index 000000000..708a6a527
--- /dev/null
+++ b/tests/caching/result_store/test_file.py
@@ -0,0 +1,112 @@
+import pathlib
+import pickle
+import shutil
+
+import pytest
+
+from hamilton.caching import fingerprinting
+from hamilton.caching.stores.base import search_data_adapter_registry
+from hamilton.caching.stores.file import FileResultStore
+
+
+def _store_size(result_store: FileResultStore) -> int:
+ return len([p for p in result_store.path.iterdir()])
+
+
+@pytest.fixture
+def file_store(tmp_path):
+ store = FileResultStore(path=tmp_path)
+ assert _store_size(store) == 0
+ yield store
+
+ shutil.rmtree(tmp_path)
+
+
+def test_set(file_store):
+ data_version = "foo"
+
+ file_store.set(data_version=data_version, result="bar")
+
+ assert pathlib.Path(file_store.path, data_version).exists()
+ assert _store_size(file_store) == 1
+
+
+def test_exists(file_store):
+ data_version = "foo"
+ assert file_store.exists(data_version) == pathlib.Path(file_store.path, data_version).exists()
+
+ file_store.set(data_version=data_version, result="bar")
+
+ assert file_store.exists(data_version) == pathlib.Path(file_store.path, data_version).exists()
+
+
+def test_set_doesnt_produce_duplicates(file_store):
+ data_version = "foo"
+ assert not file_store.exists(data_version)
+
+ file_store.set(data_version=data_version, result="bar")
+ file_store.set(data_version=data_version, result="bar")
+
+ assert file_store.exists(data_version)
+ assert _store_size(file_store) == 1
+
+
+def test_get(file_store):
+ data_version = "foo"
+ result = "bar"
+ pathlib.Path(file_store.path, data_version).write_bytes(pickle.dumps(result))
+ assert file_store.exists(data_version)
+
+ retrieved_value = file_store.get(data_version)
+
+ assert retrieved_value
+ assert result == retrieved_value
+ assert _store_size(file_store) == 1
+
+
+def test_get_missing_result_is_none(file_store):
+ result = file_store.get("foo")
+ assert result is None
+
+
+def test_delete(file_store):
+ data_version = "foo"
+ file_store.set(data_version, "bar")
+ assert pathlib.Path(file_store.path, data_version).exists()
+ assert _store_size(file_store) == 1
+
+ file_store.delete(data_version)
+
+ assert not pathlib.Path(file_store.path, data_version).exists()
+ assert _store_size(file_store) == 0
+
+
+def test_delete_all(file_store):
+ file_store.set("foo", "foo")
+ file_store.set("bar", "bar")
+ assert _store_size(file_store) == 2
+
+ file_store.delete_all()
+
+ assert _store_size(file_store) == 0
+
+
+@pytest.mark.parametrize(
+ "format,value",
+ [
+ ("json", {"key1": "value1", "key2": 2}),
+ ("pickle", ("value1", "value2", "value3")),
+ ],
+)
+def test_save_and_load_materializer(format, value, file_store):
+ saver_cls, loader_cls = search_data_adapter_registry(name=format, type_=type(value))
+ data_version = "foo"
+ materialized_path = file_store._materialized_path(data_version, saver_cls)
+
+ file_store.set(
+ data_version=data_version, result=value, saver_cls=saver_cls, loader_cls=loader_cls
+ )
+ retrieved_value = file_store.get(data_version)
+
+ assert materialized_path.exists()
+ assert fingerprinting.hash_value(value) == fingerprinting.hash_value(retrieved_value)
diff --git a/tests/caching/result_store/test_memory.py b/tests/caching/result_store/test_memory.py
new file mode 100644
index 000000000..7d7fc87e3
--- /dev/null
+++ b/tests/caching/result_store/test_memory.py
@@ -0,0 +1,131 @@
+import pytest
+
+from hamilton.caching.stores.base import StoredResult
+from hamilton.caching.stores.file import FileResultStore
+from hamilton.caching.stores.memory import InMemoryResultStore
+
+# `result_store` is imported but not directly used because it's
+# a pytest fixture automatically provided to tests
+from .test_base import result_store # noqa: F401
+
+ # implementations that the in-memory result store can `.persist_to()` to and `.load_from()` from
+PERSISTENT_IMPLEMENTATIONS = [FileResultStore]
+
+
+def _store_size(memory_store: InMemoryResultStore) -> int:
+ return len(memory_store._results)
+
+
+@pytest.fixture
+def memory_store():
+ store = InMemoryResultStore()
+ assert _store_size(store) == 0
+ yield store
+
+
+def test_set(memory_store):
+ data_version = "foo"
+
+ memory_store.set(data_version=data_version, result="bar")
+
+ assert memory_store._results[data_version].value == "bar"
+ assert _store_size(memory_store) == 1
+
+
+def test_exists(memory_store):
+ data_version = "foo"
+ assert memory_store.exists(data_version) is False
+
+ memory_store.set(data_version=data_version, result="bar")
+
+ assert memory_store.exists(data_version) is True
+
+
+def test_set_doesnt_produce_duplicates(memory_store):
+ data_version = "foo"
+ assert not memory_store.exists(data_version)
+
+ memory_store.set(data_version=data_version, result="bar")
+ memory_store.set(data_version=data_version, result="bar")
+
+ assert memory_store.exists(data_version)
+ assert _store_size(memory_store) == 1
+
+
+def test_get(memory_store):
+ data_version = "foo"
+ result = StoredResult(value="bar")
+ memory_store._results[data_version] = result
+ assert memory_store.exists(data_version)
+
+ retrieved_value = memory_store.get(data_version)
+
+ assert retrieved_value is not None
+ assert result.value == retrieved_value
+ assert _store_size(memory_store) == 1
+
+
+def test_get_missing_result_is_none(memory_store):
+ result = memory_store.get("foo")
+ assert result is None
+
+
+def test_delete(memory_store):
+ data_version = "foo"
+ memory_store._results[data_version] = StoredResult(value="bar")
+ assert _store_size(memory_store) == 1
+
+ memory_store.delete(data_version)
+
+ assert memory_store._results.get(data_version) is None
+ assert _store_size(memory_store) == 0
+
+
+def test_delete_all(memory_store):
+ memory_store._results["foo"] = "foo"
+ memory_store._results["bar"] = "bar"
+ assert _store_size(memory_store) == 2
+
+ memory_store.delete_all()
+
+ assert _store_size(memory_store) == 0
+
+
+@pytest.mark.parametrize("result_store", PERSISTENT_IMPLEMENTATIONS, indirect=True)
+def test_persist_to(result_store, memory_store): # noqa: F811
+ data_version = "foo"
+ result = "bar"
+
+ # set values in-memory
+ memory_store.set(data_version=data_version, result=result)
+
+ # values exist in memory, but not in destination
+ assert memory_store.get(data_version) == result
+ assert result_store.get(data_version) is None
+
+ # persist to destination
+ memory_store.persist_to(result_store)
+ assert memory_store.get(data_version) == result_store.get(data_version)
+
+
+@pytest.mark.parametrize("result_store", PERSISTENT_IMPLEMENTATIONS, indirect=True)
+def test_load_from(result_store): # noqa: F811
+ data_version = "foo"
+ result = "bar"
+
+ # set values in source
+ result_store.set(data_version=data_version, result=result)
+
+ # values exist in source
+ assert result_store.get(data_version) == result
+
+ memory_store = InMemoryResultStore.load_from(
+ result_store=result_store, data_versions=[data_version]
+ )
+ assert memory_store.get(data_version) == result_store.get(data_version)
+
+
+def test_load_from_must_have_metadata_store_or_data_versions(tmp_path):
+ file_result_store = FileResultStore(tmp_path)
+ with pytest.raises(ValueError):
+ InMemoryResultStore.load_from(result_store=file_result_store)
diff --git a/tests/caching/test_integration.py b/tests/caching/test_integration.py
index cb14c33e1..61315ce34 100644
--- a/tests/caching/test_integration.py
+++ b/tests/caching/test_integration.py
@@ -1,3 +1,4 @@
+import itertools
from typing import List
import pandas as pd
@@ -5,6 +6,9 @@
from hamilton import ad_hoc_utils, driver
from hamilton.caching.adapter import CachingEventType, HamiltonCacheAdapter
+from hamilton.caching.stores.file import FileResultStore
+from hamilton.caching.stores.memory import InMemoryMetadataStore, InMemoryResultStore
+from hamilton.caching.stores.sqlite import SQLiteMetadataStore
from hamilton.execution.executors import (
MultiProcessingExecutor,
MultiThreadingExecutor,
@@ -12,9 +16,30 @@
)
from hamilton.function_modifiers import cache as cache_decorator
+# `metadata_store` and `result_store` are imported but not directly used because they
+# are pytest fixtures automatically provided to tests
+from .metadata_store.test_base import metadata_store # noqa: F401
+from .result_store.test_base import result_store # noqa: F401
from tests.resources.dynamic_parallelism import parallel_linear_basic, parallelism_with_caching
+def _instantiate_executor(executor_cls):
+ if executor_cls == SynchronousLocalTaskExecutor:
+ return SynchronousLocalTaskExecutor()
+ elif executor_cls == MultiProcessingExecutor:
+ return MultiProcessingExecutor(max_tasks=10)
+ elif executor_cls == MultiThreadingExecutor:
+ return MultiThreadingExecutor(max_tasks=10)
+ else:
+ raise ValueError(f"Class `{executor_cls}` isn't defined in `_instantiate_executor()`")
+
+
+@pytest.fixture
+def executor(request):
+ executor_cls = request.param
+ return _instantiate_executor(executor_cls)
+
+
@pytest.fixture
def dr(request, tmp_path):
module = request.param
@@ -483,19 +508,34 @@ def foo() -> dict:
assert result[node_name] == retrieved_result
+EXECUTORS_AND_STORES_CONFIGURATIONS = list(
+ itertools.product(
+ [SynchronousLocalTaskExecutor, MultiThreadingExecutor, MultiProcessingExecutor],
+ [SQLiteMetadataStore],
+ [FileResultStore],
+ )
+)
+
+# InMemory stores can't be used with multiprocessing because they don't share memory.
+IN_MEMORY_CONFIGURATIONS = list(
+ itertools.product(
+ [SynchronousLocalTaskExecutor, MultiThreadingExecutor],
+ [InMemoryMetadataStore, SQLiteMetadataStore],
+ [InMemoryResultStore, FileResultStore],
+ )
+)
+
+EXECUTORS_AND_STORES_CONFIGURATIONS += IN_MEMORY_CONFIGURATIONS
+
+
@pytest.mark.parametrize(
- "executor",
- [
- SynchronousLocalTaskExecutor(),
- MultiProcessingExecutor(max_tasks=10),
- MultiThreadingExecutor(max_tasks=10),
- ],
+ "executor,metadata_store,result_store", EXECUTORS_AND_STORES_CONFIGURATIONS, indirect=True
)
-def test_parallel_synchronous_step_by_step(tmp_path, executor):
+def test_parallel_synchronous_step_by_step(executor, metadata_store, result_store): # noqa: F811
dr = (
driver.Builder()
.with_modules(parallel_linear_basic)
- .with_cache(path=tmp_path)
+ .with_cache(metadata_store=metadata_store, result_store=result_store)
.enable_dynamic_execution(allow_experimental_mode=True)
.with_remote_executor(executor)
.build()
@@ -559,11 +599,8 @@ def test_parallel_synchronous_step_by_step(tmp_path, executor):
@pytest.mark.parametrize(
"executor",
- [
- SynchronousLocalTaskExecutor(),
- MultiProcessingExecutor(max_tasks=10),
- MultiThreadingExecutor(max_tasks=10),
- ],
+ [SynchronousLocalTaskExecutor, MultiProcessingExecutor, MultiThreadingExecutor],
+ indirect=True,
)
def test_materialize_parallel_branches(tmp_path, executor):
# NOTE the module can't be defined here because multithreading requires functions to be top-level.
@@ -596,7 +633,7 @@ def test_materialize_parallel_branches(tmp_path, executor):
)
-def test_consistent_cache_key_with_or_without_defaut_parameter(tmp_path):
+def test_consistent_cache_key_with_or_without_default_parameter(tmp_path):
def foo(external_dep: int = 3) -> int:
return external_dep + 1
diff --git a/tests/caching/test_metadata_store.py b/tests/caching/test_metadata_store.py
deleted file mode 100644
index 156e44b3e..000000000
--- a/tests/caching/test_metadata_store.py
+++ /dev/null
@@ -1,155 +0,0 @@
-import pytest
-
-from hamilton.caching.cache_key import create_cache_key
-from hamilton.caching.stores.sqlite import SQLiteMetadataStore
-
-
-@pytest.fixture
-def metadata_store(request, tmp_path):
- metdata_store_cls = request.param
- metadata_store = metdata_store_cls(path=tmp_path)
- run_id = "test-run-id"
- try:
- metadata_store.initialize(run_id)
- except BaseException:
- pass
-
- yield metadata_store
-
- metadata_store.delete_all()
-
-
-@pytest.mark.parametrize("metadata_store", [SQLiteMetadataStore], indirect=True)
-def test_initialize_empty(metadata_store):
- assert metadata_store.size == 0
-
-
-@pytest.mark.parametrize("metadata_store", [SQLiteMetadataStore], indirect=True)
-def test_not_empty_after_set(metadata_store):
- code_version = "FOO-1"
- data_version = "foo-a"
- node_name = "foo"
- cache_key = create_cache_key(
- node_name=node_name, code_version=code_version, dependencies_data_versions={}
- )
-
- metadata_store.set(
- cache_key=cache_key,
- node_name=node_name,
- code_version=code_version,
- data_version=data_version,
- run_id="...",
- )
-
- assert metadata_store.size > 0
-
-
-@pytest.mark.parametrize("metadata_store", [SQLiteMetadataStore], indirect=True)
-def test_set_doesnt_produce_duplicates(metadata_store):
- code_version = "FOO-1"
- data_version = "foo-a"
- node_name = "foo"
- cache_key = create_cache_key(
- node_name=node_name, code_version=code_version, dependencies_data_versions={}
- )
- metadata_store.set(
- cache_key=cache_key,
- node_name=node_name,
- code_version=code_version,
- data_version=data_version,
- run_id="...",
- )
- assert metadata_store.size == 1
-
- metadata_store.set(
- cache_key=cache_key,
- node_name=node_name,
- code_version=code_version,
- data_version=data_version,
- run_id="...",
- )
- assert metadata_store.size == 1
-
-
-@pytest.mark.parametrize("metadata_store", [SQLiteMetadataStore], indirect=True)
-def test_get_miss_returns_none(metadata_store):
- cache_key = create_cache_key(
- node_name="foo", code_version="FOO-1", dependencies_data_versions={"bar": "bar-a"}
- )
- data_version = metadata_store.get(cache_key=cache_key)
- assert data_version is None
-
-
-@pytest.mark.parametrize("metadata_store", [SQLiteMetadataStore], indirect=True)
-def test_set_get_without_dependencies(metadata_store):
- code_version = "FOO-1"
- data_version = "foo-a"
- node_name = "foo"
- cache_key = create_cache_key(
- node_name=node_name, code_version=code_version, dependencies_data_versions={}
- )
- metadata_store.set(
- cache_key=cache_key,
- node_name=node_name,
- code_version=code_version,
- data_version=data_version,
- run_id="...",
- )
- retrieved_data_version = metadata_store.get(cache_key=cache_key)
-
- assert retrieved_data_version == data_version
-
-
-@pytest.mark.parametrize("metadata_store", [SQLiteMetadataStore], indirect=True)
-def test_get_run_ids_returns_ordered_list(metadata_store):
- pre_run_ids = metadata_store.get_run_ids()
- assert pre_run_ids == ["test-run-id"] # this is from the fixture
-
- metadata_store.initialize(run_id="foo")
- metadata_store.initialize(run_id="bar")
- metadata_store.initialize(run_id="baz")
-
- post_run_ids = metadata_store.get_run_ids()
- assert post_run_ids == ["test-run-id", "foo", "bar", "baz"]
-
-
-@pytest.mark.parametrize("metadata_store", [SQLiteMetadataStore], indirect=True)
-def test_get_run_results_include_cache_key_and_data_version(metadata_store):
- run_id = "test-run-id"
- metadata_store.set(
- cache_key="foo",
- data_version="1",
- run_id=run_id,
- node_name="a", # kwarg specific to SQLiteMetadataStore
- code_version="b", # kwarg specific to SQLiteMetadataStore
- )
- metadata_store.set(
- cache_key="bar",
- data_version="2",
- run_id=run_id,
- node_name="a", # kwarg specific to SQLiteMetadataStore
- code_version="b", # kwarg specific to SQLiteMetadataStore
- )
-
- run_info = metadata_store.get_run(run_id=run_id)
-
- assert isinstance(run_info, list)
- assert len(run_info) == 2
- assert isinstance(run_info[1], dict)
- assert run_info[0]["cache_key"] == "foo"
- assert run_info[0]["data_version"] == "1"
- assert run_info[1]["cache_key"] == "bar"
- assert run_info[1]["data_version"] == "2"
-
-
-@pytest.mark.parametrize("metadata_store", [SQLiteMetadataStore], indirect=True)
-def test_get_run_returns_empty_list_if_run_started_but_no_execution_recorded(metadata_store):
- metadata_store.initialize(run_id="foo")
- run_info = metadata_store.get_run(run_id="foo")
- assert run_info == []
-
-
-@pytest.mark.parametrize("metadata_store", [SQLiteMetadataStore], indirect=True)
-def test_get_run_raises_error_if_run_id_not_found(metadata_store):
- with pytest.raises(IndexError):
- metadata_store.get_run(run_id="foo")