Skip to content

Commit

Permalink
Document & test process_response modifying the response.
Browse files Browse the repository at this point in the history
a78b554 inadvertently changed the test from "returning a new response"
to "modifying the existing response". Both are supported..
  • Loading branch information
aaugustin committed Aug 20, 2024
1 parent 09b1d8d commit 8eaa5a2
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 44 deletions.
11 changes: 6 additions & 5 deletions src/websockets/asyncio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,15 +236,16 @@ class Server:
handler: Connection handler. It receives the WebSocket connection,
which is a :class:`ServerConnection`, in argument.
process_request: Intercept the request during the opening handshake.
Return an HTTP response to force the response or :obj:`None` to
Return an HTTP response to force the response. Return :obj:`None` to
continue normally. When you force an HTTP 101 Continue response, the
handshake is successful. Else, the connection is aborted.
``process_request`` may be a function or a coroutine.
process_response: Intercept the response during the opening handshake.
Return an HTTP response to force the response or :obj:`None` to
continue normally. When you force an HTTP 101 Continue response, the
handshake is successful. Else, the connection is aborted.
``process_response`` may be a function or a coroutine.
Modify the response or return a new HTTP response to force the
response. Return :obj:`None` to continue normally. When you force an
HTTP 101 Continue response, the handshake is successful. Else, the
connection is aborted. ``process_response`` may be a function or a
coroutine.
server_header: Value of the ``Server`` response header.
It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to
:obj:`None` removes the header.
Expand Down
13 changes: 7 additions & 6 deletions src/websockets/sync/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,13 +401,14 @@ def handler(websocket):
:meth:`ServerProtocol.select_subprotocol
<websockets.server.ServerProtocol.select_subprotocol>` method.
process_request: Intercept the request during the opening handshake.
Return an HTTP response to force the response or :obj:`None` to
continue normally. When you force an HTTP 101 Continue response,
the handshake is successful. Else, the connection is aborted.
Return an HTTP response to force the response. Return :obj:`None` to
continue normally. When you force an HTTP 101 Continue response, the
handshake is successful. Else, the connection is aborted.
process_response: Intercept the response during the opening handshake.
Return an HTTP response to force the response or :obj:`None` to
continue normally. When you force an HTTP 101 Continue response,
the handshake is successful. Else, the connection is aborted.
Modify the response or return a new HTTP response to force the
response. Return :obj:`None` to continue normally. When you force an
HTTP 101 Continue response, the handshake is successful. Else, the
connection is aborted.
server_header: Value of the ``Server`` response header.
It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to
:obj:`None` removes the header.
Expand Down
65 changes: 43 additions & 22 deletions tests/asyncio/test_server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import dataclasses
import http
import logging
import socket
Expand Down Expand Up @@ -117,8 +118,8 @@ def select_subprotocol(ws, subprotocols):
"server rejected WebSocket connection: HTTP 500",
)

async def test_process_request(self):
"""Server runs process_request before processing the handshake."""
async def test_process_request_returns_none(self):
"""Server runs process_request and continues the handshake."""

def process_request(ws, request):
self.assertIsInstance(request, Request)
Expand All @@ -128,8 +129,8 @@ def process_request(ws, request):
async with run_client(server) as client:
await self.assertEval(client, "ws.process_request_ran", "True")

async def test_async_process_request(self):
"""Server runs async process_request before processing the handshake."""
async def test_async_process_request_returns_none(self):
"""Server runs async process_request and continues the handshake."""

async def process_request(ws, request):
self.assertIsInstance(request, Request)
Expand All @@ -139,7 +140,7 @@ async def process_request(ws, request):
async with run_client(server) as client:
await self.assertEval(client, "ws.process_request_ran", "True")

async def test_process_request_abort_handshake(self):
async def test_process_request_returns_response(self):
"""Server aborts handshake if process_request returns a response."""

def process_request(ws, request):
Expand All @@ -154,7 +155,7 @@ def process_request(ws, request):
"server rejected WebSocket connection: HTTP 403",
)

async def test_async_process_request_abort_handshake(self):
async def test_async_process_request_returns_response(self):
"""Server aborts handshake if async process_request returns a response."""

async def process_request(ws, request):
Expand Down Expand Up @@ -199,8 +200,8 @@ async def process_request(ws, request):
"server rejected WebSocket connection: HTTP 500",
)

async def test_process_response(self):
"""Server runs process_response after processing the handshake."""
async def test_process_response_returns_none(self):
"""Server runs process_response but keeps the handshake response."""

def process_response(ws, request, response):
self.assertIsInstance(request, Request)
Expand All @@ -211,8 +212,8 @@ def process_response(ws, request, response):
async with run_client(server) as client:
await self.assertEval(client, "ws.process_response_ran", "True")

async def test_async_process_response(self):
"""Server runs async process_response after processing the handshake."""
async def test_async_process_response_returns_none(self):
"""Server runs async process_response but keeps the handshake response."""

async def process_response(ws, request, response):
self.assertIsInstance(request, Request)
Expand All @@ -223,29 +224,49 @@ async def process_response(ws, request, response):
async with run_client(server) as client:
await self.assertEval(client, "ws.process_response_ran", "True")

async def test_process_response_override_response(self):
"""Server runs process_response and overrides the handshake response."""
async def test_process_response_modifies_response(self):
"""Server runs process_response and modifies the handshake response."""

def process_response(ws, request, response):
response.headers["X-ProcessResponse-Ran"] = "true"
response.headers["X-ProcessResponse"] = "OK"

async with run_server(process_response=process_response) as server:
async with run_client(server) as client:
self.assertEqual(
client.response.headers["X-ProcessResponse-Ran"], "true"
)
self.assertEqual(client.response.headers["X-ProcessResponse"], "OK")

async def test_async_process_response_override_response(self):
"""Server runs async process_response and overrides the handshake response."""
async def test_async_process_response_modifies_response(self):
"""Server runs async process_response and modifies the handshake response."""

async def process_response(ws, request, response):
response.headers["X-ProcessResponse-Ran"] = "true"
response.headers["X-ProcessResponse"] = "OK"

async with run_server(process_response=process_response) as server:
async with run_client(server) as client:
self.assertEqual(
client.response.headers["X-ProcessResponse-Ran"], "true"
)
self.assertEqual(client.response.headers["X-ProcessResponse"], "OK")

async def test_process_response_replaces_response(self):
"""Server runs process_response and replaces the handshake response."""

def process_response(ws, request, response):
headers = response.headers.copy()
headers["X-ProcessResponse"] = "OK"
return dataclasses.replace(response, headers=headers)

async with run_server(process_response=process_response) as server:
async with run_client(server) as client:
self.assertEqual(client.response.headers["X-ProcessResponse"], "OK")

async def test_async_process_response_replaces_response(self):
"""Server runs async process_response and replaces the handshake response."""

async def process_response(ws, request, response):
headers = response.headers.copy()
headers["X-ProcessResponse"] = "OK"
return dataclasses.replace(response, headers=headers)

async with run_server(process_response=process_response) as server:
async with run_client(server) as client:
self.assertEqual(client.response.headers["X-ProcessResponse"], "OK")

async def test_process_response_raises_exception(self):
"""Server returns an error if process_response raises an exception."""
Expand Down
33 changes: 22 additions & 11 deletions tests/sync/test_server.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dataclasses
import http
import logging
import socket
Expand Down Expand Up @@ -115,8 +116,8 @@ def select_subprotocol(ws, subprotocols):
"server rejected WebSocket connection: HTTP 500",
)

def test_process_request(self):
"""Server runs process_request before processing the handshake."""
def test_process_request_returns_none(self):
"""Server runs process_request and continues the handshake."""

def process_request(ws, request):
self.assertIsInstance(request, Request)
Expand All @@ -126,7 +127,7 @@ def process_request(ws, request):
with run_client(server) as client:
self.assertEval(client, "ws.process_request_ran", "True")

def test_process_request_abort_handshake(self):
def test_process_request_returns_response(self):
"""Server aborts handshake if process_request returns a response."""

def process_request(ws, request):
Expand Down Expand Up @@ -156,8 +157,8 @@ def process_request(ws, request):
"server rejected WebSocket connection: HTTP 500",
)

def test_process_response(self):
"""Server runs process_response after processing the handshake."""
def test_process_response_returns_none(self):
"""Server runs process_response but keeps the handshake response."""

def process_response(ws, request, response):
self.assertIsInstance(request, Request)
Expand All @@ -168,17 +169,27 @@ def process_response(ws, request, response):
with run_client(server) as client:
self.assertEval(client, "ws.process_response_ran", "True")

def test_process_response_override_response(self):
"""Server runs process_response and overrides the handshake response."""
def test_process_response_modifies_response(self):
"""Server runs process_response and modifies the handshake response."""

def process_response(ws, request, response):
response.headers["X-ProcessResponse-Ran"] = "true"
response.headers["X-ProcessResponse"] = "OK"

with run_server(process_response=process_response) as server:
with run_client(server) as client:
self.assertEqual(
client.response.headers["X-ProcessResponse-Ran"], "true"
)
self.assertEqual(client.response.headers["X-ProcessResponse"], "OK")

def test_process_response_replaces_response(self):
"""Server runs process_response and replaces the handshake response."""

def process_response(ws, request, response):
headers = response.headers.copy()
headers["X-ProcessResponse"] = "OK"
return dataclasses.replace(response, headers=headers)

with run_server(process_response=process_response) as server:
with run_client(server) as client:
self.assertEqual(client.response.headers["X-ProcessResponse"], "OK")

def test_process_response_raises_exception(self):
"""Server returns an error if process_response raises an exception."""
Expand Down

0 comments on commit 8eaa5a2

Please sign in to comment.