Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix in-process main module source loading #3119

Merged
merged 5 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 39 additions & 2 deletions src/zenml/utils/source_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,16 @@
from distutils.sysconfig import get_python_lib
from pathlib import Path, PurePath
from types import BuiltinFunctionType, FunctionType, ModuleType
from typing import Any, Callable, Dict, Iterator, Optional, Type, Union, cast
from typing import (
Any,
Callable,
Dict,
Iterator,
Optional,
Type,
Union,
cast,
)
from uuid import UUID

from zenml.config.source import (
Expand Down Expand Up @@ -121,7 +130,17 @@ def load(source: Union[Source, str]) -> Any:
# root in python path just to be sure
import_root = get_source_root()

module = _load_module(module_name=source.module, import_root=import_root)
if _should_load_from_main_module(source):
# This source points to the __main__ module of the current process.
# If we were to load the module here, we would load the same python
# file with a different module name, which would rerun all top-level
# code. To avoid this, we instead load the source from the __main__
# module which is already loaded.
module = sys.modules["__main__"]
else:
module = _load_module(
module_name=source.module, import_root=import_root
)

if source.attribute:
obj = getattr(module, source.attribute)
Expand Down Expand Up @@ -780,3 +799,21 @@ def get_resolved_notebook_sources() -> Dict[str, str]:
of their notebook cell.
"""
return _resolved_notebook_sources.copy()


def _should_load_from_main_module(source: Source) -> bool:
"""Check whether the source should be loaded from the main module.

Args:
source: The source to check.

Returns:
If the source should be loaded from the main module instead of the
module defined in the source object.
"""
try:
resolved_main_module = _resolve_module(sys.modules["__main__"])
except RuntimeError:
return False

return resolved_main_module == source.module
32 changes: 32 additions & 0 deletions tests/unit/utils/source_utils_test_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
import sys

from zenml.utils import source_utils


def f():
pass


if __name__ == "__main__":
source = source_utils.resolve(f)
assert source.module != "__main__"
assert source.attribute == "f"
assert source.module not in sys.modules

obj = source_utils.load(source)

assert obj.__module__ == "__main__"
assert source.module not in sys.modules
11 changes: 11 additions & 0 deletions tests/unit/utils/test_source_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
# or implied. See the License for the specific language governing
# permissions and limitations under the License.

import os
import pathlib
import subprocess
import sys
from contextlib import ExitStack as does_not_raise
from types import BuiltinFunctionType, FunctionType
Expand Down Expand Up @@ -350,3 +352,12 @@ def test_package_utility_functions():
source_utils._get_package_version(package_name="non_existent_package")
is None
)


def test_resolving_and_loading_main_module_sources():
"""Test resolving and loading a main source in the same process."""

subprocess.check_call(
[sys.executable, "source_utils_test_helper.py"],
cwd=os.path.dirname(__file__),
)
Loading