Skip to content

Commit

Permalink
Cherry-pick: Fixes Tribler#7972: UDP server stops accepting datagrams…
Browse files Browse the repository at this point in the history
… from any clients after a single client disconnects

(cherry picked from commit 449b864)
  • Loading branch information
kozlovsky committed Apr 19, 2024
1 parent f9a6db6 commit 4b454d5
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/run_tribler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from tribler.core.sentry_reporter.sentry_reporter import SentryReporter, SentryStrategy
from tribler.core.sentry_reporter.sentry_scrubber import SentryScrubber
from tribler.core.utilities.asyncio_fixes.finish_accept_patch import apply_finish_accept_patch
from tribler.core.utilities.asyncio_fixes.proactor_recvfrom_patch import apply_proactor_recvfrom_patch
from tribler.core.utilities.slow_coro_detection.main_thread_stack_tracking import start_main_thread_stack_tracing
from tribler.core.utilities.osutils import get_root_state_directory
from tribler.core.utilities.utilities import is_frozen
Expand Down Expand Up @@ -95,6 +96,7 @@ def main():
if parsed_args.core:
if sys.platform == 'win32':
apply_finish_accept_patch()
apply_proactor_recvfrom_patch()

from tribler.core.utilities.pony_utils import track_slow_db_sessions
track_slow_db_sessions()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from asyncio.log import logger

try:
import _overlapped
except ImportError:
_overlapped = None


NULL = 0

ERROR_PORT_UNREACHABLE = 1234 # _overlapped.ERROR_PORT_UNREACHABLE, available in Python >= 3.11
ERROR_NETNAME_DELETED = 64
ERROR_OPERATION_ABORTED = 995

patch_applied = False


def apply_proactor_recvfrom_patch(): # pragma: no cover
global patch_applied # pylint: disable=global-statement
if patch_applied:
return

from asyncio import IocpProactor # pylint: disable=import-outside-toplevel

IocpProactor.recvfrom = patched_recvfrom

patch_applied = True
logger.info("Patched IocpProactor.recvfrom to handle ERROR_PORT_UNREACHABLE")


# pylint: disable=protected-access


def patched_recvfrom(self, conn, nbytes, flags=0):
self._register_with_iocp(conn)
ov = _overlapped.Overlapped(NULL)
try:
ov.WSARecvFrom(conn.fileno(), nbytes, flags)
except BrokenPipeError:
return self._result((b'', None))

def finish_recvfrom(trans, key, ov, error_class=OSError): # pylint: disable=unused-argument
try:
return ov.getresult()
except error_class as exc:
if exc.winerror in (ERROR_NETNAME_DELETED, ERROR_OPERATION_ABORTED):
raise ConnectionResetError(*exc.args) # pylint: disable=raise-missing-from

# ******************** START OF THE PATCH ********************
# WSARecvFrom will report ERROR_PORT_UNREACHABLE when the same
# socket was used to send to an address that is not listening.
if exc.winerror == ERROR_PORT_UNREACHABLE:
return b'', None # ignore the error
# ******************** END OF THE PATCH **********************

raise

return self._register(ov, conn, finish_recvfrom)
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from unittest.mock import Mock, patch

import pytest

from tribler.core.utilities.asyncio_fixes.proactor_recvfrom_patch import ERROR_NETNAME_DELETED, \
ERROR_OPERATION_ABORTED, ERROR_PORT_UNREACHABLE, patched_recvfrom


# pylint: disable=protected-access


@patch('tribler.core.utilities.asyncio_fixes.proactor_recvfrom_patch._overlapped')
def test_patched_recvfrom_broken_pipe_error(overlapped):
proactor, conn, nbytes, flags, ov = (Mock() for _ in range(5))
overlapped.Overlapped.return_value = ov
conn.fileno.return_value = Mock()
ov.WSARecvFrom.side_effect = BrokenPipeError()
proactor._result.return_value = Mock()

result = patched_recvfrom(proactor, conn, nbytes, flags)

proactor._register_with_iocp.assert_called_with(conn)
overlapped.Overlapped.assert_called_with(0)
ov.WSARecvFrom.assert_called_with(conn.fileno.return_value, nbytes, flags)
proactor._result.assert_called_with((b'', None))
assert result is proactor._result.return_value


@patch('tribler.core.utilities.asyncio_fixes.proactor_recvfrom_patch._overlapped')
def test_patched_recvfrom(overlapped):
proactor, conn, nbytes, flags, ov, trans, key = (Mock() for _ in range(7))
overlapped.Overlapped.return_value = ov
conn.fileno.return_value = Mock()
proactor._register.return_value = Mock()

result = patched_recvfrom(proactor, conn, nbytes, flags)
proactor._register.assert_called_once()
assert result is proactor._register.return_value
args = proactor._register.call_args.args
assert args[:2] == (ov, conn) and len(args) == 3

finish_recvfrom = args[2]

class OSErrorMock(Exception):
def __init__(self, winerror):
self.winerror = winerror

with patch('tribler.core.utilities.asyncio_fixes.proactor_recvfrom_patch.OSError', 'OSErrorMock'):

# Should raise ConnectionResetError if ov.getresult() raises OSError with winerror=ERROR_NETNAME_DELETED

ov.getresult.assert_not_called()
ov.getresult.side_effect = OSErrorMock(ERROR_NETNAME_DELETED)
with pytest.raises(ConnectionResetError):
finish_recvfrom(trans, key, ov, error_class=OSErrorMock)

# Should raise ConnectionResetError if ov.getresult() raises OSError with winerror=ERROR_OPERATION_ABORTED

ov.getresult.side_effect = OSErrorMock(ERROR_OPERATION_ABORTED)
with pytest.raises(ConnectionResetError):
finish_recvfrom(trans, key, ov, error_class=OSErrorMock)

# Should return empty result if ov.getresult() raises OSError with winerror=ERROR_PORT_UNREACHABLE

ov.getresult.side_effect = OSErrorMock(ERROR_PORT_UNREACHABLE)
result = finish_recvfrom(trans, key, ov, error_class=OSErrorMock)
assert result == (b'', None)

# Should reraise any other OSError raised by ov.getresult()

ov.getresult.side_effect = OSErrorMock(-1)
with pytest.raises(OSErrorMock):
finish_recvfrom(trans, key, ov, error_class=OSErrorMock)

# Should return result of ov.getresult() if no exceptions arisen

ov.getresult.side_effect = None
ov.getresult.return_value = Mock()
result = finish_recvfrom(trans, key, ov)
assert result is ov.getresult.return_value

assert ov.getresult.call_count == 5

0 comments on commit 4b454d5

Please sign in to comment.