From 692209b521a072f124aa94bef124be8ffa78f90f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E5=AE=9D?= Date: Wed, 1 Dec 2021 15:22:31 +0800 Subject: [PATCH 1/3] Mars extension support --- mars/core/entrypoints.py | 42 ++++++++++ mars/core/tests/test_entrypoints.py | 124 ++++++++++++++++++++++++++++ mars/deploy/oscar/local.py | 3 + mars/deploy/oscar/ray.py | 3 + mars/deploy/oscar/session.py | 3 + mars/oscar/backends/pool.py | 3 + 6 files changed, 178 insertions(+) create mode 100644 mars/core/entrypoints.py create mode 100644 mars/core/tests/test_entrypoints.py diff --git a/mars/core/entrypoints.py b/mars/core/entrypoints.py new file mode 100644 index 0000000000..bd839e1c26 --- /dev/null +++ b/mars/core/entrypoints.py @@ -0,0 +1,42 @@ +# Copyright 1999-2021 Alibaba Group Holding Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import warnings +import functools + +from pkg_resources import iter_entry_points + +logger = logging.getLogger(__name__) + + +# from https://github.com/numba/numba/blob/master/numba/core/entrypoints.py +# Must put this here to avoid extensions re-triggering initialization +@functools.lru_cache(maxsize=None) +def init_all(): + """Execute all `mars_extensions` entry points with the name `init` + If extensions have already been initialized, this function does nothing. + """ + for entry_point in iter_entry_points("mars_extensions", "init"): + logger.info("Loading extension: %s", entry_point) + try: + func = entry_point.load() + func() + except Exception as e: + msg = "Mars extension module '{}' failed to load due to '{}({})'." + warnings.warn( + msg.format(entry_point.module_name, type(e).__name__, str(e)), + stacklevel=2, + ) + logger.info("Extension loading failed for: %s", entry_point) diff --git a/mars/core/tests/test_entrypoints.py b/mars/core/tests/test_entrypoints.py new file mode 100644 index 0000000000..07fab36310 --- /dev/null +++ b/mars/core/tests/test_entrypoints.py @@ -0,0 +1,124 @@ +# Copyright 1999-2021 Alibaba Group Holding Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import types +import warnings +import pkg_resources + + +class _DummyClass(object): + def __init__(self, value): + self.value = value + + def __repr__(self): + return "_DummyClass(%f, %f)" % self.value + + +def test_init_entrypoint(): + # FIXME: Python 2 workaround because nonlocal doesn't exist + counters = {"init": 0} + + def init_function(): + counters["init"] += 1 + + mod = types.ModuleType("_test_mars_extension") + mod.init_func = init_function + + try: + # will remove this module at the end of the test + sys.modules[mod.__name__] = mod + + # We are registering an entry point using the "mars" package + # ("distribution" in pkg_resources-speak) itself, though these are + # normally registered by other packages. + dist = "pymars" + entrypoints = pkg_resources.get_entry_map(dist) + my_entrypoint = pkg_resources.EntryPoint( + "init", # name of entry point + mod.__name__, # module with entry point object + attrs=["init_func"], # name of entry point object + dist=pkg_resources.get_distribution(dist), + ) + entrypoints.setdefault("mars_extensions", {})["init"] = my_entrypoint + + from mars.core import entrypoints + + # Allow reinitialization + entrypoints.init_all.cache_clear() + + entrypoints.init_all() + + # was our init function called? + assert counters["init"] == 1 + + # ensure we do not initialize twice + entrypoints.init_all() + assert counters["init"] == 1 + finally: + # remove fake module + if mod.__name__ in sys.modules: + del sys.modules[mod.__name__] + + +def test_entrypoint_tolerance(): + # FIXME: Python 2 workaround because nonlocal doesn't exist + counters = {"init": 0} + + def init_function(): + counters["init"] += 1 + raise ValueError("broken") + + mod = types.ModuleType("_test_mars_bad_extension") + mod.init_func = init_function + + try: + # will remove this module at the end of the test + sys.modules[mod.__name__] = mod + + # We are registering an entry point using the "mars" package + # ("distribution" in pkg_resources-speak) itself, though these are + # normally registered by other packages. + dist = "pymars" + entrypoints = pkg_resources.get_entry_map(dist) + my_entrypoint = pkg_resources.EntryPoint( + "init", # name of entry point + mod.__name__, # module with entry point object + attrs=["init_func"], # name of entry point object + dist=pkg_resources.get_distribution(dist), + ) + entrypoints.setdefault("mars_extensions", {})["init"] = my_entrypoint + + from mars.core import entrypoints + + # Allow reinitialization + entrypoints.init_all.cache_clear() + + with warnings.catch_warnings(record=True) as w: + entrypoints.init_all() + + bad_str = "Mars extension module '_test_mars_bad_extension'" + for x in w: + if bad_str in str(x): + break + else: + raise ValueError("Expected warning message not found") + + # was our init function called? + assert counters["init"] == 1 + + finally: + # remove fake module + if mod.__name__ in sys.modules: + del sys.modules[mod.__name__] diff --git a/mars/deploy/oscar/local.py b/mars/deploy/oscar/local.py index afaa737386..b2371303ed 100644 --- a/mars/deploy/oscar/local.py +++ b/mars/deploy/oscar/local.py @@ -22,6 +22,7 @@ import numpy as np from ... import oscar as mo +from ...core.entrypoints import init_all from ...lib.aio import get_isolation, stop_isolation from ...resource import cpu_count, cuda_count from ...services import NodeRole @@ -111,6 +112,8 @@ def __init__( web: Union[bool, str] = "auto", timeout: float = None, ): + # load third party extensions. + init_all() # load config file to dict. if not config or isinstance(config, str): config = load_config(config) diff --git a/mars/deploy/oscar/ray.py b/mars/deploy/oscar/ray.py index a6d911e06f..adfbb8a375 100644 --- a/mars/deploy/oscar/ray.py +++ b/mars/deploy/oscar/ray.py @@ -20,6 +20,7 @@ from typing import Union, Dict, List, Optional, AsyncGenerator from ... import oscar as mo +from ...core.entrypoints import init_all from ...oscar.backends.ray.driver import RayActorDriver from ...oscar.backends.ray.utils import ( process_placement_to_address, @@ -371,6 +372,8 @@ def __init__( worker_mem: int = 32 * 1024 ** 3, config: Union[str, Dict] = None, ): + # load third party extensions. + init_all() self._cluster_name = cluster_name self._supervisor_mem = supervisor_mem self._worker_num = worker_num diff --git a/mars/deploy/oscar/session.py b/mars/deploy/oscar/session.py index 2368c8306f..a605eca8a5 100644 --- a/mars/deploy/oscar/session.py +++ b/mars/deploy/oscar/session.py @@ -35,6 +35,7 @@ from ... import oscar as mo from ...config import options from ...core import ChunkType, TileableType, TileableGraph, enter_mode +from ...core.entrypoints import init_all from ...core.operand import Fetch from ...lib.aio import ( alru_cache, @@ -1839,6 +1840,8 @@ def new_session( new: bool = True, **kwargs, ) -> AbstractSession: + # load third party extensions. + init_all() ensure_isolation_created(kwargs) if address is None: diff --git a/mars/oscar/backends/pool.py b/mars/oscar/backends/pool.py index 17770f906a..21738517fc 100644 --- a/mars/oscar/backends/pool.py +++ b/mars/oscar/backends/pool.py @@ -23,6 +23,7 @@ from abc import ABC, ABCMeta, abstractmethod from typing import Dict, List, Type, TypeVar, Coroutine, Callable, Union, Optional +from ...core.entrypoints import init_all from ...utils import implements, to_binary from ...utils import lazy_import, register_asyncio_task_timeout_detector from ..api import Actor @@ -141,6 +142,8 @@ def __init__( self._asyncio_task_timeout_detector_task = ( register_asyncio_task_timeout_detector() ) + # load third party extensions. + init_all() @property def router(self): From d94690d456af80f28692509ae3bef1de77b149c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E5=AE=9D?= Date: Wed, 1 Dec 2021 16:30:00 +0800 Subject: [PATCH 2/3] Fix imports --- mars/core/tests/test_entrypoints.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mars/core/tests/test_entrypoints.py b/mars/core/tests/test_entrypoints.py index 07fab36310..4a7ee26d3d 100644 --- a/mars/core/tests/test_entrypoints.py +++ b/mars/core/tests/test_entrypoints.py @@ -53,7 +53,7 @@ def init_function(): ) entrypoints.setdefault("mars_extensions", {})["init"] = my_entrypoint - from mars.core import entrypoints + from .. import entrypoints # Allow reinitialization entrypoints.init_all.cache_clear() @@ -100,7 +100,7 @@ def init_function(): ) entrypoints.setdefault("mars_extensions", {})["init"] = my_entrypoint - from mars.core import entrypoints + from .. import entrypoints # Allow reinitialization entrypoints.init_all.cache_clear() From c5c01a21a2fc4cae91ac75a620c8023fa4690e30 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E5=AE=9D?= Date: Thu, 2 Dec 2021 11:00:18 +0800 Subject: [PATCH 3/3] Rename init_all to init_extension_entrypoints --- mars/core/entrypoints.py | 2 +- mars/core/tests/test_entrypoints.py | 10 +++++----- mars/deploy/oscar/local.py | 4 ++-- mars/deploy/oscar/ray.py | 4 ++-- mars/deploy/oscar/session.py | 4 ++-- mars/oscar/backends/pool.py | 4 ++-- 6 files changed, 14 insertions(+), 14 deletions(-) diff --git a/mars/core/entrypoints.py b/mars/core/entrypoints.py index bd839e1c26..de94512920 100644 --- a/mars/core/entrypoints.py +++ b/mars/core/entrypoints.py @@ -24,7 +24,7 @@ # from https://github.com/numba/numba/blob/master/numba/core/entrypoints.py # Must put this here to avoid extensions re-triggering initialization @functools.lru_cache(maxsize=None) -def init_all(): +def init_extension_entrypoints(): """Execute all `mars_extensions` entry points with the name `init` If extensions have already been initialized, this function does nothing. """ diff --git a/mars/core/tests/test_entrypoints.py b/mars/core/tests/test_entrypoints.py index 4a7ee26d3d..33965060c6 100644 --- a/mars/core/tests/test_entrypoints.py +++ b/mars/core/tests/test_entrypoints.py @@ -56,15 +56,15 @@ def init_function(): from .. import entrypoints # Allow reinitialization - entrypoints.init_all.cache_clear() + entrypoints.init_extension_entrypoints.cache_clear() - entrypoints.init_all() + entrypoints.init_extension_entrypoints() # was our init function called? assert counters["init"] == 1 # ensure we do not initialize twice - entrypoints.init_all() + entrypoints.init_extension_entrypoints() assert counters["init"] == 1 finally: # remove fake module @@ -103,10 +103,10 @@ def init_function(): from .. import entrypoints # Allow reinitialization - entrypoints.init_all.cache_clear() + entrypoints.init_extension_entrypoints.cache_clear() with warnings.catch_warnings(record=True) as w: - entrypoints.init_all() + entrypoints.init_extension_entrypoints() bad_str = "Mars extension module '_test_mars_bad_extension'" for x in w: diff --git a/mars/deploy/oscar/local.py b/mars/deploy/oscar/local.py index b2371303ed..ed121f406c 100644 --- a/mars/deploy/oscar/local.py +++ b/mars/deploy/oscar/local.py @@ -22,7 +22,7 @@ import numpy as np from ... import oscar as mo -from ...core.entrypoints import init_all +from ...core.entrypoints import init_extension_entrypoints from ...lib.aio import get_isolation, stop_isolation from ...resource import cpu_count, cuda_count from ...services import NodeRole @@ -113,7 +113,7 @@ def __init__( timeout: float = None, ): # load third party extensions. - init_all() + init_extension_entrypoints() # load config file to dict. if not config or isinstance(config, str): config = load_config(config) diff --git a/mars/deploy/oscar/ray.py b/mars/deploy/oscar/ray.py index adfbb8a375..16ffe27971 100644 --- a/mars/deploy/oscar/ray.py +++ b/mars/deploy/oscar/ray.py @@ -20,7 +20,7 @@ from typing import Union, Dict, List, Optional, AsyncGenerator from ... import oscar as mo -from ...core.entrypoints import init_all +from ...core.entrypoints import init_extension_entrypoints from ...oscar.backends.ray.driver import RayActorDriver from ...oscar.backends.ray.utils import ( process_placement_to_address, @@ -373,7 +373,7 @@ def __init__( config: Union[str, Dict] = None, ): # load third party extensions. - init_all() + init_extension_entrypoints() self._cluster_name = cluster_name self._supervisor_mem = supervisor_mem self._worker_num = worker_num diff --git a/mars/deploy/oscar/session.py b/mars/deploy/oscar/session.py index a605eca8a5..40178f8e57 100644 --- a/mars/deploy/oscar/session.py +++ b/mars/deploy/oscar/session.py @@ -35,7 +35,7 @@ from ... import oscar as mo from ...config import options from ...core import ChunkType, TileableType, TileableGraph, enter_mode -from ...core.entrypoints import init_all +from ...core.entrypoints import init_extension_entrypoints from ...core.operand import Fetch from ...lib.aio import ( alru_cache, @@ -1841,7 +1841,7 @@ def new_session( **kwargs, ) -> AbstractSession: # load third party extensions. - init_all() + init_extension_entrypoints() ensure_isolation_created(kwargs) if address is None: diff --git a/mars/oscar/backends/pool.py b/mars/oscar/backends/pool.py index 21738517fc..5c22ccb9e3 100644 --- a/mars/oscar/backends/pool.py +++ b/mars/oscar/backends/pool.py @@ -23,7 +23,7 @@ from abc import ABC, ABCMeta, abstractmethod from typing import Dict, List, Type, TypeVar, Coroutine, Callable, Union, Optional -from ...core.entrypoints import init_all +from ...core.entrypoints import init_extension_entrypoints from ...utils import implements, to_binary from ...utils import lazy_import, register_asyncio_task_timeout_detector from ..api import Actor @@ -143,7 +143,7 @@ def __init__( register_asyncio_task_timeout_detector() ) # load third party extensions. - init_all() + init_extension_entrypoints() @property def router(self):