diff --git a/mars/dataframe/base/tests/test_base_execution.py b/mars/dataframe/base/tests/test_base_execution.py index 541a4b9223..4390962b75 100644 --- a/mars/dataframe/base/tests/test_base_execution.py +++ b/mars/dataframe/base/tests/test_base_execution.py @@ -682,6 +682,7 @@ def test_datetime_method_execution(setup): pd.testing.assert_series_equal(result, expected) +@pytest.mark.ray_dag def test_isin_execution(setup): # one chunk in multiple chunks a = pd.Series([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) diff --git a/mars/services/task/execution/ray/context.py b/mars/services/task/execution/ray/context.py index dde28914a8..80d96f5071 100644 --- a/mars/services/task/execution/ray/context.py +++ b/mars/services/task/execution/ray/context.py @@ -15,9 +15,11 @@ import functools import inspect import logging +from dataclasses import asdict from typing import Union, Dict, List from .....core.context import Context +from .....storage.base import StorageLevel from .....utils import implements, lazy_import from ....context import ThreadedServiceContext @@ -116,9 +118,10 @@ def destroy_remote_object(self, name: str): class RayExecutionContext(_RayRemoteObjectContext, ThreadedServiceContext): """The context for tiling.""" - def __init__(self, task_context: Dict, *args, **kwargs): + def __init__(self, task_context: Dict, task_chunks_meta: Dict, *args, **kwargs): super().__init__(*args, **kwargs) self._task_context = task_context + self._task_chunks_meta = task_chunks_meta @implements(Context.get_chunks_result) def get_chunks_result(self, data_keys: List[str]) -> List: @@ -128,11 +131,60 @@ def get_chunks_result(self, data_keys: List[str]) -> List: logger.info("Got %s chunks result.", len(result)) return result + @implements(Context.get_chunks_meta) + def get_chunks_meta( + self, data_keys: List[str], fields: List[str] = None, error="raise" + ) -> List[Dict]: + result = [] + # TODO(fyrestone): Support get_chunks_meta from meta service if needed. + for key in data_keys: + chunk_meta = self._task_chunks_meta[key] + meta = asdict(chunk_meta) + meta = {f: meta.get(f) for f in fields} + result.append(meta) + return result + # TODO(fyrestone): Implement more APIs for Ray. class RayExecutionWorkerContext(_RayRemoteObjectContext, dict): """The context for executing operands.""" - @staticmethod - def new_custom_log_dir(): + @classmethod + @implements(Context.new_custom_log_dir) + def new_custom_log_dir(cls): + logger.info( + "%s does not support register_custom_log_path / new_custom_log_dir", + cls.__name__, + ) return None + + @staticmethod + @implements(Context.register_custom_log_path) + def register_custom_log_path( + session_id: str, + tileable_op_key: str, + chunk_op_key: str, + worker_address: str, + log_path: str, + ): + raise NotImplementedError + + @classmethod + @implements(Context.set_progress) + def set_progress(cls, progress: float): + logger.info( + "%s does not support set_running_operand_key / set_progress", cls.__name__ + ) + + @staticmethod + @implements(Context.set_running_operand_key) + def set_running_operand_key(session_id: str, op_key: str): + raise NotImplementedError + + @classmethod + @implements(Context.get_storage_info) + def get_storage_info( + cls, address: str = None, level: StorageLevel = StorageLevel.MEMORY + ): + logger.info("%s does not support get_storage_info", cls.__name__) + return {} diff --git a/mars/services/task/execution/ray/executor.py b/mars/services/task/execution/ray/executor.py index c4d714dbc8..7b52048ea4 100644 --- a/mars/services/task/execution/ray/executor.py +++ b/mars/services/task/execution/ray/executor.py @@ -15,6 +15,7 @@ import asyncio import functools import logging +from dataclasses import dataclass from typing import List, Dict, Any, Set from .....core import ChunkGraph, Chunk, TileContext from .....core.context import set_context @@ -30,6 +31,7 @@ from .....serialization import serialize, deserialize from .....typing import BandType from .....utils import ( + calc_data_size, lazy_import, get_chunk_params, get_chunk_key_to_data_keys, @@ -56,6 +58,11 @@ logger = logging.getLogger(__name__) +@dataclass +class _RayChunkMeta: + memory_size: int + + class RayTaskState(RayRemoteObjectManager): @classmethod def gen_name(cls, task_id: str): @@ -102,11 +109,14 @@ def execute_subtask( if output_meta_keys: output_meta = {} for chunk in subtask_chunk_graph.result_chunks: - if chunk.key in output_meta_keys: + chunk_key = chunk.key + if chunk_key in output_meta_keys and chunk_key not in output_meta: if isinstance(chunk.op, Fuse): # fuse op chunk = chunk.chunk - output_meta[chunk.key] = get_chunk_params(chunk) + data = context[chunk_key] + memory_size = calc_data_size(data) + output_meta[chunk_key] = get_chunk_params(chunk), memory_size assert len(output_meta_keys) == len(output_meta) output_values.append(output_meta) output_values.extend(output.values()) @@ -125,6 +135,7 @@ def __init__( task: Task, tile_context: TileContext, task_context: Dict[str, "ray.ObjectRef"], + task_chunks_meta: Dict[str, _RayChunkMeta], task_state_actor: "ray.actor.ActorHandle", lifecycle_api: LifecycleAPI, meta_api: MetaAPI, @@ -133,6 +144,7 @@ def __init__( self._task = task self._tile_context = tile_context self._task_context = task_context + self._task_chunks_meta = task_chunks_meta self._task_state_actor = task_state_actor self._ray_executor = self._get_ray_executor() @@ -166,12 +178,16 @@ async def create( .remote() ) task_context = {} - await cls._init_context(task_context, task_state_actor, session_id, address) + task_chunks_meta = {} + await cls._init_context( + task_context, task_chunks_meta, task_state_actor, session_id, address + ) return cls( config, task, tile_context, task_context, + task_chunks_meta, task_state_actor, lifecycle_api, meta_api, @@ -183,6 +199,7 @@ def destroy(self): self._task = None self._tile_context = None self._task_context = None + self._task_chunks_meta = None self._task_state_actor = None self._ray_executor = None @@ -207,7 +224,7 @@ async def _get_apis(cls, session_id: str, address: str): ) @staticmethod - @functools.lru_cache(maxsize=1) + @functools.lru_cache(maxsize=None) # Specify maxsize=None to make it faster def _get_ray_executor(): # Export remote function once. return ray.remote(execute_subtask) @@ -216,6 +233,7 @@ def _get_ray_executor(): async def _init_context( cls, task_context: Dict[str, "ray.ObjectRef"], + task_chunks_meta: Dict[str, _RayChunkMeta], task_state_actor: "ray.actor.ActorHandle", session_id: str, address: str, @@ -223,6 +241,7 @@ async def _init_context( loop = asyncio.get_running_loop() context = RayExecutionContext( task_context, + task_chunks_meta, task_state_actor, session_id, address, @@ -293,7 +312,9 @@ async def execute_subtask_graph( logger.info("Getting %s metas of stage %s.", meta_count, stage_id) meta_list = await asyncio.gather(*output_meta_object_refs) for meta in meta_list: - key_to_meta.update(meta) + for key, (params, memory_size) in meta.items(): + key_to_meta[key] = params + self._task_chunks_meta[key] = _RayChunkMeta(memory_size=memory_size) assert len(key_to_meta) == len(result_meta_keys) logger.info("Got %s metas of stage %s.", meta_count, stage_id) @@ -304,9 +325,9 @@ async def execute_subtask_graph( chunk_key = chunk.key object_ref = task_context[chunk_key] output_object_refs.add(object_ref) - chunk_meta = key_to_meta.get(chunk_key) - if chunk_meta is not None: - chunk_to_meta[chunk] = ExecutionChunkResult(chunk_meta, object_ref) + chunk_params = key_to_meta.get(chunk_key) + if chunk_params is not None: + chunk_to_meta[chunk] = ExecutionChunkResult(chunk_params, object_ref) logger.info("Waiting for stage %s complete.", stage_id) # Patched the asyncio.to_thread for Python < 3.9 at mars/lib/aio/__init__.py @@ -319,36 +340,42 @@ async def execute_subtask_graph( return chunk_to_meta async def __aexit__(self, exc_type, exc_val, exc_tb): - if exc_type is None: - tileable_keys = [] - update_metas = [] - update_lifecycles = [] - for tileable in self._task.tileable_graph.result_tileables: - tileable_keys.append(tileable.key) - tileable = tileable.data if hasattr(tileable, "data") else tileable - chunk_keys = [] - for chunk in self._tile_context[tileable].chunks: - chunk_keys.append(chunk.key) - if chunk.key in self._task_context: - # Some tileable graph may have result chunks that not be executed, - # for example: - # r, b = cut(series, bins, retbins=True) - # r_result = r.execute().fetch() - # b_result = b.execute().fetch() <- This is the case - object_ref = self._task_context[chunk.key] - update_metas.append( - self._meta_api.set_chunk_meta.delay( - chunk, - bands=[], - object_ref=object_ref, - ) + if exc_type is not None: + return + + # Update info if no exception occurs. + tileable_keys = [] + update_metas = [] + update_lifecycles = [] + for tileable in self._task.tileable_graph.result_tileables: + tileable_keys.append(tileable.key) + tileable = tileable.data if hasattr(tileable, "data") else tileable + chunk_keys = [] + for chunk in self._tile_context[tileable].chunks: + chunk_key = chunk.key + chunk_keys.append(chunk_key) + if chunk_key in self._task_context: + # Some tileable graph may have result chunks that not be executed, + # for example: + # r, b = cut(series, bins, retbins=True) + # r_result = r.execute().fetch() + # b_result = b.execute().fetch() <- This is the case + object_ref = self._task_context[chunk_key] + chunk_meta = self._task_chunks_meta[chunk_key] + update_metas.append( + self._meta_api.set_chunk_meta.delay( + chunk, + bands=[], + object_ref=object_ref, + memory_size=chunk_meta.memory_size, ) - update_lifecycles.append( - self._lifecycle_api.track.delay(tileable.key, chunk_keys) ) - await self._meta_api.set_chunk_meta.batch(*update_metas) - await self._lifecycle_api.track.batch(*update_lifecycles) - await self._lifecycle_api.incref_tileables(tileable_keys) + update_lifecycles.append( + self._lifecycle_api.track.delay(tileable.key, chunk_keys) + ) + await self._meta_api.set_chunk_meta.batch(*update_metas) + await self._lifecycle_api.track.batch(*update_lifecycles) + await self._lifecycle_api.incref_tileables(tileable_keys) async def get_available_band_resources(self) -> Dict[BandType, Resource]: if self._available_band_resources is None: diff --git a/mars/services/task/execution/ray/fetcher.py b/mars/services/task/execution/ray/fetcher.py index e638964ccb..11371b213b 100644 --- a/mars/services/task/execution/ray/fetcher.py +++ b/mars/services/task/execution/ray/fetcher.py @@ -13,6 +13,7 @@ # limitations under the License. import asyncio +import functools from collections import namedtuple from typing import Dict, List @@ -23,17 +24,28 @@ _FetchInfo = namedtuple("FetchInfo", ["key", "object_ref", "conditions"]) +def _query_object_with_condition(o, conditions): + try: + return o.iloc[conditions] + except AttributeError: + return o[conditions] + + @register_fetcher_cls class RayFetcher(Fetcher): name = "ray" required_meta_keys = ("object_refs",) def __init__(self, **kwargs): - _make_query_function_remote() - self._fetch_info_list = [] self._no_conditions = True + @staticmethod + @functools.lru_cache(maxsize=None) # Specify maxsize=None to make it faster + def _remote_query_object_with_condition(): + # Export remote function once. + return ray.remote(_query_object_with_condition) + async def append(self, chunk_key: str, chunk_meta: Dict, conditions: List = None): if conditions is not None: self._no_conditions = False @@ -51,24 +63,7 @@ async def get(self): if fetch_info.conditions is None: refs[index] = fetch_info.object_ref else: - refs[index] = _remote_query_object_with_condition.remote( + refs[index] = self._remote_query_object_with_condition().remote( fetch_info.object_ref, fetch_info.conditions ) return await asyncio.gather(*refs) - - -def _query_object_with_condition(o, conditions): - try: - return o.iloc[conditions] - except AttributeError: - return o[conditions] - - -_remote_query_object_with_condition = None - - -def _make_query_function_remote(): - global _remote_query_object_with_condition - - if _remote_query_object_with_condition is None: - _remote_query_object_with_condition = ray.remote(_query_object_with_condition) diff --git a/mars/services/task/execution/ray/tests/test_ray_execution_backend.py b/mars/services/task/execution/ray/tests/test_ray_execution_backend.py index 48ead7fc1f..905c430417 100644 --- a/mars/services/task/execution/ray/tests/test_ray_execution_backend.py +++ b/mars/services/task/execution/ray/tests/test_ray_execution_backend.py @@ -30,6 +30,7 @@ from ..config import RayExecutionConfig from ..context import ( RayExecutionContext, + RayExecutionWorkerContext, RayRemoteObjectManager, _RayRemoteObjectContext, ) @@ -71,6 +72,7 @@ def test_ray_executor_destroy(): task=task, tile_context=TileContext(), task_context={}, + task_chunks_meta={}, task_state_actor=None, lifecycle_api=None, meta_api=None, @@ -104,7 +106,9 @@ def test_ray_execute_subtask_basic(): assert len(r) == 2 meta_dict, r = r assert len(meta_dict) == 1 - assert meta_dict[test_get_meta_chunk.key] == get_chunk_params(test_get_meta_chunk) + assert meta_dict[test_get_meta_chunk.key][0] == get_chunk_params( + test_get_meta_chunk + ) np.testing.assert_array_equal(r, raw_expect) @@ -191,6 +195,24 @@ def fake_init(self): pass with mock.patch.object(ThreadedServiceContext, "__init__", new=fake_init): - context = RayExecutionContext({"abc": o}, None) + context = RayExecutionContext({"abc": o}, {}, None) r = context.get_chunks_result(["abc"]) assert r == [value] + + +def test_ray_execution_worker_context(): + context = RayExecutionWorkerContext(None) + with pytest.raises(NotImplementedError): + context.set_running_operand_key("mock_session_id", "mock_op_key") + with pytest.raises(NotImplementedError): + context.register_custom_log_path( + "mock_session_id", + "mock_tileable_op_key", + "mock_chunk_op_key", + "mock_worker_address", + "mock_log_path", + ) + + assert context.set_progress(0.1) is None + assert context.new_custom_log_dir() is None + assert context.get_storage_info("mock_address") == {} diff --git a/mars/tensor/base/tests/test_base_execution.py b/mars/tensor/base/tests/test_base_execution.py index 03d9b97ab6..32260a2ea3 100644 --- a/mars/tensor/base/tests/test_base_execution.py +++ b/mars/tensor/base/tests/test_base_execution.py @@ -815,6 +815,7 @@ def test_tile_execution(setup): np.testing.assert_equal(res, expected) +@pytest.mark.ray_dag def test_isin_execution(setup): element = 2 * arange(4, chunk_size=1).reshape((2, 2)) test_elements = [1, 2, 4, 8]