diff --git a/src/run_tribler.py b/src/run_tribler.py index 51db913e779..be30dc51242 100644 --- a/src/run_tribler.py +++ b/src/run_tribler.py @@ -12,6 +12,7 @@ from tribler.core.sentry_reporter.sentry_reporter import SentryReporter, SentryStrategy from tribler.core.sentry_reporter.sentry_scrubber import SentryScrubber from tribler.core.utilities.asyncio_fixes.finish_accept_patch import apply_finish_accept_patch +from tribler.core.utilities.asyncio_fixes.proactor_recvfrom_patch import apply_proactor_recvfrom_patch from tribler.core.utilities.slow_coro_detection.main_thread_stack_tracking import start_main_thread_stack_tracing from tribler.core.utilities.osutils import get_root_state_directory from tribler.core.utilities.utilities import is_frozen @@ -95,6 +96,7 @@ def main(): if parsed_args.core: if sys.platform == 'win32': apply_finish_accept_patch() + apply_proactor_recvfrom_patch() from tribler.core.utilities.pony_utils import track_slow_db_sessions track_slow_db_sessions() diff --git a/src/tribler/core/utilities/asyncio_fixes/proactor_recvfrom_patch.py b/src/tribler/core/utilities/asyncio_fixes/proactor_recvfrom_patch.py new file mode 100644 index 00000000000..047c3a7232e --- /dev/null +++ b/src/tribler/core/utilities/asyncio_fixes/proactor_recvfrom_patch.py @@ -0,0 +1,58 @@ +from asyncio.log import logger + +try: + import _overlapped +except ImportError: + _overlapped = None + + +NULL = 0 + +ERROR_PORT_UNREACHABLE = 1234 # _overlapped.ERROR_PORT_UNREACHABLE, available in Python >= 3.11 +ERROR_NETNAME_DELETED = 64 +ERROR_OPERATION_ABORTED = 995 + +patch_applied = False + + +def apply_proactor_recvfrom_patch(): # pragma: no cover + global patch_applied # pylint: disable=global-statement + if patch_applied: + return + + from asyncio import IocpProactor # pylint: disable=import-outside-toplevel + + IocpProactor.recvfrom = patched_recvfrom + + patch_applied = True + logger.info("Patched IocpProactor.recvfrom to handle ERROR_PORT_UNREACHABLE") + + +# pylint: disable=protected-access + + +def patched_recvfrom(self, conn, nbytes, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + try: + ov.WSARecvFrom(conn.fileno(), nbytes, flags) + except BrokenPipeError: + return self._result((b'', None)) + + def finish_recvfrom(trans, key, ov, error_class=OSError): # pylint: disable=unused-argument + try: + return ov.getresult() + except error_class as exc: + if exc.winerror in (ERROR_NETNAME_DELETED, ERROR_OPERATION_ABORTED): + raise ConnectionResetError(*exc.args) # pylint: disable=raise-missing-from + + # ******************** START OF THE PATCH ******************** + # WSARecvFrom will report ERROR_PORT_UNREACHABLE when the same + # socket was used to send to an address that is not listening. + if exc.winerror == ERROR_PORT_UNREACHABLE: + return b'', None # ignore the error + # ******************** END OF THE PATCH ********************** + + raise + + return self._register(ov, conn, finish_recvfrom) diff --git a/src/tribler/core/utilities/asyncio_fixes/tests/test_proactor_recvfrom_patch.py b/src/tribler/core/utilities/asyncio_fixes/tests/test_proactor_recvfrom_patch.py new file mode 100644 index 00000000000..5f3b9af153e --- /dev/null +++ b/src/tribler/core/utilities/asyncio_fixes/tests/test_proactor_recvfrom_patch.py @@ -0,0 +1,82 @@ +from unittest.mock import Mock, patch + +import pytest + +from tribler.core.utilities.asyncio_fixes.proactor_recvfrom_patch import ERROR_NETNAME_DELETED, \ + ERROR_OPERATION_ABORTED, ERROR_PORT_UNREACHABLE, patched_recvfrom + + +# pylint: disable=protected-access + + +@patch('tribler.core.utilities.asyncio_fixes.proactor_recvfrom_patch._overlapped') +def test_patched_recvfrom_broken_pipe_error(overlapped): + proactor, conn, nbytes, flags, ov = (Mock() for _ in range(5)) + overlapped.Overlapped.return_value = ov + conn.fileno.return_value = Mock() + ov.WSARecvFrom.side_effect = BrokenPipeError() + proactor._result.return_value = Mock() + + result = patched_recvfrom(proactor, conn, nbytes, flags) + + proactor._register_with_iocp.assert_called_with(conn) + overlapped.Overlapped.assert_called_with(0) + ov.WSARecvFrom.assert_called_with(conn.fileno.return_value, nbytes, flags) + proactor._result.assert_called_with((b'', None)) + assert result is proactor._result.return_value + + +@patch('tribler.core.utilities.asyncio_fixes.proactor_recvfrom_patch._overlapped') +def test_patched_recvfrom(overlapped): + proactor, conn, nbytes, flags, ov, trans, key = (Mock() for _ in range(7)) + overlapped.Overlapped.return_value = ov + conn.fileno.return_value = Mock() + proactor._register.return_value = Mock() + + result = patched_recvfrom(proactor, conn, nbytes, flags) + proactor._register.assert_called_once() + assert result is proactor._register.return_value + args = proactor._register.call_args.args + assert args[:2] == (ov, conn) and len(args) == 3 + + finish_recvfrom = args[2] + + class OSErrorMock(Exception): + def __init__(self, winerror): + self.winerror = winerror + + with patch('tribler.core.utilities.asyncio_fixes.proactor_recvfrom_patch.OSError', 'OSErrorMock'): + + # Should raise ConnectionResetError if ov.getresult() raises OSError with winerror=ERROR_NETNAME_DELETED + + ov.getresult.assert_not_called() + ov.getresult.side_effect = OSErrorMock(ERROR_NETNAME_DELETED) + with pytest.raises(ConnectionResetError): + finish_recvfrom(trans, key, ov, error_class=OSErrorMock) + + # Should raise ConnectionResetError if ov.getresult() raises OSError with winerror=ERROR_OPERATION_ABORTED + + ov.getresult.side_effect = OSErrorMock(ERROR_OPERATION_ABORTED) + with pytest.raises(ConnectionResetError): + finish_recvfrom(trans, key, ov, error_class=OSErrorMock) + + # Should return empty result if ov.getresult() raises OSError with winerror=ERROR_PORT_UNREACHABLE + + ov.getresult.side_effect = OSErrorMock(ERROR_PORT_UNREACHABLE) + result = finish_recvfrom(trans, key, ov, error_class=OSErrorMock) + assert result == (b'', None) + + # Should reraise any other OSError raised by ov.getresult() + + ov.getresult.side_effect = OSErrorMock(-1) + with pytest.raises(OSErrorMock): + finish_recvfrom(trans, key, ov, error_class=OSErrorMock) + + # Should return result of ov.getresult() if no exceptions arisen + + ov.getresult.side_effect = None + ov.getresult.return_value = Mock() + result = finish_recvfrom(trans, key, ov) + assert result is ov.getresult.return_value + + assert ov.getresult.call_count == 5