Skip to content

Commit

Permalink
[Serve] Expose serve request id from http request/resp (ray-project#3…
Browse files Browse the repository at this point in the history
…5789)

User can inject request id:
```
@serve.deployment
class Model:
    def __call__(self) -> int:
        return 1

serve.run(Model.bind())
resp = requests.get("http://localhost:8000", headers={"RAY_SERVE_REQUEST_ID": "123-234"})
```

Signed-off-by: e428265 <arvind.chandramouli@lmco.com>
  • Loading branch information
sihanwang41 authored and arvind-chandra committed Aug 31, 2023
1 parent 890880c commit 8b0e067
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 4 deletions.
9 changes: 9 additions & 0 deletions python/ray/serve/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,7 @@ py_test(
"tests/test_failure.py",
"tests/test_fastapi.py",
"tests/test_http_adapters.py",
"tests/test_http_headers.py",
"**/conftest.py"]),
tags = ["exclusive", "team:serve"],
deps = [":serve_lib"],
Expand Down Expand Up @@ -665,3 +666,11 @@ py_test(
tags = ["exclusive", "team:serve"],
deps = [":serve_lib"],
)

py_test(
name = "test_http_headers",
size = "small",
srcs = serve_tests_srcs,
tags = ["exclusive", "team:serve"],
deps = [":serve_lib"],
)
4 changes: 4 additions & 0 deletions python/ray/serve/_private/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,10 @@ class ServeHandleType(str, Enum):
os.environ.get("RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING", "0") == "1"
)

# Request ID used for logging. Can be provided as a request
# header and will always be returned as a response header.
RAY_SERVE_REQUEST_ID_HEADER = "RAY_SERVE_REQUEST_ID"

# Feature flag to enable power of two choices routing.
RAY_SERVE_ENABLE_NEW_ROUTING = (
os.environ.get("RAY_SERVE_ENABLE_NEW_ROUTING", "0") == "1"
Expand Down
30 changes: 26 additions & 4 deletions python/ray/serve/_private/http_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import starlette.responses
import starlette.routing
from starlette.types import Message, Receive, Scope, Send
from starlette.datastructures import MutableHeaders
from starlette.middleware import Middleware

import ray
from ray.exceptions import RayActorError, RayTaskError
Expand All @@ -36,6 +38,7 @@
SERVE_NAMESPACE,
DEFAULT_LATENCY_BUCKET_MS,
RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING,
RAY_SERVE_REQUEST_ID_HEADER,
RAY_SERVE_HTTP_PROXY_CALLBACK_IMPORT_PATH,
)
from ray.serve._private.long_poll import LongPollClient, LongPollNamespace
Expand Down Expand Up @@ -458,14 +461,15 @@ async def __call__(self, scope, receive, send):
for key, value in scope.get("headers", []):
if key.decode() == SERVE_MULTIPLEXED_MODEL_ID:
request_context_info["multiplexed_model_id"] = value.decode()
break
if key.decode().upper() == RAY_SERVE_REQUEST_ID_HEADER:
request_context_info["request_id"] = value.decode()
ray.serve.context._serve_request_context.set(
ray.serve.context.RequestContext(**request_context_info)
)

if RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING:
status_code = await self.send_request_to_replica_streaming(
request_id, handle, scope, receive, send
request_context_info["request_id"], handle, scope, receive, send
)
else:
status_code = await self.send_request_to_replica_unary(
Expand Down Expand Up @@ -706,6 +710,23 @@ async def send_request_to_replica_streaming(
return status_code


class RequestIdMiddleware:
def __init__(self, app):
self.app = app

async def __call__(self, scope, receive, send):
async def send_with_request_id(message: Dict):
request_id = ray.serve.context._serve_request_context.get().request_id
if message["type"] == "http.response.start":
headers = MutableHeaders(scope=message)
headers.append(RAY_SERVE_REQUEST_ID_HEADER, request_id)
if message["type"] == "websocket.accept":
message[RAY_SERVE_REQUEST_ID_HEADER] = request_id
await send(message)

await self.app(scope, receive, send_with_request_id)


@ray.remote(num_cpus=0)
class HTTPProxyActor:
def __init__(
Expand All @@ -722,9 +743,10 @@ def __init__(
configure_component_logger(
component_name="http_proxy", component_id=node_ip_address
)

if http_middlewares is None:
http_middlewares = []
http_middlewares = [Middleware(RequestIdMiddleware)]
else:
http_middlewares.append(Middleware(RequestIdMiddleware))

if RAY_SERVE_HTTP_PROXY_CALLBACK_IMPORT_PATH:
logger.info(
Expand Down
76 changes: 76 additions & 0 deletions python/ray/serve/tests/test_http_headers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import pytest
import requests
from fastapi import FastAPI
import starlette

import ray
from ray import serve
from ray.serve._private.constants import RAY_SERVE_REQUEST_ID_HEADER


def test_request_id_header_by_default(ray_shutdown):
"""Test that a request_id is generated by default and returned as a header."""

@serve.deployment
class Model:
def __call__(self):
request_id = ray.serve.context._serve_request_context.get().request_id
return request_id

serve.run(Model.bind())
resp = requests.get("http://localhost:8000")
assert resp.status_code == 200
assert RAY_SERVE_REQUEST_ID_HEADER in resp.headers
assert resp.text == resp.headers[RAY_SERVE_REQUEST_ID_HEADER]


@pytest.mark.parametrize("deploy_type", ["basic", "fastapi", "starlette_resp"])
def test_user_provided_request_id_header(ray_shutdown, deploy_type):
"""Test that a user-provided request_id is propagated to the
replica and returned as a header."""

if deploy_type == "fastapi":
app = FastAPI()

@serve.deployment
@serve.ingress(app)
class Model:
@app.get("/")
def say_hi(self) -> int:
request_id = ray.serve.context._serve_request_context.get().request_id
assert request_id == "123-234"
return 1

elif deploy_type == "basic":

@serve.deployment
class Model:
def __call__(self) -> int:
request_id = ray.serve.context._serve_request_context.get().request_id
assert request_id == "123-234"
return 1

else:

@serve.deployment
class Model:
def __call__(self) -> int:
request_id = ray.serve.context._serve_request_context.get().request_id
assert request_id == "123-234"
return starlette.responses.Response("1", media_type="application/json")

serve.run(Model.bind())

resp = requests.get(
"http://localhost:8000", headers={RAY_SERVE_REQUEST_ID_HEADER: "123-234"}
)
assert resp.status_code == 200
assert resp.json() == 1
assert RAY_SERVE_REQUEST_ID_HEADER in resp.headers
assert resp.headers[RAY_SERVE_REQUEST_ID_HEADER] == "123-234"


if __name__ == "__main__":
import sys

sys.exit(pytest.main(["-v", "-s", __file__]))

0 comments on commit 8b0e067

Please sign in to comment.