From 784fbd60820c387cb0ca3d338b6e01960888db2c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E5=AE=9D?= Date: Thu, 5 May 2022 10:49:42 +0800 Subject: [PATCH 1/8] Ray execution state --- mars/services/task/execution/mars/executor.py | 12 +++ mars/services/task/execution/ray/context.py | 90 ++++++++++++++++++- mars/services/task/execution/ray/executor.py | 66 ++++++++++---- mars/services/task/supervisor/manager.py | 13 --- .../supervisor/tests/test_task_manager.py | 5 +- .../tests/test_task_manager_on_ray_dag.py | 6 +- 6 files changed, 154 insertions(+), 38 deletions(-) diff --git a/mars/services/task/execution/mars/executor.py b/mars/services/task/execution/mars/executor.py index 7ee7f6ae43..3bbea5b261 100644 --- a/mars/services/task/execution/mars/executor.py +++ b/mars/services/task/execution/mars/executor.py @@ -20,6 +20,7 @@ from ..... import oscar as mo from .....core import ChunkGraph, TileContext +from .....core.context import set_context from .....core.operand import ( Fetch, MapReduceOperand, @@ -33,6 +34,7 @@ from .....resource import Resource from .....typing import TileableType, BandType from .....utils import Timer +from ....context import ThreadedServiceContext from ....cluster.api import ClusterAPI from ....lifecycle.api import LifecycleAPI from ....meta.api import MetaAPI @@ -111,6 +113,7 @@ async def create( cluster_api, lifecycle_api, scheduling_api, meta_api = await cls._get_apis( session_id, address ) + await cls._init_context(session_id, address) return cls( config, task, @@ -131,6 +134,15 @@ async def _get_apis(cls, session_id: str, address: str): MetaAPI.create(session_id, address), ) + @classmethod + async def _init_context(cls, session_id: str, address: str): + loop = asyncio.get_running_loop() + context = ThreadedServiceContext( + session_id, address, address, address, loop=loop + ) + await context.init() + set_context(context) + async def __aenter__(self): profiling = ProfilingData[self._task.task_id, "general"] # incref fetch tileables to ensure fetch data not deleted diff --git a/mars/services/task/execution/ray/context.py b/mars/services/task/execution/ray/context.py index d38d7a1a18..ba74746900 100644 --- a/mars/services/task/execution/ray/context.py +++ b/mars/services/task/execution/ray/context.py @@ -12,9 +12,95 @@ # See the License for the specific language governing permissions and # limitations under the License. +import functools +import inspect + +from .....core.context import Context +from .....utils import implements, lazy_import +from ....context import ThreadedServiceContext + +ray = lazy_import("ray") + + +class RayRemoteObjectManager: + """The remote object manager in task state actor.""" + + def __init__(self): + self._named_remote_objects = {} + + def create_remote_object(self, name: str, object_cls, *args, **kwargs): + remote_object = object_cls(*args, **kwargs) + self._named_remote_objects[name] = remote_object + + def destroy_remote_object(self, name: str): + self._named_remote_objects.pop(name, None) + + async def call_remote_object(self, name: str, attr: str, *args, **kwargs): + remote_object = self._named_remote_objects[name] + meth = getattr(remote_object, attr) + async_meth = self._sync_to_async(meth) + return await async_meth(*args, **kwargs) + + @staticmethod + @functools.lru_cache + def _sync_to_async(func): + if inspect.iscoroutinefunction(func): + return func + else: + + async def async_wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return async_wrapper + + +class _RayRemoteObjectWrapper: + def __init__(self, task_state_actor: "ray.actor.ActorHandle", name: str): + self._task_state_actor = task_state_actor + self._name = name + + def __getattr__(self, attr): + def wrap(*args, **kwargs): + r = self._task_state_actor.call_remote_object.remote( + self._name, attr, *args, **kwargs + ) + return ray.get(r) + + return wrap + + +class _RayRemoteObjectContext: + def __init__(self, task_state_actor: "ray.actor.ActorHandle", *args, **kwargs): + super().__init__(*args, **kwargs) + self._task_state_actor = task_state_actor + + @implements(Context.create_remote_object) + def create_remote_object(self, name: str, object_cls, *args, **kwargs): + self._task_state_actor.create_remote_object.remote( + name, object_cls, *args, **kwargs + ) + return _RayRemoteObjectWrapper(self._task_state_actor, name) + + @implements(Context.get_remote_object) + def get_remote_object(self, name: str): + return _RayRemoteObjectWrapper(self._task_state_actor, name) + + @implements(Context.destroy_remote_object) + def destroy_remote_object(self, name: str): + self._task_state_actor.destroy_remote_object.remote(name) + + +# TODO(fyrestone): Implement more APIs for Ray. +class RayExecutionContext(_RayRemoteObjectContext, ThreadedServiceContext): + """The context for tiling.""" + + pass + + +# TODO(fyrestone): Implement more APIs for Ray. +class RayExecutionWorkerContext(_RayRemoteObjectContext, dict): + """The context for executing operands.""" -# TODO(fyrestone): Should implement the mars.core.context.Context. -class RayExecutionContext(dict): @staticmethod def new_custom_log_dir(): return None diff --git a/mars/services/task/execution/ray/executor.py b/mars/services/task/execution/ray/executor.py index 84b8ef88b6..bf2e56d0c6 100644 --- a/mars/services/task/execution/ray/executor.py +++ b/mars/services/task/execution/ray/executor.py @@ -16,7 +16,9 @@ import logging from typing import List, Dict, Any, Set from .....core import ChunkGraph, Chunk, TileContext +from .....core.context import set_context from .....core.operand import ( + Fetch, Fuse, VirtualOperand, MapReduceOperand, @@ -43,13 +45,22 @@ ExecutionChunkResult, register_executor_cls, ) -from .context import RayExecutionContext +from .context import ( + RayExecutionContext, + RayExecutionWorkerContext, + RayRemoteObjectManager, +) ray = lazy_import("ray") logger = logging.getLogger(__name__) +class RayTaskState(RayRemoteObjectManager): + pass + + def execute_subtask( + task_state_actor: "ray.actor.ActorHandle", subtask_id: str, subtask_chunk_graph: ChunkGraph, output_meta_keys: Set[str], @@ -60,7 +71,7 @@ def execute_subtask( ensure_coverage() subtask_chunk_graph = deserialize(*subtask_chunk_graph) # inputs = [i[1] for i in inputs] - context = RayExecutionContext(zip(input_keys, inputs)) + context = RayExecutionWorkerContext(task_state_actor, zip(input_keys, inputs)) # optimize chunk graph. subtask_chunk_graph = optimize(subtask_chunk_graph) # from data_key to results @@ -100,6 +111,7 @@ def __init__( ray_executor, lifecycle_api, meta_api, + task_state_actor, ): self._config = config self._task = task @@ -110,6 +122,7 @@ def __init__( self._lifecycle_api = lifecycle_api self._meta_api = meta_api + self._task_state_actor = task_state_actor self._task_context = {} self._available_band_resources = None @@ -126,6 +139,8 @@ async def create( ) -> "TaskExecutor": ray_executor = ray.remote(execute_subtask) lifecycle_api, meta_api = await cls._get_apis(session_id, address) + task_state_actor = ray.remote(RayTaskState).remote() + await cls._init_context(task_state_actor, session_id, address) return cls( config, task, @@ -133,6 +148,7 @@ async def create( ray_executor, lifecycle_api, meta_api, + task_state_actor, ) @classmethod @@ -143,6 +159,17 @@ async def _get_apis(cls, session_id: str, address: str): MetaAPI.create(session_id, address), ) + @classmethod + async def _init_context( + cls, task_state_actor: "ray.actor.ActorHandle", session_id: str, address: str + ): + loop = asyncio.get_running_loop() + context = RayExecutionContext( + task_state_actor, session_id, address, address, address, loop=loop + ) + await context.init() + set_context(context) + async def execute_subtask_graph( self, stage_id: str, @@ -156,19 +183,23 @@ async def execute_subtask_graph( output_meta_object_refs = [] logger.info("Submitting %s subtasks of stage %s.", len(subtask_graph), stage_id) - # TODO(fyrestone): Filter out the Fetch chunk. - result_keys = {chunk.key for chunk in chunk_graph.result_chunks} + result_meta_keys = { + chunk.key + for chunk in chunk_graph.result_chunks + if not isinstance(chunk.op, Fetch) + } for subtask in subtask_graph.topological_iter(): subtask_chunk_graph = subtask.chunk_graph key_to_input = await self._load_subtask_inputs( stage_id, subtask, subtask_chunk_graph, context ) output_keys = self._get_subtask_output_keys(subtask_chunk_graph) - output_meta_keys = result_keys & output_keys + output_meta_keys = result_meta_keys & output_keys output_count = len(output_keys) + bool(output_meta_keys) output_object_refs = self._ray_executor.options( num_returns=output_count ).remote( + self._task_state_actor, subtask.subtask_id, serialize(subtask_chunk_graph), output_meta_keys, @@ -186,28 +217,31 @@ async def execute_subtask_graph( context.update(zip(output_keys, output_object_refs)) logger.info("Submitted %s subtasks of stage %s.", len(subtask_graph), stage_id) - assert len(output_meta_object_refs) > 0 key_to_meta = {} - meta_list = await asyncio.gather(*output_meta_object_refs) - for meta in meta_list: - key_to_meta.update(meta) - assert len(key_to_meta) == len(chunk_graph.result_chunks) - logger.info("Got %s metas of stage %s.", len(output_meta_object_refs), stage_id) + if len(output_meta_object_refs) > 0: + # TODO(fyrestone): Optimize update meta by fetching partial meta. + meta_list = await asyncio.gather(*output_meta_object_refs) + for meta in meta_list: + key_to_meta.update(meta) + assert len(key_to_meta) == len(result_meta_keys) + logger.info( + "Got %s metas of stage %s.", len(output_meta_object_refs), stage_id + ) - chunk_to_result = {} + chunk_to_meta = {} output_object_refs = [] for chunk in chunk_graph.result_chunks: chunk_key = chunk.key object_ref = context[chunk_key] output_object_refs.append(object_ref) - chunk_to_result[chunk] = ExecutionChunkResult( - key_to_meta[chunk_key], object_ref - ) + chunk_meta = key_to_meta.get(chunk_key) + if chunk_meta is not None: + chunk_to_meta[chunk] = ExecutionChunkResult(chunk_meta, object_ref) logger.info("Waiting for stage %s complete.", stage_id) ray.wait(output_object_refs, fetch_local=False) logger.info("Stage %s is complete.", stage_id) - return chunk_to_result + return chunk_to_meta async def __aexit__(self, exc_type, exc_val, exc_tb): if exc_type is None: diff --git a/mars/services/task/supervisor/manager.py b/mars/services/task/supervisor/manager.py index 62bbc6982e..871e184af0 100644 --- a/mars/services/task/supervisor/manager.py +++ b/mars/services/task/supervisor/manager.py @@ -22,9 +22,7 @@ from .... import oscar as mo from ....core import TileableGraph, TileableType, enter_mode, TileContext -from ....core.context import set_context from ....core.operand import Fetch -from ...context import ThreadedServiceContext from ...subtask import SubtaskResult, SubtaskGraph from ..config import task_options from ..core import Task, new_task_id, TaskStatus @@ -106,21 +104,10 @@ async def __post_create__(self): ) self._task_preprocessor_cls = self._get_task_preprocessor_cls() - # init context - await self._init_context() - async def __pre_destroy__(self): for processor_ref in self._task_id_to_processor_ref.values(): await processor_ref.destroy() - async def _init_context(self): - loop = asyncio.get_running_loop() - context = ThreadedServiceContext( - self._session_id, self.address, self.address, self.address, loop=loop - ) - await context.init() - set_context(context) - @staticmethod def gen_uid(session_id): return f"{session_id}_task_manager" diff --git a/mars/services/task/supervisor/tests/test_task_manager.py b/mars/services/task/supervisor/tests/test_task_manager.py index c0fd329a2b..23dd84af3b 100644 --- a/mars/services/task/supervisor/tests/test_task_manager.py +++ b/mars/services/task/supervisor/tests/test_task_manager.py @@ -550,9 +550,8 @@ async def test_numexpr(actor_pool): ) == [1] * len(result_tileable.chunks) -@pytest.mark.parametrize("config", [{"incremental_index": True}]) @pytest.mark.asyncio -async def test_optimization(actor_pool, config): +async def test_optimization(actor_pool): ( execution_backend, pool, @@ -576,7 +575,7 @@ async def test_optimization(actor_pool, config): ) pdf.to_csv(file_path, index=False) - df = md.read_csv(file_path, incremental_index=config["incremental_index"]) + df = md.read_csv(file_path, incremental_index=True) df2 = df.groupby("c").agg({"a": "sum"}) df3 = df[["b", "a"]] diff --git a/mars/services/task/supervisor/tests/test_task_manager_on_ray_dag.py b/mars/services/task/supervisor/tests/test_task_manager_on_ray_dag.py index 44ca5e1c10..a570baa7a1 100644 --- a/mars/services/task/supervisor/tests/test_task_manager_on_ray_dag.py +++ b/mars/services/task/supervisor/tests/test_task_manager_on_ray_dag.py @@ -57,10 +57,8 @@ async def test_numexpr(ray_start_regular_shared2, actor_pool): await test_task_manager.test_numexpr(actor_pool) -# TODO(fyrestone): Support incremental index in ray backend. @require_ray -@pytest.mark.parametrize("config", [{"incremental_index": False}]) @pytest.mark.parametrize("actor_pool", [{"backend": "ray"}], indirect=True) @pytest.mark.asyncio -async def test_optimization(ray_start_regular_shared2, actor_pool, config): - await test_task_manager.test_optimization(actor_pool, config) +async def test_optimization(ray_start_regular_shared2, actor_pool): + await test_task_manager.test_optimization(actor_pool) From 3a101f920a51b7278a6b6f6dea69b2c8bc352dfa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E5=AE=9D?= Date: Fri, 6 May 2022 10:32:07 +0800 Subject: [PATCH 2/8] Fix stop pool --- mars/oscar/backends/pool.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mars/oscar/backends/pool.py b/mars/oscar/backends/pool.py index dce0f68bfb..97552d5e82 100644 --- a/mars/oscar/backends/pool.py +++ b/mars/oscar/backends/pool.py @@ -456,7 +456,9 @@ async def join(self, timeout: float = None): async def stop(self): try: # clean global router - Router.get_instance().remove_router(self._router) + router = Router.get_instance() + if router is not None: + router.remove_router(self._router) stop_tasks = [] # stop all servers stop_tasks.extend([server.stop() for server in self._servers]) From 57299611aa98ab366fa0839edd492b718070887a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E5=AE=9D?= Date: Fri, 6 May 2022 10:36:45 +0800 Subject: [PATCH 3/8] Not to fetch chunk meta when tiling HeadOptimizedDataSource --- mars/dataframe/datasource/core.py | 4 +-- mars/deploy/oscar/tests/test_local.py | 44 ++++++++++++++----------- mars/deploy/oscar/tests/test_ray_dag.py | 3 +- 3 files changed, 27 insertions(+), 24 deletions(-) diff --git a/mars/dataframe/datasource/core.py b/mars/dataframe/datasource/core.py index 139f1c09c4..7f372d2212 100644 --- a/mars/dataframe/datasource/core.py +++ b/mars/dataframe/datasource/core.py @@ -58,9 +58,7 @@ def _tile_head(cls, op: "HeadOptimizedDataSource"): # execute first chunk yield chunks[:1] - ctx = get_context() - chunk_shape = ctx.get_chunks_meta([chunks[0].key], fields=["shape"])[0]["shape"] - + chunk_shape = chunks[0].shape if chunk_shape[0] == op.nrows: # the first chunk has enough data tileds[0]._nsplits = tuple((s,) for s in chunk_shape) diff --git a/mars/deploy/oscar/tests/test_local.py b/mars/deploy/oscar/tests/test_local.py index 97a4513b36..f6f147f553 100644 --- a/mars/deploy/oscar/tests/test_local.py +++ b/mars/deploy/oscar/tests/test_local.py @@ -490,7 +490,7 @@ async def test_web_session(create_cluster, config): ) -@pytest.mark.parametrize("config", [{"backend": "mars", "incremental_index": True}]) +@pytest.mark.parametrize("config", [{"backend": "mars"}]) def test_sync_execute(config): session = new_session( backend=config["backend"], n_cpu=2, web=False, use_uvloop=False @@ -518,25 +518,31 @@ def test_sync_execute(config): assert d is c assert abs(session.fetch(d) - raw.sum()) < 0.001 - # TODO(fyrestone): Remove this when the Ray backend support incremental index. - if config["incremental_index"]: - with tempfile.TemporaryDirectory() as tempdir: - file_path = os.path.join(tempdir, "test.csv") - pdf = pd.DataFrame( - np.random.RandomState(0).rand(100, 10), - columns=[f"col{i}" for i in range(10)], - ) - pdf.to_csv(file_path, index=False) - - df = md.read_csv(file_path, chunk_bytes=os.stat(file_path).st_size / 5) - result = df.sum(axis=1).execute().fetch() - expected = pd.read_csv(file_path).sum(axis=1) - pd.testing.assert_series_equal(result, expected) + with tempfile.TemporaryDirectory() as tempdir: + file_path = os.path.join(tempdir, "test.csv") + pdf = pd.DataFrame( + np.random.RandomState(0).rand(100, 10), + columns=[f"col{i}" for i in range(10)], + ) + pdf.to_csv(file_path, index=False) - df = md.read_csv(file_path, chunk_bytes=os.stat(file_path).st_size / 5) - result = df.head(10).execute().fetch() - expected = pd.read_csv(file_path).head(10) - pd.testing.assert_frame_equal(result, expected) + df = md.read_csv( + file_path, + chunk_bytes=os.stat(file_path).st_size / 5, + incremental_index=True, + ) + result = df.sum(axis=1).execute().fetch() + expected = pd.read_csv(file_path).sum(axis=1) + pd.testing.assert_series_equal(result, expected) + + df = md.read_csv( + file_path, + chunk_bytes=os.stat(file_path).st_size / 5, + incremental_index=True, + ) + result = df.head(10).execute().fetch() + expected = pd.read_csv(file_path).head(10) + pd.testing.assert_frame_equal(result, expected) for worker_pool in session._session.client._cluster._worker_pools: _assert_storage_cleaned( diff --git a/mars/deploy/oscar/tests/test_ray_dag.py b/mars/deploy/oscar/tests/test_ray_dag.py index 383ab1b16c..3fcaefd57b 100644 --- a/mars/deploy/oscar/tests/test_ray_dag.py +++ b/mars/deploy/oscar/tests/test_ray_dag.py @@ -112,8 +112,7 @@ async def test_iterative_tiling(ray_start_regular_shared2, create_cluster): await test_local.test_iterative_tiling(create_cluster) -# TODO(fyrestone): Support incremental index in ray backend. @require_ray -@pytest.mark.parametrize("config", [{"backend": "ray", "incremental_index": False}]) +@pytest.mark.parametrize("config", [{"backend": "ray"}]) def test_sync_execute(config): test_local.test_sync_execute(config) From 600ff0444b612932d6199e90888d11529792826e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E5=AE=9D?= Date: Fri, 6 May 2022 10:58:37 +0800 Subject: [PATCH 4/8] Fix --- mars/services/task/execution/ray/context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mars/services/task/execution/ray/context.py b/mars/services/task/execution/ray/context.py index ba74746900..2a83d285fd 100644 --- a/mars/services/task/execution/ray/context.py +++ b/mars/services/task/execution/ray/context.py @@ -42,7 +42,7 @@ async def call_remote_object(self, name: str, attr: str, *args, **kwargs): return await async_meth(*args, **kwargs) @staticmethod - @functools.lru_cache + @functools.lru_cache(100) def _sync_to_async(func): if inspect.iscoroutinefunction(func): return func From 96efb6c6e67952d224112b7a9ede96c1db4c7fdf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E5=AE=9D?= Date: Fri, 6 May 2022 11:33:30 +0800 Subject: [PATCH 5/8] Fix --- .../task/execution/ray/tests/test_ray_execution_backend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mars/services/task/execution/ray/tests/test_ray_execution_backend.py b/mars/services/task/execution/ray/tests/test_ray_execution_backend.py index da42206426..aa3a810afa 100644 --- a/mars/services/task/execution/ray/tests/test_ray_execution_backend.py +++ b/mars/services/task/execution/ray/tests/test_ray_execution_backend.py @@ -43,11 +43,11 @@ def test_ray_execute_subtask_basic(): subtask_id = new_task_id() subtask_chunk_graph = _gen_subtask_chunk_graph(b) - r = execute_subtask(subtask_id, serialize(subtask_chunk_graph), set(), []) + r = execute_subtask(None, subtask_id, serialize(subtask_chunk_graph), set(), []) np.testing.assert_array_equal(r, raw_expect) test_get_meta_chunk = subtask_chunk_graph.result_chunks[0] r = execute_subtask( - subtask_id, serialize(subtask_chunk_graph), {test_get_meta_chunk.key}, [] + None, subtask_id, serialize(subtask_chunk_graph), {test_get_meta_chunk.key}, [] ) assert len(r) == 2 meta_dict, r = r From d62ad34a97d199f24d663001e91361ce703649c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E5=AE=9D?= Date: Fri, 6 May 2022 18:02:51 +0800 Subject: [PATCH 6/8] Use named actor for Ray task state --- mars/services/task/execution/ray/context.py | 29 +++++++++++---- mars/services/task/execution/ray/executor.py | 37 ++++++++++++------- .../ray/tests/test_ray_execution_backend.py | 4 +- 3 files changed, 46 insertions(+), 24 deletions(-) diff --git a/mars/services/task/execution/ray/context.py b/mars/services/task/execution/ray/context.py index 2a83d285fd..3b458d5351 100644 --- a/mars/services/task/execution/ray/context.py +++ b/mars/services/task/execution/ray/context.py @@ -14,6 +14,7 @@ import functools import inspect +from typing import Union from .....core.context import Context from .....utils import implements, lazy_import @@ -70,24 +71,36 @@ def wrap(*args, **kwargs): class _RayRemoteObjectContext: - def __init__(self, task_state_actor: "ray.actor.ActorHandle", *args, **kwargs): + def __init__( + self, actor_name_or_handle: Union[str, "ray.actor.ActorHandle"], *args, **kwargs + ): super().__init__(*args, **kwargs) - self._task_state_actor = task_state_actor + self._actor_name_or_handle = actor_name_or_handle + self._task_state_actor = None + + def _get_task_state_actor(self) -> "ray.actor.ActorHandle": + if self._task_state_actor is None: + if isinstance(self._actor_name_or_handle, ray.actor.ActorHandle): + self._task_state_actor = self._actor_name_or_handle + else: + self._task_state_actor = ray.get_actor(self._actor_name_or_handle) + return self._task_state_actor @implements(Context.create_remote_object) def create_remote_object(self, name: str, object_cls, *args, **kwargs): - self._task_state_actor.create_remote_object.remote( - name, object_cls, *args, **kwargs - ) - return _RayRemoteObjectWrapper(self._task_state_actor, name) + task_state_actor = self._get_task_state_actor() + task_state_actor.create_remote_object.remote(name, object_cls, *args, **kwargs) + return _RayRemoteObjectWrapper(task_state_actor, name) @implements(Context.get_remote_object) def get_remote_object(self, name: str): - return _RayRemoteObjectWrapper(self._task_state_actor, name) + task_state_actor = self._get_task_state_actor() + return _RayRemoteObjectWrapper(task_state_actor, name) @implements(Context.destroy_remote_object) def destroy_remote_object(self, name: str): - self._task_state_actor.destroy_remote_object.remote(name) + task_state_actor = self._get_task_state_actor() + task_state_actor.destroy_remote_object.remote(name) # TODO(fyrestone): Implement more APIs for Ray. diff --git a/mars/services/task/execution/ray/executor.py b/mars/services/task/execution/ray/executor.py index bf2e56d0c6..8c86758dae 100644 --- a/mars/services/task/execution/ray/executor.py +++ b/mars/services/task/execution/ray/executor.py @@ -39,6 +39,7 @@ from ....meta.api import MetaAPI from ....subtask import Subtask, SubtaskGraph from ....subtask.utils import iter_input_data_keys, iter_output_data +from ...core import Task from ..api import ( TaskExecutor, ExecutionConfig, @@ -56,11 +57,13 @@ class RayTaskState(RayRemoteObjectManager): - pass + @classmethod + def gen_name(cls, task_id: str): + return f"{cls.__name__}_{task_id}" def execute_subtask( - task_state_actor: "ray.actor.ActorHandle", + task_id: str, subtask_id: str, subtask_chunk_graph: ChunkGraph, output_meta_keys: Set[str], @@ -71,7 +74,9 @@ def execute_subtask( ensure_coverage() subtask_chunk_graph = deserialize(*subtask_chunk_graph) # inputs = [i[1] for i in inputs] - context = RayExecutionWorkerContext(task_state_actor, zip(input_keys, inputs)) + context = RayExecutionWorkerContext( + RayTaskState.gen_name(task_id), zip(input_keys, inputs) + ) # optimize chunk graph. subtask_chunk_graph = optimize(subtask_chunk_graph) # from data_key to results @@ -106,23 +111,23 @@ class RayTaskExecutor(TaskExecutor): def __init__( self, config: ExecutionConfig, - task, - tile_context, - ray_executor, - lifecycle_api, - meta_api, - task_state_actor, + task: Task, + tile_context: TileContext, + ray_executor: "ray.remote_function.RemoteFunction", + task_state_actor: "ray.actor.ActorHandle", + lifecycle_api: LifecycleAPI, + meta_api: MetaAPI, ): self._config = config self._task = task self._tile_context = tile_context self._ray_executor = ray_executor + self._task_state_actor = task_state_actor # api self._lifecycle_api = lifecycle_api self._meta_api = meta_api - self._task_state_actor = task_state_actor self._task_context = {} self._available_band_resources = None @@ -133,22 +138,26 @@ async def create( *, session_id: str, address: str, - task, + task: Task, tile_context: TileContext, **kwargs, ) -> "TaskExecutor": ray_executor = ray.remote(execute_subtask) lifecycle_api, meta_api = await cls._get_apis(session_id, address) - task_state_actor = ray.remote(RayTaskState).remote() + task_state_actor = ( + ray.remote(RayTaskState) + .options(name=RayTaskState.gen_name(task.task_id)) + .remote() + ) await cls._init_context(task_state_actor, session_id, address) return cls( config, task, tile_context, ray_executor, + task_state_actor, lifecycle_api, meta_api, - task_state_actor, ) @classmethod @@ -199,7 +208,7 @@ async def execute_subtask_graph( output_object_refs = self._ray_executor.options( num_returns=output_count ).remote( - self._task_state_actor, + subtask.task_id, subtask.subtask_id, serialize(subtask_chunk_graph), output_meta_keys, diff --git a/mars/services/task/execution/ray/tests/test_ray_execution_backend.py b/mars/services/task/execution/ray/tests/test_ray_execution_backend.py index aa3a810afa..0944ab4b92 100644 --- a/mars/services/task/execution/ray/tests/test_ray_execution_backend.py +++ b/mars/services/task/execution/ray/tests/test_ray_execution_backend.py @@ -43,11 +43,11 @@ def test_ray_execute_subtask_basic(): subtask_id = new_task_id() subtask_chunk_graph = _gen_subtask_chunk_graph(b) - r = execute_subtask(None, subtask_id, serialize(subtask_chunk_graph), set(), []) + r = execute_subtask("", subtask_id, serialize(subtask_chunk_graph), set(), []) np.testing.assert_array_equal(r, raw_expect) test_get_meta_chunk = subtask_chunk_graph.result_chunks[0] r = execute_subtask( - None, subtask_id, serialize(subtask_chunk_graph), {test_get_meta_chunk.key}, [] + "", subtask_id, serialize(subtask_chunk_graph), {test_get_meta_chunk.key}, [] ) assert len(r) == 2 meta_dict, r = r From cff2da90913449eff873c6c27cc3ef2c61f1841f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E5=AE=9D?= Date: Fri, 6 May 2022 18:19:12 +0800 Subject: [PATCH 7/8] Improve coverage --- .../ray/tests/test_ray_execution_backend.py | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/mars/services/task/execution/ray/tests/test_ray_execution_backend.py b/mars/services/task/execution/ray/tests/test_ray_execution_backend.py index 0944ab4b92..627b4fe48d 100644 --- a/mars/services/task/execution/ray/tests/test_ray_execution_backend.py +++ b/mars/services/task/execution/ray/tests/test_ray_execution_backend.py @@ -23,6 +23,11 @@ from ......tests.core import require_ray from ......utils import lazy_import, get_chunk_params from ....core import new_task_id +from ..context import ( + RayRemoteObjectManager, + _RayRemoteObjectContext, + _RayRemoteObjectWrapper, +) from ..executor import execute_subtask from ..fetcher import RayFetcher @@ -79,3 +84,42 @@ async def test_ray_fetcher(ray_start_regular_shared2): np.testing.assert_array_equal(results[1], np_value) pd.testing.assert_frame_equal(results[2], pd_value.iloc[[1, 3]]) np.testing.assert_array_equal(results[3], np_value[[1, 3]]) + + +@require_ray +@pytest.mark.asyncio +async def test_ray_remote_object(ray_start_regular_shared2): + class _TestRemoteObject: + def __init__(self, i): + self._i = i + + def foo(self, a, b): + return self._i + a + b + + async def bar(self, a, b): + return self._i * a * b + + # Test RayRemoteObjectManager + name = "abc" + manager = RayRemoteObjectManager() + manager.create_remote_object(name, _TestRemoteObject, 2) + r = await manager.call_remote_object(name, "foo", 3, 4) + assert r == 9 + r = await manager.call_remote_object(name, "bar", 3, 4) + assert r == 24 + manager.destroy_remote_object(name) + with pytest.raises(KeyError): + await manager.call_remote_object(name, "foo", 3, 4) + + # Test _RayRemoteObjectContext + remote_manager = ray.remote(RayRemoteObjectManager).remote() + context = _RayRemoteObjectContext(remote_manager) + context.create_remote_object(name, _TestRemoteObject, 2) + remote_object = context.get_remote_object(name) + r = remote_object.foo(3, 4) + assert r == 9 + r = remote_object.bar(3, 4) + assert r == 24 + context.destroy_remote_object(name) + with pytest.raises(KeyError): + remote_object.foo(3, 4) From 21749fb2306c1b93c5d929c28a85fa1566e52802 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E5=AE=9D?= Date: Fri, 6 May 2022 18:48:17 +0800 Subject: [PATCH 8/8] Fix lint --- .../task/execution/ray/tests/test_ray_execution_backend.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/mars/services/task/execution/ray/tests/test_ray_execution_backend.py b/mars/services/task/execution/ray/tests/test_ray_execution_backend.py index 627b4fe48d..2ae52871a6 100644 --- a/mars/services/task/execution/ray/tests/test_ray_execution_backend.py +++ b/mars/services/task/execution/ray/tests/test_ray_execution_backend.py @@ -23,11 +23,7 @@ from ......tests.core import require_ray from ......utils import lazy_import, get_chunk_params from ....core import new_task_id -from ..context import ( - RayRemoteObjectManager, - _RayRemoteObjectContext, - _RayRemoteObjectWrapper, -) +from ..context import RayRemoteObjectManager, _RayRemoteObjectContext from ..executor import execute_subtask from ..fetcher import RayFetcher