Skip to content

Commit

Permalink
support imports in pushed python code
Browse files Browse the repository at this point in the history
  • Loading branch information
sachinpad committed Dec 21, 2024
1 parent 52cd156 commit b440020
Showing 1 changed file with 66 additions and 11 deletions.
77 changes: 66 additions & 11 deletions py/src/braintrust/cli/push.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Implements the braintrust push subcommand."""

import importlib.abc
import importlib.machinery
import importlib.metadata
import importlib.util
import inspect
Expand All @@ -11,7 +13,7 @@
import tempfile
import textwrap
import zipfile
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Sequence

import requests

Expand Down Expand Up @@ -71,19 +73,63 @@ def _check_uv():
)


def _execute_module(file: str):
spec = importlib.util.spec_from_file_location("unused", file)
if spec is None or spec.loader is None:
raise ValueError(f"Failed to load module from {file}")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
class _ProjectRootImporter(importlib.abc.MetaPathFinder):
"""An importer that only resolves top-level modules from the project root and their submodules,
and collects the source files of all imported modules.
"""

def __init__(self) -> None:
self._project_root, self._path_rest = sys.path[0], sys.path[1:]
self._sources = []

def _under_project_root(self, path: List[str]) -> bool:
"""Returns true if all paths in `path` are under the project root."""
return all(p.startswith(self._project_root) for p in path)

def _under_rest(self, path: List[str]) -> bool:
"""Returns true if any path in `path` is under one of the remaining paths in `sys.path`."""
return any(p.startswith(pr) for p in path for pr in self._path_rest)

def find_spec(self, fullname, path, target=None):
if path is None:
# Resolve top-level modules only from the project root.
path = [self._project_root]
elif not self._under_project_root(path) or self._under_rest(path):
# Defer paths that are not under the project root or covered by another sys.path entry
# to the subsequent importers.
return None
spec = importlib.machinery.PathFinder.find_spec(fullname, path, target)
if spec is not None and spec.origin is not None:
self._sources.append(spec.origin)
return spec

def sources(self) -> List[str]:
return self._sources


def _import_module(name: str, path: str) -> List[str]:
"""Imports the module and returns the list of source files
of all modules imported in the process.
Args:
name: The fully qualified name of the module to import.
path: The absolute path to the module to import.
Returns:
A list of absolute paths to source files of all modules imported in the process.
"""
importer = _ProjectRootImporter()
sys.meta_path.insert(0, importer)

importlib.import_module(name)
return importer.sources()


def _py_version() -> str:
return f"{sys.version_info.major}.{sys.version_info.minor}"


def _upload_bundle(file: str, requirements: Optional[str]) -> str:
def _upload_bundle(entry_module_name: str, sources: List[str], requirements: Optional[str]) -> str:
_check_uv()

resp = proxy_conn().post_json(
Expand Down Expand Up @@ -153,7 +199,9 @@ def _upload_bundle(file: str, requirements: Optional[str]) -> str:
if os.path.isfile(path):
arcname = os.path.join(arcdirpath, name)
zf.write(path, arcname)
zf.write(file, "register.py")
for source in sources:
zf.write(source, os.path.relpath(source))
zf.writestr("register.py", f"import {entry_module_name} as _\n")
with open(os.path.join(td, "pkg.zip"), "rb") as zf:
requests.put(bundle_upload_url, data=zf.read()).raise_for_status()

Expand Down Expand Up @@ -251,12 +299,19 @@ def run(args):
app_url=args.app_url,
)

_execute_module(args.file)
if sys.path[0] != os.getcwd():
raise ValueError(
f"The current working directory ({os.getcwd()}) is not the project root. "
"Please run the push command from the project root."
)
path = os.path.abspath(args.file)
module_name = re.sub(r"\.py$", "", re.sub(r"[/\\]", ".", os.path.relpath(path)))
sources = _import_module(module_name, path)

project_ids = _ProjectIdCache()
functions: List[Dict[str, Any]] = []
if len(global_.functions) > 0:
bundle_id = _upload_bundle(args.file, args.requirements)
bundle_id = _upload_bundle(module_name, sources, args.requirements)
_collect_function_function_defs(project_ids, functions, bundle_id, args.if_exists)
if len(global_.prompts) > 0:
_collect_prompt_function_defs(project_ids, functions, args.if_exists)
Expand Down

0 comments on commit b440020

Please sign in to comment.