Unify the node and dataflow versioning API #734

Merged · 13 commits · Mar 14, 2024
19 changes: 4 additions & 15 deletions hamilton/cli/logic.py
@@ -57,34 +57,23 @@ def get_git_reference(git_relative_path: Union[str, Path], git_reference: str) -

def version_hamilton_functions(module: ModuleType) -> Dict[str, str]:
"""Hash the source code of Hamilton functions from a module"""
from hamilton import graph_utils
from hamilton import graph_types, graph_utils

origins_version: Dict[str, str] = dict()

for origin_name, _ in graph_utils.find_functions(module):
origin_callable = getattr(module, origin_name)
origins_version[origin_name] = graph_utils.hash_source_code(origin_callable, strip=True)
origins_version[origin_name] = graph_types.hash_source_code(origin_callable, strip=True)

return origins_version


def hash_hamilton_nodes(dr: driver.Driver) -> Dict[str, str]:
"""Hash the source code of Hamilton functions from nodes in a Driver"""
from hamilton import graph_types, graph_utils
from hamilton import graph_types

graph = graph_types.HamiltonGraph.from_graph(dr.graph)

nodes_version = dict()
for n in graph.nodes:
# is None for config nodes
if n.originating_functions is None:
continue

node_origin = n.originating_functions[0]
origin_hash = graph_utils.hash_source_code(node_origin, strip=True)
nodes_version[n.name] = origin_hash

return nodes_version
return {n.name: n.version for n in graph.nodes}


def map_nodes_to_functions(dr: driver.Driver) -> Dict[str, str]:
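The CLI refactor above leans on the public `HamiltonGraph`/`HamiltonNode` API instead of re-hashing callables itself. A minimal usage sketch of that API (the module name `my_dataflow` is hypothetical):

```python
from hamilton import driver, graph_types

import my_dataflow  # hypothetical module containing Hamilton functions

dr = driver.Builder().with_modules(my_dataflow).build()
hg = graph_types.HamiltonGraph.from_graph(dr.graph)

# Same mapping hash_hamilton_nodes now returns: node name -> source code hash.
node_versions = {n.name: n.version for n in hg.nodes}
```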
87 changes: 87 additions & 0 deletions hamilton/graph_types.py
@@ -1,5 +1,8 @@
"""Module for external-facing graph constructs. These help the user navigate/manage the graph as needed."""

import ast
import functools
import hashlib
import inspect
import typing
from dataclasses import dataclass
@@ -16,6 +19,70 @@
from hamilton import graph


def _remove_docs_and_comments(source: str) -> str:
"""Remove the docs and comments from a source code string.

The use of `ast.unparse()` requires Python 3.9

1. Parsing then unparsing the AST of the source code will
create a code object and convert it back to a string. In the
process, comments are stripped.

2. walk the AST to check if first element after `def` is a
docstring. If so, edit AST to skip the docstring

NOTE. The ast parsing will fail if `source` has syntax errors. For the
majority of cases this is caught upstream (e.g., by calling `import`).
The foreseeable edge case is if `source` is the result of `inspect.getsource`
on a nested function, method, or callable where `def` isn't at column 0.
Standard usage of Hamilton requires users to define functions/nodes at the top
level of a module, and therefore no issues should arise.
"""
parsed = ast.parse(source)
for n in ast.walk(parsed):
if not isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef)):
continue

if not len(n.body):
continue

# check if 1st node is a docstring
if not isinstance(n.body[0], ast.Expr):
continue

if not hasattr(n.body[0], "value") or not isinstance(n.body[0].value, ast.Str):
continue

# skip docstring
n.body = n.body[1:]

return ast.unparse(parsed)


def hash_source_code(source: typing.Union[str, typing.Callable], strip: bool = False) -> str:
"""Hashes the source code of a function (str).

The `strip` parameter requires Python 3.9

If strip, try to remove docs and comments from source code string. Since
they don't impact function behavior, they shouldn't influence the hash.
"""
if isinstance(source, typing.Callable):
source = inspect.getsource(source)

source = source.strip()

if strip:
try:
# could fail if source is indented code.
# see `remove_docs_and_comments` docstring for details.
source = _remove_docs_and_comments(source)
except Exception:
pass

return hashlib.sha256(source.encode()).hexdigest()
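To illustrate the `strip` behavior (a sketch; requires Python 3.9+ for `ast.unparse`): two sources that differ only in docstrings or comments hash identically with `strip=True`, but not with `strip=False`.

```python
from hamilton import graph_types

src_a = "def foo(x: int) -> int:\n    return x + 1\n"
src_b = (
    "def foo(x: int) -> int:\n"
    '    """Add one."""\n'
    "    return x + 1  # inline comment\n"
)

# Docstring and comment are stripped before hashing, so the versions match.
assert graph_types.hash_source_code(src_a, strip=True) == graph_types.hash_source_code(src_b, strip=True)

# Without stripping, the raw strings differ and so do the hashes.
assert graph_types.hash_source_code(src_a, strip=False) != graph_types.hash_source_code(src_b, strip=False)
```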


@dataclass
class HamiltonNode:
"""External facing API for hamilton Nodes. Having this as a dataclass allows us
@@ -45,6 +112,7 @@ def as_dict(self):
else None
),
"documentation": self.documentation,
"version": self.version,
}

@staticmethod
@@ -73,6 +141,15 @@ def from_node(n: node.Node) -> "HamiltonNode":
},
)

@functools.cached_property
def version(self) -> str:
"""Generate a hash of the node originating function source code.

The option `strip=True` means docstring and comments are ignored
when hashing the function.
"""
return hash_source_code(self.originating_functions[0], strip=True)
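The property memoizes the hash of the node's first originating function, so the two expressions below should agree (a sketch; assumes a driver `dr` built as in the earlier example and a regular function node named "A"):

```python
hg = graph_types.HamiltonGraph.from_graph(dr.graph)
node_a = next(n for n in hg.nodes if n.name == "A")  # "A" is a hypothetical node name

assert node_a.version == graph_types.hash_source_code(
    node_a.originating_functions[0], strip=True
)
```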

def __repr__(self):
return f"{self.name}: {htypes.get_type_as_string(self.type)}"

@@ -101,3 +178,13 @@ def from_graph(fn_graph: "graph.FunctionGraph") -> "HamiltonGraph":
return HamiltonGraph(
nodes=[HamiltonNode.from_node(n) for n in fn_graph.nodes.values()],
)

@functools.cached_property
def version(self) -> str:
"""Generate a hash of the dataflow based on the collection of node hashes.

Node hashes are in a sorted list, then concatenated as a string before hashing.
To find differences between dataflows, you need to inspect the node level.
"""
sorted_node_versions = sorted([n.version for n in self.nodes])
return hashlib.sha256(str(sorted_node_versions).encode()).hexdigest()
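With this property, comparing two revisions of a dataflow becomes a single equality check, and node-level versions pinpoint what changed. A hedged sketch (`dataflow_v1` and `dataflow_v2` are hypothetical modules):

```python
from hamilton import driver, graph_types

import dataflow_v1  # hypothetical: old revision of the dataflow
import dataflow_v2  # hypothetical: new revision of the dataflow

hg1 = graph_types.HamiltonGraph.from_graph(driver.Builder().with_modules(dataflow_v1).build().graph)
hg2 = graph_types.HamiltonGraph.from_graph(driver.Builder().with_modules(dataflow_v2).build().graph)

if hg1.version != hg2.version:
    # The dataflow changed; drill into node versions to find the edited functions.
    v1 = {n.name: n.version for n in hg1.nodes}
    v2 = {n.name: n.version for n in hg2.nodes}
    modified = [name for name in v1.keys() & v2.keys() if v1[name] != v2[name]]
```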
68 changes: 1 addition & 67 deletions hamilton/graph_utils.py
@@ -1,8 +1,6 @@
import ast
import hashlib
import inspect
from types import ModuleType
from typing import Callable, List, Tuple, Union
from typing import Callable, List, Tuple


def is_submodule(child: ModuleType, parent: ModuleType):
@@ -24,67 +22,3 @@ def valid_fn(fn):
)

return [f for f in inspect.getmembers(function_module, predicate=valid_fn)]


def hash_source_code(source: Union[str, Callable], strip: bool = False) -> str:
"""Hashes the source code of a function (str).

The `strip` parameter requires Python 3.9

If strip, try to remove docs and comments from source code string. Since
they don't impact function behavior, they shouldn't influence the hash.
"""
if isinstance(source, Callable):
source = inspect.getsource(source)

source = source.strip()

if strip:
try:
# could fail if source is indented code.
# see `remove_docs_and_comments` docstring for details.
source = remove_docs_and_comments(source)
except Exception:
pass

return hashlib.sha256(source.encode()).hexdigest()


def remove_docs_and_comments(source: str) -> str:
"""Remove the docs and comments from a source code string.

The use of `ast.unparse()` requires Python 3.9

1. Parsing then unparsing the AST of the source code will
create a code object and convert it back to a string. In the
process, comments are stripped.

2. walk the AST to check if first element after `def` is a
docstring. If so, edit AST to skip the docstring

NOTE. The ast parsing will fail if `source` has syntax errors. For the
majority of cases this is caught upstream (e.g., by calling `import`).
The foreseeable edge case is if `source` is the result of `inspect.getsource`
on a nested function, method, or callable where `def` isn't at column 0.
Standard usage of Hamilton requires users to define functions/nodes at the top
level of a module, and therefore no issues should arise.
"""
parsed = ast.parse(source)
for node in ast.walk(parsed):
if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
continue

if not len(node.body):
continue

# check if 1st node is a docstring
if not isinstance(node.body[0], ast.Expr):
continue

if not hasattr(node.body[0], "value") or not isinstance(node.body[0].value, ast.Str):
continue

# skip docstring
node.body = node.body[1:]

return ast.unparse(parsed)
6 changes: 3 additions & 3 deletions hamilton/plugins/h_diskcache.py
@@ -3,7 +3,7 @@

import diskcache

from hamilton import driver, graph_types, graph_utils, lifecycle, node
from hamilton import driver, graph_types, lifecycle, node

logger = logging.getLogger(__name__)

@@ -26,7 +26,7 @@ def evict_all_except(nodes_to_keep: Dict[str, node.Node], cache: diskcache.Cache

if node_name in nodes_to_keep.keys():
node_to_keep = nodes_to_keep[node_name]
hash_to_keep = graph_utils.hash_source_code(node_to_keep.callable, strip=True)
hash_to_keep = graph_types.hash_source_code(node_to_keep.callable, strip=True)
history.remove(hash_to_keep)
new_nodes_history[node_name] = [hash_to_keep]

@@ -99,7 +99,7 @@ def run_to_execute_node(
if node_name not in self.cache_vars:
return node_callable(**node_kwargs)

node_hash = graph_utils.hash_source_code(node_callable, strip=True)
node_hash = graph_types.hash_source_code(node_callable, strip=True)
self.used_nodes_hash[node_name] = node_hash
cache_key = (node_hash, *node_kwargs.values())

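For reference, a minimal sketch of the cache-key scheme the adapter uses above (the function and kwargs are illustrative):

```python
from hamilton import graph_types

def A(external_input: int) -> int:
    """Docstring is ignored because the hash is computed with strip=True."""
    return external_input % 7

node_kwargs = {"external_input": 7}
node_hash = graph_types.hash_source_code(A, strip=True)

# Cached results are keyed on (source code hash, *argument values).
cache_key = (node_hash, *node_kwargs.values())
```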
17 changes: 1 addition & 16 deletions hamilton/plugins/h_experiments/hook.py
@@ -50,21 +50,6 @@ def json_encoder(obj: Any):
return serialized


def graph_hash(graph: graph_types.HamiltonGraph) -> str:
"""Create a single hash (str) from the bytecode of all sorted functions"""
nodes_data = []
for node in graph.nodes:
source_code = ""
if node.originating_functions is not None:
source_code = inspect.getsource(node.originating_functions[0])

nodes_data.append(dict(name=node.name, source_code=source_code))

digest = hashlib.sha256()
digest.update(json.dumps(nodes_data, default=json_encoder, sort_keys=True).encode())
return digest.hexdigest()


@dataclass
class NodeImplementation:
name: str
@@ -147,7 +132,7 @@ def run_before_graph_execution(
**kwargs,
):
"""Store execution metadata: graph hash, inputs, overrides"""
self.graph_hash = graph_hash(graph)
self.graph_hash = graph.version

for node in graph.nodes:
if node.tags.get("module"):
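Custom hooks can rely on the same property rather than a hand-rolled graph hash. A rough sketch (the class name is illustrative and assumes Hamilton's `GraphExecutionHook` lifecycle API):

```python
from hamilton import graph_types, lifecycle

class VersionLoggingHook(lifecycle.GraphExecutionHook):
    """Illustrative hook that records the dataflow version before execution."""

    def run_before_graph_execution(self, *, graph: graph_types.HamiltonGraph, **kwargs):
        self.graph_hash = graph.version  # same value the experiment hook now stores

    def run_after_graph_execution(self, **kwargs):
        pass
```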
18 changes: 9 additions & 9 deletions tests/plugins/test_h_diskcache.py
@@ -3,7 +3,7 @@

import pytest

from hamilton import ad_hoc_utils, driver, graph_utils, node
from hamilton import ad_hoc_utils, driver, graph_types, node
from hamilton.plugins import h_diskcache


@@ -53,7 +53,7 @@ def A(external_input: int) -> int:

def test_set_result(hook: h_diskcache.DiskCacheAdapter, node_a: node.Node):
"""Hook sets value and assert value in cache"""
node_hash = graph_utils.hash_source_code(node_a.callable, strip=True)
node_hash = graph_types.hash_source_code(node_a.callable, strip=True)
node_kwargs = dict(external_input=7)
cache_key = (node_hash, *node_kwargs.values())
result = 2
@@ -73,7 +73,7 @@ def test_set_result(hook: h_diskcache.DiskCacheAdapter, node_a: node.Node):

def test_get_result(hook: h_diskcache.DiskCacheAdapter, node_a: node.Node):
"""Hooks get value and assert cache hit"""
node_hash = graph_utils.hash_source_code(node_a.callable, strip=True)
node_hash = graph_types.hash_source_code(node_a.callable, strip=True)
node_kwargs = dict(external_input=7)
result = 2
cache_key = (node_hash, *node_kwargs.values())
@@ -105,7 +105,7 @@ def test_append_nodes_history(
hook.cache_vars = [node_a.name]

# node version 1
hook.used_nodes_hash[node_name] = graph_utils.hash_source_code(node_a.callable, strip=True)
hook.used_nodes_hash[node_name] = graph_types.hash_source_code(node_a.callable, strip=True)
hook.run_to_execute_node(
node_name=node_name,
node_kwargs=node_kwargs,
@@ -116,7 +116,7 @@
assert len(hook.nodes_history.get(node_name, [])) == 1

# node version 2
hook.used_nodes_hash[node_name] = graph_utils.hash_source_code(node_a_body.callable, strip=True)
hook.used_nodes_hash[node_name] = graph_types.hash_source_code(node_a_body.callable, strip=True)
hook.run_to_execute_node(
node_name=node_name,
node_kwargs=node_kwargs,
@@ -141,8 +141,8 @@ def test_evict_all_except(
node_a_body: node.Node,
):
"""Check utility function to evict all except passed nodes"""
node_a_hash = graph_utils.hash_source_code(node_a.callable, strip=True)
node_a_body_hash = graph_utils.hash_source_code(node_a_body.callable, strip=True)
node_a_hash = graph_types.hash_source_code(node_a.callable, strip=True)
node_a_body_hash = graph_types.hash_source_code(node_a_body.callable, strip=True)
hook.cache[h_diskcache.DiskCacheAdapter.nodes_history_key] = dict(
A=[node_a_hash, node_a_body_hash]
)
@@ -159,8 +159,8 @@ def test_evict_from_driver(
node_a_body: node.Node,
):
"""Check utility function to evict all except driver"""
node_a_hash = graph_utils.hash_source_code(node_a.callable, strip=True)
node_a_body_hash = graph_utils.hash_source_code(node_a_body.callable, strip=True)
node_a_hash = graph_types.hash_source_code(node_a.callable, strip=True)
node_a_body_hash = graph_types.hash_source_code(node_a_body.callable, strip=True)
hook.cache[h_diskcache.DiskCacheAdapter.nodes_history_key] = dict(
A=[node_a_hash, node_a_body_hash]
)