From c2a70095631472bc492e4580e8f2b383fc0fa7a4 Mon Sep 17 00:00:00 2001 From: hekaisheng Date: Wed, 11 Aug 2021 14:19:08 +0800 Subject: [PATCH] fix ut --- mars/services/storage/tests/test_transfer.py | 2 +- mars/services/storage/transfer.py | 27 +++++++++++++++----- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/mars/services/storage/tests/test_transfer.py b/mars/services/storage/tests/test_transfer.py index 0d5658f8d1..6a239e5658 100644 --- a/mars/services/storage/tests/test_transfer.py +++ b/mars/services/storage/tests/test_transfer.py @@ -146,7 +146,7 @@ async def create_writers(self, data_sizes, level): await asyncio.sleep(3) - await super().create_writers(session_id, data_keys, data_sizes, level) + return await super().create_writers(session_id, data_keys, data_sizes, level) class MockSenderManagerActor2(SenderManagerActor): diff --git a/mars/services/storage/transfer.py b/mars/services/storage/transfer.py index e776544120..0f5622b4ad 100644 --- a/mars/services/storage/transfer.py +++ b/mars/services/storage/transfer.py @@ -168,8 +168,10 @@ async def send_batch_data(self, else: to_send_keys.append(data_key) - await self._send_data(receiver_ref, session_id, to_send_keys, level, block_size) - await receiver_ref.wait_transfer_done(session_id, to_wait_keys) + if to_send_keys: + await self._send_data(receiver_ref, session_id, to_send_keys, level, block_size) + if to_wait_keys: + await receiver_ref.wait_transfer_done(session_id, to_wait_keys) logger.debug('Finish sending data (%s, %s) to %s', session_id, data_keys, address) @@ -181,6 +183,7 @@ def __init__(self, self._quota_refs = quota_refs self._storage_handler = storage_handler_ref self._writing_keys = dict() + self._writing_refs = dict() self._lock = asyncio.Lock() async def __post_create__(self): @@ -192,6 +195,12 @@ async def __post_create__(self): def gen_uid(cls, band_name: str): return f'sender_receiver_{band_name}' + def _decref_writing_key(self, session_id: str, data_key: str): + self._writing_refs[(session_id, data_key)] -= 1 + if self._writing_refs[(session_id, data_key)] == 0: + del self._writing_refs[(session_id, data_key)] + del self._writing_keys[(session_id, data_key)] + async def create_writers(self, session_id: str, data_keys: List[str], @@ -200,22 +209,24 @@ async def create_writers(self, async with self._lock: tasks = dict() data_key_to_size = dict() - is_writing_tasks = [] + being_processed = [] for data_key, data_size in zip(data_keys, data_sizes): data_key_to_size[data_key] = data_size if (session_id, data_key) not in self._key_to_writer_info: - is_writing_tasks.append(False) + being_processed.append(False) tasks[data_key] = self._storage_handler.open_writer.delay( session_id, data_key, data_size, level, request_quota=False) else: - is_writing_tasks.append(True) + being_processed.append(True) + self._writing_refs[(session_id, data_key)] += 1 if tasks: writers = await self._storage_handler.open_writer.batch(*tuple(tasks.values())) for data_key, writer in zip(tasks, writers): self._key_to_writer_info[ (session_id, data_key)] = (writer, data_key_to_size[data_key], level) self._writing_keys[(session_id, data_key)] = asyncio.Event() - return is_writing_tasks + self._writing_refs[(session_id, data_key)] = 1 + return being_processed async def open_writers(self, session_id: str, @@ -251,6 +262,7 @@ async def do_write(self, message: TransferMessage): for data_key in finished_keys: self._key_to_writer_info.pop((session_id, data_key)) self._writing_keys[(session_id, data_key)].set() + self._decref_writing_key(session_id, data_key) async def receive_part_data(self, message: TransferMessage): try: @@ -266,7 +278,10 @@ async def receive_part_data(self, message: TransferMessage): await writer.clean_up() self._key_to_writer_info.pop(( message.session_id, data_key)) + raise async def wait_transfer_done(self, session_id, data_keys): await asyncio.gather(*[self._writing_keys[(session_id, key)].wait() for key in data_keys]) + [self._decref_writing_key(session_id, data_key) + for data_key in data_keys]