Skip to content

Commit

Permalink
[BACKPORT] [Ray] Destroy Ray executor when the task finish (#3049) (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
wjsi authored May 24, 2022
1 parent 433ce03 commit 4512041
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 2 deletions.
3 changes: 3 additions & 0 deletions mars/services/task/execution/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,9 @@ async def create(
**kwargs,
)

def destroy(self):
"""Destroy the executor."""

async def __aenter__(self):
"""Called when begin to execute the task."""

Expand Down
21 changes: 21 additions & 0 deletions mars/services/task/execution/ray/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,27 @@ async def create(
meta_api,
)

# noinspection DuplicatedCode
def destroy(self):
self._config = None
self._task = None
self._tile_context = None
self._task_context = None
self._task_state_actor = None
self._ray_executor = None

# api
self._lifecycle_api = None
self._meta_api = None

self._available_band_resources = None

# For progress
self._pre_all_stages_progress = 1
self._pre_all_stages_tile_progress = 1
self._cur_stage_tile_progress = 1
self._cur_stage_output_object_refs = []

@classmethod
@alru_cache(cache_exceptions=False)
async def _get_apis(cls, session_id: str, address: str):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import Counter

import numpy as np
import pandas as pd
import pytest

from ...... import tensor as mt

from ......core import TileContext
from ......core.graph import TileableGraph, TileableGraphBuilder, ChunkGraphBuilder
from ......serialization import serialize
from ......tests.core import require_ray, mock
from ......utils import lazy_import, get_chunk_params
from .....context import ThreadedServiceContext
from ....core import new_task_id
from ....core import new_task_id, Task
from ..config import RayExecutionConfig
from ..context import (
RayExecutionContext,
RayRemoteObjectManager,
_RayRemoteObjectContext,
)
from ..executor import execute_subtask
from ..executor import execute_subtask, RayTaskExecutor
from ..fetcher import RayFetcher

ray = lazy_import("ray")
Expand All @@ -41,6 +45,48 @@ def _gen_subtask_chunk_graph(t):
return next(ChunkGraphBuilder(graph, fuse_enabled=False).build())


class MockRayTaskExecutor(RayTaskExecutor):
def __init__(self, *args, **kwargs):
self._set_attrs = Counter()
super().__init__(*args, **kwargs)

@staticmethod
def _get_ray_executor():
# Export remote function once.
return None

def set_attr_counter(self):
return self._set_attrs

def __setattr__(self, key, value):
super().__setattr__(key, value)
self._set_attrs[key] += 1


def test_ray_executor_destroy():
task = Task("mock_task", "mock_session")
config = RayExecutionConfig.from_execution_config({"backend": "mars"})
executor = MockRayTaskExecutor(
config=config,
task=task,
tile_context=TileContext(),
task_context={},
task_state_actor=None,
lifecycle_api=None,
meta_api=None,
)
counter = executor.set_attr_counter()
assert len(counter) > 0
keys = executor.__dict__.keys()
assert counter.keys() >= keys
counter.clear()
executor.destroy()
keys = set(keys) - {"_set_attrs"}
assert counter.keys() == keys, "Some keys are not reset in destroy()."
for k, v in counter.items():
assert v == 1


def test_ray_execute_subtask_basic():
raw = np.ones((10, 10))
raw_expect = raw + 1
Expand Down
1 change: 1 addition & 0 deletions mars/services/task/supervisor/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,7 @@ def dump_subtask_graph(self):
f.write(dot)

def _finish(self):
self._executor.destroy()
self.done.set()
if self._dump_subtask_graph:
self.dump_subtask_graph()
Expand Down

0 comments on commit 4512041

Please sign in to comment.