Skip to content

Commit

Permalink
[BACKPORT] [Ray] Implement cancel method on Ray task executor (mars-p…
Browse files Browse the repository at this point in the history
  • Loading branch information
Xuye (Chris) Qin authored Jun 6, 2022
1 parent 49c3b84 commit 080d6b8
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 15 deletions.
11 changes: 11 additions & 0 deletions mars/deploy/oscar/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@
os.path.dirname(os.path.abspath(__file__)), "config.yml"
)

# Default maximum number of retries for a subtask; used as the fallback for
# the "subtask_max_retries" key of the "scheduling" config section.
DEFAULT_SUBTASK_MAX_RETRIES = 3
# Default timeout for cancelling a subtask; used as the fallback for the
# "subtask_cancel_timeout" key of the "scheduling" config section.
# NOTE(review): presumably expressed in seconds -- confirm against the consumer.
DEFAULT_SUBTASK_CANCEL_TIMEOUT = 5


def _load_config(config: Union[str, Dict] = None):
    """Load ``config`` through ``load_config``.

    ``DEFAULT_CONFIG_FILE`` is passed as the ``default_config_file`` argument.
    """
    loaded_config = load_config(config, default_config_file=DEFAULT_CONFIG_FILE)
    return loaded_config
Expand Down Expand Up @@ -167,6 +172,12 @@ def __init__(
n_cpu=self._n_cpu,
mem_bytes=self._mem_bytes,
cuda_devices=self._cuda_devices,
subtask_cancel_timeout=self._config.get("scheduling", {}).get(
"subtask_cancel_timeout", DEFAULT_SUBTASK_CANCEL_TIMEOUT
),
subtask_max_retries=self._config.get("scheduling", {}).get(
"subtask_max_retries", DEFAULT_SUBTASK_MAX_RETRIES
),
)
)

Expand Down
25 changes: 18 additions & 7 deletions mars/deploy/oscar/tests/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,8 +633,12 @@ def f1(c):


@pytest.fixture
def setup_session():
session = new_session(n_cpu=2, use_uvloop=False)
def setup_session(request):
param = getattr(request, "param", {})
config = param.get("config", {})
session = new_session(
backend=config.get("backend", "mars"), n_cpu=2, use_uvloop=False
)
assert session.get_web_endpoint() is not None

try:
Expand Down Expand Up @@ -706,6 +710,11 @@ def test_decref(setup_session):
_assert_storage_cleaned(session.session_id, worker_addr, StorageLevel.MEMORY)


def _assert_worker_pool_storage_cleaned(session):
    """Run ``_assert_storage_cleaned`` for the session's first worker pool.

    The check targets the external address of the first entry in the
    cluster's worker pools, at ``StorageLevel.MEMORY``.
    """
    cluster = session._session.client._cluster
    first_worker_address = cluster._worker_pools[0].external_address
    _assert_storage_cleaned(
        session.session_id, first_worker_address, StorageLevel.MEMORY
    )


def _cancel_when_execute(session, cancelled):
def run():
time.sleep(200)
Expand All @@ -720,8 +729,10 @@ def run():
ref_counts = session._get_ref_counts()
assert len(ref_counts) == 0

worker_addr = session._session.client._cluster._worker_pools[0].external_address
_assert_storage_cleaned(session.session_id, worker_addr, StorageLevel.MEMORY)

def _cancel_assert_when_execute(session, cancelled):
    """Run the cancel-during-execute scenario, then check worker storage.

    The storage check is performed after ``_cancel_when_execute`` so that it
    inspects the first worker pool once the cancelled run has finished.
    Running it beforehand would only look at the state that existed before
    the cancel scenario started.
    """
    _cancel_when_execute(session, cancelled)
    _assert_worker_pool_storage_cleaned(session)


class SlowTileAdd(TensorAdd):
Expand All @@ -745,9 +756,9 @@ def _cancel_when_tile(session, cancelled):
assert len(ref_counts) == 0


@pytest.mark.parametrize("test_func", [_cancel_when_execute, _cancel_when_tile])
def test_cancel(setup_session, test_func):
session = setup_session
@pytest.mark.parametrize("test_func", [_cancel_assert_when_execute, _cancel_when_tile])
def test_cancel(create_cluster, test_func):
session = get_default_session()

async def _new_cancel_event():
return asyncio.Event()
Expand Down
7 changes: 7 additions & 0 deletions mars/deploy/oscar/tests/test_ray_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ..session import new_session
from ..tests import test_local
from ..tests.session import new_test_session
from ..tests.test_local import _cancel_when_tile, _cancel_when_execute
from .modules.utils import ( # noqa: F401; pylint: disable=unused-variable
cleanup_third_party_modules_output,
get_output_filenames,
Expand Down Expand Up @@ -129,3 +130,9 @@ def test_sync_execute(config):
@pytest.mark.asyncio
async def test_session_get_progress(ray_start_regular_shared2, create_cluster):
await test_local.test_session_get_progress(create_cluster)


# Cancel test for this module: delegates to ``test_local.test_cancel``, passing
# the ``create_cluster`` fixture value and the parametrized ``test_func``.
# The parametrization uses ``_cancel_when_execute`` directly rather than the
# ``_cancel_assert_when_execute`` wrapper from ``test_local``.
# NOTE(review): ``ray_start_regular_shared2`` is requested but not referenced in
# the body; presumably it is a fixture needed only for its setup -- confirm.
@require_ray
@pytest.mark.parametrize("test_func", [_cancel_when_execute, _cancel_when_tile])
def test_cancel(ray_start_regular_shared2, create_cluster, test_func):
test_local.test_cancel(create_cluster, test_func)
11 changes: 4 additions & 7 deletions mars/services/task/execution/ray/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,6 @@
from ..utils import get_band_resources_from_config


# the default times to retry subtask.
DEFAULT_SUBTASK_MAX_RETRIES = 3


@register_config_cls
class RayExecutionConfig(ExecutionConfig):
name = "ray"
Expand All @@ -41,12 +37,13 @@ def get_deploy_band_resources(self) -> List[Dict[str, Resource]]:
return []

def get_subtask_max_retries(self):
# Look up "subtask_max_retries" in the Ray execution config.
# NOTE(review): two consecutive return statements follow, so the second one
# is unreachable as written. This looks like the removed and added sides of
# a diff rendered together -- confirm which lookup (with or without the
# DEFAULT_SUBTASK_MAX_RETRIES fallback) is the intended one.
return self._ray_execution_config.get(
"subtask_max_retries", DEFAULT_SUBTASK_MAX_RETRIES
)
return self._ray_execution_config.get("subtask_max_retries")

def get_n_cpu(self):
    """Return the ``"n_cpu"`` entry of the Ray execution config.

    The entry is read by subscript, so a missing key raises whatever the
    underlying mapping raises (``KeyError`` for a plain dict).
    """
    ray_config = self._ray_execution_config
    return ray_config["n_cpu"]

def get_n_worker(self):
    """Return the ``"n_worker"`` entry of the Ray execution config.

    The entry is read by subscript, so a missing key raises whatever the
    underlying mapping raises (``KeyError`` for a plain dict).
    """
    ray_config = self._ray_execution_config
    return ray_config["n_worker"]

def get_subtask_cancel_timeout(self):
    """Return the ``"subtask_cancel_timeout"`` entry of the Ray execution config.

    The entry is read with ``.get`` and no fallback, so ``None`` is returned
    when the key is absent.
    """
    ray_config = self._ray_execution_config
    return ray_config.get("subtask_cancel_timeout")
61 changes: 60 additions & 1 deletion mars/services/task/execution/ray/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,22 @@ def _optimize_subtask_graph(subtask_graph):
return _optimize_physical(subtask_graph)


async def _cancel_ray_task(obj_ref, kill_timeout: int = 3):
    """Cancel the Ray task behind ``obj_ref``, escalating to a force cancel.

    A non-force ``ray.cancel`` is issued first. ``ray.get`` is then run in a
    worker thread with ``kill_timeout`` as its timeout to wait for the
    outcome:

    * ``TaskCancelledError`` means the cancel took effect; it is only logged.
    * Any other exception raised by the wait triggers a second
      ``ray.cancel`` with ``force=True``.
    * If the wait returns normally, nothing further is done.

    Parameters
    ----------
    obj_ref
        Object ref produced by the Ray task to cancel.
    kill_timeout : int
        Timeout handed to ``ray.get`` while waiting for the non-force
        cancel to take effect.
    """
    ray.cancel(obj_ref, force=False)
    try:
        # ``ray.get`` is a blocking call, so run it off the event loop.
        await asyncio.to_thread(ray.get, obj_ref, timeout=kill_timeout)
    except ray.exceptions.TaskCancelledError:  # pragma: no cover
        logger.info("Cancel ray task %s successfully.", obj_ref)
    except BaseException as e:
        # Kept deliberately broad: every other failure of the wait escalates
        # to a force cancel. The log arguments follow the placeholder order
        # of the message (task ref first, then the exception).
        logger.info(
            "Failed to cancel ray task %s with exception %s, "
            "force cancel the task by killing the worker.",
            obj_ref,
            e,
        )
        ray.cancel(obj_ref, force=True)


def execute_subtask(
task_id: str,
subtask_id: str,
Expand Down Expand Up @@ -159,6 +175,14 @@ def __init__(
self._pre_all_stages_tile_progress = 0
self._cur_stage_tile_progress = 0
self._cur_stage_output_object_refs = []
# This list records the number of output object refs of each subtask, so
# together with `self._cur_stage_output_object_refs` we can call
# `ray.cancel` with a single object ref per subtask instead of cancelling
# every object ref. This avoids a lot of unnecessary calls to Ray.
self._output_object_refs_nums = []
# For meta and data gc
self._execute_subtask_graph_aiotask = None
self._cancelled = False

@classmethod
async def create(
Expand Down Expand Up @@ -219,6 +243,9 @@ def destroy(self):
self._pre_all_stages_tile_progress = 1
self._cur_stage_tile_progress = 1
self._cur_stage_output_object_refs = []
self._output_object_refs_nums = []
self._execute_subtask_graph_aiotask = None
self._cancelled = None

@classmethod
@alru_cache(cache_exceptions=False)
Expand Down Expand Up @@ -267,6 +294,10 @@ async def execute_subtask_graph(
tile_context: TileContext,
context: Any = None,
) -> Dict[Chunk, ExecutionChunkResult]:
if self._cancelled is True: # pragma: no cover
raise asyncio.CancelledError()
self._execute_subtask_graph_aiotask = asyncio.current_task()

logger.info("Stage %s start.", stage_id)
task_context = self._task_context
output_meta_object_refs = []
Expand Down Expand Up @@ -307,6 +338,7 @@ async def execute_subtask_graph(
elif output_count == 1:
output_object_refs = [output_object_refs]
self._cur_stage_output_object_refs.extend(output_object_refs)
self._output_object_refs_nums.append(len(output_object_refs))
if output_meta_keys:
meta_object_ref, *output_object_refs = output_object_refs
# TODO(fyrestone): Fetch(not get) meta object here.
Expand Down Expand Up @@ -345,6 +377,7 @@ async def execute_subtask_graph(
# because current stage is finished, its progress is 1.
self._pre_all_stages_progress += self._cur_stage_tile_progress
self._cur_stage_output_object_refs.clear()
self._output_object_refs_nums.clear()
logger.info("Stage %s is complete.", stage_id)
return chunk_to_meta

Expand Down Expand Up @@ -416,7 +449,33 @@ async def get_progress(self) -> float:
return self._pre_all_stages_progress + stage_progress

async def cancel(self):
    """
    Cancel the task execution.

    1. Try to cancel the recorded `execute_subtask_graph` asyncio task.
    2. Try to cancel the submitted subtasks through `_cancel_ray_task`,
       using the first output object ref of each subtask.
    """
    logger.info("Start to cancel task %s.", self._task)
    if self._task is None:
        return
    self._cancelled = True

    graph_aiotask = self._execute_subtask_graph_aiotask
    if graph_aiotask is not None and not graph_aiotask.cancelled():
        graph_aiotask.cancel()

    timeout = self._config.get_subtask_cancel_timeout()
    if not self._output_object_refs_nums:
        return

    # Each entry of `_output_object_refs_nums` is the number of consecutive
    # items in `_cur_stage_output_object_refs` that belong to one subtask, so
    # a running offset lands on the first object ref of every subtask.
    offset = 0
    cancel_coros = []
    for ref_count in self._output_object_refs_nums:
        cancel_coros.append(
            _cancel_ray_task(self._cur_stage_output_object_refs[offset], timeout)
        )
        offset += ref_count
    await asyncio.gather(*cancel_coros)

async def _load_subtask_inputs(
self, stage_id: str, subtask: Subtask, chunk_graph: ChunkGraph, context: Dict
Expand Down

0 comments on commit 080d6b8

Please sign in to comment.