Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support imports in pushed python code #496

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 65 additions & 10 deletions py/src/braintrust/cli/push.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Implements the braintrust push subcommand."""

import importlib.abc
import importlib.machinery
import importlib.metadata
import importlib.util
import inspect
Expand Down Expand Up @@ -71,19 +73,63 @@ def _check_uv():
)


def _execute_module(file: str):
spec = importlib.util.spec_from_file_location("unused", file)
if spec is None or spec.loader is None:
raise ValueError(f"Failed to load module from {file}")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
class _ProjectRootImporter(importlib.abc.MetaPathFinder):
"""An importer that only resolves top-level modules from the project root and their submodules,
and collects the source files of all imported modules.
"""

def __init__(self) -> None:
self._project_root, self._path_rest = sys.path[0], sys.path[1:]
self._sources = []

def _under_project_root(self, path: List[str]) -> bool:
"""Returns true if all paths in `path` are under the project root."""
return all(p.startswith(self._project_root) for p in path)

def _under_rest(self, path: List[str]) -> bool:
"""Returns true if any path in `path` is under one of the remaining paths in `sys.path`."""
return any(p.startswith(pr) for p in path for pr in self._path_rest)

def find_spec(self, fullname, path, target=None):
if path is None:
# Resolve top-level modules only from the project root.
path = [self._project_root]
elif not self._under_project_root(path) or self._under_rest(path):
# Defer paths that are not under the project root or covered by another sys.path entry
# to the subsequent importers.
return None
spec = importlib.machinery.PathFinder.find_spec(fullname, path, target)
if spec is not None and spec.origin is not None:
self._sources.append(spec.origin)
return spec

def sources(self) -> List[str]:
return self._sources


def _import_module(name: str, path: str) -> List[str]:
"""Imports the module and returns the list of source files
of all modules imported in the process.

Args:
name: The fully qualified name of the module to import.
path: The absolute path to the module to import.

Returns:
A list of absolute paths to source files of all modules imported in the process.
"""
importer = _ProjectRootImporter()
sys.meta_path.insert(0, importer)

importlib.import_module(name)
return importer.sources()


def _py_version() -> str:
return f"{sys.version_info.major}.{sys.version_info.minor}"


def _upload_bundle(file: str, requirements: Optional[str]) -> str:
def _upload_bundle(entry_module_name: str, sources: List[str], requirements: Optional[str]) -> str:
_check_uv()

resp = proxy_conn().post_json(
Expand Down Expand Up @@ -153,7 +199,9 @@ def _upload_bundle(file: str, requirements: Optional[str]) -> str:
if os.path.isfile(path):
arcname = os.path.join(arcdirpath, name)
zf.write(path, arcname)
zf.write(file, "register.py")
for source in sources:
zf.write(source, os.path.relpath(source))
zf.writestr("register.py", f"import {entry_module_name} as _\n")
with open(os.path.join(td, "pkg.zip"), "rb") as zf:
requests.put(bundle_upload_url, data=zf.read()).raise_for_status()

Expand Down Expand Up @@ -251,12 +299,19 @@ def run(args):
app_url=args.app_url,
)

_execute_module(args.file)
if sys.path[0] != os.getcwd():
raise ValueError(
f"The current working directory ({os.getcwd()}) is not the project root. "
"Please run the push command from the project root."
)
path = os.path.abspath(args.file)
module_name = re.sub(".py$", "", re.sub("/", ".", os.path.relpath(path)))
sources = _import_module(module_name, path)

project_ids = _ProjectIdCache()
functions: List[Dict[str, Any]] = []
if len(global_.functions) > 0:
bundle_id = _upload_bundle(args.file, args.requirements)
bundle_id = _upload_bundle(module_name, sources, args.requirements)
_collect_function_function_defs(project_ids, functions, bundle_id, args.if_exists)
if len(global_.prompts) > 0:
_collect_prompt_function_defs(project_ids, functions, args.if_exists)
Expand Down
Loading