Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Serve] Expose serve request id from http request/resp #35789

Merged
merged 26 commits into from
Jun 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions python/ray/serve/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,7 @@ py_test(
"tests/test_failure.py",
"tests/test_fastapi.py",
"tests/test_http_adapters.py",
"tests/test_http_headers.py",
"**/conftest.py"]),
tags = ["exclusive", "team:serve"],
deps = [":serve_lib"],
Expand Down Expand Up @@ -665,3 +666,11 @@ py_test(
tags = ["exclusive", "team:serve"],
deps = [":serve_lib"],
)

py_test(
name = "test_http_headers",
size = "small",
srcs = serve_tests_srcs,
tags = ["exclusive", "team:serve"],
deps = [":serve_lib"],
)
4 changes: 4 additions & 0 deletions python/ray/serve/_private/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,10 @@ class ServeHandleType(str, Enum):
os.environ.get("RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING", "0") == "1"
)

# Request ID used for logging. Can be provided as a request
# header and will always be returned as a response header.
RAY_SERVE_REQUEST_ID_HEADER = "RAY_SERVE_REQUEST_ID"

# Feature flag to enable power of two choices routing.
RAY_SERVE_ENABLE_NEW_ROUTING = (
os.environ.get("RAY_SERVE_ENABLE_NEW_ROUTING", "0") == "1"
Expand Down
30 changes: 26 additions & 4 deletions python/ray/serve/_private/http_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import starlette.responses
import starlette.routing
from starlette.types import Message, Receive, Scope, Send
from starlette.datastructures import MutableHeaders
from starlette.middleware import Middleware

import ray
from ray.exceptions import RayActorError, RayTaskError
Expand All @@ -36,6 +38,7 @@
SERVE_NAMESPACE,
DEFAULT_LATENCY_BUCKET_MS,
RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING,
RAY_SERVE_REQUEST_ID_HEADER,
RAY_SERVE_HTTP_PROXY_CALLBACK_IMPORT_PATH,
)
from ray.serve._private.long_poll import LongPollClient, LongPollNamespace
Expand Down Expand Up @@ -458,14 +461,15 @@ async def __call__(self, scope, receive, send):
for key, value in scope.get("headers", []):
if key.decode() == SERVE_MULTIPLEXED_MODEL_ID:
request_context_info["multiplexed_model_id"] = value.decode()
break
if key.decode().upper() == RAY_SERVE_REQUEST_ID_HEADER:
request_context_info["request_id"] = value.decode()
ray.serve.context._serve_request_context.set(
ray.serve.context.RequestContext(**request_context_info)
)

if RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING:
status_code = await self.send_request_to_replica_streaming(
request_id, handle, scope, receive, send
request_context_info["request_id"], handle, scope, receive, send
)
else:
status_code = await self.send_request_to_replica_unary(
Expand Down Expand Up @@ -706,6 +710,23 @@ async def send_request_to_replica_streaming(
return status_code


class RequestIdMiddleware:
def __init__(self, app):
self.app = app

async def __call__(self, scope, receive, send):
async def send_with_request_id(message: Dict):
request_id = ray.serve.context._serve_request_context.get().request_id
if message["type"] == "http.response.start":
sihanwang41 marked this conversation as resolved.
Show resolved Hide resolved
headers = MutableHeaders(scope=message)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this update the underlying scope passed in?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure i understand the question, I assume all headers information for http.response.start is under the message directly.
The MutableHeaders will only extract headers from the message https://github.com/encode/starlette/blob/2168e47052239da5df35d5353bb986f760c51cef/starlette/datastructures.py#L534

Copy link
Contributor

@edoakes edoakes Jun 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh I'm just confused because you're never modifying the message here, does headers.append(RAY_SERVE_REQUEST_ID, request_id) do it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, yes. the headers object hold the message header Dict ref. and whenever you update the headers, it will update the message. It is same as you update the message header directly, but this is more standard instead of updating raw dict by yourself.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perfect thanks for explaining

headers.append(RAY_SERVE_REQUEST_ID_HEADER, request_id)
if message["type"] == "websocket.accept":
message[RAY_SERVE_REQUEST_ID_HEADER] = request_id
await send(message)

await self.app(scope, receive, send_with_request_id)


@ray.remote(num_cpus=0)
class HTTPProxyActor:
def __init__(
Expand All @@ -722,9 +743,10 @@ def __init__(
configure_component_logger(
component_name="http_proxy", component_id=node_ip_address
)

if http_middlewares is None:
http_middlewares = []
http_middlewares = [Middleware(RequestIdMiddleware)]
else:
http_middlewares.append(Middleware(RequestIdMiddleware))

if RAY_SERVE_HTTP_PROXY_CALLBACK_IMPORT_PATH:
logger.info(
Expand Down
76 changes: 76 additions & 0 deletions python/ray/serve/tests/test_http_headers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import pytest
import requests
from fastapi import FastAPI
import starlette

import ray
from ray import serve
from ray.serve._private.constants import RAY_SERVE_REQUEST_ID_HEADER


def test_request_id_header_by_default(ray_shutdown):
"""Test that a request_id is generated by default and returned as a header."""

@serve.deployment
class Model:
def __call__(self):
request_id = ray.serve.context._serve_request_context.get().request_id
return request_id

serve.run(Model.bind())
resp = requests.get("http://localhost:8000")
assert resp.status_code == 200
assert RAY_SERVE_REQUEST_ID_HEADER in resp.headers
assert resp.text == resp.headers[RAY_SERVE_REQUEST_ID_HEADER]


@pytest.mark.parametrize("deploy_type", ["basic", "fastapi", "starlette_resp"])
def test_user_provided_request_id_header(ray_shutdown, deploy_type):
"""Test that a user-provided request_id is propagated to the
replica and returned as a header."""

if deploy_type == "fastapi":
app = FastAPI()

@serve.deployment
@serve.ingress(app)
class Model:
@app.get("/")
def say_hi(self) -> int:
request_id = ray.serve.context._serve_request_context.get().request_id
assert request_id == "123-234"
return 1

elif deploy_type == "basic":

@serve.deployment
class Model:
def __call__(self) -> int:
request_id = ray.serve.context._serve_request_context.get().request_id
assert request_id == "123-234"
return 1

else:

@serve.deployment
class Model:
def __call__(self) -> int:
request_id = ray.serve.context._serve_request_context.get().request_id
assert request_id == "123-234"
return starlette.responses.Response("1", media_type="application/json")

serve.run(Model.bind())

resp = requests.get(
"http://localhost:8000", headers={RAY_SERVE_REQUEST_ID_HEADER: "123-234"}
)
assert resp.status_code == 200
assert resp.json() == 1
assert RAY_SERVE_REQUEST_ID_HEADER in resp.headers
assert resp.headers[RAY_SERVE_REQUEST_ID_HEADER] == "123-234"

sihanwang41 marked this conversation as resolved.
Show resolved Hide resolved

if __name__ == "__main__":
import sys

sys.exit(pytest.main(["-v", "-s", __file__]))