From 808f448b4dcb9ad4e325b473dd372257b24e4296 Mon Sep 17 00:00:00 2001 From: zilto Date: Mon, 29 Apr 2024 14:38:02 -0400 Subject: [PATCH] fix code display for temporary modules --- hamilton/ad_hoc_utils.py | 6 +++--- ui/sdk/src/hamilton_sdk/driver.py | 36 +++++++++++++++++++++++-------- 2 files changed, 30 insertions(+), 12 deletions(-) diff --git a/hamilton/ad_hoc_utils.py b/hamilton/ad_hoc_utils.py index 773d078f0..5b2dab361 100644 --- a/hamilton/ad_hoc_utils.py +++ b/hamilton/ad_hoc_utils.py @@ -5,7 +5,7 @@ import types import uuid from types import ModuleType -from typing import Callable +from typing import Callable, Optional def _copy_func(f): @@ -60,9 +60,9 @@ def create_temporary_module(*functions: Callable, module_name: str = None) -> Mo return module -def module_from_source(source: str) -> ModuleType: +def module_from_source(source: str, module_name: Optional[str] = None) -> ModuleType: """Create a temporary module from source code""" - module_name = _generate_unique_temp_module_name() + module_name = module_name if module_name else _generate_unique_temp_module_name() module_object = ModuleType(module_name) code_object = compile(source, module_name, "exec") sys.modules[module_name] = module_object diff --git a/ui/sdk/src/hamilton_sdk/driver.py b/ui/sdk/src/hamilton_sdk/driver.py index 0087af798..7677ce89c 100644 --- a/ui/sdk/src/hamilton_sdk/driver.py +++ b/ui/sdk/src/hamilton_sdk/driver.py @@ -2,6 +2,7 @@ import hashlib import inspect import json +import linecache import logging import operator import os @@ -67,15 +68,19 @@ def _hash_module( f"attribute or it is None. This happens with lazy loaders." 
) continue - # Check if the module is in the same top level package - if value.__package__ != module.__package__ and not value.__package__.startswith( - module.__package__ - ): - logger.debug( - f"Skipping hash for module {value.__name__} because it is in a different " - f"package {value.__package__} than {module.__package__}" - ) - continue + + # Modules imported in a temporary module have no `__package__` attribute + if module.__package__: + # Check if the module is in the same top level package + if value.__package__ != module.__package__ and not value.__package__.startswith( + module.__package__ + ): + logger.debug( + f"Skipping hash for module {value.__name__} because it is in a different " + f"package {value.__package__} than {module.__package__}" + ) + continue + # Recursively hash the sub-module hash_object = _hash_module(value, hash_object, seen_modules) @@ -688,6 +693,11 @@ def extract_task_updates_from_tracking_state( def _slurp_code(fg: graph.FunctionGraph, repo_base: str) -> List[dict]: + """Get the source code from modules. Returns a list with a dictionary for each module. 
+ + The `path` attribute needs to match the `path` of code artifacts generated by + `extract_code_artifacts_from_function_graph()` + """ modules = set() for node_ in fg.nodes.values(): originating_functions = node_.originating_functions @@ -702,6 +712,14 @@ def _slurp_code(fg: graph.FunctionGraph, repo_base: str) -> List[dict]: module_path = os.path.relpath(module.__file__, repo_base) with open(module.__file__, "r") as f: out.append({"path": module_path, "contents": f.read()}) + # for temporary modules registered via `module_from_source` + else: + # get source code from the linecache; returns a tuple (size, mtime, lines, fullname) + source_lines = linecache.cache[module.__name__][2] + source = "".join(source_lines) + # the path won't have a `.py` suffix to match `extract_code_artifacts_from_function_graph()` + module_path = os.path.relpath(module.__name__, repo_base) + out.append({"path": module_path, "contents": source}) return out