diff --git a/mars/services/task/execution/ray/fetcher.py b/mars/services/task/execution/ray/fetcher.py index 4247034ed1..708e448a8d 100644 --- a/mars/services/task/execution/ray/fetcher.py +++ b/mars/services/task/execution/ray/fetcher.py @@ -15,9 +15,10 @@ import asyncio from collections import namedtuple from typing import Dict, List +from .....utils import lazy_import from ..api import Fetcher, register_fetcher_cls - +ray = lazy_import("ray") _FetchInfo = namedtuple("FetchInfo", ["key", "object_ref", "conditions"]) @@ -38,18 +39,27 @@ async def append(self, chunk_key: str, chunk_meta: Dict, conditions: List = None ) async def get(self): - objects = await asyncio.gather( - *(info.object_ref for info in self._fetch_info_list) - ) if self._no_conditions: - return objects - results = [] - for o, fetch_info in zip(objects, self._fetch_info_list): + return await asyncio.gather( + *(info.object_ref for info in self._fetch_info_list) + ) + refs = [None] * len(self._fetch_info_list) + for index, fetch_info in enumerate(self._fetch_info_list): if fetch_info.conditions is None: - results.append(o) + refs[index] = fetch_info.object_ref else: - try: - results.append(o.iloc[fetch_info.conditions]) - except AttributeError: - results.append(o[fetch_info.conditions]) - return results + refs[index] = query_object_with_condition.remote( + fetch_info.object_ref, fetch_info.conditions + ) + return await asyncio.gather(*refs) + + +def query_object_with_condition(o, conditions): + try: + return o.iloc[conditions] + except AttributeError: + return o[conditions] + + +if ray: + query_object_with_condition = ray.remote(query_object_with_condition)