Skip to content

Commit

Permalink
Make request.scheme return ws/wss for WS even when http/https in SERV…
Browse files Browse the repository at this point in the history
…ER_NAME or proxy headers (#2854)

Co-authored-by: L. Karkkainen <tronic@users.noreply.github.com>
Co-authored-by: Adam Hopkins <adam@amhopkins.com>
  • Loading branch information
3 people committed Dec 7, 2023
1 parent 5f0787b commit 160ec7a
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 51 deletions.
3 changes: 3 additions & 0 deletions sanic/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -1066,6 +1066,9 @@ def url_for(self, view_name: str, **kwargs):
scheme = netloc[:8].split(":", 1)[0]
else:
scheme = "http"
# Replace http/https with ws/wss for WebSocket handlers
if route.extra.websocket:
scheme = scheme.replace("http", "ws")

if "://" in netloc[:8]:
netloc = netloc.split("://", 1)[-1]
Expand Down
22 changes: 13 additions & 9 deletions sanic/request/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -964,20 +964,23 @@ def scheme(self) -> str:
str: http|https|ws|wss or arbitrary value given by the headers.
"""
if not hasattr(self, "_scheme"):
if "//" in self.app.config.get("SERVER_NAME", ""):
return self.app.config.SERVER_NAME.split("//")[0]
if "proto" in self.forwarded:
return str(self.forwarded["proto"])

if (
self.app.websocket_enabled
and self.headers.getone("upgrade", "").lower() == "websocket"
and self.headers.upgrade.lower() == "websocket"
):
scheme = "ws"
else:
scheme = "http"

if self.transport.get_extra_info("sslcontext"):
proto = None
sp = self.app.config.get("SERVER_NAME", "").split("://", 1)
if len(sp) == 2:
proto = sp[0]
elif "proto" in self.forwarded:
proto = str(self.forwarded["proto"])
if proto:
# Give ws/wss if websocket, otherwise keep the same
scheme = proto.replace("http", scheme)
elif self.conn_info and self.conn_info.ssl:
scheme += "s"
self._scheme = scheme

Expand Down Expand Up @@ -1072,7 +1075,8 @@ def url_for(self, view_name: str, **kwargs) -> str:
"""
# Full URL SERVER_NAME can only be handled in app.url_for
try:
if "//" in self.app.config.SERVER_NAME:
sp = self.app.config.get("SERVER_NAME", "").split("://", 1)
if len(sp) == 2:
return self.app.url_for(view_name, _external=True, **kwargs)
except AttributeError:
pass
Expand Down
108 changes: 66 additions & 42 deletions tests/test_ws_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
from sanic import Request, Sanic, Websocket


MimicClientType = Callable[
[WebSocketClientProtocol], Coroutine[None, None, Any]
]
MimicClientType = Callable[[WebSocketClientProtocol], Coroutine[None, None, Any]]


@pytest.fixture
Expand All @@ -23,39 +21,6 @@ async def client_mimic(ws: WebSocketClientProtocol):
return client_mimic


def test_ws_handler(
app: Sanic,
simple_ws_mimic_client: MimicClientType,
):
@app.websocket("/ws")
async def ws_echo_handler(request: Request, ws: Websocket):
while True:
msg = await ws.recv()
await ws.send(msg)

_, ws_proxy = app.test_client.websocket(
"/ws", mimic=simple_ws_mimic_client
)
assert ws_proxy.client_sent == ["test 1", "test 2", ""]
assert ws_proxy.client_received == ["test 1", "test 2"]


def test_ws_handler_async_for(
app: Sanic,
simple_ws_mimic_client: MimicClientType,
):
@app.websocket("/ws")
async def ws_echo_handler(request: Request, ws: Websocket):
async for msg in ws:
await ws.send(msg)

_, ws_proxy = app.test_client.websocket(
"/ws", mimic=simple_ws_mimic_client
)
assert ws_proxy.client_sent == ["test 1", "test 2", ""]
assert ws_proxy.client_received == ["test 1", "test 2"]


def signalapp(app):
@app.signal("websocket.handler.before")
async def ws_before(request: Request, websocket: Websocket):
Expand Down Expand Up @@ -90,16 +55,77 @@ async def ws_error(request: Request, ws: Websocket):
print("wserr2")


def test_ws_handler(
app: Sanic,
simple_ws_mimic_client: MimicClientType,
):
@app.websocket("/ws")
async def ws_echo_handler(request: Request, ws: Websocket):
while True:
msg = await ws.recv()
await ws.send(msg)

_, ws_proxy = app.test_client.websocket("/ws", mimic=simple_ws_mimic_client)
assert ws_proxy.client_sent == ["test 1", "test 2", ""]
assert ws_proxy.client_received == ["test 1", "test 2"]


def test_ws_handler_async_for(
app: Sanic,
simple_ws_mimic_client: MimicClientType,
):
@app.websocket("/ws")
async def ws_echo_handler(request: Request, ws: Websocket):
async for msg in ws:
await ws.send(msg)

_, ws_proxy = app.test_client.websocket("/ws", mimic=simple_ws_mimic_client)
assert ws_proxy.client_sent == ["test 1", "test 2", ""]
assert ws_proxy.client_received == ["test 1", "test 2"]


@pytest.mark.parametrize("proxy", ["", "proxy", "servername"])
def test_request_url(
app: Sanic,
simple_ws_mimic_client: MimicClientType,
proxy: str,
):
@app.websocket("/ws")
async def ws_url_handler(request: Request, ws: Websocket):
request.headers[
"forwarded"
] = "for=[2001:db8::1];proto=https;host=example.com;by=proxy"

await ws.recv()
await ws.send(request.url)
await ws.recv()
await ws.send(request.url_for("ws_url_handler"))
await ws.recv()

app.config.FORWARDED_SECRET = proxy
app.config.SERVER_NAME = "https://example.com" if proxy == "servername" else ""
_, ws_proxy = app.test_client.websocket(
"/ws",
mimic=simple_ws_mimic_client,
)
assert ws_proxy.client_sent == ["test 1", "test 2", ""]
assert ws_proxy.client_received[0] == ws_proxy.client_received[1]
if proxy:
assert ws_proxy.client_received[0] == "wss://example.com/ws"
assert ws_proxy.client_received[1] == "wss://example.com/ws"
else:
assert ws_proxy.client_received[0].startswith("ws://127.0.0.1")
assert ws_proxy.client_received[1].startswith("ws://127.0.0.1")


def test_ws_signals(
app: Sanic,
simple_ws_mimic_client: MimicClientType,
):
signalapp(app)

app.ctx.seq = []
_, ws_proxy = app.test_client.websocket(
"/ws", mimic=simple_ws_mimic_client
)
_, ws_proxy = app.test_client.websocket("/ws", mimic=simple_ws_mimic_client)
assert ws_proxy.client_received == ["before: test 1", "after: test 2"]
assert app.ctx.seq == ["before", "ws", "after"]

Expand All @@ -111,8 +137,6 @@ def test_ws_signals_exception(
signalapp(app)

app.ctx.seq = []
_, ws_proxy = app.test_client.websocket(
"/wserror", mimic=simple_ws_mimic_client
)
_, ws_proxy = app.test_client.websocket("/wserror", mimic=simple_ws_mimic_client)
assert ws_proxy.client_received == ["before: test 1", "exception: test 2"]
assert app.ctx.seq == ["before", "wserror", "exception"]

0 comments on commit 160ec7a

Please sign in to comment.