Skip to content

Commit

Permalink
Fix cancel issues
Browse files Browse the repository at this point in the history
  • Loading branch information
hekaisheng committed Aug 11, 2021
1 parent c2a7009 commit df8811a
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
16 changes: 16 additions & 0 deletions mars/services/storage/tests/test_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,9 @@ async def test_cancel_transfer(create_actors, mock_sender, mock_receiver):
data1 = np.random.rand(10, 10)
await storage_handler1.put('mock', 'data_key1',
data1, StorageLevel.MEMORY)
data2 = pd.DataFrame(np.random.rand(100, 100))
await storage_handler1.put('mock', 'data_key2',
data2, StorageLevel.MEMORY)

used_before = (await quota_refs[StorageLevel.MEMORY].get_quota())[1]

Expand All @@ -207,6 +210,19 @@ async def test_cancel_transfer(create_actors, mock_sender, mock_receiver):
get_data = await storage_handler2.get('mock', 'data_key1')
np.testing.assert_array_equal(data1, get_data)

# cancel when fetching the same data simultaneously
if mock_sender is MockSenderManagerActor:
send_task1 = asyncio.create_task(sender_actor.send_batch_data(
'mock', ['data_key2'], worker_address_2, StorageLevel.MEMORY))
send_task2 = asyncio.create_task(sender_actor.send_batch_data(
'mock', ['data_key2'], worker_address_2, StorageLevel.MEMORY))
await asyncio.sleep(0.5)
send_task1.cancel()
with pytest.raises(asyncio.CancelledError):
await send_task2
with pytest.raises(DataNotExist):
await storage_handler2.get('mock', 'data_key2')


@pytest.mark.asyncio
async def test_transfer_same_tasks(create_actors):
Expand Down
15 changes: 12 additions & 3 deletions mars/services/storage/transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def __init__(self,
self._lock = asyncio.Lock()

async def __post_create__(self):
if self._storage_handler is None: # for test
if self._storage_handler is None: # for test
self._storage_handler = await mo.actor_ref(
self.address, StorageHandlerActor.gen_uid('numa-0'))

Expand Down Expand Up @@ -261,6 +261,7 @@ async def do_write(self, message: TransferMessage):
await asyncio.gather(*close_tasks)
for data_key in finished_keys:
self._key_to_writer_info.pop((session_id, data_key))
self._writing_keys[(message.session_id, data_key)].is_success = True
self._writing_keys[(session_id, data_key)].set()
self._decref_writing_key(session_id, data_key)

Expand All @@ -276,12 +277,20 @@ async def receive_part_data(self, message: TransferMessage):
await self._storage_handler.delete(
message.session_id, data_key, error='ignore')
await writer.clean_up()
self._writing_keys[(message.session_id, data_key)].is_success = False
self._writing_keys[(message.session_id, data_key)].set()
self._key_to_writer_info.pop((
message.session_id, data_key))
raise

# Wait until the writing event of every key in data_keys has been set, then
# raise asyncio.CancelledError if any of those keys is not flagged is_success.
# NOTE(review): this text is a rendered diff with +/- markers and indentation
# stripped, so old and new lines are interleaved; confirm against the real file.
async def wait_transfer_done(self, session_id, data_keys):
# Block until the event stored under each (session_id, key) has been set.
await asyncio.gather(*[self._writing_keys[(session_id, key)].wait()
for key in data_keys])
# NOTE(review): the next two lines appear to be the pre-commit decref that the
# try/finally below replaces; if both were kept, every key would be
# decref'ed twice. TODO confirm only the finally-based version remains.
[self._decref_writing_key(session_id, data_key)
for data_key in data_keys]
try:
# is_success is set next to the event's set() call; a False value on any key
# means that key's transfer was not completed, so waiters are cancelled too.
if not all(self._writing_keys[(session_id, key)].is_success
for key in data_keys):
raise asyncio.CancelledError(f'Transfer cancelled for'
f' data {session_id, data_keys}')
finally:
# Call _decref_writing_key for every key whether or not the check raised.
for data_key in data_keys:
self._decref_writing_key(session_id, data_key)

0 comments on commit df8811a

Please sign in to comment.