Skip to content

Commit

Permalink
support to detect shell activated runtime uri for model build
Browse files Browse the repository at this point in the history
  • Loading branch information
tianweidut committed Dec 15, 2023
1 parent 4247130 commit 01583ee
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 21 deletions.
6 changes: 6 additions & 0 deletions client/starwhale/consts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@
ENV_BUILD_BUNDLE_FIXED_VERSION_FOR_TEST = "SW_BUILD_BUNDLE_FIXED_VERSION_FOR_TEST"


class SWShellActivatedRuntimeEnv:
URI = "SW_ACTIVATED_RUNTIME_URI_IN_SHELL"
MODE = "SW_ACTIVATED_RUNTIME_MODE_IN_SHELL"
PREFIX = "SW_ACTIVATED_RUNTIME_PREFIX_IN_SHELL"


class DefaultYAMLName:
MODEL = "model.yaml"
DATASET = "dataset.yaml"
Expand Down
11 changes: 3 additions & 8 deletions client/starwhale/core/model/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,14 +367,9 @@ def build(
typ=ResourceType.model,
)
m = Model.get_model(model_uri)

if package_runtime:
packaging_runtime_uri = os.environ.get(
RuntimeProcess.ActivatedRuntimeURI
)
else:
packaging_runtime_uri = None

packaging_runtime_uri = (
RuntimeProcess.get_activated_runtime_uri() if package_runtime else None
)
m.build(
Path(workdir),
model_config=model_config,
Expand Down
5 changes: 4 additions & 1 deletion client/starwhale/core/runtime/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1660,7 +1660,10 @@ def activate(cls, uri: Resource, force_restore: bool = False) -> None:

console.print(f":carrot: activate the current shell for the runtime uri: {uri}")
activate_python_env(
mode=mode, identity=str(prefix_path.resolve()), interactive=True
mode=mode,
identity=str(prefix_path.resolve()),
interactive=True,
original_runtime_uri=uri.full_uri,
)

@classmethod
Expand Down
52 changes: 51 additions & 1 deletion client/starwhale/core/runtime/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@
from functools import partial

from starwhale.utils import console
from starwhale.consts import PythonRunEnv
from starwhale.consts import (
ENV_VENV,
PythonRunEnv,
ENV_CONDA_PREFIX,
SWShellActivatedRuntimeEnv,
)
from starwhale.utils.fs import extract_tar
from starwhale.utils.venv import (
get_conda_bin,
Expand Down Expand Up @@ -148,3 +153,48 @@ def _restore_runtime(
)

return prefix.resolve()

@classmethod
def get_activated_runtime_uri(cls) -> str | None:
_env = os.environ.get
# process activated runtime is the first priority
uri = _env(cls.ActivatedRuntimeURI)
if uri:
return uri

# detect shell activated runtime
uri = _env(SWShellActivatedRuntimeEnv.URI)
if not uri:
return None

mode = _env(SWShellActivatedRuntimeEnv.MODE)
prefix = _env(SWShellActivatedRuntimeEnv.PREFIX)
if not mode or not prefix:
console.debug(
f"env {SWShellActivatedRuntimeEnv.MODE} or {SWShellActivatedRuntimeEnv.PREFIX} is not set, skip detect"
)
return None

if mode == PythonRunEnv.VENV:
env_prefix = _env(ENV_VENV, "")
if prefix != env_prefix:
console.debug(
f"venv prefix: {prefix} is not equal to {ENV_VENV} env({env_prefix}), skip detect"
)
return None
else:
return uri
elif mode == PythonRunEnv.CONDA:
env_prefix = _env(ENV_CONDA_PREFIX, "")
if prefix != env_prefix:
console.debug(
f"conda prefix: {prefix} is not equal to {ENV_CONDA_PREFIX} env({env_prefix}), skip detect"
)
return None
else:
return uri
else:
console.debug(
f"{SWShellActivatedRuntimeEnv.MODE}={mode} is not supported to detect shell activated runtime, skip detect"
)
return None
23 changes: 18 additions & 5 deletions client/starwhale/utils/venv.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
SW_DEV_DUMMY_VERSION,
WHEEL_FILE_EXTENSION,
DEFAULT_CONDA_CHANNEL,
SWShellActivatedRuntimeEnv,
)
from starwhale.version import STARWHALE_VERSION
from starwhale.utils.fs import ensure_dir, ensure_file, extract_tar
Expand Down Expand Up @@ -666,7 +667,9 @@ def package_python_env(
return True


def activate_python_env(mode: str, identity: str, interactive: bool) -> None:
def activate_python_env(
mode: str, identity: str, interactive: bool, original_runtime_uri: str = ""
) -> None:
if mode == PythonRunEnv.VENV:
cmd = f"source {identity}/bin/activate"
elif mode == PythonRunEnv.CONDA:
Expand All @@ -685,24 +688,34 @@ def activate_python_env(mode: str, identity: str, interactive: bool) -> None:
if not _bin.startswith("/") or _name == _bin:
_bin = shutil.which(_name) or _bin

envs = os.environ.copy()
envs[SWShellActivatedRuntimeEnv.MODE] = mode
envs[SWShellActivatedRuntimeEnv.PREFIX] = identity
envs[SWShellActivatedRuntimeEnv.URI] = original_runtime_uri

if _name == "zsh":
# https://zsh.sourceforge.io/Intro/intro_3.html
os.execl(
os.execle(
_bin,
_bin,
"-c",
f"""temp_dir={identity} && \
echo ". $HOME/.zshrc && {cmd}" > $temp_dir/.zshrc && \
ZDOTDIR=$temp_dir zsh -i""",
envs,
)
elif _name == "bash":
# https://www.gnu.org/software/bash/manual/html_node/Bash-Startup-Files.html
os.execl(
_bin, _bin, "-c", f'bash --rcfile <(echo ". "$HOME/.bashrc" && {cmd}")'
os.execle(
_bin,
_bin,
"-c",
f'bash --rcfile <(echo ". "$HOME/.bashrc" && {cmd}")',
envs,
)
elif _name == "fish":
# https://fishshell.com/docs/current/language.html#configuration
os.execl(_bin, _bin, "-C", cmd)
os.execle(_bin, _bin, "-C", cmd, envs)

# user executes the command manually
console.print(":cake: run command in shell :cake:")
Expand Down
10 changes: 5 additions & 5 deletions client/tests/core/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -3212,10 +3212,10 @@ def test_dockerize(self, m_check: MagicMock) -> None:
@patch("starwhale.core.runtime.model.StandaloneRuntime.restore")
@patch("starwhale.core.runtime.model.StandaloneRuntime.extract")
@patch("shellingham.detect_shell")
@patch("os.execl")
@patch("os.execle")
def test_activate(
self,
m_execl: MagicMock,
m_execle: MagicMock,
m_detect: MagicMock,
m_extract: MagicMock,
m_restore: MagicMock,
Expand Down Expand Up @@ -3243,7 +3243,7 @@ def test_activate(
m_detect.return_value = ["zsh", "/usr/bin/zsh"]
uri = Resource(f"{name}/version/{version}", typ=ResourceType.runtime)
StandaloneRuntime.activate(uri=uri)
assert m_execl.call_args[0][0] == "/usr/bin/zsh"
assert m_execle.call_args[0][0] == "/usr/bin/zsh"
assert not m_extract.called
assert not m_restore.called

Expand All @@ -3265,14 +3265,14 @@ def test_activate(
StandaloneRuntime.activate(uri=uri, force_restore=False)
assert m_restore.called

m_execl.reset_mock()
m_execle.reset_mock()
runtime_config = self.get_runtime_config()
runtime_config["mode"] = "conda"
ensure_file(
snapshot_dir / DefaultYAMLName.RUNTIME, yaml.safe_dump(runtime_config)
)

m_execl.reset_mock()
m_execle.reset_mock()
m_detect.return_value = ["bash", "/usr/bin/bash"]
StandaloneRuntime.activate(uri=uri, force_restore=True)
assert not m_extract.called
Expand Down
36 changes: 35 additions & 1 deletion client/tests/core/test_runtime_process.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import sys
import tempfile
from pathlib import Path
Expand All @@ -6,7 +7,7 @@
from requests_mock import Mocker
from pyfakefs.fake_filesystem_unittest import TestCase

from starwhale.consts import DEFAULT_MANIFEST_NAME
from starwhale.consts import DEFAULT_MANIFEST_NAME, SWShellActivatedRuntimeEnv
from starwhale.utils.fs import empty_dir, ensure_dir, ensure_file
from starwhale.utils.error import NoSupportError, FieldTypeOrValueError
from starwhale.core.model.store import ModelStorage
Expand Down Expand Up @@ -210,3 +211,36 @@ def test_run_exceptions(
):
uri = "http://1.1.1.1:8081/projects/self/runtimes/rttest/versoin/123"
Process(uri).run()

@patch("os.environ", {})
def test_activated_runtime_uri(self) -> None:
get_uri = Process.get_activated_runtime_uri
os.environ[Process.ActivatedRuntimeURI] = "runtime-in-process"
assert get_uri() == "runtime-in-process"

del os.environ[Process.ActivatedRuntimeURI]

os.environ[SWShellActivatedRuntimeEnv.MODE] = "system"
assert get_uri() is None

os.environ[SWShellActivatedRuntimeEnv.URI] = ""
assert get_uri() is None

os.environ[SWShellActivatedRuntimeEnv.URI] = "runtime-in-shell-venv"
assert get_uri() is None

os.environ[SWShellActivatedRuntimeEnv.MODE] = "venv"
os.environ[SWShellActivatedRuntimeEnv.PREFIX] = "/opt/sw/venv"
assert get_uri() is None

os.environ["VIRTUAL_ENV"] = "/opt/sw/venv"
assert get_uri() == "runtime-in-shell-venv"
del os.environ["VIRTUAL_ENV"]

os.environ[SWShellActivatedRuntimeEnv.URI] = "runtime-in-shell-conda"
os.environ[SWShellActivatedRuntimeEnv.MODE] = "conda"
os.environ[SWShellActivatedRuntimeEnv.PREFIX] = "/opt/sw/conda"
assert get_uri() is None

os.environ["CONDA_PREFIX"] = "/opt/sw/conda"
assert get_uri() == "runtime-in-shell-conda"

0 comments on commit 01583ee

Please sign in to comment.