Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Ray] Ray execution state #3002

Merged
merged 9 commits into from
May 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions mars/dataframe/datasource/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,7 @@ def _tile_head(cls, op: "HeadOptimizedDataSource"):
# execute first chunk
yield chunks[:1]

ctx = get_context()
chunk_shape = ctx.get_chunks_meta([chunks[0].key], fields=["shape"])[0]["shape"]

chunk_shape = chunks[0].shape
zhongchun marked this conversation as resolved.
Show resolved Hide resolved
if chunk_shape[0] == op.nrows:
# the first chunk has enough data
tileds[0]._nsplits = tuple((s,) for s in chunk_shape)
Expand Down
44 changes: 25 additions & 19 deletions mars/deploy/oscar/tests/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ async def test_web_session(create_cluster, config):
)


@pytest.mark.parametrize("config", [{"backend": "mars", "incremental_index": True}])
@pytest.mark.parametrize("config", [{"backend": "mars"}])
def test_sync_execute(config):
session = new_session(
backend=config["backend"], n_cpu=2, web=False, use_uvloop=False
Expand Down Expand Up @@ -518,25 +518,31 @@ def test_sync_execute(config):
assert d is c
assert abs(session.fetch(d) - raw.sum()) < 0.001

# TODO(fyrestone): Remove this when the Ray backend support incremental index.
if config["incremental_index"]:
with tempfile.TemporaryDirectory() as tempdir:
file_path = os.path.join(tempdir, "test.csv")
pdf = pd.DataFrame(
np.random.RandomState(0).rand(100, 10),
columns=[f"col{i}" for i in range(10)],
)
pdf.to_csv(file_path, index=False)

df = md.read_csv(file_path, chunk_bytes=os.stat(file_path).st_size / 5)
result = df.sum(axis=1).execute().fetch()
expected = pd.read_csv(file_path).sum(axis=1)
pd.testing.assert_series_equal(result, expected)
with tempfile.TemporaryDirectory() as tempdir:
file_path = os.path.join(tempdir, "test.csv")
pdf = pd.DataFrame(
np.random.RandomState(0).rand(100, 10),
columns=[f"col{i}" for i in range(10)],
)
pdf.to_csv(file_path, index=False)

df = md.read_csv(file_path, chunk_bytes=os.stat(file_path).st_size / 5)
result = df.head(10).execute().fetch()
expected = pd.read_csv(file_path).head(10)
pd.testing.assert_frame_equal(result, expected)
df = md.read_csv(
file_path,
chunk_bytes=os.stat(file_path).st_size / 5,
incremental_index=True,
)
result = df.sum(axis=1).execute().fetch()
expected = pd.read_csv(file_path).sum(axis=1)
pd.testing.assert_series_equal(result, expected)

df = md.read_csv(
file_path,
chunk_bytes=os.stat(file_path).st_size / 5,
incremental_index=True,
)
result = df.head(10).execute().fetch()
expected = pd.read_csv(file_path).head(10)
pd.testing.assert_frame_equal(result, expected)

for worker_pool in session._session.client._cluster._worker_pools:
_assert_storage_cleaned(
Expand Down
3 changes: 1 addition & 2 deletions mars/deploy/oscar/tests/test_ray_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,7 @@ async def test_iterative_tiling(ray_start_regular_shared2, create_cluster):
await test_local.test_iterative_tiling(create_cluster)


# TODO(fyrestone): Support incremental index in ray backend.
@require_ray
@pytest.mark.parametrize("config", [{"backend": "ray", "incremental_index": False}])
@pytest.mark.parametrize("config", [{"backend": "ray"}])
def test_sync_execute(config):
    """Run the shared synchronous-execution test suite against the Ray backend."""
    test_local.test_sync_execute(config)
4 changes: 3 additions & 1 deletion mars/oscar/backends/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,9 @@ async def join(self, timeout: float = None):
async def stop(self):
try:
# clean global router
Router.get_instance().remove_router(self._router)
router = Router.get_instance()
if router is not None:
router.remove_router(self._router)
stop_tasks = []
# stop all servers
stop_tasks.extend([server.stop() for server in self._servers])
Expand Down
12 changes: 12 additions & 0 deletions mars/services/task/execution/mars/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from ..... import oscar as mo
from .....core import ChunkGraph, TileContext
from .....core.context import set_context
from .....core.operand import (
Fetch,
MapReduceOperand,
Expand All @@ -33,6 +34,7 @@
from .....resource import Resource
from .....typing import TileableType, BandType
from .....utils import Timer
from ....context import ThreadedServiceContext
from ....cluster.api import ClusterAPI
from ....lifecycle.api import LifecycleAPI
from ....meta.api import MetaAPI
Expand Down Expand Up @@ -121,6 +123,7 @@ async def create(
task_id=task.task_id,
cluster_api=cluster_api,
)
await cls._init_context(session_id, address)
return cls(
config,
task,
Expand All @@ -142,6 +145,15 @@ async def _get_apis(cls, session_id: str, address: str):
MetaAPI.create(session_id, address),
)

@classmethod
async def _init_context(cls, session_id: str, address: str):
    """Create a ``ThreadedServiceContext`` bound to the running event loop
    and install it as the current global context via ``set_context``.

    Parameters
    ----------
    session_id : str
        Identifier of the session this executor serves.
    address : str
        Service address.
        # NOTE(review): the same address is passed for all three address
        # parameters of ThreadedServiceContext — presumably the supervisor
        # and local endpoints coincide here; confirm against its signature.
    """
    loop = asyncio.get_running_loop()
    context = ThreadedServiceContext(
        session_id, address, address, address, loop=loop
    )
    await context.init()
    set_context(context)

async def __aenter__(self):
profiling = ProfilingData[self._task.task_id, "general"]
# incref fetch tileables to ensure fetch data not deleted
Expand Down
103 changes: 101 additions & 2 deletions mars/services/task/execution/ray/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,108 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import inspect
from typing import Union

from .....core.context import Context
from .....utils import implements, lazy_import
from ....context import ThreadedServiceContext

ray = lazy_import("ray")


class RayRemoteObjectManager:
    """The remote object manager in task state actor.

    Named remote objects are created on demand with
    ``create_remote_object`` and later invoked by name through
    ``call_remote_object``; coroutine methods are awaited transparently.
    """

    def __init__(self):
        # name -> remote object instance
        self._named_remote_objects = {}

    def create_remote_object(self, name: str, object_cls, *args, **kwargs):
        """Instantiate ``object_cls(*args, **kwargs)`` and register it under ``name``."""
        self._named_remote_objects[name] = object_cls(*args, **kwargs)

    def destroy_remote_object(self, name: str):
        """Remove the object registered under ``name``; missing names are a no-op."""
        self._named_remote_objects.pop(name, None)

    async def call_remote_object(self, name: str, attr: str, *args, **kwargs):
        """Call attribute ``attr`` of the object registered under ``name``.

        Raises
        ------
        KeyError
            If no object is registered under ``name``.
        """
        meth = getattr(self._named_remote_objects[name], attr)
        # NOTE: do not lru_cache on the bound method here — the previous
        # implementation cached bound methods, which kept destroyed remote
        # objects alive in the cache after destroy_remote_object().
        if inspect.iscoroutinefunction(meth):
            return await meth(*args, **kwargs)
        return meth(*args, **kwargs)


class _RayRemoteObjectWrapper:
    """Client-side proxy forwarding attribute calls to the task state actor."""

    def __init__(self, task_state_actor: "ray.actor.ActorHandle", name: str):
        self._task_state_actor = task_state_actor
        self._name = name

    def __getattr__(self, attr):
        # Return a callable that performs a synchronous remote invocation
        # of the named attribute on the managed remote object.
        def _invoke(*args, **kwargs):
            object_ref = self._task_state_actor.call_remote_object.remote(
                self._name, attr, *args, **kwargs
            )
            return ray.get(object_ref)

        return _invoke


class _RayRemoteObjectContext:
    """Mixin exposing the remote-object API backed by the task state actor."""

    def __init__(
        self, actor_name_or_handle: Union[str, "ray.actor.ActorHandle"], *args, **kwargs
    ):
        super().__init__(*args, **kwargs)
        self._actor_name_or_handle = actor_name_or_handle
        self._task_state_actor = None

    def _get_task_state_actor(self) -> "ray.actor.ActorHandle":
        # Resolve the actor handle lazily and cache it for later calls.
        if self._task_state_actor is not None:
            return self._task_state_actor
        handle = self._actor_name_or_handle
        if not isinstance(handle, ray.actor.ActorHandle):
            handle = ray.get_actor(handle)
        self._task_state_actor = handle
        return handle

    @implements(Context.create_remote_object)
    def create_remote_object(self, name: str, object_cls, *args, **kwargs):
        actor = self._get_task_state_actor()
        actor.create_remote_object.remote(name, object_cls, *args, **kwargs)
        return _RayRemoteObjectWrapper(actor, name)

    @implements(Context.get_remote_object)
    def get_remote_object(self, name: str):
        return _RayRemoteObjectWrapper(self._get_task_state_actor(), name)

    @implements(Context.destroy_remote_object)
    def destroy_remote_object(self, name: str):
        self._get_task_state_actor().destroy_remote_object.remote(name)


# TODO(fyrestone): Implement more APIs for Ray.
class RayExecutionContext(_RayRemoteObjectContext, ThreadedServiceContext):
    """The context for tiling."""


# TODO(fyrestone): Implement more APIs for Ray.
class RayExecutionWorkerContext(_RayRemoteObjectContext, dict):
    """The context for executing operands."""

    @staticmethod
    def new_custom_log_dir():
        # Custom log directories are not supported on the Ray backend yet.
        return None
Loading