diff --git a/python/ray/serve/BUILD b/python/ray/serve/BUILD index 8bbe33ac7a0f..b2b4e2558966 100644 --- a/python/ray/serve/BUILD +++ b/python/ray/serve/BUILD @@ -503,6 +503,7 @@ py_test( "tests/test_failure.py", "tests/test_fastapi.py", "tests/test_http_adapters.py", + "tests/test_http_headers.py", "**/conftest.py"]), tags = ["exclusive", "team:serve"], deps = [":serve_lib"], @@ -665,3 +666,11 @@ py_test( tags = ["exclusive", "team:serve"], deps = [":serve_lib"], ) + +py_test( + name = "test_http_headers", + size = "small", + srcs = serve_tests_srcs, + tags = ["exclusive", "team:serve"], + deps = [":serve_lib"], +) diff --git a/python/ray/serve/_private/constants.py b/python/ray/serve/_private/constants.py index a7a38c1a7db5..05b0a615df37 100644 --- a/python/ray/serve/_private/constants.py +++ b/python/ray/serve/_private/constants.py @@ -208,6 +208,10 @@ class ServeHandleType(str, Enum): os.environ.get("RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING", "0") == "1" ) +# Request ID used for logging. Can be provided as a request +# header and will always be returned as a response header. +RAY_SERVE_REQUEST_ID_HEADER = "RAY_SERVE_REQUEST_ID" + # Feature flag to enable power of two choices routing. 
class RequestIdMiddleware:
    """ASGI middleware that returns the Serve request ID to the client.

    The wrapped application is called with a replacement ``send`` callable.
    When the response (or websocket handshake) starts, the request ID is read
    from the Serve request context and appended to the outgoing headers under
    ``RAY_SERVE_REQUEST_ID_HEADER``.
    """

    def __init__(self, app):
        """Store the wrapped ASGI application.

        Args:
            app: The downstream ASGI callable every request is delegated to.
        """
        self.app = app

    async def __call__(self, scope, receive, send):
        """Invoke the wrapped app with a ``send`` that stamps the request ID.

        Args:
            scope: ASGI connection scope, passed through unchanged.
            receive: ASGI receive callable, passed through unchanged.
            send: ASGI send callable; every message is forwarded to it.
        """

        async def send_with_request_id(message: Dict):
            # Only the two messages that carry response headers need the ID;
            # body chunks and other events skip the context lookup entirely.
            if message["type"] in ("http.response.start", "websocket.accept"):
                request_id = ray.serve.context._serve_request_context.get().request_id
                # ASGI makes ``headers`` optional on both message types and
                # only requires an iterable, so normalise it to a list before
                # wrapping it; otherwise a missing key raises KeyError and a
                # tuple cannot be appended to.
                message["headers"] = list(message.get("headers") or [])
                # MutableHeaders edits message["headers"] in place.
                MutableHeaders(scope=message).append(
                    RAY_SERVE_REQUEST_ID_HEADER, request_id
                )
                if message["type"] == "websocket.accept":
                    # Kept for backward compatibility with anything that
                    # inspected this ad-hoc key; servers ignore it, which is
                    # why the ID is also placed in ``headers`` above.
                    message[RAY_SERVE_REQUEST_ID_HEADER] = request_id
            await send(message)

        await self.app(scope, receive, send_with_request_id)
@app.get("/") + def say_hi(self) -> int: + request_id = ray.serve.context._serve_request_context.get().request_id + assert request_id == "123-234" + return 1 + + elif deploy_type == "basic": + + @serve.deployment + class Model: + def __call__(self) -> int: + request_id = ray.serve.context._serve_request_context.get().request_id + assert request_id == "123-234" + return 1 + + else: + + @serve.deployment + class Model: + def __call__(self) -> int: + request_id = ray.serve.context._serve_request_context.get().request_id + assert request_id == "123-234" + return starlette.responses.Response("1", media_type="application/json") + + serve.run(Model.bind()) + + resp = requests.get( + "http://localhost:8000", headers={RAY_SERVE_REQUEST_ID_HEADER: "123-234"} + ) + assert resp.status_code == 200 + assert resp.json() == 1 + assert RAY_SERVE_REQUEST_ID_HEADER in resp.headers + assert resp.headers[RAY_SERVE_REQUEST_ID_HEADER] == "123-234" + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(["-v", "-s", __file__]))