Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: caching SQLiteMetadataStore.get_run_ids() #1205

Merged
merged 2 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions hamilton/caching/stores/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def set(self, cache_key: str, data_version: str, **kwargs) -> Optional[Any]:
"""

@abc.abstractmethod
def get(self, cache_key: str) -> Optional[str]:
def get(self, cache_key: str, **kwargs) -> Optional[str]:
"""Try to retrieve ``data_version`` keyed by ``cache_key``.
If retrieval misses return ``None``.
"""
Expand All @@ -118,15 +118,19 @@ def exists(self, cache_key: str) -> bool:
def get_run_ids(self) -> Sequence[str]:
"""Return a list of run ids, sorted from oldest to newest start time.
A ``run_id`` is registered when the metadata_store ``.initialize()`` is called.

NOTE because of race conditions, the order could theoretically differ from the
order stored on the SmartCacheAdapter `._run_ids` attribute.
"""

@abc.abstractmethod
def get_run(self, run_id: str) -> Any:
"""Return all the metadata associated with a run.
The metadata content may differ across MetadataStore implementations
def get_run(self, run_id: str) -> Sequence[dict]:
"""Return a list of node metadata associated with a run.
For each node, the metadata should include:
- ``cache_key`` (created or used)
- ``data_version``
This is to allow users to manually query the MetadataStore or ResultStore.
skrawcz marked this conversation as resolved.
Show resolved Hide resolved

Decoding the ``cache_key`` gives the ``node_name``, ``code_version``, and
``dependencies_data_versions``. Individual implementations may add more
information or decode the ``cache_key`` before returning metadata.
"""

@property
Expand Down
82 changes: 61 additions & 21 deletions hamilton/caching/stores/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import threading
from typing import List, Optional

from hamilton.caching.cache_key import decode_key
from hamilton.caching.stores.base import MetadataStore


Expand All @@ -19,14 +20,14 @@ def __init__(

self._thread_local = threading.local()

def _get_connection(self):
def _get_connection(self) -> sqlite3.Connection:
if not hasattr(self._thread_local, "connection"):
self._thread_local.connection = sqlite3.connect(
str(self._path), check_same_thread=False, **self.connection_kwargs
)
return self._thread_local.connection

def _close_connection(self):
def _close_connection(self) -> None:
if hasattr(self._thread_local, "connection"):
self._thread_local.connection.close()
del self._thread_local.connection
Expand Down Expand Up @@ -76,9 +77,9 @@ def _create_tables_if_not_exists(self):
"""\
CREATE TABLE IF NOT EXISTS cache_metadata (
cache_key TEXT PRIMARY KEY,
data_version TEXT NOT NULL,
node_name TEXT NOT NULL,
code_version TEXT NOT NULL,
data_version TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,

FOREIGN KEY (cache_key) REFERENCES history(cache_key)
Expand Down Expand Up @@ -106,13 +107,21 @@ def set(
self,
*,
cache_key: str,
node_name: str,
code_version: str,
data_version: str,
run_id: str,
node_name: str = None,
code_version: str = None,
Comment on lines +112 to +113
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this could be backwards incompatible. But if we're the only ones using this, then all good.

Copy link
Collaborator Author

@zilto zilto Oct 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not backwards incompatible, because the HamiltonCacheAdapter is the only object passing values here. As shown in the updated logic of SQLiteMetadataStore.set(), it will resolve the values of node_name and code_version from the cache_key, so these values will never be None (for instance, the SQL table still enforces NOT NULL constraints).

For instance, the MetadataStore base class doesn't expose node_name and code_version as arguments because other implementations might not provide them directly. This information is redundant with the content of cache_key.

The SQLiteMetadataStore implementation uses node_name and code_version explicitly to allow indexing and cache management (e.g., delete all instances of this node). These values could be derived by decoding cache_key, but that would be a wasteful operation.

Copy link
Collaborator

@skrawcz skrawcz Oct 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's more that if people were not using keyword args, it would break due to the change in argument order; the fix is easy though, as you just update all callers to the new order.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(I didn't see any update to callers of this function hence why I asked)

**kwargs,
) -> None:
cur = self.connection.cursor()

# if the caller of ``.set()`` directly provides the ``node_name`` and ``code_version``,
# we can skip the decoding step.
if (node_name is None) or (code_version is None):
decoded_key = decode_key(cache_key)
node_name = decoded_key["node_name"]
code_version = decoded_key["code_version"]

cur.execute("INSERT INTO history (cache_key, run_id) VALUES (?, ?)", (cache_key, run_id))
cur.execute(
"""\
Expand Down Expand Up @@ -150,7 +159,7 @@ def delete(self, cache_key: str) -> None:
cur.execute("DELETE FROM cache_metadata WHERE cache_key = ?", (cache_key,))
self.connection.commit()

def delete_all(self):
def delete_all(self) -> None:
"""Delete all existing tables from the database"""
cur = self.connection.cursor()

Expand All @@ -170,35 +179,66 @@ def exists(self, cache_key: str) -> bool:
return result is not None

def get_run_ids(self) -> List[str]:
"""Return a list of run ids, sorted from oldest to newest start time."""
cur = self.connection.cursor()
cur.execute("SELECT run_id FROM history ORDER BY id")
cur.execute("SELECT run_id FROM run_ids ORDER BY id")
result = cur.fetchall()

if result is None:
raise IndexError("No `run_id` found. Table `history` is empty.")
return [r[0] for r in result]

return result[0]
def _run_exists(self, run_id: str) -> bool:
    """Tell whether ``run_id`` was registered in the ``run_ids`` table.

    A run that was initialized but recorded no node executions still
    counts as existing.

    :param run_id: ID of the run to look up
    :return: True if a row with this ``run_id`` is found, False otherwise
    """
    # statement text is kept verbatim; only the Python around it is restyled
    query = """\
SELECT EXISTS(
SELECT 1
FROM run_ids
WHERE run_id = ?
)
"""
    cursor = self.connection.cursor()
    cursor.execute(query, (run_id,))
    # SELECT EXISTS yields a single row whose only column is 1 when found
    row = cursor.fetchone()
    return row[0] == 1

def get_run(self, run_id: str) -> List[dict]:
"""Return all the metadata associated with a run."""
"""Return a list of node metadata associated with a run.

:param run_id: ID of the run to retrieve
:return: List of node metadata which includes ``cache_key``, ``data_version``,
``node_name``, and ``code_version``. The list can be empty if a run was initialized
but no nodes were executed.

Raises an ``IndexError`` if the ``run_id`` is not found in metadata store.
"""
cur = self.connection.cursor()
if self._run_exists(run_id) is False:
raise IndexError(f"`run_id` not found in table `run_ids`: {run_id}")

cur.execute(
"""\
SELECT
cache_metadata.cache_key,
cache_metadata.data_version,
cache_metadata.node_name,
cache_metadata.code_version,
cache_metadata.data_version
FROM (SELECT * FROM history WHERE history.run_id = ?) AS run_history
JOIN cache_metadata ON run_history.cache_key = cache_metadata.cache_key
cache_metadata.code_version
FROM history
JOIN cache_metadata ON history.cache_key = cache_metadata.cache_key
WHERE history.run_id = ?
""",
(run_id,),
)
results = cur.fetchall()

if results is None:
raise IndexError(f"`run_id` not found in table `history`: {run_id}")

return [
dict(node_name=node_name, code_version=code_version, data_version=data_version)
for node_name, code_version, data_version in results
dict(
cache_key=cache_key,
data_version=data_version,
node_name=node_name,
code_version=code_version,
)
for cache_key, data_version, node_name, code_version in results
]
55 changes: 55 additions & 0 deletions tests/caching/test_metadata_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,58 @@ def test_set_get_without_dependencies(metadata_store):
retrieved_data_version = metadata_store.get(cache_key=cache_key)

assert retrieved_data_version == data_version


@pytest.mark.parametrize("metadata_store", [SQLiteMetadataStore], indirect=True)
def test_get_run_ids_returns_ordered_list(metadata_store):
    """``get_run_ids()`` lists run ids in the order they were initialized."""
    # this run id is from the fixture
    assert metadata_store.get_run_ids() == ["test-run-id"]

    for new_run_id in ("foo", "bar", "baz"):
        metadata_store.initialize(run_id=new_run_id)

    assert metadata_store.get_run_ids() == ["test-run-id", "foo", "bar", "baz"]


@pytest.mark.parametrize("metadata_store", [SQLiteMetadataStore], indirect=True)
def test_get_run_results_include_cache_key_and_data_version(metadata_store):
    """``get_run()`` returns one dict per recorded node, exposing at least
    ``cache_key`` and ``data_version``, in insertion order.
    """
    run_id = "test-run-id"
    recorded = [("foo", "1"), ("bar", "2")]
    for cache_key, data_version in recorded:
        metadata_store.set(
            cache_key=cache_key,
            data_version=data_version,
            run_id=run_id,
            node_name="a",  # kwarg specific to SQLiteMetadataStore
            code_version="b",  # kwarg specific to SQLiteMetadataStore
        )

    run_info = metadata_store.get_run(run_id=run_id)

    assert isinstance(run_info, list)
    assert len(run_info) == 2
    assert isinstance(run_info[1], dict)
    for entry, (cache_key, data_version) in zip(run_info, recorded):
        assert entry["cache_key"] == cache_key
        assert entry["data_version"] == data_version


@pytest.mark.parametrize("metadata_store", [SQLiteMetadataStore], indirect=True)
def test_get_run_returns_empty_list_if_run_started_but_no_execution_recorded(metadata_store):
    """An initialized run with no recorded node yields an empty list."""
    started_run_id = "foo"
    metadata_store.initialize(run_id=started_run_id)

    assert metadata_store.get_run(run_id=started_run_id) == []


@pytest.mark.parametrize("metadata_store", [SQLiteMetadataStore], indirect=True)
def test_get_run_raises_error_if_run_id_not_found(metadata_store):
    """``get_run()`` raises ``IndexError`` for a run id this test never initialized."""
    unknown_run_id = "foo"
    with pytest.raises(IndexError):
        metadata_store.get_run(run_id=unknown_run_id)