diff --git a/examples/caching/README.md b/examples/caching/README.md
index 4b53b2218..82a038bf3 100644
--- a/examples/caching/README.md
+++ b/examples/caching/README.md
@@ -3,4 +3,5 @@
 This directory contains tutorial notebooks for the Hamilton caching feature.
 
 - `tutorial.ipynb`: the main tutorial for caching
-- `materializer_tutorial.ipynb`: tutorial on the interactions between `DataLoader/DataSaver` and caching. This is a more advanced tutorial for materializer users. You should complete the `tutorial.ipynb` first.
+- `in_memory_tutorial.ipynb`: How to use caching without writing metadata and results to disk.
+- `materializer_tutorial.ipynb`: Learn about the interactions between `DataLoader/DataSaver` and caching. This is a more advanced tutorial.
diff --git a/examples/caching/in_memory_tutorial.ipynb b/examples/caching/in_memory_tutorial.ipynb
new file mode 100644
index 000000000..712422a22
--- /dev/null
+++ b/examples/caching/in_memory_tutorial.ipynb
@@ -0,0 +1,673 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# In-memory caching tutorial\n",
+    "\n",
+    "This notebook shows how to use in-memory caching, which lets you cache results between runs without writing to disk. This uses the `InMemoryResultStore` and `InMemoryMetadataStore` classes.\n",
+    "\n",
+    "> ⛔ In-memory caching can consume a lot of memory if you're storing large results. Selectively caching nodes is recommended.\n",
+    "\n",
+    "If you're new to caching, you should take a look at the [caching tutorial](./tutorial.ipynb) first!"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Setup\n",
+    "Throughout this tutorial, we'll be using the Hamilton notebook extension to define dataflows directly in the notebook ([see tutorial](https://github.com/DAGWorks-Inc/hamilton/blob/main/examples/jupyter_notebook_magic/example.ipynb)).\n",
+    "\n",
+    "Then, we get the logger for caching and clear previously cached results."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import logging\n",
+    "import shutil\n",
+    "\n",
+    "# avoid loading all available plugins for fast startup time\n",
+    "from hamilton import registry\n",
+    "registry.disable_autoload()\n",
+    "registry.load_extension(\"pandas\")\n",
+    "\n",
+    "from hamilton import driver\n",
+    "\n",
+    "# load the notebook extension\n",
+    "%reload_ext hamilton.plugins.jupyter_magic\n",
+    "\n",
+    "logger = logging.getLogger(\"hamilton.caching\")\n",
+    "logger.setLevel(logging.INFO)\n",
+    "logger.addHandler(logging.StreamHandler())\n",
+    "\n",
+    "shutil.rmtree(\"./.hamilton_cache\", ignore_errors=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Define a dataflow\n",
+    "We define a simple dataflow that loads a dataframe of transactions, filters by date, converts currency to USD, and sums the amount per country."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/svg+xml": [
+       "[SVG output: dataflow graph with cutoff_date (input), raw_data -> processed_data -> amount_per_country; legend: input, function]"
+      ],
+      "text/plain": [
+       ""
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "%%cell_to_module dataflow_module --display\n",
+    "import pandas as pd\n",
+    "\n",
+    "DATA = {\n",
+    "    \"cities\": [\"New York\", \"Los Angeles\", \"Chicago\", \"Montréal\", \"Vancouver\"],\n",
+    "    \"date\": [\"2024-09-13\", \"2024-09-12\", \"2024-09-11\", \"2024-09-11\", \"2024-09-09\"],\n",
+    "    \"amount\": [478.23, 251.67, 989.34, 742.14, 584.56],\n",
+    "    \"country\": [\"USA\", \"USA\", \"USA\", \"Canada\", \"Canada\"],\n",
+    "    \"currency\": [\"USD\", \"USD\", \"USD\", \"CAD\", \"CAD\"],\n",
+    "}\n",
+    "\n",
+    "def raw_data() -> pd.DataFrame:\n",
+    "    \"\"\"Loading raw data. This simulates loading from a file, database, or external service.\"\"\"\n",
+    "    return pd.DataFrame(DATA)\n",
+    "\n",
+    "def processed_data(raw_data: pd.DataFrame, cutoff_date: str) -> pd.DataFrame:\n",
+    "    \"\"\"Filter out rows before cutoff date and convert currency to USD.\"\"\"\n",
+    "    df = raw_data.loc[raw_data.date > cutoff_date].copy()\n",
+    "    df[\"amound_in_usd\"] = df[\"amount\"]\n",
+    "    df.loc[df.country == \"Canada\", \"amound_in_usd\"] *= 0.71\n",
+    "    df.loc[df.country == \"Brazil\", \"amound_in_usd\"] *= 0.18  # <- LINE ADDED\n",
+    "    df.loc[df.country == \"Mexico\", \"amound_in_usd\"] *= 0.05  # <- LINE ADDED\n",
+    "    return df\n",
+    "\n",
+    "def amount_per_country(processed_data: pd.DataFrame) -> pd.DataFrame:\n",
+    "    \"\"\"Sum the amount in USD per country\"\"\"\n",
+    "    return processed_data.groupby(\"country\")[\"amound_in_usd\"].sum().to_frame()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## In-memory caching\n",
+    "To use in-memory caching, pass `InMemoryResultStore` and `InMemoryMetadataStore` to `Builder().with_cache()`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from hamilton.caching.stores.memory import InMemoryMetadataStore, InMemoryResultStore\n",
+    "\n",
+    "dr = (\n",
+    "    driver.Builder()\n",
+    "    .with_modules(dataflow_module)\n",
+    "    .with_cache(\n",
+    "        result_store=InMemoryResultStore(),\n",
+    "        metadata_store=InMemoryMetadataStore(),\n",
+    "    )\n",
+    "    .build()\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Execution 1\n",
+    "For execution 1, we see that all nodes are executed."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "raw_data::adapter::execute_node\n",
+      "processed_data::adapter::execute_node\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "        cities        date  amount country currency  amound_in_usd\n",
+      "0     New York  2024-09-13  478.23     USA      USD       478.2300\n",
+      "1  Los Angeles  2024-09-12  251.67     USA      USD       251.6700\n",
+      "2      Chicago  2024-09-11  989.34     USA      USD       989.3400\n",
+      "3     Montréal  2024-09-11  742.14  Canada      CAD       526.9194\n",
+      "4    Vancouver  2024-09-09  584.56  Canada      CAD       415.0376\n"
+     ]
+    },
+    {
+     "data": {
+      "image/svg+xml": [
+       "[SVG output: run view with cutoff_date (input), raw_data -> processed_data; legend: input, function, output]"
+      ],
+      "text/plain": [
+       ""
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "results = dr.execute([\"processed_data\"], inputs={\"cutoff_date\": \"2024-09-01\"})\n",
+    "print()\n",
+    "print(results[\"processed_data\"].head())\n",
+    "dr.cache.view_run()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Execution 2\n",
+    "For execution 2, we see that all nodes are retrieved from cache."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "raw_data::result_store::get_result::hit\n",
+      "processed_data::result_store::get_result::hit\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "        cities        date  amount country currency  amound_in_usd\n",
+      "0     New York  2024-09-13  478.23     USA      USD       478.2300\n",
+      "1  Los Angeles  2024-09-12  251.67     USA      USD       251.6700\n",
+      "2      Chicago  2024-09-11  989.34     USA      USD       989.3400\n",
+      "3     Montréal  2024-09-11  742.14  Canada      CAD       526.9194\n",
+      "4    Vancouver  2024-09-09  584.56  Canada      CAD       415.0376\n"
+     ]
+    },
+    {
+     "data": {
+      "image/svg+xml": [
+       "[SVG output: run view with cutoff_date (input), raw_data -> processed_data; legend: input, output, from cache]"
+      ],
+      "text/plain": [
+       ""
+      ]
+     },
+     "execution_count": 5,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "results = dr.execute([\"processed_data\"], inputs={\"cutoff_date\": \"2024-09-01\"})\n",
+    "print()\n",
+    "print(results[\"processed_data\"].head())\n",
+    "dr.cache.view_run()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Persisting in-memory data\n",
+    "\n",
+    "Now, we import `SQLiteMetadataStore` and `FileResultStore` to persist the data to disk. We access the in-memory stores via `dr.cache.result_store` and `dr.cache.metadata_store` and call the `.persist_to()` method on each.\n",
+    "\n",
+    "After executing the cell, you should see a new directory `./.persisted_cache` with results and metadata."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from hamilton.caching.stores.sqlite import SQLiteMetadataStore\n",
+    "from hamilton.caching.stores.file import FileResultStore\n",
+    "\n",
+    "path = \"./.persisted_cache\"\n",
+    "on_disk_results = FileResultStore(path=path)\n",
+    "on_disk_metadata = SQLiteMetadataStore(path=path)\n",
+    "\n",
+    "dr.cache.result_store.persist_to(on_disk_results)\n",
+    "dr.cache.metadata_store.persist_to(on_disk_metadata)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Loading persisted data\n",
+    "\n",
+    "Now, we create a new `Driver`. Instead of starting with empty in-memory stores, we will load the previously persisted results by calling `.load_from()` on the `InMemoryResultStore` and `InMemoryMetadataStore` classes.\n",
+    "\n",
+    "For `InMemoryResultStore.load_from()`, we must provide a `MetadataStore` or a list of `data_version` values to load results for."
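As a quick illustration of that last point, here is a rough sketch of the two call forms. It assumes the `on_disk_results` and `on_disk_metadata` stores created in the previous cell, and the explicit `data_versions` values are placeholders:

```python
from hamilton.caching.stores.memory import InMemoryResultStore

# 1. let a metadata store enumerate every data_version to load into memory
from_metadata = InMemoryResultStore.load_from(
    on_disk_results, metadata_store=on_disk_metadata
)

# 2. or pass an explicit list of data_version strings (placeholder values shown)
from_versions = InMemoryResultStore.load_from(
    on_disk_results,
    data_versions=["<raw_data_version>", "<processed_data_version>"],
)
```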
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dr = (\n",
+    "    driver.Builder()\n",
+    "    .with_modules(dataflow_module)\n",
+    "    .with_cache(\n",
+    "        result_store=InMemoryResultStore.load_from(\n",
+    "            on_disk_results,\n",
+    "            metadata_store=on_disk_metadata,\n",
+    "        ),\n",
+    "        metadata_store=InMemoryMetadataStore.load_from(on_disk_metadata),\n",
+    "    )\n",
+    "    .build()\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We print the size of the metadata store to show it contains 2 entries (one for `raw_data` and another for `processed_data`). Also, we see that results loaded from the `FileResultStore` are successfully retrieved from the in-memory stores."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "raw_data::result_store::get_result::hit\n",
+      "processed_data::result_store::get_result::hit\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "2\n",
+      "\n",
+      "        cities        date  amount country currency  amound_in_usd\n",
+      "0     New York  2024-09-13  478.23     USA      USD       478.2300\n",
+      "1  Los Angeles  2024-09-12  251.67     USA      USD       251.6700\n",
+      "2      Chicago  2024-09-11  989.34     USA      USD       989.3400\n",
+      "3     Montréal  2024-09-11  742.14  Canada      CAD       526.9194\n",
+      "4    Vancouver  2024-09-09  584.56  Canada      CAD       415.0376\n"
+     ]
+    },
+    {
+     "data": {
+      "image/svg+xml": [
+       "[SVG output: run view with cutoff_date (input), raw_data -> processed_data; legend: input, output, from cache]"
+      ],
+      "text/plain": [
+       ""
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "print(dr.cache.metadata_store.size)\n",
+    "\n",
+    "results = dr.execute([\"processed_data\"], inputs={\"cutoff_date\": \"2024-09-01\"})\n",
+    "print()\n",
+    "print(results[\"processed_data\"].head())\n",
+    "dr.cache.view_run()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Use cases\n",
+    "\n",
+    "In-memory caching can be useful when you're doing a lot of experimentation in a notebook or an interactive session and don't want to persist results for future use.\n",
+    "\n",
+    "It can also speed up execution in some cases because you're no longer reading from and writing to disk for each node execution."
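One way to combine those two use cases, sketched with the `dataflow_module` defined earlier in this notebook, is to keep the whole session in memory and persist only at the end if the results turn out to be worth keeping:

```python
from hamilton import driver
from hamilton.caching.stores.file import FileResultStore
from hamilton.caching.stores.memory import InMemoryMetadataStore, InMemoryResultStore
from hamilton.caching.stores.sqlite import SQLiteMetadataStore

dr = (
    driver.Builder()
    .with_modules(dataflow_module)
    .with_cache(
        result_store=InMemoryResultStore(),
        metadata_store=InMemoryMetadataStore(),
    )
    .build()
)

# iterate freely; repeated runs hit the in-memory cache and nothing touches disk
dr.execute(["amount_per_country"], inputs={"cutoff_date": "2024-09-01"})
dr.execute(["amount_per_country"], inputs={"cutoff_date": "2024-09-10"})

# persist at the end of the session, only if the cached results are worth keeping
dr.cache.result_store.persist_to(FileResultStore(path="./.hamilton_cache"))
dr.cache.metadata_store.persist_to(SQLiteMetadataStore(path="./.hamilton_cache"))
```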
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/hamilton/caching/adapter.py b/hamilton/caching/adapter.py index c41bfa4fd..59fd18139 100644 --- a/hamilton/caching/adapter.py +++ b/hamilton/caching/adapter.py @@ -274,7 +274,9 @@ def __getstate__(self) -> dict: state = self.__dict__.copy() # store the classes to reinstantiate the same backend in __setstate__ state["metadata_store_cls"] = self.metadata_store.__class__ + state["metadata_store_init"] = self.metadata_store.__getstate__() state["result_store_cls"] = self.result_store.__class__ + state["result_store_init"] = self.result_store.__getstate__() del state["metadata_store"] del state["result_store"] return state @@ -288,8 +290,8 @@ def __setstate__(self, state: dict) -> None: """ # instantiate the backend from the class, then delete the attribute before # setting it on the adapter instance. - self.metadata_store = state["metadata_store_cls"](path=state["_path"]) - self.result_store = state["result_store_cls"](path=state["_path"]) + self.metadata_store = state["metadata_store_cls"](**state["metadata_store_init"]) + self.result_store = state["result_store_cls"](**state["result_store_init"]) del state["metadata_store_cls"] del state["result_store_cls"] self.__dict__.update(state) diff --git a/hamilton/caching/stores/file.py b/hamilton/caching/stores/file.py index d483ca292..34144d9c1 100644 --- a/hamilton/caching/stores/file.py +++ b/hamilton/caching/stores/file.py @@ -2,9 +2,15 @@ from pathlib import Path from typing import Any, Optional -from hamilton.caching.stores.base import ResultStore, StoredResult +try: + from typing import override +except ImportError: + override = lambda x: x # noqa E731 + from hamilton.io.data_adapters import DataLoader, DataSaver +from .base import ResultStore, StoredResult + class FileResultStore(ResultStore): def __init__(self, path: str, create_dir: bool = True) -> None: @@ -14,6 +20,12 @@ def __init__(self, path: str, create_dir: bool = True) -> None: if self.create_dir: self.path.mkdir(exist_ok=True, parents=True) + def __getstate__(self) -> dict: + """Serialize the `__init__` kwargs to pass in Parallelizable branches + when using multiprocessing. 
+ """ + return {"path": str(self.path)} + @staticmethod def _write_result(file_path: Path, stored_result: StoredResult) -> None: file_path.write_bytes(stored_result.save()) @@ -33,10 +45,12 @@ def _materialized_path(self, data_version: str, saver_cls: DataSaver) -> Path: # TODO allow a more flexible mechanism to specify file path extension return self._path_from_data_version(data_version).with_suffix(f".{saver_cls.name()}") + @override def exists(self, data_version: str) -> bool: result_path = self._path_from_data_version(data_version) return result_path.exists() + @override def set( self, data_version: str, @@ -65,6 +79,7 @@ def set( stored_result = StoredResult.new(value=result, saver=saver, loader=loader) self._write_result(result_path, stored_result) + @override def get(self, data_version: str) -> Optional[Any]: result_path = self._path_from_data_version(data_version) stored_result = self._load_result_from_path(result_path) @@ -74,10 +89,12 @@ def get(self, data_version: str) -> Optional[Any]: return stored_result.value + @override def delete(self, data_version: str) -> None: result_path = self._path_from_data_version(data_version) result_path.unlink(missing_ok=True) + @override def delete_all(self) -> None: shutil.rmtree(self.path) self.path.mkdir(exist_ok=True) diff --git a/hamilton/caching/stores/memory.py b/hamilton/caching/stores/memory.py new file mode 100644 index 000000000..37eee89d5 --- /dev/null +++ b/hamilton/caching/stores/memory.py @@ -0,0 +1,279 @@ +from typing import Any, Dict, List, Optional, Sequence + +try: + from typing import override +except ImportError: + override = lambda x: x # noqa E731 + +from hamilton.caching.cache_key import decode_key + +from .base import MetadataStore, ResultStore, StoredResult +from .file import FileResultStore +from .sqlite import SQLiteMetadataStore + + +class InMemoryMetadataStore(MetadataStore): + def __init__(self) -> None: + self._data_versions: Dict[str, str] = {} # {cache_key: data_version} + self._cache_keys_by_run: Dict[str, List[str]] = {} # {run_id: [cache_key]} + self._run_ids: List[str] = [] + + @override + def __len__(self) -> int: + """Number of unique ``cache_key`` values.""" + return len(self._data_versions.keys()) + + @override + def exists(self, cache_key: str) -> bool: + """Indicate if ``cache_key`` exists and it can retrieve a ``data_version``.""" + return cache_key in self._data_versions.keys() + + @override + def initialize(self, run_id: str) -> None: + """Set up and log the beginning of the run.""" + self._cache_keys_by_run[run_id] = [] + self._run_ids.append(run_id) + + @override + def set(self, cache_key: str, data_version: str, run_id: str, **kwargs) -> Optional[Any]: + """Set the ``data_version`` for ``cache_key`` and associate it with the ``run_id``.""" + self._data_versions[cache_key] = data_version + self._cache_keys_by_run[run_id].append(cache_key) + + @override + def get(self, cache_key: str) -> Optional[str]: + """Retrieve the ``data_version`` for ``cache_key``.""" + return self._data_versions.get(cache_key, None) + + @override + def delete(self, cache_key: str) -> None: + """Delete the ``data_version`` for ``cache_key``.""" + del self._data_versions[cache_key] + + @override + def delete_all(self) -> None: + """Delete all stored metadata.""" + self._data_versions.clear() + + def persist_to(self, metadata_store: Optional[MetadataStore] = None) -> None: + """Persist in-memory metadata using another MetadataStore implementation. + + :param metadata_store: MetadataStore implementation to use for persistence. 
+            If None, a SQLiteMetadataStore is created with the default path "./.hamilton_cache".
+
+        .. code-block:: python
+
+            from hamilton import driver
+            from hamilton.caching.stores.sqlite import SQLiteMetadataStore
+            from hamilton.caching.stores.memory import InMemoryMetadataStore
+            import my_dataflow
+
+            dr = (
+                driver.Builder()
+                .with_modules(my_dataflow)
+                .with_cache(metadata_store=InMemoryMetadataStore())
+                .build()
+            )
+
+            # execute the Driver several times. This will populate the in-memory metadata store
+            dr.execute(...)
+
+            # persist the in-memory metadata to disk
+            dr.cache.metadata_store.persist_to(SQLiteMetadataStore(path="./.hamilton_cache"))
+
+        """
+        if metadata_store is None:
+            metadata_store = SQLiteMetadataStore(path="./.hamilton_cache")
+
+        for run_id in self._run_ids:
+            metadata_store.initialize(run_id)
+
+        for run_id, cache_keys in self._cache_keys_by_run.items():
+            for cache_key in cache_keys:
+                data_version = self._data_versions[cache_key]
+                metadata_store.set(
+                    cache_key=cache_key,
+                    data_version=data_version,
+                    run_id=run_id,
+                )
+
+    @classmethod
+    def load_from(cls, metadata_store: MetadataStore) -> "InMemoryMetadataStore":
+        """Load in-memory metadata from another MetadataStore instance.
+
+        :param metadata_store: MetadataStore instance to load from.
+        :return: InMemoryMetadataStore copy of the ``metadata_store``.
+
+        .. code-block:: python
+
+            from hamilton import driver
+            from hamilton.caching.stores.sqlite import SQLiteMetadataStore
+            from hamilton.caching.stores.memory import InMemoryMetadataStore
+            import my_dataflow
+
+            sqlite_metadata_store = SQLiteMetadataStore(path="./.hamilton_cache")
+            in_memory_metadata_store = InMemoryMetadataStore.load_from(sqlite_metadata_store)
+
+            # create the Driver with the in-memory metadata store
+            dr = (
+                driver.Builder()
+                .with_modules(my_dataflow)
+                .with_cache(metadata_store=in_memory_metadata_store)
+                .build()
+            )
+
+        """
+        in_memory_metadata_store = InMemoryMetadataStore()
+
+        for run_id in metadata_store.get_run_ids():
+            in_memory_metadata_store.initialize(run_id)
+
+            for node_metadata in metadata_store.get_run(run_id):
+                in_memory_metadata_store.set(
+                    cache_key=node_metadata["cache_key"],
+                    data_version=node_metadata["data_version"],
+                    run_id=run_id,
+                )
+
+        return in_memory_metadata_store
+
+    @override
+    def get_run_ids(self) -> List[str]:
+        """Return a list of all ``run_id`` values stored."""
+        return self._run_ids
+
+    @override
+    def get_run(self, run_id: str) -> List[Dict[str, str]]:
+        """Return a list of node metadata associated with a run."""
+        if self._cache_keys_by_run.get(run_id, None) is None:
+            raise IndexError(f"Run ID not found: {run_id}")
+
+        nodes_metadata = []
+        for cache_key in self._cache_keys_by_run[run_id]:
+            decoded_key = decode_key(cache_key)
+            nodes_metadata.append(
+                dict(
+                    cache_key=cache_key,
+                    data_version=self._data_versions[cache_key],
+                    node_name=decoded_key["node_name"],
+                    code_version=decoded_key["code_version"],
+                    dependencies_data_versions=decoded_key["dependencies_data_versions"],
+                )
+            )
+
+        return nodes_metadata
+
+
+class InMemoryResultStore(ResultStore):
+    def __init__(self, persist_on_exit: bool = False) -> None:
+        self._results: Dict[str, StoredResult] = {}  # {data_version: result}
+
+    @override
+    def exists(self, data_version: str) -> bool:
+        return data_version in self._results.keys()
+
+    # TODO handle materialization
+    @override
+    def set(self, data_version: str, result: Any, **kwargs) -> None:
+        self._results[data_version] = StoredResult.new(value=result)
+
+    @override
+    def get(self, data_version: str) -> Optional[Any]:
+        stored_result = self._results.get(data_version, None)
+        if stored_result is None:
+            return None
+
+        return stored_result.value
+
+    @override
+    def delete(self, data_version: str) -> None:
+        del self._results[data_version]
+
+    @override
+    def delete_all(self) -> None:
+        self._results.clear()
+
+    def delete_expired(self) -> None:
+        to_delete = [
+            data_version
+            for data_version, stored_result in self._results.items()
+            if stored_result.expired
+        ]
+
+        # first collect keys, then delete, because you can't delete from a dictionary
+        # while iterating through it
+        for data_version in to_delete:
+            self.delete(data_version)
+
+    def persist_to(self, result_store: Optional[ResultStore] = None) -> None:
+        """Persist in-memory results using another ``ResultStore`` implementation.
+
+        :param result_store: ResultStore implementation to use for persistence.
+            If None, a FileResultStore is created with the default path "./.hamilton_cache".
+        """
+        if result_store is None:
+            result_store = FileResultStore(path="./.hamilton_cache")
+
+        for data_version, stored_result in self._results.items():
+            result_store.set(data_version, stored_result.value)
+
+    @classmethod
+    def load_from(
+        cls,
+        result_store: ResultStore,
+        metadata_store: Optional[MetadataStore] = None,
+        data_versions: Optional[Sequence[str]] = None,
+    ) -> "InMemoryResultStore":
+        """Load in-memory results from another ResultStore instance.
+
+        Since result stores do not store an index of their keys, you must provide a
+        ``MetadataStore`` instance or a list of ``data_version`` for which results
+        should be loaded in memory.
+
+        :param result_store: ``ResultStore`` instance to load results from.
+        :param metadata_store: ``MetadataStore`` instance from which all ``data_version`` are retrieved.
+        :param data_versions: Explicit list of ``data_version`` to load results for.
+        :return: InMemoryResultStore copy of the ``result_store``.
+
+        .. code-block:: python
+
+            from hamilton import driver
+            from hamilton.caching.stores.file import FileResultStore
+            from hamilton.caching.stores.sqlite import SQLiteMetadataStore
+            from hamilton.caching.stores.memory import InMemoryResultStore
+            import my_dataflow
+
+            file_result_store = FileResultStore(path="./.hamilton_cache")
+            sqlite_metadata_store = SQLiteMetadataStore(path="./.hamilton_cache")
+            in_memory_result_store = InMemoryResultStore.load_from(
+                file_result_store, metadata_store=sqlite_metadata_store
+            )
+
+            # create the Driver with the in-memory result store
+            dr = (
+                driver.Builder()
+                .with_modules(my_dataflow)
+                .with_cache(result_store=in_memory_result_store)
+                .build()
+            )
+
+        """
+        if metadata_store is None and data_versions is None:
+            raise ValueError(
+                "A `metadata_store` or `data_versions` must be provided to load results."
+            )
+
+        in_memory_result_store = InMemoryResultStore()
+
+        data_versions_to_retrieve = set()
+        if data_versions is not None:
+            data_versions_to_retrieve.update(data_versions)
+
+        if metadata_store is not None:
+            for run_id in metadata_store.get_run_ids():
+                for node_metadata in metadata_store.get_run(run_id):
+                    data_versions_to_retrieve.add(node_metadata["data_version"])
+
+        for data_version in data_versions_to_retrieve:
+            # TODO disambiguate "result is None" from the sentinel value when `data_version`
+            # is not found in `result_store`.
+ result = result_store.get(data_version) + in_memory_result_store.set(data_version, result) + + return in_memory_result_store diff --git a/hamilton/caching/stores/sqlite.py b/hamilton/caching/stores/sqlite.py index 4ad0eb480..1e434844b 100644 --- a/hamilton/caching/stores/sqlite.py +++ b/hamilton/caching/stores/sqlite.py @@ -20,6 +20,20 @@ def __init__( self._thread_local = threading.local() + # creating tables at `__init__` prevents other methods from encountering + # `sqlite3.OperationalError` because tables are missing. + self._create_tables_if_not_exists() + + def __getstate__(self) -> dict: + """Serialized `__init__` arguments required to initialize the + MetadataStore in a new thread or process. + """ + state = {} + # NOTE kwarg `path` is not equivalent to `self._path` + state["path"] = self._directory + state["connection_kwargs"] = self.connection_kwargs + return state + def _get_connection(self) -> sqlite3.Connection: if not hasattr(self._thread_local, "connection"): self._thread_local.connection = sqlite3.connect( @@ -33,14 +47,15 @@ def _close_connection(self) -> None: del self._thread_local.connection @property - def connection(self): + def connection(self) -> sqlite3.Connection: + """Connection to the SQLite database.""" return self._get_connection() def __del__(self): """Close the SQLite connection when the object is deleted""" self._close_connection() - def _create_tables_if_not_exists(self): + def _create_tables_if_not_exists(self) -> None: """Create the tables necessary for the cache: run_ids: queue of run_ids, ordered by start time. @@ -92,12 +107,11 @@ def initialize(self, run_id) -> None: """Call initialize when starting a run. This will create database tables if necessary. """ - self._create_tables_if_not_exists() cur = self.connection.cursor() cur.execute("INSERT INTO run_ids (run_id) VALUES (?)", (run_id,)) self.connection.commit() - def __len__(self): + def __len__(self) -> int: """Number of entries in cache_metadata""" cur = self.connection.cursor() cur.execute("SELECT COUNT(*) FROM cache_metadata") @@ -118,7 +132,15 @@ def set( # if the caller of ``.set()`` directly provides the ``node_name`` and ``code_version``, # we can skip the decoding step. 
if (node_name is None) or (code_version is None): - decoded_key = decode_key(cache_key) + try: + decoded_key = decode_key(cache_key) + except BaseException as e: + raise ValueError( + f"Failed decoding the cache_key: {cache_key}.\n", + "The `cache_key` must be created by `hamilton.caching.cache_key.create_cache_key()` ", + "if `node_name` and `code_version` are not provided.", + ) from e + node_name = decoded_key["node_name"] code_version = decoded_key["code_version"] diff --git a/hamilton/caching/stores/utils.py b/hamilton/caching/stores/utils.py index 1a69b41eb..e4df5203f 100644 --- a/hamilton/caching/stores/utils.py +++ b/hamilton/caching/stores/utils.py @@ -2,6 +2,7 @@ def get_directory_size(directory: str) -> float: + """Get the size of the content of a directory in bytes.""" total_size = 0 for p in pathlib.Path(directory).rglob("*"): if p.is_file(): @@ -11,6 +12,7 @@ def get_directory_size(directory: str) -> float: def readable_bytes_size(n_bytes: float) -> str: + """Convert a number of bytes to a human-readable unit.""" labels = ["B", "KB", "MB", "GB", "TB"] exponent = 0 diff --git a/tests/caching/metadata_store/__init__.py b/tests/caching/metadata_store/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/caching/metadata_store/test_base.py b/tests/caching/metadata_store/test_base.py new file mode 100644 index 000000000..0e6078bb6 --- /dev/null +++ b/tests/caching/metadata_store/test_base.py @@ -0,0 +1,173 @@ +from typing import Dict + +import pytest + +from hamilton.caching.cache_key import create_cache_key +from hamilton.caching.fingerprinting import hash_value +from hamilton.caching.stores.memory import InMemoryMetadataStore +from hamilton.caching.stores.sqlite import SQLiteMetadataStore +from hamilton.graph_types import hash_source_code + +# if you're adding a new `MetadataStore` implementation, add it to this list. +# Your implementation should successfully pass all the tests. +# Implementation-specific tests should be added to `test_{implementation}.py`. +IMPLEMENTATIONS = [SQLiteMetadataStore, InMemoryMetadataStore] + + +def _instantiate_metadata_store(metadata_store_cls, tmp_path): + if metadata_store_cls == SQLiteMetadataStore: + return SQLiteMetadataStore(path=tmp_path) + elif metadata_store_cls == InMemoryMetadataStore: + return InMemoryMetadataStore() + else: + raise ValueError( + f"Class `{metadata_store_cls}` isn't defined in `_instantiate_metadata_store()`" + ) + + +def _mock_cache_key( + node_name: str = "foo", + code_version: str = "FOO-1", + dependencies_data_versions: Dict[str, str] = None, +) -> str: + """Utility to create a valid cache key from mock values. + This is helpful because ``code_version`` and ``data_version`` found in ``dependencies_data_versions`` + must respect specific encoding. 
+ """ + dependencies_data_versions = ( + dependencies_data_versions if dependencies_data_versions is not None else {} + ) + return create_cache_key( + node_name=node_name, + code_version=hash_source_code(code_version), + dependencies_data_versions={k: hash_value(v) for k, v in dependencies_data_versions}, + ) + + +@pytest.fixture +def metadata_store(request, tmp_path): + metadata_store_cls = request.param + metadata_store = _instantiate_metadata_store(metadata_store_cls, tmp_path) + + yield metadata_store + + +@pytest.mark.parametrize("metadata_store", IMPLEMENTATIONS, indirect=True) +def test_initialize_empty(metadata_store): + metadata_store.initialize(run_id="test-run-id") + assert metadata_store.size == 0 + + +@pytest.mark.parametrize("metadata_store", IMPLEMENTATIONS, indirect=True) +def test_not_empty_after_set(metadata_store): + cache_key = _mock_cache_key() + run_id = "test-run-id" + metadata_store.initialize(run_id=run_id) + + metadata_store.set( + cache_key=cache_key, + data_version="foo-a", + run_id=run_id, + ) + + assert metadata_store.size > 0 + + +@pytest.mark.parametrize("metadata_store", IMPLEMENTATIONS, indirect=True) +def test_set_doesnt_produce_duplicates(metadata_store): + cache_key = _mock_cache_key() + data_version = "foo-a" + run_id = "test-run-id" + metadata_store.initialize(run_id=run_id) + + metadata_store.set( + cache_key=cache_key, + data_version=data_version, + run_id=run_id, + ) + assert metadata_store.size == 1 + + metadata_store.set( + cache_key=cache_key, + data_version=data_version, + run_id=run_id, + ) + assert metadata_store.size == 1 + + +@pytest.mark.parametrize("metadata_store", IMPLEMENTATIONS, indirect=True) +def test_get_miss_returns_none(metadata_store): + cache_key = _mock_cache_key() + run_id = "test-run-id" + metadata_store.initialize(run_id=run_id) + + data_version = metadata_store.get(cache_key=cache_key) + + assert data_version is None + + +@pytest.mark.parametrize("metadata_store", IMPLEMENTATIONS, indirect=True) +def test_set_and_get_with_empty_dependencies(metadata_store): + cache_key = _mock_cache_key() + data_version = "foo-a" + run_id = "test-run-id" + metadata_store.initialize(run_id=run_id) + + metadata_store.set( + cache_key=cache_key, + data_version=data_version, + run_id=run_id, + ) + retrieved_data_version = metadata_store.get(cache_key=cache_key) + + assert retrieved_data_version == data_version + + +@pytest.mark.parametrize("metadata_store", IMPLEMENTATIONS, indirect=True) +def test_get_run_ids_returns_ordered_list(metadata_store): + pre_run_ids = metadata_store.get_run_ids() + assert pre_run_ids == [] + + metadata_store.initialize(run_id="foo") + metadata_store.initialize(run_id="bar") + metadata_store.initialize(run_id="baz") + + post_run_ids = metadata_store.get_run_ids() + assert post_run_ids == ["foo", "bar", "baz"] + + +@pytest.mark.parametrize("metadata_store", IMPLEMENTATIONS, indirect=True) +def test_get_run_results_include_cache_key_and_data_version(metadata_store): + cache_key = _mock_cache_key() + data_version = "foo-a" + run_id = "test-run-id" + metadata_store.initialize(run_id=run_id) + + metadata_store.set( + cache_key=cache_key, + data_version=data_version, + run_id=run_id, + ) + + run_info = metadata_store.get_run(run_id=run_id) + + assert isinstance(run_info, list) + assert len(run_info) == 1 + assert isinstance(run_info[0], dict) + assert run_info[0]["cache_key"] == cache_key + assert run_info[0]["data_version"] == data_version + + +@pytest.mark.parametrize("metadata_store", IMPLEMENTATIONS, indirect=True) 
+def test_get_run_returns_empty_list_if_run_started_but_no_execution_recorded(metadata_store): + run_id = "test-run-id" + metadata_store.initialize(run_id=run_id) + run_info = metadata_store.get_run(run_id=run_id) + assert run_info == [] + + +@pytest.mark.parametrize("metadata_store", IMPLEMENTATIONS, indirect=True) +def test_get_run_raises_error_if_run_id_not_found(metadata_store): + metadata_store.initialize(run_id="test-run-id") + with pytest.raises(IndexError): + metadata_store.get_run(run_id="foo") diff --git a/tests/caching/metadata_store/test_memory.py b/tests/caching/metadata_store/test_memory.py new file mode 100644 index 000000000..d0f76d1b4 --- /dev/null +++ b/tests/caching/metadata_store/test_memory.py @@ -0,0 +1,60 @@ +import pytest + +from hamilton.caching.stores.memory import InMemoryMetadataStore +from hamilton.caching.stores.sqlite import SQLiteMetadataStore + +# `metadata_store` is imported but not directly used because it's +# a pytest fixture automatically provided to tests +from .test_base import _mock_cache_key, metadata_store # noqa: F401 + +# implementations that in-memory metadata store can `.persist_to()` and `.load_from()` +PERSISTENT_IMPLEMENTATIONS = [SQLiteMetadataStore] + + +@pytest.mark.parametrize("metadata_store", PERSISTENT_IMPLEMENTATIONS, indirect=True) +def test_persist_to(metadata_store): # noqa: F811 + cache_key = _mock_cache_key() + data_version = "foo-a" + run_id = "test-run-id" + in_memory_metadata_store = InMemoryMetadataStore() + + # set values in-memory + in_memory_metadata_store.initialize(run_id=run_id) + in_memory_metadata_store.set( + cache_key=cache_key, + data_version=data_version, + run_id=run_id, + ) + + # values exist in memory, but not in destination + assert in_memory_metadata_store.get(cache_key) == data_version + assert metadata_store.get(cache_key) is None + + # persist to destination + in_memory_metadata_store.persist_to(metadata_store) + assert metadata_store.get(cache_key) == data_version + assert in_memory_metadata_store.size == metadata_store.size + assert in_memory_metadata_store.get_run_ids() == metadata_store.get_run_ids() + + +@pytest.mark.parametrize("metadata_store", PERSISTENT_IMPLEMENTATIONS, indirect=True) +def test_load_from(metadata_store): # noqa: F811 + cache_key = _mock_cache_key() + data_version = "foo-a" + run_id = "test-run-id" + + # set values in source + metadata_store.initialize(run_id=run_id) + metadata_store.set( + cache_key=cache_key, + data_version=data_version, + run_id=run_id, + ) + + # values exist in source + assert metadata_store.get(cache_key) == data_version + + in_memory_metadata_store = InMemoryMetadataStore.load_from(metadata_store) + assert in_memory_metadata_store.get(cache_key) == data_version + assert in_memory_metadata_store.size == metadata_store.size + assert in_memory_metadata_store.get_run_ids() == metadata_store.get_run_ids() diff --git a/tests/caching/result_store/__init__.py b/tests/caching/result_store/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/caching/result_store/test_base.py b/tests/caching/result_store/test_base.py new file mode 100644 index 000000000..c7b461e4f --- /dev/null +++ b/tests/caching/result_store/test_base.py @@ -0,0 +1,28 @@ +import pytest + +from hamilton.caching.stores.file import FileResultStore +from hamilton.caching.stores.memory import InMemoryResultStore + + +def _instantiate_result_store(result_store_cls, tmp_path): + if result_store_cls == FileResultStore: + return FileResultStore(path=tmp_path) + elif result_store_cls 
== InMemoryResultStore: + return InMemoryResultStore() + else: + raise ValueError( + f"Class `{result_store_cls}` isn't defined in `_instantiate_metadata_store()`" + ) + + +@pytest.fixture +def result_store(request, tmp_path): + result_store_cls = request.param + result_store = _instantiate_result_store(result_store_cls, tmp_path) + + yield result_store + + result_store.delete_all() + + +# NOTE add tests that check properties shared across result store implementations below diff --git a/tests/caching/result_store/test_file.py b/tests/caching/result_store/test_file.py new file mode 100644 index 000000000..708a6a527 --- /dev/null +++ b/tests/caching/result_store/test_file.py @@ -0,0 +1,112 @@ +import pathlib +import pickle +import shutil + +import pytest + +from hamilton.caching import fingerprinting +from hamilton.caching.stores.base import search_data_adapter_registry +from hamilton.caching.stores.file import FileResultStore + + +def _store_size(result_store: FileResultStore) -> int: + return len([p for p in result_store.path.iterdir()]) + + +@pytest.fixture +def file_store(tmp_path): + store = FileResultStore(path=tmp_path) + assert _store_size(store) == 0 + yield store + + shutil.rmtree(tmp_path) + + +def test_set(file_store): + data_version = "foo" + + file_store.set(data_version=data_version, result="bar") + + assert pathlib.Path(file_store.path, data_version).exists() + assert _store_size(file_store) == 1 + + +def test_exists(file_store): + data_version = "foo" + assert file_store.exists(data_version) == pathlib.Path(file_store.path, data_version).exists() + + file_store.set(data_version=data_version, result="bar") + + assert file_store.exists(data_version) == pathlib.Path(file_store.path, data_version).exists() + + +def test_set_doesnt_produce_duplicates(file_store): + data_version = "foo" + assert not file_store.exists(data_version) + + file_store.set(data_version=data_version, result="bar") + file_store.set(data_version=data_version, result="bar") + + assert file_store.exists(data_version) + assert _store_size(file_store) == 1 + + +def test_get(file_store): + data_version = "foo" + result = "bar" + pathlib.Path(file_store.path, data_version).open("wb").write(pickle.dumps(result)) + assert file_store.exists(data_version) + + retrieved_value = file_store.get(data_version) + + assert retrieved_value + assert result == retrieved_value + assert _store_size(file_store) == 1 + + +def test_get_missing_result_is_none(file_store): + result = file_store.get("foo") + assert result is None + + +def test_delete(file_store): + data_version = "foo" + file_store.set(data_version, "bar") + assert pathlib.Path(file_store.path, data_version).exists() + assert _store_size(file_store) == 1 + + file_store.delete(data_version) + + assert not pathlib.Path(file_store.path, data_version).exists() + assert _store_size(file_store) == 0 + + +def test_delete_all(file_store): + file_store.set("foo", "foo") + file_store.set("bar", "bar") + assert _store_size(file_store) == 2 + + file_store.delete_all() + + assert _store_size(file_store) == 0 + + +@pytest.mark.parametrize( + "format,value", + [ + ("json", {"key1": "value1", "key2": 2}), + ("pickle", ("value1", "value2", "value3")), + ], +) +def test_save_and_load_materializer(format, value, file_store): + saver_cls, loader_cls = search_data_adapter_registry(name=format, type_=type(value)) + data_version = "foo" + materialized_path = file_store._materialized_path(data_version, saver_cls) + + file_store.set( + data_version=data_version, result=value, 
saver_cls=saver_cls, loader_cls=loader_cls + ) + retrieved_value = file_store.get(data_version) + + assert materialized_path.exists() + assert fingerprinting.hash_value(value) == fingerprinting.hash_value(retrieved_value) diff --git a/tests/caching/result_store/test_memory.py b/tests/caching/result_store/test_memory.py new file mode 100644 index 000000000..7d7fc87e3 --- /dev/null +++ b/tests/caching/result_store/test_memory.py @@ -0,0 +1,131 @@ +import pytest + +from hamilton.caching.stores.base import StoredResult +from hamilton.caching.stores.file import FileResultStore +from hamilton.caching.stores.memory import InMemoryResultStore + +# `result_store` is imported but not directly used because it's +# a pytest fixture automatically provided to tests +from .test_base import result_store # noqa: F401 + +# implementations that in-memory result store can `.persist_to()` and `.load_from()` +PERSISTENT_IMPLEMENTATIONS = [FileResultStore] + + +def _store_size(memory_store: InMemoryResultStore) -> int: + return len(memory_store._results) + + +@pytest.fixture +def memory_store(): + store = InMemoryResultStore() + assert _store_size(store) == 0 + yield store + + +def test_set(memory_store): + data_version = "foo" + + memory_store.set(data_version=data_version, result="bar") + + assert memory_store._results[data_version].value == "bar" + assert _store_size(memory_store) == 1 + + +def test_exists(memory_store): + data_version = "foo" + assert memory_store.exists(data_version) is False + + memory_store.set(data_version=data_version, result="bar") + + assert memory_store.exists(data_version) is True + + +def test_set_doesnt_produce_duplicates(memory_store): + data_version = "foo" + assert not memory_store.exists(data_version) + + memory_store.set(data_version=data_version, result="bar") + memory_store.set(data_version=data_version, result="bar") + + assert memory_store.exists(data_version) + assert _store_size(memory_store) == 1 + + +def test_get(memory_store): + data_version = "foo" + result = StoredResult(value="bar") + memory_store._results[data_version] = result + assert memory_store.exists(data_version) + + retrieved_value = memory_store.get(data_version) + + assert retrieved_value is not None + assert result.value == retrieved_value + assert _store_size(memory_store) == 1 + + +def test_get_missing_result_is_none(memory_store): + result = memory_store.get("foo") + assert result is None + + +def test_delete(memory_store): + data_version = "foo" + memory_store._results[data_version] = StoredResult(value="bar") + assert _store_size(memory_store) == 1 + + memory_store.delete(data_version) + + assert memory_store._results.get(data_version) is None + assert _store_size(memory_store) == 0 + + +def test_delete_all(memory_store): + memory_store._results["foo"] = "foo" + memory_store._results["bar"] = "bar" + assert _store_size(memory_store) == 2 + + memory_store.delete_all() + + assert _store_size(memory_store) == 0 + + +@pytest.mark.parametrize("result_store", PERSISTENT_IMPLEMENTATIONS, indirect=True) +def test_persist_to(result_store, memory_store): # noqa: F811 + data_version = "foo" + result = "bar" + + # set values in-memory + memory_store.set(data_version=data_version, result=result) + + # values exist in memory, but not in destination + assert memory_store.get(data_version) == result + assert result_store.get(data_version) is None + + # persist to destination + memory_store.persist_to(result_store) + assert memory_store.get(data_version) == result_store.get(data_version) + + 
+@pytest.mark.parametrize("result_store", PERSISTENT_IMPLEMENTATIONS, indirect=True) +def test_load_from(result_store): # noqa: F811 + data_version = "foo" + result = "bar" + + # set values in source + result_store.set(data_version=data_version, result=result) + + # values exist in source + assert result_store.get(data_version) == result + + memory_store = InMemoryResultStore.load_from( + result_store=result_store, data_versions=[data_version] + ) + assert memory_store.get(data_version) == result_store.get(data_version) + + +def test_load_from_must_have_metadata_store_or_data_versions(tmp_path): + file_result_store = FileResultStore(tmp_path) + with pytest.raises(ValueError): + InMemoryResultStore.load_from(result_store=file_result_store) diff --git a/tests/caching/test_integration.py b/tests/caching/test_integration.py index cb14c33e1..61315ce34 100644 --- a/tests/caching/test_integration.py +++ b/tests/caching/test_integration.py @@ -1,3 +1,4 @@ +import itertools from typing import List import pandas as pd @@ -5,6 +6,9 @@ from hamilton import ad_hoc_utils, driver from hamilton.caching.adapter import CachingEventType, HamiltonCacheAdapter +from hamilton.caching.stores.file import FileResultStore +from hamilton.caching.stores.memory import InMemoryMetadataStore, InMemoryResultStore +from hamilton.caching.stores.sqlite import SQLiteMetadataStore from hamilton.execution.executors import ( MultiProcessingExecutor, MultiThreadingExecutor, @@ -12,9 +16,30 @@ ) from hamilton.function_modifiers import cache as cache_decorator +# `metadata_store` and `result_store` are imported but not directly used because they +# are pytest fixtures automatically provided to tests +from .metadata_store.test_base import metadata_store # noqa: F401 +from .result_store.test_base import result_store # noqa: F401 from tests.resources.dynamic_parallelism import parallel_linear_basic, parallelism_with_caching +def _instantiate_executor(executor_cls): + if executor_cls == SynchronousLocalTaskExecutor: + return SynchronousLocalTaskExecutor() + elif executor_cls == MultiProcessingExecutor: + return MultiProcessingExecutor(max_tasks=10) + elif executor_cls == MultiThreadingExecutor: + return MultiThreadingExecutor(max_tasks=10) + else: + raise ValueError(f"Class `{executor_cls}` isn't defined in `_instantiate_executor()`") + + +@pytest.fixture +def executor(request): + executor_cls = request.param + return _instantiate_executor(executor_cls) + + @pytest.fixture def dr(request, tmp_path): module = request.param @@ -483,19 +508,34 @@ def foo() -> dict: assert result[node_name] == retrieved_result +EXECUTORS_AND_STORES_CONFIGURATIONS = list( + itertools.product( + [SynchronousLocalTaskExecutor, MultiThreadingExecutor, MultiProcessingExecutor], + [SQLiteMetadataStore], + [FileResultStore], + ) +) + +# InMemory stores can't be used with multiprocessing because they don't share memory. 
+IN_MEMORY_CONFIGURATIONS = list( + itertools.product( + [SynchronousLocalTaskExecutor, MultiThreadingExecutor], + [InMemoryMetadataStore, SQLiteMetadataStore], + [InMemoryResultStore, FileResultStore], + ) +) + +EXECUTORS_AND_STORES_CONFIGURATIONS += IN_MEMORY_CONFIGURATIONS + + @pytest.mark.parametrize( - "executor", - [ - SynchronousLocalTaskExecutor(), - MultiProcessingExecutor(max_tasks=10), - MultiThreadingExecutor(max_tasks=10), - ], + "executor,metadata_store,result_store", EXECUTORS_AND_STORES_CONFIGURATIONS, indirect=True ) -def test_parallel_synchronous_step_by_step(tmp_path, executor): +def test_parallel_synchronous_step_by_step(executor, metadata_store, result_store): # noqa: F811 dr = ( driver.Builder() .with_modules(parallel_linear_basic) - .with_cache(path=tmp_path) + .with_cache(metadata_store=metadata_store, result_store=result_store) .enable_dynamic_execution(allow_experimental_mode=True) .with_remote_executor(executor) .build() @@ -559,11 +599,8 @@ def test_parallel_synchronous_step_by_step(tmp_path, executor): @pytest.mark.parametrize( "executor", - [ - SynchronousLocalTaskExecutor(), - MultiProcessingExecutor(max_tasks=10), - MultiThreadingExecutor(max_tasks=10), - ], + [SynchronousLocalTaskExecutor, MultiProcessingExecutor, MultiThreadingExecutor], + indirect=True, ) def test_materialize_parallel_branches(tmp_path, executor): # NOTE the module can't be defined here because multithreading requires functions to be top-level. @@ -596,7 +633,7 @@ def test_materialize_parallel_branches(tmp_path, executor): ) -def test_consistent_cache_key_with_or_without_defaut_parameter(tmp_path): +def test_consistent_cache_key_with_or_without_default_parameter(tmp_path): def foo(external_dep: int = 3) -> int: return external_dep + 1 diff --git a/tests/caching/test_metadata_store.py b/tests/caching/test_metadata_store.py deleted file mode 100644 index 156e44b3e..000000000 --- a/tests/caching/test_metadata_store.py +++ /dev/null @@ -1,155 +0,0 @@ -import pytest - -from hamilton.caching.cache_key import create_cache_key -from hamilton.caching.stores.sqlite import SQLiteMetadataStore - - -@pytest.fixture -def metadata_store(request, tmp_path): - metdata_store_cls = request.param - metadata_store = metdata_store_cls(path=tmp_path) - run_id = "test-run-id" - try: - metadata_store.initialize(run_id) - except BaseException: - pass - - yield metadata_store - - metadata_store.delete_all() - - -@pytest.mark.parametrize("metadata_store", [SQLiteMetadataStore], indirect=True) -def test_initialize_empty(metadata_store): - assert metadata_store.size == 0 - - -@pytest.mark.parametrize("metadata_store", [SQLiteMetadataStore], indirect=True) -def test_not_empty_after_set(metadata_store): - code_version = "FOO-1" - data_version = "foo-a" - node_name = "foo" - cache_key = create_cache_key( - node_name=node_name, code_version=code_version, dependencies_data_versions={} - ) - - metadata_store.set( - cache_key=cache_key, - node_name=node_name, - code_version=code_version, - data_version=data_version, - run_id="...", - ) - - assert metadata_store.size > 0 - - -@pytest.mark.parametrize("metadata_store", [SQLiteMetadataStore], indirect=True) -def test_set_doesnt_produce_duplicates(metadata_store): - code_version = "FOO-1" - data_version = "foo-a" - node_name = "foo" - cache_key = create_cache_key( - node_name=node_name, code_version=code_version, dependencies_data_versions={} - ) - metadata_store.set( - cache_key=cache_key, - node_name=node_name, - code_version=code_version, - data_version=data_version, - 
run_id="...", - ) - assert metadata_store.size == 1 - - metadata_store.set( - cache_key=cache_key, - node_name=node_name, - code_version=code_version, - data_version=data_version, - run_id="...", - ) - assert metadata_store.size == 1 - - -@pytest.mark.parametrize("metadata_store", [SQLiteMetadataStore], indirect=True) -def test_get_miss_returns_none(metadata_store): - cache_key = create_cache_key( - node_name="foo", code_version="FOO-1", dependencies_data_versions={"bar": "bar-a"} - ) - data_version = metadata_store.get(cache_key=cache_key) - assert data_version is None - - -@pytest.mark.parametrize("metadata_store", [SQLiteMetadataStore], indirect=True) -def test_set_get_without_dependencies(metadata_store): - code_version = "FOO-1" - data_version = "foo-a" - node_name = "foo" - cache_key = create_cache_key( - node_name=node_name, code_version=code_version, dependencies_data_versions={} - ) - metadata_store.set( - cache_key=cache_key, - node_name=node_name, - code_version=code_version, - data_version=data_version, - run_id="...", - ) - retrieved_data_version = metadata_store.get(cache_key=cache_key) - - assert retrieved_data_version == data_version - - -@pytest.mark.parametrize("metadata_store", [SQLiteMetadataStore], indirect=True) -def test_get_run_ids_returns_ordered_list(metadata_store): - pre_run_ids = metadata_store.get_run_ids() - assert pre_run_ids == ["test-run-id"] # this is from the fixture - - metadata_store.initialize(run_id="foo") - metadata_store.initialize(run_id="bar") - metadata_store.initialize(run_id="baz") - - post_run_ids = metadata_store.get_run_ids() - assert post_run_ids == ["test-run-id", "foo", "bar", "baz"] - - -@pytest.mark.parametrize("metadata_store", [SQLiteMetadataStore], indirect=True) -def test_get_run_results_include_cache_key_and_data_version(metadata_store): - run_id = "test-run-id" - metadata_store.set( - cache_key="foo", - data_version="1", - run_id=run_id, - node_name="a", # kwarg specific to SQLiteMetadataStore - code_version="b", # kwarg specific to SQLiteMetadataStore - ) - metadata_store.set( - cache_key="bar", - data_version="2", - run_id=run_id, - node_name="a", # kwarg specific to SQLiteMetadataStore - code_version="b", # kwarg specific to SQLiteMetadataStore - ) - - run_info = metadata_store.get_run(run_id=run_id) - - assert isinstance(run_info, list) - assert len(run_info) == 2 - assert isinstance(run_info[1], dict) - assert run_info[0]["cache_key"] == "foo" - assert run_info[0]["data_version"] == "1" - assert run_info[1]["cache_key"] == "bar" - assert run_info[1]["data_version"] == "2" - - -@pytest.mark.parametrize("metadata_store", [SQLiteMetadataStore], indirect=True) -def test_get_run_returns_empty_list_if_run_started_but_no_execution_recorded(metadata_store): - metadata_store.initialize(run_id="foo") - run_info = metadata_store.get_run(run_id="foo") - assert run_info == [] - - -@pytest.mark.parametrize("metadata_store", [SQLiteMetadataStore], indirect=True) -def test_get_run_raises_error_if_run_id_not_found(metadata_store): - with pytest.raises(IndexError): - metadata_store.get_run(run_id="foo")