-
Notifications
You must be signed in to change notification settings - Fork 133
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Handles partials in source code hash #1116
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,7 @@ | |
import random | ||
import shelve | ||
import time | ||
from functools import partial | ||
from typing import Any, Callable, Dict, List, Optional, Type, Union | ||
|
||
from hamilton import graph_types, htypes | ||
|
@@ -359,7 +360,7 @@ def __init__( | |
def run_before_graph_execution(self, *, graph: HamiltonGraph, **kwargs): | ||
"""Set `cache_vars` to all nodes if received None during `__init__`""" | ||
self.cache = shelve.open(self.cache_path) | ||
if self.cache_vars == []: | ||
if len(self.cache_vars) == 0: | ||
self.cache_vars = [n.name for n in graph.nodes] | ||
|
||
def run_to_execute_node( | ||
|
@@ -376,7 +377,10 @@ def run_to_execute_node( | |
if node_name not in self.cache_vars: | ||
return node_callable(**node_kwargs) | ||
|
||
node_hash = graph_types.hash_source_code(node_callable, strip=True) | ||
source_of_node_callable = node_callable | ||
while isinstance(source_of_node_callable, partial): # handle partials | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cool -- can make it into a function for clarity: |
||
source_of_node_callable = source_of_node_callable.func | ||
node_hash = graph_types.hash_source_code(source_of_node_callable, strip=True) | ||
zilto marked this conversation as resolved.
Show resolved
Hide resolved
|
||
cache_key = CacheAdapter.create_key(node_hash, node_kwargs) | ||
|
||
from_cache = self.cache.get(cache_key, None) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
import inspect | ||
import functools | ||
import pathlib | ||
import shelve | ||
|
||
|
@@ -8,12 +8,8 @@ | |
from hamilton.lifecycle.default import CacheAdapter | ||
|
||
|
||
def _callable_to_node(callable) -> node.Node: | ||
return node.Node( | ||
name=callable.__name__, | ||
typ=inspect.signature(callable).return_annotation, | ||
callabl=callable, | ||
) | ||
def _callable_to_node(callable, name=None) -> node.Node: | ||
return node.Node.from_fn(callable, name) | ||
|
||
|
||
@pytest.fixture() | ||
|
@@ -52,6 +48,37 @@ def A(external_input: int) -> int: | |
return _callable_to_node(A) | ||
|
||
|
||
@pytest.fixture() | ||
def node_a_partial(): | ||
"""The function A() is a partial""" | ||
|
||
def A(external_input: int, remainder: int) -> int: | ||
return external_input % remainder | ||
|
||
base_node: node.Node = _callable_to_node(A) | ||
|
||
A = functools.partial(A, remainder=7) | ||
base_node._callable = A | ||
del base_node.input_types["remainder"] | ||
return base_node | ||
|
||
|
||
@pytest.fixture() | ||
def node_a_nested_partial(): | ||
"""The function A() is a partial""" | ||
|
||
def A(external_input: int, remainder: int, extra: int) -> int: | ||
return external_input % remainder | ||
|
||
base_node: node.Node = _callable_to_node(A) | ||
A = functools.partial(A, remainder=7) | ||
A = functools.partial(A, extra=7) | ||
base_node._callable = A | ||
del base_node.input_types["remainder"] | ||
del base_node.input_types["extra"] | ||
return base_node | ||
|
||
|
||
def test_set_result(hook: CacheAdapter, node_a: node.Node): | ||
"""Hook sets value and assert value in cache""" | ||
node_hash = graph_types.hash_source_code(node_a.callable, strip=True) | ||
|
@@ -138,3 +165,49 @@ def test_commit_nodes_history(hook: CacheAdapter): | |
# need to reopen the hook cache | ||
with shelve.open(hook.cache_path) as cache: | ||
assert cache.get(CacheAdapter.nodes_history_key) == hook.nodes_history | ||
|
||
|
||
def test_partial_handling(hook: CacheAdapter, node_a_partial: node.Node): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah if the func-derivation is a function then this one is less necessary (although still nice to have), and we can test that/get more confidence. |
||
"""Tests partial functions are handled properly""" | ||
hook.cache_vars = [node_a_partial.name] | ||
hook.run_before_graph_execution(graph=graph_types.HamiltonGraph([])) # needed to open cache | ||
node_kwargs = dict(external_input=7) | ||
result = hook.run_to_execute_node( | ||
node_name=node_a_partial.name, | ||
node_kwargs=node_kwargs, | ||
node_callable=node_a_partial.callable, | ||
) | ||
hook.run_after_node_execution( | ||
node_name=node_a_partial.name, | ||
node_kwargs=node_kwargs, | ||
result=result, | ||
) | ||
result2 = hook.run_to_execute_node( | ||
node_name=node_a_partial.name, | ||
node_kwargs=node_kwargs, | ||
node_callable=node_a_partial.callable, | ||
) | ||
assert result2 == result | ||
|
||
|
||
def test_nested_partial_handling(hook: CacheAdapter, node_a_nested_partial: node.Node): | ||
"""Tests nested partial functions are handled properly""" | ||
hook.cache_vars = [node_a_nested_partial.name] | ||
hook.run_before_graph_execution(graph=graph_types.HamiltonGraph([])) # needed to open cache | ||
node_kwargs = dict(external_input=7) | ||
result = hook.run_to_execute_node( | ||
node_name=node_a_nested_partial.name, | ||
node_kwargs=node_kwargs, | ||
node_callable=node_a_nested_partial.callable, | ||
) | ||
hook.run_after_node_execution( | ||
node_name=node_a_nested_partial.name, | ||
node_kwargs=node_kwargs, | ||
result=result, | ||
) | ||
result2 = hook.run_to_execute_node( | ||
node_name=node_a_nested_partial.name, | ||
node_kwargs=node_kwargs, | ||
node_callable=node_a_nested_partial.callable, | ||
) | ||
assert result2 == result |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
clever 😂