From f7b16a343d06c73832db0cdf2821af6d3df251ea Mon Sep 17 00:00:00 2001 From: Matth Date: Tue, 11 Jun 2024 15:17:50 +0200 Subject: [PATCH 1/5] Add headers to ws connection --- src/roslibpy/comm/comm_autobahn.py | 16 +++++++++++++++- src/roslibpy/ros.py | 5 +++-- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/src/roslibpy/comm/comm_autobahn.py b/src/roslibpy/comm/comm_autobahn.py index 3a830fd..737af4b 100644 --- a/src/roslibpy/comm/comm_autobahn.py +++ b/src/roslibpy/comm/comm_autobahn.py @@ -23,10 +23,17 @@ class AutobahnRosBridgeProtocol(RosBridgeProtocol, WebSocketClientProtocol): def __init__(self, *args, **kwargs): super(AutobahnRosBridgeProtocol, self).__init__(*args, **kwargs) + self.headers = {} def onConnect(self, response): LOGGER.debug("Server connected: %s", response.peer) + def getHandshakeRequestHeaders(self): + headers = super(AutobahnRosBridgeProtocol, self).getHandshakeRequestHeaders() + for key, value in self.headers.items(): + headers.append((key, value)) + return headers + def onOpen(self): LOGGER.info("Connection to ROS ready.") self._manual_disconnect = False @@ -62,13 +69,20 @@ class AutobahnRosBridgeClientFactory(EventEmitterMixin, ReconnectingClientFactor protocol = AutobahnRosBridgeProtocol - def __init__(self, *args, **kwargs): + def __init__(self, *args, headers=None, **kwargs): super(AutobahnRosBridgeClientFactory, self).__init__(*args, **kwargs) + self.headers = headers or {} self._proto = None self._manager = None self.connector = None self.setProtocolOptions(closeHandshakeTimeout=5) + def buildProtocol(self, addr): + proto = self.protocol() + proto.factory = self + proto.headers = self.headers + return proto + def connect(self): """Establish WebSocket connection to the ROS server defined for this factory.""" self.connector = connectWS(self) diff --git a/src/roslibpy/ros.py b/src/roslibpy/ros.py index 72ff7d6..c69822e 100644 --- a/src/roslibpy/ros.py +++ b/src/roslibpy/ros.py @@ -32,12 +32,13 @@ class Ros(object): host (:obj:`str`): Name or IP address of the ROS bridge host, e.g. ``127.0.0.1``. port (:obj:`int`): ROS bridge port, e.g. ``9090``. is_secure (:obj:`bool`): ``True`` to use a secure web sockets connection, otherwise ``False``. + headers (:obj:`dict`): Additional headers to include in the WebSocket connection. """ - def __init__(self, host, port=None, is_secure=False): + def __init__(self, host, port=None, is_secure=False, headers=None): self._id_counter = 0 url = RosBridgeClientFactory.create_url(host, port, is_secure) - self.factory = RosBridgeClientFactory(url) + self.factory = RosBridgeClientFactory(url, headers=headers) self.is_connecting = False self.connect() From 75a0f5cb6a686d5d3d3ac88c38e3ea4c0085f7e9 Mon Sep 17 00:00:00 2001 From: Matth Date: Mon, 17 Jun 2024 09:09:42 +0200 Subject: [PATCH 2/5] Add websocket header test --- tests/test_ws_headers.py | 56 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) create mode 100644 tests/test_ws_headers.py diff --git a/tests/test_ws_headers.py b/tests/test_ws_headers.py new file mode 100644 index 0000000..726a3e1 --- /dev/null +++ b/tests/test_ws_headers.py @@ -0,0 +1,56 @@ +from __future__ import print_function +import threading +import time + +from autobahn.twisted.websocket import WebSocketServerProtocol, WebSocketServerFactory +from twisted.internet import reactor + +from roslibpy import Ros + +headers = { + 'cookie': 'token=rosbridge', + 'authorization': 'Some auth' +} + +class TestWebSocketServerProtocol(WebSocketServerProtocol): + def onConnect(self, request): + for key, value in headers.items(): + assert request.headers.get(key) == value, f"Header {key} did not match expected value {value}" + self.factory.context['wait'].set() + + def onOpen(self): + self.sendClose() + +def run_server(context): + factory = WebSocketServerFactory() + factory.protocol = TestWebSocketServerProtocol + factory.context = context + + reactor.listenTCP(9000, factory) + reactor.run(installSignalHandlers=False) + +def run_client(): + client = Ros('127.0.0.1', 9000, headers=headers) + client.run() + client.close() + +def test_websocket_headers(): + context = dict(wait=threading.Event()) + + server_thread = threading.Thread(target=run_server, args=(context,)) + server_thread.start() + + time.sleep(1) # Give the server time to start + + client_thread = threading.Thread(target=run_client) + client_thread.start() + + if not context["wait"].wait(10): + raise Exception("Headers were not as expected") + + client_thread.join() + reactor.callFromThread(reactor.stop) + server_thread.join() + +if __name__ == "__main__": + test_websocket_headers() From 54967199c85edd579d25a0bc608cedc87069235f Mon Sep 17 00:00:00 2001 From: Matth Date: Mon, 17 Jun 2024 10:14:00 +0200 Subject: [PATCH 3/5] Make the test actually work --- requirements-dev.txt | 1 + tests/test_ws_headers.py | 58 +++++++++++++++++++++++----------------- 2 files changed, 35 insertions(+), 24 deletions(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index e670db4..a154efc 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -10,4 +10,5 @@ pydocstyle pytest>=6.0 sphinx >=3.4 twine +websockets >= 12.0 -e . diff --git a/tests/test_ws_headers.py b/tests/test_ws_headers.py index 726a3e1..5e3d0a1 100644 --- a/tests/test_ws_headers.py +++ b/tests/test_ws_headers.py @@ -1,9 +1,10 @@ from __future__ import print_function + +import asyncio import threading import time -from autobahn.twisted.websocket import WebSocketServerProtocol, WebSocketServerFactory -from twisted.internet import reactor +import websockets from roslibpy import Ros @@ -12,32 +13,36 @@ 'authorization': 'Some auth' } -class TestWebSocketServerProtocol(WebSocketServerProtocol): - def onConnect(self, request): - for key, value in headers.items(): - assert request.headers.get(key) == value, f"Header {key} did not match expected value {value}" - self.factory.context['wait'].set() - def onOpen(self): - self.sendClose() +async def websocket_handler(websocket, path): + request_headers = websocket.request_headers + for key, value in headers.items(): + assert request_headers.get(key) == value, f"Header {key} did not match expected value {value}" + await websocket.close() + + +async def start_server(stop_event): + server = await websockets.serve(websocket_handler, '127.0.0.1', 9000) + await stop_event.wait() + server.close() + await server.wait_closed() + -def run_server(context): - factory = WebSocketServerFactory() - factory.protocol = TestWebSocketServerProtocol - factory.context = context +def run_server(stop_event): + asyncio.run(start_server(stop_event)) - reactor.listenTCP(9000, factory) - reactor.run(installSignalHandlers=False) def run_client(): client = Ros('127.0.0.1', 9000, headers=headers) client.run() client.close() + def test_websocket_headers(): - context = dict(wait=threading.Event()) + server_stop_event = asyncio.Event() + stop_event = threading.Event() - server_thread = threading.Thread(target=run_server, args=(context,)) + server_thread = threading.Thread(target=run_server, args=(server_stop_event,)) server_thread.start() time.sleep(1) # Give the server time to start @@ -45,12 +50,17 @@ def test_websocket_headers(): client_thread = threading.Thread(target=run_client) client_thread.start() - if not context["wait"].wait(10): - raise Exception("Headers were not as expected") + # Wait for the client thread to finish or timeout after 10 seconds + client_thread.join(timeout=10) + + if client_thread.is_alive(): + raise Exception("Client did not terminate as expected") + + # Signal the server to stop + server_stop_event.set() + server_thread.join(timeout=10) - client_thread.join() - reactor.callFromThread(reactor.stop) - server_thread.join() + if server_thread.is_alive(): + raise Exception("Server did not stop as expected") -if __name__ == "__main__": - test_websocket_headers() + stop_event.set() From bfe733b238c3ad1113bce4a245e186d6e8cde380 Mon Sep 17 00:00:00 2001 From: Matth Date: Mon, 17 Jun 2024 10:34:16 +0200 Subject: [PATCH 4/5] Remove unnecessary code --- src/roslibpy/comm/comm_autobahn.py | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/src/roslibpy/comm/comm_autobahn.py b/src/roslibpy/comm/comm_autobahn.py index 737af4b..3a830fd 100644 --- a/src/roslibpy/comm/comm_autobahn.py +++ b/src/roslibpy/comm/comm_autobahn.py @@ -23,17 +23,10 @@ class AutobahnRosBridgeProtocol(RosBridgeProtocol, WebSocketClientProtocol): def __init__(self, *args, **kwargs): super(AutobahnRosBridgeProtocol, self).__init__(*args, **kwargs) - self.headers = {} def onConnect(self, response): LOGGER.debug("Server connected: %s", response.peer) - def getHandshakeRequestHeaders(self): - headers = super(AutobahnRosBridgeProtocol, self).getHandshakeRequestHeaders() - for key, value in self.headers.items(): - headers.append((key, value)) - return headers - def onOpen(self): LOGGER.info("Connection to ROS ready.") self._manual_disconnect = False @@ -69,20 +62,13 @@ class AutobahnRosBridgeClientFactory(EventEmitterMixin, ReconnectingClientFactor protocol = AutobahnRosBridgeProtocol - def __init__(self, *args, headers=None, **kwargs): + def __init__(self, *args, **kwargs): super(AutobahnRosBridgeClientFactory, self).__init__(*args, **kwargs) - self.headers = headers or {} self._proto = None self._manager = None self.connector = None self.setProtocolOptions(closeHandshakeTimeout=5) - def buildProtocol(self, addr): - proto = self.protocol() - proto.factory = self - proto.headers = self.headers - return proto - def connect(self): """Establish WebSocket connection to the ROS server defined for this factory.""" self.connector = connectWS(self) From e005dfa0bf6ffd2e4d44f60cc2ddbf82277bfc10 Mon Sep 17 00:00:00 2001 From: Matth Date: Mon, 17 Jun 2024 11:00:25 +0200 Subject: [PATCH 5/5] Add change to changelog --- CHANGELOG.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 067e65a..843f053 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -12,6 +12,8 @@ Unreleased **Added** +* Added websocket header support to the ROS-client. + **Changed** **Fixed**