diff --git a/src/engineio/async_client.py b/src/engineio/async_client.py index 5e5980f..07fe2df 100644 --- a/src/engineio/async_client.py +++ b/src/engineio/async_client.py @@ -71,7 +71,6 @@ class AsyncClient(base_client.BaseClient): arguments passed to ``aiohttp.ws_connect()``. """ - def is_asyncio_based(self): return True diff --git a/src/engineio/base_client.py b/src/engineio/base_client.py index 5b2eac1..6381be2 100644 --- a/src/engineio/base_client.py +++ b/src/engineio/base_client.py @@ -1,4 +1,6 @@ import logging +import signal +import threading import time import urllib from . import packet @@ -7,12 +9,35 @@ connected_clients = [] +def signal_handler(sig, frame): + """SIGINT handler. + + Disconnect all active clients and then invoke the original signal handler. + """ + for client in connected_clients[:]: + if not client.is_asyncio_based(): + client.disconnect() + if callable(original_signal_handler): + return original_signal_handler(sig, frame) + else: # pragma: no cover + # Handle case where no original SIGINT handler was present. + return signal.default_int_handler(sig, frame) + + +original_signal_handler = None + + class BaseClient: event_names = ['connect', 'disconnect', 'message'] def __init__(self, logger=False, json=None, request_timeout=5, http_session=None, ssl_verify=True, handle_sigint=True, websocket_extra_options=None): + global original_signal_handler + if handle_sigint and original_signal_handler is None and \ + threading.current_thread() == threading.main_thread(): + original_signal_handler = signal.signal(signal.SIGINT, + signal_handler) self.handlers = {} self.base_url = None self.transports = None diff --git a/src/engineio/client.py b/src/engineio/client.py index d387b62..51207a1 100644 --- a/src/engineio/client.py +++ b/src/engineio/client.py @@ -2,7 +2,6 @@ from engineio.json import JSONDecodeError import logging import queue -import signal import ssl import threading import time @@ -24,24 +23,6 @@ default_logger = logging.getLogger('engineio.client') -def signal_handler(sig, frame): - """SIGINT handler. - - Disconnect all active clients and then invoke the original signal handler. - """ - for client in base_client.connected_clients[:]: - if not client.is_asyncio_based(): - client.disconnect() - if callable(original_signal_handler): - return original_signal_handler(sig, frame) - else: # pragma: no cover - # Handle case where no original SIGINT handler was present. - return signal.default_int_handler(sig, frame) - - -original_signal_handler = None - - class Client(base_client.BaseClient): """An Engine.IO client. @@ -75,22 +56,6 @@ class Client(base_client.BaseClient): arguments passed to ``websocket.create_connection()``. """ - event_names = ['connect', 'disconnect', 'message'] - - def __init__(self, logger=False, json=None, request_timeout=5, - http_session=None, ssl_verify=True, handle_sigint=True, - websocket_extra_options=None): - global original_signal_handler - if handle_sigint and original_signal_handler is None and \ - threading.current_thread() == threading.main_thread(): - original_signal_handler = signal.signal(signal.SIGINT, - signal_handler) - super().__init__(logger=logger, json=json, - request_timeout=request_timeout, - http_session=http_session, ssl_verify=ssl_verify, - handle_sigint=handle_sigint, - websocket_extra_options=websocket_extra_options) - def connect(self, url, headers=None, transports=None, engineio_path='engine.io'): """Connect to an Engine.IO server. diff --git a/tests/async/test_client.py b/tests/async/test_client.py index 665a3a4..d71f664 100644 --- a/tests/async/test_client.py +++ b/tests/async/test_client.py @@ -1473,7 +1473,7 @@ def test_write_loop_websocket_bad_connection(self): _run(c._write_loop()) assert c.state == 'connected' - @mock.patch('engineio.client.original_signal_handler') + @mock.patch('engineio.base_client.original_signal_handler') def test_signal_handler(self, original_handler): clients = [mock.MagicMock(), mock.MagicMock()] base_client.connected_clients = clients[:] diff --git a/tests/common/test_client.py b/tests/common/test_client.py index dacc4b4..4f45932 100644 --- a/tests/common/test_client.py +++ b/tests/common/test_client.py @@ -1715,12 +1715,12 @@ def test_write_loop_websocket_bad_connection(self): c._write_loop() assert c.state == 'connected' - @mock.patch('engineio.client.original_signal_handler') + @mock.patch('engineio.base_client.original_signal_handler') def test_signal_handler(self, original_handler): clients = [mock.MagicMock(), mock.MagicMock()] base_client.connected_clients = clients[:] base_client.connected_clients[0].is_asyncio_based.return_value = False base_client.connected_clients[1].is_asyncio_based.return_value = True - client.signal_handler('sig', 'frame') + base_client.signal_handler('sig', 'frame') clients[0].disconnect.assert_called_once_with() clients[1].disconnect.assert_not_called()