From c21ba92b96acbcb65d0b4c11fc3608f5ef31af66 Mon Sep 17 00:00:00 2001
From: Sachin Padmanabhan
Date: Thu, 19 Dec 2024 22:37:18 -0800
Subject: [PATCH] support imports in pushed python code

Import the entry module through a meta path finder that resolves
top-level modules from the project root only and records the source file
of every module it resolves. The recorded sources are added to the
uploaded bundle next to a generated register.py that imports the entry
module by its fully qualified name.

The finder is removed from sys.meta_path once the entry module has been
executed, and project-root containment is checked on whole path
components rather than on raw string prefixes.
---
 py/src/braintrust/cli/push.py | 101 ++++++++++++++++++++++++++++++----
 1 file changed, 90 insertions(+), 11 deletions(-)

diff --git a/py/src/braintrust/cli/push.py b/py/src/braintrust/cli/push.py
index 3e25904b..c02208ff 100644
--- a/py/src/braintrust/cli/push.py
+++ b/py/src/braintrust/cli/push.py
@@ -1,5 +1,7 @@
 """Implements the braintrust push subcommand."""
 
+import importlib.abc
+import importlib.machinery
 import importlib.metadata
 import importlib.util
 import inspect
@@ -11,7 +13,7 @@
 import tempfile
 import textwrap
 import zipfile
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Sequence
 
 import requests
 
@@ -71,19 +73,87 @@ def _check_uv():
         )
 
 
-def _execute_module(file: str):
-    spec = importlib.util.spec_from_file_location("unused", file)
-    if spec is None or spec.loader is None:
-        raise ValueError(f"Failed to load module from {file}")
-    module = importlib.util.module_from_spec(spec)
-    spec.loader.exec_module(module)
+class _ProjectRootImporter(importlib.abc.MetaPathFinder):
+    """An importer that only resolves top-level modules from the project root and their submodules,
+    and collects the source files of all imported modules.
+    """
+
+    def __init__(self) -> None:
+        self._project_root, self._path_rest = sys.path[0], sys.path[1:]
+        self._sources: List[str] = []
+
+    @staticmethod
+    def _is_under(path: str, root: str) -> bool:
+        """Returns true if `path` is `root` itself or is nested inside it.
+
+        Paths are compared by whole components, so `/proj-old` is not considered to be under `/proj`.
+        """
+        path = os.path.normcase(os.path.abspath(path))
+        root = os.path.normcase(os.path.abspath(root))
+        try:
+            return os.path.commonpath([path, root]) == root
+        except ValueError:
+            # Raised when the paths cannot be compared, e.g. different drives on Windows.
+            return False
+
+    def _under_project_root(self, path: Sequence[str]) -> bool:
+        """Returns true if all paths in `path` are under the project root."""
+        return all(self._is_under(p, self._project_root) for p in path)
+
+    def _under_rest(self, path: Sequence[str]) -> bool:
+        """Returns true if any path in `path` is under one of the remaining paths in `sys.path`."""
+        return any(self._is_under(p, pr) for p in path for pr in self._path_rest)
+
+    def find_spec(self, fullname: str, path: Optional[Sequence[str]], target=None):
+        if path is None:
+            # Resolve top-level modules only from the project root.
+            path = [self._project_root]
+        elif not self._under_project_root(path) or self._under_rest(path):
+            # Defer paths that are not under the project root or covered by another sys.path entry
+            # to the subsequent importers.
+            return None
+        spec = importlib.machinery.PathFinder.find_spec(fullname, path, target)
+        if spec is not None and spec.origin is not None:
+            self._sources.append(spec.origin)
+        return spec
+
+    def sources(self) -> List[str]:
+        return self._sources
+
+
+def _import_module(name: str, path: str) -> List[str]:
+    """Imports the module and returns the list of source files
+    of all modules imported in the process.
+
+    Args:
+        name: The fully qualified name of the module to import.
+        path: The absolute path to the module to import.
+
+    Returns:
+        A list of absolute paths to source files of all modules imported in the process.
+
+    Raises:
+        ValueError: If no module can be loaded from `path`.
+    """
+    importer = _ProjectRootImporter()
+    sys.meta_path.insert(0, importer)
+    try:
+        spec = importlib.util.spec_from_file_location(name, path)
+        if spec is None or spec.loader is None:
+            raise ValueError(f"Failed to load module from {path}")
+        module = importlib.util.module_from_spec(spec)
+        spec.loader.exec_module(module)
+    finally:
+        # Uninstall the importer even on failure so it stops intercepting (and recording) later imports.
+        sys.meta_path.remove(importer)
+    return importer.sources() + [path]
 
 
 def _py_version() -> str:
     return f"{sys.version_info.major}.{sys.version_info.minor}"
 
 
-def _upload_bundle(file: str, requirements: Optional[str]) -> str:
+def _upload_bundle(entry_module_name: str, sources: List[str], requirements: Optional[str]) -> str:
     _check_uv()
 
     resp = proxy_conn().post_json(
@@ -153,7 +223,9 @@ def _upload_bundle(file: str, requirements: Optional[str]) -> str:
                     if os.path.isfile(path):
                         arcname = os.path.join(arcdirpath, name)
                         zf.write(path, arcname)
-            zf.write(file, "register.py")
+            for source in sources:
+                zf.write(source, os.path.relpath(source))
+            zf.writestr("register.py", f"import {entry_module_name} as _\n")
 
         with open(os.path.join(td, "pkg.zip"), "rb") as zf:
             requests.put(bundle_upload_url, data=zf.read()).raise_for_status()
@@ -251,12 +323,19 @@ def run(args):
         app_url=args.app_url,
     )
 
-    _execute_module(args.file)
+    if sys.path[0] != os.getcwd():
+        raise ValueError(
+            f"The current working directory ({os.getcwd()}) is not the project root. "
+            "Please run the push command from the project root."
+        )
+    path = os.path.abspath(args.file)
+    module_name = re.sub(r"\.py$", "", re.sub(r"[/\\]", ".", os.path.relpath(path)))
+    sources = _import_module(module_name, path)
 
     project_ids = _ProjectIdCache()
     functions: List[Dict[str, Any]] = []
     if len(global_.functions) > 0:
-        bundle_id = _upload_bundle(args.file, args.requirements)
+        bundle_id = _upload_bundle(module_name, sources, args.requirements)
         _collect_function_function_defs(project_ids, functions, bundle_id, args.if_exists)
     if len(global_.prompts) > 0:
         _collect_prompt_function_defs(project_ids, functions, args.if_exists)