Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Supports adding Mars extensions via setup entrypoints #2589

Merged
merged 3 commits into from
Dec 3, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions mars/core/entrypoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import warnings
import functools

from pkg_resources import iter_entry_points

logger = logging.getLogger(__name__)


# from https://github.com/numba/numba/blob/master/numba/core/entrypoints.py
# Must put this here to avoid extensions re-triggering initialization
@functools.lru_cache(maxsize=None)
def init_all():
hekaisheng marked this conversation as resolved.
Show resolved Hide resolved
"""Execute all `mars_extensions` entry points with the name `init`
If extensions have already been initialized, this function does nothing.
"""
for entry_point in iter_entry_points("mars_extensions", "init"):
logger.info("Loading extension: %s", entry_point)
try:
func = entry_point.load()
func()
except Exception as e:
msg = "Mars extension module '{}' failed to load due to '{}({})'."
warnings.warn(
msg.format(entry_point.module_name, type(e).__name__, str(e)),
stacklevel=2,
)
logger.info("Extension loading failed for: %s", entry_point)
124 changes: 124 additions & 0 deletions mars/core/tests/test_entrypoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# Copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import types
import warnings
import pkg_resources


class _DummyClass(object):
def __init__(self, value):
self.value = value

def __repr__(self):
return "_DummyClass(%f, %f)" % self.value


def test_init_entrypoint():
# FIXME: Python 2 workaround because nonlocal doesn't exist
counters = {"init": 0}

def init_function():
counters["init"] += 1

mod = types.ModuleType("_test_mars_extension")
mod.init_func = init_function

try:
# will remove this module at the end of the test
sys.modules[mod.__name__] = mod

# We are registering an entry point using the "mars" package
# ("distribution" in pkg_resources-speak) itself, though these are
# normally registered by other packages.
dist = "pymars"
entrypoints = pkg_resources.get_entry_map(dist)
my_entrypoint = pkg_resources.EntryPoint(
"init", # name of entry point
mod.__name__, # module with entry point object
attrs=["init_func"], # name of entry point object
dist=pkg_resources.get_distribution(dist),
)
entrypoints.setdefault("mars_extensions", {})["init"] = my_entrypoint

from .. import entrypoints

# Allow reinitialization
entrypoints.init_all.cache_clear()

entrypoints.init_all()

# was our init function called?
assert counters["init"] == 1

# ensure we do not initialize twice
entrypoints.init_all()
assert counters["init"] == 1
finally:
# remove fake module
if mod.__name__ in sys.modules:
del sys.modules[mod.__name__]


def test_entrypoint_tolerance():
# FIXME: Python 2 workaround because nonlocal doesn't exist
counters = {"init": 0}

def init_function():
counters["init"] += 1
raise ValueError("broken")

mod = types.ModuleType("_test_mars_bad_extension")
mod.init_func = init_function

try:
# will remove this module at the end of the test
sys.modules[mod.__name__] = mod

# We are registering an entry point using the "mars" package
# ("distribution" in pkg_resources-speak) itself, though these are
# normally registered by other packages.
dist = "pymars"
entrypoints = pkg_resources.get_entry_map(dist)
my_entrypoint = pkg_resources.EntryPoint(
"init", # name of entry point
mod.__name__, # module with entry point object
attrs=["init_func"], # name of entry point object
dist=pkg_resources.get_distribution(dist),
)
entrypoints.setdefault("mars_extensions", {})["init"] = my_entrypoint

from .. import entrypoints

# Allow reinitialization
entrypoints.init_all.cache_clear()

with warnings.catch_warnings(record=True) as w:
entrypoints.init_all()

bad_str = "Mars extension module '_test_mars_bad_extension'"
for x in w:
if bad_str in str(x):
break
else:
raise ValueError("Expected warning message not found")

# was our init function called?
assert counters["init"] == 1

finally:
# remove fake module
if mod.__name__ in sys.modules:
del sys.modules[mod.__name__]
3 changes: 3 additions & 0 deletions mars/deploy/oscar/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import numpy as np

from ... import oscar as mo
from ...core.entrypoints import init_all
from ...lib.aio import get_isolation, stop_isolation
from ...resource import cpu_count, cuda_count
from ...services import NodeRole
Expand Down Expand Up @@ -111,6 +112,8 @@ def __init__(
web: Union[bool, str] = "auto",
timeout: float = None,
):
# load third party extensions.
init_all()
# load config file to dict.
if not config or isinstance(config, str):
config = load_config(config)
Expand Down
3 changes: 3 additions & 0 deletions mars/deploy/oscar/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import Union, Dict, List, Optional, AsyncGenerator

from ... import oscar as mo
from ...core.entrypoints import init_all
from ...oscar.backends.ray.driver import RayActorDriver
from ...oscar.backends.ray.utils import (
process_placement_to_address,
Expand Down Expand Up @@ -371,6 +372,8 @@ def __init__(
worker_mem: int = 32 * 1024 ** 3,
config: Union[str, Dict] = None,
):
# load third party extensions.
init_all()
self._cluster_name = cluster_name
self._supervisor_mem = supervisor_mem
self._worker_num = worker_num
Expand Down
3 changes: 3 additions & 0 deletions mars/deploy/oscar/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from ... import oscar as mo
from ...config import options
from ...core import ChunkType, TileableType, TileableGraph, enter_mode
from ...core.entrypoints import init_all
from ...core.operand import Fetch
from ...lib.aio import (
alru_cache,
Expand Down Expand Up @@ -1839,6 +1840,8 @@ def new_session(
new: bool = True,
**kwargs,
) -> AbstractSession:
# load third party extensions.
init_all()
ensure_isolation_created(kwargs)

if address is None:
Expand Down
3 changes: 3 additions & 0 deletions mars/oscar/backends/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from abc import ABC, ABCMeta, abstractmethod
from typing import Dict, List, Type, TypeVar, Coroutine, Callable, Union, Optional

from ...core.entrypoints import init_all
from ...utils import implements, to_binary
from ...utils import lazy_import, register_asyncio_task_timeout_detector
from ..api import Actor
Expand Down Expand Up @@ -141,6 +142,8 @@ def __init__(
self._asyncio_task_timeout_detector_task = (
register_asyncio_task_timeout_detector()
)
# load third party extensions.
init_all()

@property
def router(self):
Expand Down