forked from Tribler/tribler
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Cherry-pick: Fixes Tribler#7972: UDP server stops accepting datagrams…
… from any clients after a single client disconnects (cherry picked from commit 449b864)
- Loading branch information
Showing
3 changed files
with
142 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
58 changes: 58 additions & 0 deletions
58
src/tribler/core/utilities/asyncio_fixes/proactor_recvfrom_patch.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from asyncio.log import logger | ||
|
||
try: | ||
import _overlapped | ||
except ImportError: | ||
_overlapped = None | ||
|
||
|
||
NULL = 0 | ||
|
||
ERROR_PORT_UNREACHABLE = 1234 # _overlapped.ERROR_PORT_UNREACHABLE, available in Python >= 3.11 | ||
ERROR_NETNAME_DELETED = 64 | ||
ERROR_OPERATION_ABORTED = 995 | ||
|
||
patch_applied = False | ||
|
||
|
||
def apply_proactor_recvfrom_patch(): # pragma: no cover | ||
global patch_applied # pylint: disable=global-statement | ||
if patch_applied: | ||
return | ||
|
||
from asyncio import IocpProactor # pylint: disable=import-outside-toplevel | ||
|
||
IocpProactor.recvfrom = patched_recvfrom | ||
|
||
patch_applied = True | ||
logger.info("Patched IocpProactor.recvfrom to handle ERROR_PORT_UNREACHABLE") | ||
|
||
|
||
# pylint: disable=protected-access | ||
|
||
|
||
def patched_recvfrom(self, conn, nbytes, flags=0): | ||
self._register_with_iocp(conn) | ||
ov = _overlapped.Overlapped(NULL) | ||
try: | ||
ov.WSARecvFrom(conn.fileno(), nbytes, flags) | ||
except BrokenPipeError: | ||
return self._result((b'', None)) | ||
|
||
def finish_recvfrom(trans, key, ov, error_class=OSError): # pylint: disable=unused-argument | ||
try: | ||
return ov.getresult() | ||
except error_class as exc: | ||
if exc.winerror in (ERROR_NETNAME_DELETED, ERROR_OPERATION_ABORTED): | ||
raise ConnectionResetError(*exc.args) # pylint: disable=raise-missing-from | ||
|
||
# ******************** START OF THE PATCH ******************** | ||
# WSARecvFrom will report ERROR_PORT_UNREACHABLE when the same | ||
# socket was used to send to an address that is not listening. | ||
if exc.winerror == ERROR_PORT_UNREACHABLE: | ||
return b'', None # ignore the error | ||
# ******************** END OF THE PATCH ********************** | ||
|
||
raise | ||
|
||
return self._register(ov, conn, finish_recvfrom) |
82 changes: 82 additions & 0 deletions
82
src/tribler/core/utilities/asyncio_fixes/tests/test_proactor_recvfrom_patch.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
from unittest.mock import Mock, patch | ||
|
||
import pytest | ||
|
||
from tribler.core.utilities.asyncio_fixes.proactor_recvfrom_patch import ERROR_NETNAME_DELETED, \ | ||
ERROR_OPERATION_ABORTED, ERROR_PORT_UNREACHABLE, patched_recvfrom | ||
|
||
|
||
# pylint: disable=protected-access | ||
|
||
|
||
@patch('tribler.core.utilities.asyncio_fixes.proactor_recvfrom_patch._overlapped') | ||
def test_patched_recvfrom_broken_pipe_error(overlapped): | ||
proactor, conn, nbytes, flags, ov = (Mock() for _ in range(5)) | ||
overlapped.Overlapped.return_value = ov | ||
conn.fileno.return_value = Mock() | ||
ov.WSARecvFrom.side_effect = BrokenPipeError() | ||
proactor._result.return_value = Mock() | ||
|
||
result = patched_recvfrom(proactor, conn, nbytes, flags) | ||
|
||
proactor._register_with_iocp.assert_called_with(conn) | ||
overlapped.Overlapped.assert_called_with(0) | ||
ov.WSARecvFrom.assert_called_with(conn.fileno.return_value, nbytes, flags) | ||
proactor._result.assert_called_with((b'', None)) | ||
assert result is proactor._result.return_value | ||
|
||
|
||
@patch('tribler.core.utilities.asyncio_fixes.proactor_recvfrom_patch._overlapped') | ||
def test_patched_recvfrom(overlapped): | ||
proactor, conn, nbytes, flags, ov, trans, key = (Mock() for _ in range(7)) | ||
overlapped.Overlapped.return_value = ov | ||
conn.fileno.return_value = Mock() | ||
proactor._register.return_value = Mock() | ||
|
||
result = patched_recvfrom(proactor, conn, nbytes, flags) | ||
proactor._register.assert_called_once() | ||
assert result is proactor._register.return_value | ||
args = proactor._register.call_args.args | ||
assert args[:2] == (ov, conn) and len(args) == 3 | ||
|
||
finish_recvfrom = args[2] | ||
|
||
class OSErrorMock(Exception): | ||
def __init__(self, winerror): | ||
self.winerror = winerror | ||
|
||
with patch('tribler.core.utilities.asyncio_fixes.proactor_recvfrom_patch.OSError', 'OSErrorMock'): | ||
|
||
# Should raise ConnectionResetError if ov.getresult() raises OSError with winerror=ERROR_NETNAME_DELETED | ||
|
||
ov.getresult.assert_not_called() | ||
ov.getresult.side_effect = OSErrorMock(ERROR_NETNAME_DELETED) | ||
with pytest.raises(ConnectionResetError): | ||
finish_recvfrom(trans, key, ov, error_class=OSErrorMock) | ||
|
||
# Should raise ConnectionResetError if ov.getresult() raises OSError with winerror=ERROR_OPERATION_ABORTED | ||
|
||
ov.getresult.side_effect = OSErrorMock(ERROR_OPERATION_ABORTED) | ||
with pytest.raises(ConnectionResetError): | ||
finish_recvfrom(trans, key, ov, error_class=OSErrorMock) | ||
|
||
# Should return empty result if ov.getresult() raises OSError with winerror=ERROR_PORT_UNREACHABLE | ||
|
||
ov.getresult.side_effect = OSErrorMock(ERROR_PORT_UNREACHABLE) | ||
result = finish_recvfrom(trans, key, ov, error_class=OSErrorMock) | ||
assert result == (b'', None) | ||
|
||
# Should reraise any other OSError raised by ov.getresult() | ||
|
||
ov.getresult.side_effect = OSErrorMock(-1) | ||
with pytest.raises(OSErrorMock): | ||
finish_recvfrom(trans, key, ov, error_class=OSErrorMock) | ||
|
||
# Should return result of ov.getresult() if no exceptions arisen | ||
|
||
ov.getresult.side_effect = None | ||
ov.getresult.return_value = Mock() | ||
result = finish_recvfrom(trans, key, ov) | ||
assert result is ov.getresult.return_value | ||
|
||
assert ov.getresult.call_count == 5 |