Skip to content

Commit

Permalink
Refine lifecycle api to support incref or decref with ref counts (mar…
Browse files Browse the repository at this point in the history
  • Loading branch information
Xuye (Chris) Qin authored Apr 18, 2022
1 parent 00dbf41 commit b51afae
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 57 deletions.
6 changes: 5 additions & 1 deletion mars/services/lifecycle/api/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,18 @@

class AbstractLifecycleAPI(ABC):
@abstractmethod
async def decref_tileables(self, tileable_keys: List[str]):
async def decref_tileables(
self, tileable_keys: List[str], counts: List[int] = None
):
"""
Decref tileables.
Parameters
----------
tileable_keys : list
List of tileable keys.
counts: list
List of ref count.
"""

@abstractmethod
Expand Down
32 changes: 25 additions & 7 deletions mars/services/lifecycle/api/oscar.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,25 +73,35 @@ async def batch_track(self, args_list, kwargs_list):
tracks.append(self._lifecycle_tracker_ref.track.delay(*args, **kwargs))
return await self._lifecycle_tracker_ref.track.batch(*tracks)

async def incref_tileables(self, tileable_keys: List[str]):
async def incref_tileables(
self, tileable_keys: List[str], counts: List[int] = None
):
"""
Incref tileables.
Parameters
----------
tileable_keys : list
List of tileable keys.
counts: list
List of ref count.
"""
return await self._lifecycle_tracker_ref.incref_tileables(tileable_keys)
return await self._lifecycle_tracker_ref.incref_tileables(
tileable_keys, counts=counts
)

async def decref_tileables(self, tileable_keys: List[str]):
async def decref_tileables(
self, tileable_keys: List[str], counts: List[int] = None
):
"""
Decref tileables.
Parameters
----------
tileable_keys : list
List of tileable keys.
counts: list
List of ref count.
"""
return await self._lifecycle_tracker_ref.decref_tileables(tileable_keys)

Expand All @@ -111,27 +121,35 @@ async def get_tileable_ref_counts(self, tileable_keys: List[str]) -> List[int]:
"""
return await self._lifecycle_tracker_ref.get_tileable_ref_counts(tileable_keys)

async def incref_chunks(self, chunk_keys: List[str]):
async def incref_chunks(self, chunk_keys: List[str], counts: List[int] = None):
"""
Incref chunks.
Parameters
----------
chunk_keys : list
List of chunk keys.
counts: list
List of ref count.
"""
return await self._lifecycle_tracker_ref.incref_chunks(chunk_keys)
return await self._lifecycle_tracker_ref.incref_chunks(
chunk_keys, counts=counts
)

async def decref_chunks(self, chunk_keys: List[str]):
async def decref_chunks(self, chunk_keys: List[str], counts: List[int] = None):
"""
Decref chunks
Parameters
----------
chunk_keys : list
List of chunk keys.
counts: list
List of ref count.
"""
return await self._lifecycle_tracker_ref.decref_chunks(chunk_keys)
return await self._lifecycle_tracker_ref.decref_chunks(
chunk_keys, counts=counts
)

async def get_chunk_ref_counts(self, chunk_keys: List[str]) -> List[int]:
"""
Expand Down
14 changes: 11 additions & 3 deletions mars/services/lifecycle/api/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,12 @@ async def _get_oscar_lifecycle_api(self, session_id: str):
@web_api("", method="post", arg_filter={"action": "decref_tileables"})
async def decref_tileables(self, session_id: str):
tileable_keys = self.get_argument("tileable_keys").split(",")
counts = self.get_argument("counts", None)
if counts:
counts = [int(c) for c in counts.split(",")]

oscar_api = await self._get_oscar_lifecycle_api(session_id)
await oscar_api.decref_tileables(tileable_keys)
await oscar_api.decref_tileables(tileable_keys, counts=counts)

@web_api("", method="get", arg_filter={"action": "get_all_chunk_ref_counts"})
async def get_all_chunk_ref_counts(self, session_id: str):
Expand All @@ -52,15 +55,20 @@ def __init__(
self._address = address.rstrip("/")
self.request_rewriter = request_rewriter

async def decref_tileables(self, tileable_keys: List[str]):
async def decref_tileables(
self, tileable_keys: List[str], counts: List[int] = None
):
path = f"{self._address}/api/session/{self._session_id}/lifecycle"
params = dict(action="decref_tileables")
counts = (
f"&counts={','.join(str(c) for c in counts)}" if counts is not None else ""
)
await self._request_url(
path=path,
method="POST",
params=params,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data="tileable_keys=" + ",".join(tileable_keys),
data="tileable_keys=" + ",".join(tileable_keys) + counts,
)

async def get_all_chunk_ref_counts(self) -> Dict[str, int]:
Expand Down
7 changes: 7 additions & 0 deletions mars/services/lifecycle/supervisor/tests/test_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,18 @@ async def test_tracker():

await tracker.track(tileable_key, chunk_keys)
await tracker.incref_tileables([tileable_key])
await tracker.incref_tileables([tileable_key], [2])
await tracker.incref_chunks(chunk_keys[:2])
await tracker.incref_chunks(chunk_keys[:2], [3, 3])
await tracker.decref_chunks(chunk_keys[:2])
await tracker.decref_chunks(chunk_keys[:2], [3, 3])
await tracker.decref_tileables([tileable_key])
await tracker.decref_tileables([tileable_key], [2])
assert len(await tracker.get_all_chunk_ref_counts()) == 0

with pytest.raises(ValueError):
await tracker.incref_tileables([tileable_key], [2, 3])

for chunk_key in chunk_keys:
with pytest.raises(KeyError):
await meta_api.get_chunk_meta(chunk_key)
Expand Down
46 changes: 32 additions & 14 deletions mars/services/lifecycle/supervisor/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import asyncio
import itertools
import logging

from collections import defaultdict
Expand Down Expand Up @@ -69,29 +70,40 @@ def track(self, tileable_key: str, chunk_keys: List[str]):
if incref_chunk_keys:
self.incref_chunks(incref_chunk_keys)

def incref_chunks(self, chunk_keys: List[str]):
@classmethod
def _check_ref_counts(cls, keys: List[str], ref_counts: List[int]):
if ref_counts is not None and len(keys) != len(ref_counts):
raise ValueError(
f"`ref_counts` should have same size as `keys`, expect {len(keys)}, got {len(ref_counts)}"
)

def incref_chunks(self, chunk_keys: List[str], counts: List[int] = None):
logger.debug("Increase reference count for chunks %s", chunk_keys)
for chunk_key in chunk_keys:
self._chunk_ref_counts[chunk_key] += 1
self._check_ref_counts(chunk_keys, counts)
counts = counts if counts is not None else itertools.repeat(1)
for chunk_key, count in zip(chunk_keys, counts):
self._chunk_ref_counts[chunk_key] += count

def _get_remove_chunk_keys(self, chunk_keys: List[str]):
def _get_remove_chunk_keys(self, chunk_keys: List[str], counts: List[int] = None):
to_remove_chunk_keys = []
for chunk_key in chunk_keys:
counts = counts if counts is not None else itertools.repeat(1)
for chunk_key, count in zip(chunk_keys, counts):
ref_count = self._chunk_ref_counts[chunk_key]
ref_count -= 1
ref_count -= count
assert ref_count >= 0, f"chunk key {chunk_key} will have negative ref count"
self._chunk_ref_counts[chunk_key] = ref_count
if ref_count == 0:
# remove
to_remove_chunk_keys.append(chunk_key)
return to_remove_chunk_keys

async def decref_chunks(self, chunk_keys: List[str]):
async def decref_chunks(self, chunk_keys: List[str], counts: List[int] = None):
self._check_ref_counts(chunk_keys, counts)
logger.debug(
"Decrease reference count for chunks %s",
{ck: self._chunk_ref_counts[ck] for ck in chunk_keys},
)
to_remove_chunk_keys = self._get_remove_chunk_keys(chunk_keys)
to_remove_chunk_keys = self._get_remove_chunk_keys(chunk_keys, counts)
# make _remove_chunks release actor lock so that multiple `decref_chunks` can run concurrently.
yield self._remove_chunks(to_remove_chunk_keys)

Expand Down Expand Up @@ -160,11 +172,13 @@ def get_all_chunk_ref_counts(self) -> Dict[str, int]:
result[chunk_key] = ref_count
return result

def incref_tileables(self, tileable_keys: List[str]):
for tileable_key in tileable_keys:
def incref_tileables(self, tileable_keys: List[str], counts: List[int] = None):
self._check_ref_counts(tileable_keys, counts)
counts = counts if counts is not None else itertools.repeat(1)
for tileable_key, count in zip(tileable_keys, counts):
if tileable_key not in self._tileable_key_to_chunk_keys:
raise TileableNotTracked(f"tileable {tileable_key} not tracked before")
self._tileable_ref_counts[tileable_key] += 1
self._tileable_ref_counts[tileable_key] += count
incref_chunk_keys = self._tileable_key_to_chunk_keys[tileable_key]
# incref chunks for this tileable
logger.debug(
Expand All @@ -174,12 +188,16 @@ def incref_tileables(self, tileable_keys: List[str]):
)
self.incref_chunks(incref_chunk_keys)

async def decref_tileables(self, tileable_keys: List[str]):
async def decref_tileables(
self, tileable_keys: List[str], counts: List[int] = None
):
self._check_ref_counts(tileable_keys, counts)
decref_chunk_keys = []
for tileable_key in tileable_keys:
counts = counts if counts is not None else itertools.repeat(1)
for tileable_key, count in zip(tileable_keys, counts):
if tileable_key not in self._tileable_key_to_chunk_keys:
raise TileableNotTracked(f"tileable {tileable_key} not tracked before")
self._tileable_ref_counts[tileable_key] -= 1
self._tileable_ref_counts[tileable_key] -= count

decref_chunk_keys.extend(self._tileable_key_to_chunk_keys[tileable_key])
logger.debug(
Expand Down
80 changes: 48 additions & 32 deletions mars/services/task/execution/mars/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import asyncio
import logging
import sys
from collections import defaultdict
from typing import Dict, List, Optional

from ..... import oscar as mo
Expand Down Expand Up @@ -330,34 +331,37 @@ async def _decref_result_tileables(self):

async def _incref_stage(self, stage_processor: "TaskStageProcessor"):
subtask_graph = stage_processor.subtask_graph
incref_chunk_keys = []
incref_chunk_key_to_counts = defaultdict(lambda: 0)
for subtask in subtask_graph:
# for subtask has successors, incref number of successors
n = subtask_graph.count_successors(subtask)
for c in subtask.chunk_graph.results:
incref_chunk_keys.extend([c.key] * n)
incref_chunk_key_to_counts[c.key] += n
# process reducer, incref mapper chunks
for pre_graph in subtask_graph.iter_predecessors(subtask):
for chk in pre_graph.chunk_graph.results:
if isinstance(chk.op, ShuffleProxy):
n_reducer = _get_n_reducer(subtask)
incref_chunk_keys.extend(
[map_chunk.key for map_chunk in chk.inputs] * n_reducer
)
for map_chunk in chk.inputs:
incref_chunk_key_to_counts[map_chunk.key] += n_reducer
result_chunks = stage_processor.chunk_graph.result_chunks
incref_chunk_keys.extend([c.key for c in result_chunks])
for c in result_chunks:
incref_chunk_key_to_counts[c.key] += 1
logger.debug(
"Incref chunks for stage %s: %s",
stage_processor.stage_id,
incref_chunk_keys,
incref_chunk_key_to_counts,
)
await self._lifecycle_api.incref_chunks(
list(incref_chunk_key_to_counts),
counts=list(incref_chunk_key_to_counts.values()),
)
await self._lifecycle_api.incref_chunks(incref_chunk_keys)

@classmethod
def _get_decref_stage_chunk_keys(
def _get_decref_stage_chunk_key_to_counts(
cls, stage_processor: "TaskStageProcessor"
) -> List[str]:
decref_chunk_keys = []
) -> Dict[str, int]:
decref_chunk_key_to_counts = defaultdict(lambda: 0)
error_or_cancelled = stage_processor.error_or_cancelled()
if stage_processor.subtask_graph:
subtask_graph = stage_processor.subtask_graph
Expand All @@ -369,32 +373,42 @@ def _get_decref_stage_chunk_keys(
stage_processor.decref_subtask.add(subtask.subtask_id)
# if subtask not executed, rollback incref of predecessors
for inp_subtask in subtask_graph.predecessors(subtask):
decref_chunk_keys.extend(
[c.key for c in inp_subtask.chunk_graph.results]
)
for c in inp_subtask.chunk_graph.results:
decref_chunk_key_to_counts[c.key] += 1
# decref result of chunk graphs
decref_chunk_keys.extend(
[c.key for c in stage_processor.chunk_graph.results]
)
return decref_chunk_keys
for c in stage_processor.chunk_graph.results:
decref_chunk_key_to_counts[c.key] += 1
return decref_chunk_key_to_counts

@mo.extensible
async def _decref_stage(self, stage_processor: "TaskStageProcessor"):
decref_chunk_keys = self._get_decref_stage_chunk_keys(stage_processor)
decref_chunk_key_to_counts = self._get_decref_stage_chunk_key_to_counts(
stage_processor
)
logger.debug(
"Decref chunks when stage %s finish: %s",
stage_processor.stage_id,
decref_chunk_keys,
decref_chunk_key_to_counts,
)
await self._lifecycle_api.decref_chunks(
list(decref_chunk_key_to_counts),
counts=list(decref_chunk_key_to_counts.values()),
)
await self._lifecycle_api.decref_chunks(decref_chunk_keys)

@_decref_stage.batch
async def _decref_stage(self, args_list, kwargs_list):
decref_chunk_keys = []
decref_chunk_key_to_counts = defaultdict(lambda: 0)
for args, kwargs in zip(args_list, kwargs_list):
decref_chunk_keys.extend(self._get_decref_stage_chunk_keys(*args, **kwargs))
logger.debug("Decref chunks when stages finish: %s", decref_chunk_keys)
await self._lifecycle_api.decref_chunks(decref_chunk_keys)
chunk_key_to_counts = self._get_decref_stage_chunk_key_to_counts(
*args, **kwargs
)
for k, c in chunk_key_to_counts.items():
decref_chunk_key_to_counts[k] += c
logger.debug("Decref chunks when stages finish: %s", decref_chunk_key_to_counts)
await self._lifecycle_api.decref_chunks(
list(decref_chunk_key_to_counts),
counts=list(decref_chunk_key_to_counts.values()),
)

async def _decref_input_subtasks(
self, subtask: Subtask, subtask_graph: SubtaskGraph
Expand All @@ -406,22 +420,24 @@ async def _decref_input_subtasks(
await self._subtask_decref_events[subtask.subtask_id].wait()
return

decref_chunk_keys = []
decref_chunk_key_to_counts = defaultdict(lambda: 0)
for in_subtask in subtask_graph.iter_predecessors(subtask):
for result_chunk in in_subtask.chunk_graph.results:
# for reducer chunk, decref mapper chunks
if isinstance(result_chunk.op, ShuffleProxy):
n_reducer = _get_n_reducer(subtask)
decref_chunk_keys.extend(
[inp.key for inp in result_chunk.inputs] * n_reducer
)
decref_chunk_keys.append(result_chunk.key)
for inp in result_chunk.inputs:
decref_chunk_key_to_counts[inp.key] += n_reducer
decref_chunk_key_to_counts[result_chunk.key] += 1
logger.debug(
"Decref chunks %s when subtask %s finish",
decref_chunk_keys,
decref_chunk_key_to_counts,
subtask.subtask_id,
)
await self._lifecycle_api.decref_chunks(decref_chunk_keys)
await self._lifecycle_api.decref_chunks(
list(decref_chunk_key_to_counts),
counts=list(decref_chunk_key_to_counts.values()),
)

# `set_subtask_result` will be called when subtask finished
# but report progress will call set_subtask_result too,
Expand Down

0 comments on commit b51afae

Please sign in to comment.