diff --git a/mars/services/storage/tests/test_transfer.py b/mars/services/storage/tests/test_transfer.py index 6a239e5658..4dec9f9337 100644 --- a/mars/services/storage/tests/test_transfer.py +++ b/mars/services/storage/tests/test_transfer.py @@ -183,6 +183,9 @@ async def test_cancel_transfer(create_actors, mock_sender, mock_receiver): data1 = np.random.rand(10, 10) await storage_handler1.put('mock', 'data_key1', data1, StorageLevel.MEMORY) + data2 = pd.DataFrame(np.random.rand(100, 100)) + await storage_handler1.put('mock', 'data_key2', + data2, StorageLevel.MEMORY) used_before = (await quota_refs[StorageLevel.MEMORY].get_quota())[1] @@ -207,6 +210,19 @@ async def test_cancel_transfer(create_actors, mock_sender, mock_receiver): get_data = await storage_handler2.get('mock', 'data_key1') np.testing.assert_array_equal(data1, get_data) + # cancel when fetch the same data Simultaneously + if mock_sender is MockSenderManagerActor: + send_task1 = asyncio.create_task(sender_actor.send_batch_data( + 'mock', ['data_key2'], worker_address_2, StorageLevel.MEMORY)) + send_task2 = asyncio.create_task(sender_actor.send_batch_data( + 'mock', ['data_key2'], worker_address_2, StorageLevel.MEMORY)) + await asyncio.sleep(0.5) + send_task1.cancel() + with pytest.raises(asyncio.CancelledError): + await send_task2 + with pytest.raises(DataNotExist): + await storage_handler2.get('mock', 'data_key2') + @pytest.mark.asyncio async def test_transfer_same_tasks(create_actors): diff --git a/mars/services/storage/transfer.py b/mars/services/storage/transfer.py index 0f5622b4ad..12462b5734 100644 --- a/mars/services/storage/transfer.py +++ b/mars/services/storage/transfer.py @@ -187,7 +187,7 @@ def __init__(self, self._lock = asyncio.Lock() async def __post_create__(self): - if self._storage_handler is None: # for test + if self._storage_handler is None: # for test self._storage_handler = await mo.actor_ref( self.address, StorageHandlerActor.gen_uid('numa-0')) @@ -261,6 +261,7 @@ async def do_write(self, message: TransferMessage): await asyncio.gather(*close_tasks) for data_key in finished_keys: self._key_to_writer_info.pop((session_id, data_key)) + self._writing_keys[(message.session_id, data_key)].is_success = True self._writing_keys[(session_id, data_key)].set() self._decref_writing_key(session_id, data_key) @@ -276,6 +277,8 @@ async def receive_part_data(self, message: TransferMessage): await self._storage_handler.delete( message.session_id, data_key, error='ignore') await writer.clean_up() + self._writing_keys[(message.session_id, data_key)].is_success = False + self._writing_keys[(message.session_id, data_key)].set() self._key_to_writer_info.pop(( message.session_id, data_key)) raise @@ -283,5 +286,11 @@ async def receive_part_data(self, message: TransferMessage): async def wait_transfer_done(self, session_id, data_keys): await asyncio.gather(*[self._writing_keys[(session_id, key)].wait() for key in data_keys]) - [self._decref_writing_key(session_id, data_key) - for data_key in data_keys] + try: + if not all(self._writing_keys[(session_id, key)].is_success + for key in data_keys): + raise asyncio.CancelledError(f'Transfer cancelled for' + f' data {session_id, data_keys}') + finally: + for data_key in data_keys: + self._decref_writing_key(session_id, data_key)