Skip to content

Commit

Permalink
[Ray] Implements get_chunks_meta for Ray execution context (mars-proj…
Browse files Browse the repository at this point in the history
  • Loading branch information
fyrestone authored and wjsi committed May 24, 2022
1 parent 4512041 commit fe409c3
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 61 deletions.
1 change: 1 addition & 0 deletions mars/dataframe/base/tests/test_base_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,7 @@ def test_datetime_method_execution(setup):
pd.testing.assert_series_equal(result, expected)


@pytest.mark.ray_dag
def test_isin_execution(setup):
# one chunk in multiple chunks
a = pd.Series([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
Expand Down
58 changes: 55 additions & 3 deletions mars/services/task/execution/ray/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
import functools
import inspect
import logging
from dataclasses import asdict
from typing import Union, Dict, List

from .....core.context import Context
from .....storage.base import StorageLevel
from .....utils import implements, lazy_import
from ....context import ThreadedServiceContext

Expand Down Expand Up @@ -116,9 +118,10 @@ def destroy_remote_object(self, name: str):
class RayExecutionContext(_RayRemoteObjectContext, ThreadedServiceContext):
"""The context for tiling."""

def __init__(self, task_context: Dict, *args, **kwargs):
def __init__(self, task_context: Dict, task_chunks_meta: Dict, *args, **kwargs):
super().__init__(*args, **kwargs)
self._task_context = task_context
self._task_chunks_meta = task_chunks_meta

@implements(Context.get_chunks_result)
def get_chunks_result(self, data_keys: List[str]) -> List:
Expand All @@ -128,11 +131,60 @@ def get_chunks_result(self, data_keys: List[str]) -> List:
logger.info("Got %s chunks result.", len(result))
return result

@implements(Context.get_chunks_meta)
def get_chunks_meta(
self, data_keys: List[str], fields: List[str] = None, error="raise"
) -> List[Dict]:
result = []
# TODO(fyrestone): Support get_chunks_meta from meta service if needed.
for key in data_keys:
chunk_meta = self._task_chunks_meta[key]
meta = asdict(chunk_meta)
meta = {f: meta.get(f) for f in fields}
result.append(meta)
return result


# TODO(fyrestone): Implement more APIs for Ray.
class RayExecutionWorkerContext(_RayRemoteObjectContext, dict):
"""The context for executing operands."""

@staticmethod
def new_custom_log_dir():
@classmethod
@implements(Context.new_custom_log_dir)
def new_custom_log_dir(cls):
logger.info(
"%s does not support register_custom_log_path / new_custom_log_dir",
cls.__name__,
)
return None

@staticmethod
@implements(Context.register_custom_log_path)
def register_custom_log_path(
session_id: str,
tileable_op_key: str,
chunk_op_key: str,
worker_address: str,
log_path: str,
):
raise NotImplementedError

@classmethod
@implements(Context.set_progress)
def set_progress(cls, progress: float):
logger.info(
"%s does not support set_running_operand_key / set_progress", cls.__name__
)

@staticmethod
@implements(Context.set_running_operand_key)
def set_running_operand_key(session_id: str, op_key: str):
raise NotImplementedError

@classmethod
@implements(Context.get_storage_info)
def get_storage_info(
cls, address: str = None, level: StorageLevel = StorageLevel.MEMORY
):
logger.info("%s does not support get_storage_info", cls.__name__)
return {}
99 changes: 63 additions & 36 deletions mars/services/task/execution/ray/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import asyncio
import functools
import logging
from dataclasses import dataclass
from typing import List, Dict, Any, Set
from .....core import ChunkGraph, Chunk, TileContext
from .....core.context import set_context
Expand All @@ -30,6 +31,7 @@
from .....serialization import serialize, deserialize
from .....typing import BandType
from .....utils import (
calc_data_size,
lazy_import,
get_chunk_params,
get_chunk_key_to_data_keys,
Expand All @@ -56,6 +58,11 @@
logger = logging.getLogger(__name__)


@dataclass
class _RayChunkMeta:
memory_size: int


class RayTaskState(RayRemoteObjectManager):
@classmethod
def gen_name(cls, task_id: str):
Expand Down Expand Up @@ -102,11 +109,14 @@ def execute_subtask(
if output_meta_keys:
output_meta = {}
for chunk in subtask_chunk_graph.result_chunks:
if chunk.key in output_meta_keys:
chunk_key = chunk.key
if chunk_key in output_meta_keys and chunk_key not in output_meta:
if isinstance(chunk.op, Fuse):
# fuse op
chunk = chunk.chunk
output_meta[chunk.key] = get_chunk_params(chunk)
data = context[chunk_key]
memory_size = calc_data_size(data)
output_meta[chunk_key] = get_chunk_params(chunk), memory_size
assert len(output_meta_keys) == len(output_meta)
output_values.append(output_meta)
output_values.extend(output.values())
Expand All @@ -125,6 +135,7 @@ def __init__(
task: Task,
tile_context: TileContext,
task_context: Dict[str, "ray.ObjectRef"],
task_chunks_meta: Dict[str, _RayChunkMeta],
task_state_actor: "ray.actor.ActorHandle",
lifecycle_api: LifecycleAPI,
meta_api: MetaAPI,
Expand All @@ -133,6 +144,7 @@ def __init__(
self._task = task
self._tile_context = tile_context
self._task_context = task_context
self._task_chunks_meta = task_chunks_meta
self._task_state_actor = task_state_actor
self._ray_executor = self._get_ray_executor()

Expand Down Expand Up @@ -166,12 +178,16 @@ async def create(
.remote()
)
task_context = {}
await cls._init_context(task_context, task_state_actor, session_id, address)
task_chunks_meta = {}
await cls._init_context(
task_context, task_chunks_meta, task_state_actor, session_id, address
)
return cls(
config,
task,
tile_context,
task_context,
task_chunks_meta,
task_state_actor,
lifecycle_api,
meta_api,
Expand All @@ -183,6 +199,7 @@ def destroy(self):
self._task = None
self._tile_context = None
self._task_context = None
self._task_chunks_meta = None
self._task_state_actor = None
self._ray_executor = None

Expand All @@ -207,7 +224,7 @@ async def _get_apis(cls, session_id: str, address: str):
)

@staticmethod
@functools.lru_cache(maxsize=1)
@functools.lru_cache(maxsize=None) # Specify maxsize=None to make it faster
def _get_ray_executor():
# Export remote function once.
return ray.remote(execute_subtask)
Expand All @@ -216,13 +233,15 @@ def _get_ray_executor():
async def _init_context(
cls,
task_context: Dict[str, "ray.ObjectRef"],
task_chunks_meta: Dict[str, _RayChunkMeta],
task_state_actor: "ray.actor.ActorHandle",
session_id: str,
address: str,
):
loop = asyncio.get_running_loop()
context = RayExecutionContext(
task_context,
task_chunks_meta,
task_state_actor,
session_id,
address,
Expand Down Expand Up @@ -293,7 +312,9 @@ async def execute_subtask_graph(
logger.info("Getting %s metas of stage %s.", meta_count, stage_id)
meta_list = await asyncio.gather(*output_meta_object_refs)
for meta in meta_list:
key_to_meta.update(meta)
for key, (params, memory_size) in meta.items():
key_to_meta[key] = params
self._task_chunks_meta[key] = _RayChunkMeta(memory_size=memory_size)
assert len(key_to_meta) == len(result_meta_keys)
logger.info("Got %s metas of stage %s.", meta_count, stage_id)

Expand All @@ -304,9 +325,9 @@ async def execute_subtask_graph(
chunk_key = chunk.key
object_ref = task_context[chunk_key]
output_object_refs.add(object_ref)
chunk_meta = key_to_meta.get(chunk_key)
if chunk_meta is not None:
chunk_to_meta[chunk] = ExecutionChunkResult(chunk_meta, object_ref)
chunk_params = key_to_meta.get(chunk_key)
if chunk_params is not None:
chunk_to_meta[chunk] = ExecutionChunkResult(chunk_params, object_ref)

logger.info("Waiting for stage %s complete.", stage_id)
# Patched the asyncio.to_thread for Python < 3.9 at mars/lib/aio/__init__.py
Expand All @@ -319,36 +340,42 @@ async def execute_subtask_graph(
return chunk_to_meta

async def __aexit__(self, exc_type, exc_val, exc_tb):
if exc_type is None:
tileable_keys = []
update_metas = []
update_lifecycles = []
for tileable in self._task.tileable_graph.result_tileables:
tileable_keys.append(tileable.key)
tileable = tileable.data if hasattr(tileable, "data") else tileable
chunk_keys = []
for chunk in self._tile_context[tileable].chunks:
chunk_keys.append(chunk.key)
if chunk.key in self._task_context:
# Some tileable graph may have result chunks that not be executed,
# for example:
# r, b = cut(series, bins, retbins=True)
# r_result = r.execute().fetch()
# b_result = b.execute().fetch() <- This is the case
object_ref = self._task_context[chunk.key]
update_metas.append(
self._meta_api.set_chunk_meta.delay(
chunk,
bands=[],
object_ref=object_ref,
)
if exc_type is not None:
return

# Update info if no exception occurs.
tileable_keys = []
update_metas = []
update_lifecycles = []
for tileable in self._task.tileable_graph.result_tileables:
tileable_keys.append(tileable.key)
tileable = tileable.data if hasattr(tileable, "data") else tileable
chunk_keys = []
for chunk in self._tile_context[tileable].chunks:
chunk_key = chunk.key
chunk_keys.append(chunk_key)
if chunk_key in self._task_context:
# Some tileable graph may have result chunks that not be executed,
# for example:
# r, b = cut(series, bins, retbins=True)
# r_result = r.execute().fetch()
# b_result = b.execute().fetch() <- This is the case
object_ref = self._task_context[chunk_key]
chunk_meta = self._task_chunks_meta[chunk_key]
update_metas.append(
self._meta_api.set_chunk_meta.delay(
chunk,
bands=[],
object_ref=object_ref,
memory_size=chunk_meta.memory_size,
)
update_lifecycles.append(
self._lifecycle_api.track.delay(tileable.key, chunk_keys)
)
await self._meta_api.set_chunk_meta.batch(*update_metas)
await self._lifecycle_api.track.batch(*update_lifecycles)
await self._lifecycle_api.incref_tileables(tileable_keys)
update_lifecycles.append(
self._lifecycle_api.track.delay(tileable.key, chunk_keys)
)
await self._meta_api.set_chunk_meta.batch(*update_metas)
await self._lifecycle_api.track.batch(*update_lifecycles)
await self._lifecycle_api.incref_tileables(tileable_keys)

async def get_available_band_resources(self) -> Dict[BandType, Resource]:
if self._available_band_resources is None:
Expand Down
35 changes: 15 additions & 20 deletions mars/services/task/execution/ray/fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import asyncio
import functools
from collections import namedtuple
from typing import Dict, List

Expand All @@ -23,17 +24,28 @@
_FetchInfo = namedtuple("FetchInfo", ["key", "object_ref", "conditions"])


def _query_object_with_condition(o, conditions):
try:
return o.iloc[conditions]
except AttributeError:
return o[conditions]


@register_fetcher_cls
class RayFetcher(Fetcher):
name = "ray"
required_meta_keys = ("object_refs",)

def __init__(self, **kwargs):
_make_query_function_remote()

self._fetch_info_list = []
self._no_conditions = True

@staticmethod
@functools.lru_cache(maxsize=None) # Specify maxsize=None to make it faster
def _remote_query_object_with_condition():
# Export remote function once.
return ray.remote(_query_object_with_condition)

async def append(self, chunk_key: str, chunk_meta: Dict, conditions: List = None):
if conditions is not None:
self._no_conditions = False
Expand All @@ -51,24 +63,7 @@ async def get(self):
if fetch_info.conditions is None:
refs[index] = fetch_info.object_ref
else:
refs[index] = _remote_query_object_with_condition.remote(
refs[index] = self._remote_query_object_with_condition().remote(
fetch_info.object_ref, fetch_info.conditions
)
return await asyncio.gather(*refs)


def _query_object_with_condition(o, conditions):
try:
return o.iloc[conditions]
except AttributeError:
return o[conditions]


_remote_query_object_with_condition = None


def _make_query_function_remote():
global _remote_query_object_with_condition

if _remote_query_object_with_condition is None:
_remote_query_object_with_condition = ray.remote(_query_object_with_condition)
Loading

0 comments on commit fe409c3

Please sign in to comment.