added default CacheAdapter lifecycle hook #766

Merged · 3 commits · Mar 19, 2024
112 changes: 109 additions & 3 deletions hamilton/lifecycle/default.py
@@ -1,15 +1,18 @@
"""A selection of default lifeycle hooks/methods that come with Hamilton. These carry no additional requirements"""

import hashlib
import logging
import pdb
import pickle
import pprint
import random
import shelve
import time
from typing import Any, Callable, Dict, List, Optional, Union

from hamilton import htypes
from hamilton.lifecycle import NodeExecutionHook
from hamilton.lifecycle.api import NodeExecutionMethod
from hamilton import graph_types, htypes
from hamilton.graph_types import HamiltonGraph
from hamilton.lifecycle import GraphExecutionHook, NodeExecutionHook, NodeExecutionMethod

logger = logging.getLogger(__name__)

@@ -313,6 +316,109 @@ def run_after_node_execution(
        pdb.set_trace()


class CacheAdapter(NodeExecutionHook, NodeExecutionMethod, GraphExecutionHook):
    """Class to cache node results on disk, keyed by the node's code implementation and its inputs.
    Subsequent runs with the same key load the cached result and skip computation.

    The cache's `_nodes_history` entry maps each node name to an append-only list of the
    code-version hashes added to the cache; e.g., the last value of
    `cache["_nodes_history"][node_name]` is the hash of the most recently cached version
    of that node.

    Notes:
        - It uses the stdlib `shelve` module and the pickle format, which makes results
          dependent on the Python version. Use materialization for persistent results.
        - There is no utility to manage cache size, so you'll have to delete the cache
          periodically. See the diskcache plugin for Hamilton (`hamilton.plugins.h_diskcache`)
          for better cache management.
    """

    nodes_history_key: str = "_nodes_history"

    def __init__(
        self, cache_vars: Union[List[str], None] = None, cache_path: str = "./hamilton-cache"
    ):
        """Initialize the cache.

        :param cache_vars: List of nodes for which to store/load results. Passing None will
            use the cache for all nodes. Default is None.
        :param cache_path: File path to the cache. The file name doesn't need an extension.
        """
        self.cache_vars = cache_vars if cache_vars else []
        self.cache_path = cache_path
        self.cache = shelve.open(self.cache_path)
        self.nodes_history: Dict[str, List[str]] = self.cache.get(
            key=CacheAdapter.nodes_history_key, default=dict()
        )
        self.used_nodes_hash: Dict[str, str] = dict()

    def run_before_graph_execution(self, *, graph: HamiltonGraph, **kwargs):
        """Set `cache_vars` to all nodes if None was received during `__init__`."""
        if self.cache_vars == []:
            self.cache_vars = [n.name for n in graph.nodes]

    def run_to_execute_node(
        self, *, node_name: str, node_callable: Any, node_kwargs: Dict[str, Any], **kwargs
    ):
        """Create the cache key from the node callable's hash (equiv. to HamiltonNode.version)
        and the node inputs (`node_kwargs`). If the key is in the cache (cache hit), load the
        result; otherwise (cache miss), compute the node and record its hash in `used_nodes_hash`.

        Note:
            - the callable's hash is stored in `used_nodes_hash` because it's required to
              create the key in `run_after_node_execution`, where the callable is no longer
              accessible to recompute it.
        """
        if node_name not in self.cache_vars:
Review comment (Collaborator): For afterwards -- we may want to make this a set? Odds are high this is not a slow operation, but it could get O(n**2), which we don't want. (A sketch of this change follows the method below.)
            return node_callable(**node_kwargs)

        node_hash = graph_types.hash_source_code(node_callable, strip=True)
        cache_key = CacheAdapter.create_key(node_hash, node_kwargs)

        from_cache = self.cache.get(cache_key, None)
        if from_cache is not None:
            return from_cache

        self.used_nodes_hash[node_name] = node_hash
        self.nodes_history[node_name] = self.nodes_history.get(node_name, []) + [node_hash]
        return node_callable(**node_kwargs)
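A minimal sketch of the reviewer's suggestion above, not part of this PR: storing `cache_vars` as a set makes the membership check O(1) per node instead of O(n), avoiding O(n**2) behavior across a graph of n nodes.

    # Hypothetical follow-up sketch (not in this PR): normalize cache_vars to a set.
    def run_before_graph_execution(self, *, graph: HamiltonGraph, **kwargs):
        """Set `cache_vars` to all nodes if None was received during `__init__`."""
        if not self.cache_vars:
            self.cache_vars = {n.name for n in graph.nodes}  # set, not list
        else:
            self.cache_vars = set(self.cache_vars)  # normalize the user-passed list
    # `node_name not in self.cache_vars` in the node hooks is then O(1).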

    def run_after_node_execution(
        self, *, node_name: str, node_kwargs: Dict[str, Any], result: Any, **kwargs
    ):
        """If `run_to_execute_node` was a cache miss (the hash was stored in
        `used_nodes_hash`), store the computed result in the cache.
        """
        if node_name not in self.cache_vars:
            return

        node_hash = self.used_nodes_hash.get(node_name)
        if node_hash is None:
            return

        cache_key = CacheAdapter.create_key(node_hash, node_kwargs)
        self.cache[cache_key] = result

    def run_after_graph_execution(self, *args, **kwargs):
        """After completing execution, overwrite `nodes_history_key` in the cache and close it."""
        # TODO: updating `nodes_history` at graph completion instead of after each node
        # execution means a desync is possible if the graph fails, which could lead to
        # missing keys in `nodes_history`.
        self.cache[CacheAdapter.nodes_history_key] = self.nodes_history
        self.cache.close()
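A hedged sketch of one way to resolve the TODO above (an alternative, not the PR's behavior): write `nodes_history` through to the shelf whenever a result is stored, so a failed run can't leave cached results whose hashes are missing from the history.

    # Sketch only: write-through persistence of nodes_history (assumption, not this PR).
    def run_after_node_execution(
        self, *, node_name: str, node_kwargs: Dict[str, Any], result: Any, **kwargs
    ):
        if node_name not in self.cache_vars:
            return
        node_hash = self.used_nodes_hash.get(node_name)
        if node_hash is None:
            return
        self.cache[CacheAdapter.create_key(node_hash, node_kwargs)] = result
        # persist the history alongside the result instead of only at graph completion
        self.cache[CacheAdapter.nodes_history_key] = self.nodes_history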

    def run_before_node_execution(self, *args, **kwargs):
        """Placeholder required to subclass `NodeExecutionHook`"""
        pass

    @staticmethod
    def create_key(node_hash: str, node_inputs: Dict[str, Any]) -> str:
        """Hash the node's source-code hash together with its pickled inputs"""
        digest = hashlib.sha256()
        digest.update(node_hash.encode())

        for ins in node_inputs.values():
            digest.update(pickle.dumps(ins))

        return digest.hexdigest()
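To illustrate the key derivation (values below are arbitrary): the same source hash combined with equal, picklable inputs produces the same key, which is what allows cache hits across runs.

# Illustration only; "deadbeef" stands in for a real source-code hash.
from hamilton.lifecycle.default import CacheAdapter

k1 = CacheAdapter.create_key("deadbeef", {"external_input": 7})
k2 = CacheAdapter.create_key("deadbeef", {"external_input": 7})
assert k1 == k2
# Caveat: the key depends on the kwargs' insertion order and on pickle producing
# stable bytes; e.g. sets may pickle differently across processes.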


def wait_random(mean: float, stddev: float):
    sleep_time = random.gauss(mu=mean, sigma=stddev)
    if sleep_time < 0:
…
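A minimal usage sketch of the new adapter, assuming a user module `my_module` that defines Hamilton functions (the module name and inputs are placeholders):

from hamilton import driver
from hamilton.lifecycle.default import CacheAdapter

import my_module  # hypothetical module with Hamilton functions

dr = (
    driver.Builder()
    .with_modules(my_module)
    .with_adapters(CacheAdapter(cache_path="./hamilton-cache"))
    .build()
)
print(dr.execute(["A"], inputs={"external_input": 7}))

Note that the adapter closes its shelf in `run_after_graph_execution`, so a fresh `CacheAdapter` (and driver) is needed per run; the on-disk cache persists between runs, which is where the skip-computation benefit comes from.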
136 changes: 136 additions & 0 deletions tests/lifecycle/test_cache_adapter.py
@@ -0,0 +1,136 @@
import inspect
import pathlib
import shelve

import pytest

from hamilton import graph_types, node
from hamilton.lifecycle.default import CacheAdapter


def _callable_to_node(callable) -> node.Node:
    return node.Node(
        name=callable.__name__,
        typ=inspect.signature(callable).return_annotation,
        callabl=callable,
    )


@pytest.fixture()
def hook(tmp_path: pathlib.Path):
    return CacheAdapter(cache_path=str(tmp_path.resolve()))


@pytest.fixture()
def node_a():
    """Default function implementation"""

    def A(external_input: int) -> int:
        return external_input % 7

    return _callable_to_node(A)


@pytest.fixture()
def node_a_body():
    """The function A() has modulo 5 instead of 7"""

    def A(external_input: int) -> int:
        return external_input % 5

    return _callable_to_node(A)


@pytest.fixture()
def node_a_docstring():
    """The function A() has a docstring"""

    def A(external_input: int) -> int:
        """This one has a docstring"""
        return external_input % 7

    return _callable_to_node(A)


def test_set_result(hook: CacheAdapter, node_a: node.Node):
    """Hook sets the value; assert the value is in the cache"""
    node_hash = graph_types.hash_source_code(node_a.callable, strip=True)
    node_kwargs = dict(external_input=7)
    cache_key = CacheAdapter.create_key(node_hash, node_kwargs)
    result = 2

    hook.cache_vars = [node_a.name]
    # used_nodes_hash would be set by the run_to_execute_node() hook
    hook.used_nodes_hash[node_a.name] = node_hash
    hook.run_after_node_execution(
        node_name=node_a.name,
        node_kwargs=node_kwargs,
        result=result,
    )

    # the run_to_execute_node() hook would read this cache entry
    assert hook.cache.get(key=cache_key) == result


def test_get_result(hook: CacheAdapter, node_a: node.Node):
    """Hook gets the value; assert a cache hit"""
    node_hash = graph_types.hash_source_code(node_a.callable, strip=True)
    node_kwargs = dict(external_input=7)
    result = 2
    cache_key = CacheAdapter.create_key(node_hash, node_kwargs)

    hook.cache_vars = [node_a.name]
    # run_after_node_execution() would set this cache entry
    hook.cache[cache_key] = result
    retrieved = hook.run_to_execute_node(
        node_name=node_a.name,
        node_kwargs=node_kwargs,
        node_callable=node_a.callable,
    )

    assert retrieved == result


def test_append_nodes_history(
    hook: CacheAdapter,
    node_a: node.Node,
    node_a_body: node.Node,
):
    """Assert that CacheAdapter.nodes_history is growing;
    doesn't check the commit to cache
    """
    node_name = "A"
    node_kwargs = dict(external_input=7)
    hook.cache_vars = [node_a.name]

    # node version 1
    hook.used_nodes_hash[node_name] = graph_types.hash_source_code(node_a.callable, strip=True)
    hook.run_to_execute_node(
        node_name=node_name,
        node_kwargs=node_kwargs,
        node_callable=node_a.callable,
    )

    # check history
    assert len(hook.nodes_history.get(node_name, [])) == 1

    # node version 2
    hook.used_nodes_hash[node_name] = graph_types.hash_source_code(node_a_body.callable, strip=True)
    hook.run_to_execute_node(
        node_name=node_name,
        node_kwargs=node_kwargs,
        node_callable=node_a_body.callable,
    )

    assert len(hook.nodes_history.get(node_name, [])) == 2


def test_commit_nodes_history(hook: CacheAdapter):
    """Commit the node history to the cache"""
    hook.nodes_history = dict(A=["hash_1", "hash_2"])

    hook.run_after_graph_execution()

    # need to reopen the hook's cache
    with shelve.open(hook.cache_path) as cache:
        assert cache.get(CacheAdapter.nodes_history_key) == hook.nodes_history
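For completeness, a short sketch of inspecting the on-disk cache outside the adapter, using only the stdlib `shelve` module (the path is a placeholder):

import shelve

from hamilton.lifecycle.default import CacheAdapter

with shelve.open("./hamilton-cache") as cache:
    history = cache.get(CacheAdapter.nodes_history_key, {})
    for node_name, hashes in history.items():
        # the last hash is the most recently cached version of the node
        print(node_name, "->", hashes[-1] if hashes else None)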