diff --git a/mars/core/entrypoints.py b/mars/core/entrypoints.py new file mode 100644 index 0000000000..de94512920 --- /dev/null +++ b/mars/core/entrypoints.py @@ -0,0 +1,42 @@ +# Copyright 1999-2021 Alibaba Group Holding Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import warnings +import functools + +from pkg_resources import iter_entry_points + +logger = logging.getLogger(__name__) + + +# from https://github.com/numba/numba/blob/master/numba/core/entrypoints.py +# Must put this here to avoid extensions re-triggering initialization +@functools.lru_cache(maxsize=None) +def init_extension_entrypoints(): + """Execute all `mars_extensions` entry points with the name `init` + If extensions have already been initialized, this function does nothing. + """ + for entry_point in iter_entry_points("mars_extensions", "init"): + logger.info("Loading extension: %s", entry_point) + try: + func = entry_point.load() + func() + except Exception as e: + msg = "Mars extension module '{}' failed to load due to '{}({})'." + warnings.warn( + msg.format(entry_point.module_name, type(e).__name__, str(e)), + stacklevel=2, + ) + logger.info("Extension loading failed for: %s", entry_point) diff --git a/mars/core/tests/test_entrypoints.py b/mars/core/tests/test_entrypoints.py new file mode 100644 index 0000000000..33965060c6 --- /dev/null +++ b/mars/core/tests/test_entrypoints.py @@ -0,0 +1,124 @@ +# Copyright 1999-2021 Alibaba Group Holding Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import types +import warnings +import pkg_resources + + +class _DummyClass(object): + def __init__(self, value): + self.value = value + + def __repr__(self): + return "_DummyClass(%f, %f)" % self.value + + +def test_init_entrypoint(): + # FIXME: Python 2 workaround because nonlocal doesn't exist + counters = {"init": 0} + + def init_function(): + counters["init"] += 1 + + mod = types.ModuleType("_test_mars_extension") + mod.init_func = init_function + + try: + # will remove this module at the end of the test + sys.modules[mod.__name__] = mod + + # We are registering an entry point using the "mars" package + # ("distribution" in pkg_resources-speak) itself, though these are + # normally registered by other packages. + dist = "pymars" + entrypoints = pkg_resources.get_entry_map(dist) + my_entrypoint = pkg_resources.EntryPoint( + "init", # name of entry point + mod.__name__, # module with entry point object + attrs=["init_func"], # name of entry point object + dist=pkg_resources.get_distribution(dist), + ) + entrypoints.setdefault("mars_extensions", {})["init"] = my_entrypoint + + from .. import entrypoints + + # Allow reinitialization + entrypoints.init_extension_entrypoints.cache_clear() + + entrypoints.init_extension_entrypoints() + + # was our init function called? + assert counters["init"] == 1 + + # ensure we do not initialize twice + entrypoints.init_extension_entrypoints() + assert counters["init"] == 1 + finally: + # remove fake module + if mod.__name__ in sys.modules: + del sys.modules[mod.__name__] + + +def test_entrypoint_tolerance(): + # FIXME: Python 2 workaround because nonlocal doesn't exist + counters = {"init": 0} + + def init_function(): + counters["init"] += 1 + raise ValueError("broken") + + mod = types.ModuleType("_test_mars_bad_extension") + mod.init_func = init_function + + try: + # will remove this module at the end of the test + sys.modules[mod.__name__] = mod + + # We are registering an entry point using the "mars" package + # ("distribution" in pkg_resources-speak) itself, though these are + # normally registered by other packages. + dist = "pymars" + entrypoints = pkg_resources.get_entry_map(dist) + my_entrypoint = pkg_resources.EntryPoint( + "init", # name of entry point + mod.__name__, # module with entry point object + attrs=["init_func"], # name of entry point object + dist=pkg_resources.get_distribution(dist), + ) + entrypoints.setdefault("mars_extensions", {})["init"] = my_entrypoint + + from .. import entrypoints + + # Allow reinitialization + entrypoints.init_extension_entrypoints.cache_clear() + + with warnings.catch_warnings(record=True) as w: + entrypoints.init_extension_entrypoints() + + bad_str = "Mars extension module '_test_mars_bad_extension'" + for x in w: + if bad_str in str(x): + break + else: + raise ValueError("Expected warning message not found") + + # was our init function called? + assert counters["init"] == 1 + + finally: + # remove fake module + if mod.__name__ in sys.modules: + del sys.modules[mod.__name__] diff --git a/mars/deploy/oscar/local.py b/mars/deploy/oscar/local.py index afaa737386..ed121f406c 100644 --- a/mars/deploy/oscar/local.py +++ b/mars/deploy/oscar/local.py @@ -22,6 +22,7 @@ import numpy as np from ... import oscar as mo +from ...core.entrypoints import init_extension_entrypoints from ...lib.aio import get_isolation, stop_isolation from ...resource import cpu_count, cuda_count from ...services import NodeRole @@ -111,6 +112,8 @@ def __init__( web: Union[bool, str] = "auto", timeout: float = None, ): + # load third party extensions. + init_extension_entrypoints() # load config file to dict. if not config or isinstance(config, str): config = load_config(config) diff --git a/mars/deploy/oscar/ray.py b/mars/deploy/oscar/ray.py index a6d911e06f..16ffe27971 100644 --- a/mars/deploy/oscar/ray.py +++ b/mars/deploy/oscar/ray.py @@ -20,6 +20,7 @@ from typing import Union, Dict, List, Optional, AsyncGenerator from ... import oscar as mo +from ...core.entrypoints import init_extension_entrypoints from ...oscar.backends.ray.driver import RayActorDriver from ...oscar.backends.ray.utils import ( process_placement_to_address, @@ -371,6 +372,8 @@ def __init__( worker_mem: int = 32 * 1024 ** 3, config: Union[str, Dict] = None, ): + # load third party extensions. + init_extension_entrypoints() self._cluster_name = cluster_name self._supervisor_mem = supervisor_mem self._worker_num = worker_num diff --git a/mars/deploy/oscar/session.py b/mars/deploy/oscar/session.py index 2368c8306f..40178f8e57 100644 --- a/mars/deploy/oscar/session.py +++ b/mars/deploy/oscar/session.py @@ -35,6 +35,7 @@ from ... import oscar as mo from ...config import options from ...core import ChunkType, TileableType, TileableGraph, enter_mode +from ...core.entrypoints import init_extension_entrypoints from ...core.operand import Fetch from ...lib.aio import ( alru_cache, @@ -1839,6 +1840,8 @@ def new_session( new: bool = True, **kwargs, ) -> AbstractSession: + # load third party extensions. + init_extension_entrypoints() ensure_isolation_created(kwargs) if address is None: diff --git a/mars/oscar/backends/pool.py b/mars/oscar/backends/pool.py index 17770f906a..5c22ccb9e3 100644 --- a/mars/oscar/backends/pool.py +++ b/mars/oscar/backends/pool.py @@ -23,6 +23,7 @@ from abc import ABC, ABCMeta, abstractmethod from typing import Dict, List, Type, TypeVar, Coroutine, Callable, Union, Optional +from ...core.entrypoints import init_extension_entrypoints from ...utils import implements, to_binary from ...utils import lazy_import, register_asyncio_task_timeout_detector from ..api import Actor @@ -141,6 +142,8 @@ def __init__( self._asyncio_task_timeout_detector_task = ( register_asyncio_task_timeout_detector() ) + # load third party extensions. + init_extension_entrypoints() @property def router(self):