From 160ec7ae19ade25bbc8d2417ffd66f78c03b8de5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=2E=20K=C3=A4rkk=C3=A4inen?= <98187+Tronic@users.noreply.github.com> Date: Thu, 7 Dec 2023 10:40:44 +0000 Subject: [PATCH] Make request.scheme return ws/wss for WS even when http/https in SERVER_NAME or proxy headers (#2854) Co-authored-by: L. Karkkainen Co-authored-by: Adam Hopkins --- sanic/app.py | 3 ++ sanic/request/types.py | 22 ++++---- tests/test_ws_handlers.py | 108 +++++++++++++++++++++++--------------- 3 files changed, 82 insertions(+), 51 deletions(-) diff --git a/sanic/app.py b/sanic/app.py index 14f5bdc533..d0b1c8ffea 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -1066,6 +1066,9 @@ def url_for(self, view_name: str, **kwargs): scheme = netloc[:8].split(":", 1)[0] else: scheme = "http" + # Replace http/https with ws/wss for WebSocket handlers + if route.extra.websocket: + scheme = scheme.replace("http", "ws") if "://" in netloc[:8]: netloc = netloc.split("://", 1)[-1] diff --git a/sanic/request/types.py b/sanic/request/types.py index 45fbd2088c..ec4c2a822d 100644 --- a/sanic/request/types.py +++ b/sanic/request/types.py @@ -964,20 +964,23 @@ def scheme(self) -> str: str: http|https|ws|wss or arbitrary value given by the headers. """ if not hasattr(self, "_scheme"): - if "//" in self.app.config.get("SERVER_NAME", ""): - return self.app.config.SERVER_NAME.split("//")[0] - if "proto" in self.forwarded: - return str(self.forwarded["proto"]) - if ( self.app.websocket_enabled - and self.headers.getone("upgrade", "").lower() == "websocket" + and self.headers.upgrade.lower() == "websocket" ): scheme = "ws" else: scheme = "http" - - if self.transport.get_extra_info("sslcontext"): + proto = None + sp = self.app.config.get("SERVER_NAME", "").split("://", 1) + if len(sp) == 2: + proto = sp[0] + elif "proto" in self.forwarded: + proto = str(self.forwarded["proto"]) + if proto: + # Give ws/wss if websocket, otherwise keep the same + scheme = proto.replace("http", scheme) + elif self.conn_info and self.conn_info.ssl: scheme += "s" self._scheme = scheme @@ -1072,7 +1075,8 @@ def url_for(self, view_name: str, **kwargs) -> str: """ # Full URL SERVER_NAME can only be handled in app.url_for try: - if "//" in self.app.config.SERVER_NAME: + sp = self.app.config.get("SERVER_NAME", "").split("://", 1) + if len(sp) == 2: return self.app.url_for(view_name, _external=True, **kwargs) except AttributeError: pass diff --git a/tests/test_ws_handlers.py b/tests/test_ws_handlers.py index c70dbeed97..6236292402 100644 --- a/tests/test_ws_handlers.py +++ b/tests/test_ws_handlers.py @@ -7,9 +7,7 @@ from sanic import Request, Sanic, Websocket -MimicClientType = Callable[ - [WebSocketClientProtocol], Coroutine[None, None, Any] -] +MimicClientType = Callable[[WebSocketClientProtocol], Coroutine[None, None, Any]] @pytest.fixture @@ -23,39 +21,6 @@ async def client_mimic(ws: WebSocketClientProtocol): return client_mimic -def test_ws_handler( - app: Sanic, - simple_ws_mimic_client: MimicClientType, -): - @app.websocket("/ws") - async def ws_echo_handler(request: Request, ws: Websocket): - while True: - msg = await ws.recv() - await ws.send(msg) - - _, ws_proxy = app.test_client.websocket( - "/ws", mimic=simple_ws_mimic_client - ) - assert ws_proxy.client_sent == ["test 1", "test 2", ""] - assert ws_proxy.client_received == ["test 1", "test 2"] - - -def test_ws_handler_async_for( - app: Sanic, - simple_ws_mimic_client: MimicClientType, -): - @app.websocket("/ws") - async def ws_echo_handler(request: Request, ws: Websocket): - async for msg in ws: - await ws.send(msg) - - _, ws_proxy = app.test_client.websocket( - "/ws", mimic=simple_ws_mimic_client - ) - assert ws_proxy.client_sent == ["test 1", "test 2", ""] - assert ws_proxy.client_received == ["test 1", "test 2"] - - def signalapp(app): @app.signal("websocket.handler.before") async def ws_before(request: Request, websocket: Websocket): @@ -90,6 +55,69 @@ async def ws_error(request: Request, ws: Websocket): print("wserr2") +def test_ws_handler( + app: Sanic, + simple_ws_mimic_client: MimicClientType, +): + @app.websocket("/ws") + async def ws_echo_handler(request: Request, ws: Websocket): + while True: + msg = await ws.recv() + await ws.send(msg) + + _, ws_proxy = app.test_client.websocket("/ws", mimic=simple_ws_mimic_client) + assert ws_proxy.client_sent == ["test 1", "test 2", ""] + assert ws_proxy.client_received == ["test 1", "test 2"] + + +def test_ws_handler_async_for( + app: Sanic, + simple_ws_mimic_client: MimicClientType, +): + @app.websocket("/ws") + async def ws_echo_handler(request: Request, ws: Websocket): + async for msg in ws: + await ws.send(msg) + + _, ws_proxy = app.test_client.websocket("/ws", mimic=simple_ws_mimic_client) + assert ws_proxy.client_sent == ["test 1", "test 2", ""] + assert ws_proxy.client_received == ["test 1", "test 2"] + + +@pytest.mark.parametrize("proxy", ["", "proxy", "servername"]) +def test_request_url( + app: Sanic, + simple_ws_mimic_client: MimicClientType, + proxy: str, +): + @app.websocket("/ws") + async def ws_url_handler(request: Request, ws: Websocket): + request.headers[ + "forwarded" + ] = "for=[2001:db8::1];proto=https;host=example.com;by=proxy" + + await ws.recv() + await ws.send(request.url) + await ws.recv() + await ws.send(request.url_for("ws_url_handler")) + await ws.recv() + + app.config.FORWARDED_SECRET = proxy + app.config.SERVER_NAME = "https://example.com" if proxy == "servername" else "" + _, ws_proxy = app.test_client.websocket( + "/ws", + mimic=simple_ws_mimic_client, + ) + assert ws_proxy.client_sent == ["test 1", "test 2", ""] + assert ws_proxy.client_received[0] == ws_proxy.client_received[1] + if proxy: + assert ws_proxy.client_received[0] == "wss://example.com/ws" + assert ws_proxy.client_received[1] == "wss://example.com/ws" + else: + assert ws_proxy.client_received[0].startswith("ws://127.0.0.1") + assert ws_proxy.client_received[1].startswith("ws://127.0.0.1") + + def test_ws_signals( app: Sanic, simple_ws_mimic_client: MimicClientType, @@ -97,9 +125,7 @@ def test_ws_signals( signalapp(app) app.ctx.seq = [] - _, ws_proxy = app.test_client.websocket( - "/ws", mimic=simple_ws_mimic_client - ) + _, ws_proxy = app.test_client.websocket("/ws", mimic=simple_ws_mimic_client) assert ws_proxy.client_received == ["before: test 1", "after: test 2"] assert app.ctx.seq == ["before", "ws", "after"] @@ -111,8 +137,6 @@ def test_ws_signals_exception( signalapp(app) app.ctx.seq = [] - _, ws_proxy = app.test_client.websocket( - "/wserror", mimic=simple_ws_mimic_client - ) + _, ws_proxy = app.test_client.websocket("/wserror", mimic=simple_ws_mimic_client) assert ws_proxy.client_received == ["before: test 1", "exception: test 2"] assert app.ctx.seq == ["before", "wserror", "exception"]