Handles partials in source code hash #1116

Merged
merged 4 commits into from
Sep 4, 2024

Conversation

skrawcz
Collaborator

@skrawcz skrawcz commented Sep 2, 2024

This is a stopgap measure to handle functools.partial callables in the CacheAdapter.

Changes

  • Adds a check for partial functions, which expose the wrapped callable through their .func attribute.
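To illustrate the change above, here is a minimal standalone sketch (not Hamilton code) of why a partial trips up source lookup and how .func recovers the wrapped function. The worker function and the source_lookup_error helper are hypothetical names used only for this example.

```python
import inspect
from functools import partial


def worker(seconds: int, foo: int) -> int:
    """Stand-in for a node function (hypothetical, for illustration only)."""
    return seconds + foo


# A partial binds some arguments and keeps the wrapped callable on `.func`.
bound = partial(worker, foo=42)


def source_lookup_error(obj):
    """Return the name of the exception inspect.getsource raises, or None."""
    try:
        inspect.getsource(obj)
    except (TypeError, OSError) as exc:
        return type(exc).__name__
    return None


# inspect rejects the partial object itself with a TypeError ...
print(source_lookup_error(bound))
# ... while the wrapped function is still reachable through `.func`.
print(bound.func is worker)
```

Looking up the source of bound.func instead of bound avoids the TypeError, which is the idea behind the check added in this PR.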

How I tested this

This code now works:

from hamilton.function_modifiers import source, inject
from hamilton.ad_hoc_utils import create_temporary_module
import time

@inject(foo=source("foo"))
def worker(seconds: int, foo: int) -> int:
    print(f"start work {seconds}s")
    start = time.time()
    time.sleep(1)
    # yield {
    #     "job_name": f"sleep_{seconds}",
    #     "cmd": [f"sleep {str(seconds)}"],
    #     "block": True,
    # }
    print(f"foo: {foo}")
    end = time.time()
    print(f"end stop {end - start:.2f}s work")
    return seconds


if __name__ == "__main__":
    from hamilton import driver
    from hamilton.execution import executors
    from hamilton.lifecycle.default import CacheAdapter
    logic = create_temporary_module(
        worker,
    )

    dr = (
        driver.Builder()
        .with_modules(logic)
        .enable_dynamic_execution(allow_experimental_mode=True)
        .with_local_executor(executors.SynchronousLocalTaskExecutor())
        .with_remote_executor(
            executors.MultiThreadingExecutor(max_tasks=4)
        )
        # Before this fix, enabling the CacheAdapter below raised:
        # TypeError: module, class, method, function, traceback, frame, or code object was expected, got partial
        .with_adapters(CacheAdapter(cache_path="cache"))
        .with_config({"foo": 42})
        .build()
    )
    start = time.time()
    dr.execute(["worker"], inputs={"seconds": 3})
    print(f"Time taken: {time.time() - start: .2f} seconds")

Notes

Checklist

  • PR has an informative and human-readable title (this will be pulled into the release notes)
  • Changes are limited to a single goal (no scope creep)
  • Code passed the pre-commit check & code is left cleaner/nicer than when first encountered.
  • Any change in functionality is tested
  • New functions are documented (with a description, list of inputs, and expected output)
  • Placeholder code is flagged / future TODOs are captured in comments
  • Project documentation has been updated if adding/changing functionality.

This is a stopgap measure to handle partials for the CacheAdapter.

I put the change here rather than in the source hash function,
since for now this behavior appears to be specific to the cache
adapter.
@skrawcz skrawcz requested a review from Roy-Kid September 2, 2024 18:24

@Roy-Kid Roy-Kid left a comment

It looks nice. We can merge it.

@@ -376,7 +377,10 @@ def run_to_execute_node(
         if node_name not in self.cache_vars:
             return node_callable(**node_kwargs)

-        node_hash = graph_types.hash_source_code(node_callable, strip=True)
+        source_of_node_callable = node_callable
+        while isinstance(source_of_node_callable, partial):  # handle partials

clever 😂

Collaborator

@elijahbenizzy elijahbenizzy left a comment

Nits, looks good

@@ -376,7 +377,10 @@ def run_to_execute_node(
         if node_name not in self.cache_vars:
             return node_callable(**node_kwargs)

-        node_hash = graph_types.hash_source_code(node_callable, strip=True)
+        source_of_node_callable = node_callable
+        while isinstance(source_of_node_callable, partial):  # handle partials
Collaborator

cool -- can make it into a function for clarity: get_original_function(func: Callable) + test, but this is good
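A helper along the lines suggested here might look like the sketch below. Only the name and signature get_original_function(func: Callable) come from the comment; the body, the add function, and the checks are assumptions for illustration, not the code that was merged.

```python
from functools import partial
from typing import Callable


def get_original_function(func: Callable) -> Callable:
    """Strip functools.partial layers until a non-partial callable remains."""
    while isinstance(func, partial):
        func = func.func
    return func


def add(a: int, b: int, c: int) -> int:
    """Hypothetical function used only to exercise the helper."""
    return a + b + c


once = partial(add, 1)    # binds a
twice = partial(once, 2)  # partial of a partial, binds b as well

# Plain functions pass through unchanged; partials resolve to `add`.
assert get_original_function(add) is add
assert get_original_function(once) is add
assert get_original_function(twice) is add
```

Pulling the loop into a function would let it be unit-tested without constructing a CacheAdapter, which seems to be the point of the suggestion.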

@@ -138,3 +165,49 @@ def test_commit_nodes_history(hook: CacheAdapter):
     # need to reopen the hook cache
     with shelve.open(hook.cache_path) as cache:
         assert cache.get(CacheAdapter.nodes_history_key) == hook.nodes_history
+
+
+def test_partial_handling(hook: CacheAdapter, node_a_partial: node.Node):
Collaborator

Yeah if the func-derivation is a function then this one is less necessary (although still nice to have), and we can test that/get more confidence.

@skrawcz skrawcz merged commit 6f2376d into main Sep 4, 2024
24 checks passed
@skrawcz skrawcz deleted the fix_cache_partial branch September 4, 2024 00:55
4 participants