Skip to content

Commit

Permalink
[serve] Add FF to run sync methods in a threadpool (#48897)
Browse files Browse the repository at this point in the history
## Why are these changes needed?

Adds a feature flag to run sync user-defined methods in a threadpool by
default. This matches the existing behavior when using a FastAPI
ingress.

This should address a lot of user confusion and make it easier to write
performant code by default. For example, just sticking a torch model
call in a sync method will now provide reasonable performance out of the
box.

However, there may be some existing user code that is not thread safe,
so we need to do a gentle migration. This PR introduces the behavior
behind a feature flag and warns users about the upcoming change and how
to opt into the new behavior or maintain existing behavior once it does
(just adding `async def` will do it).

I've opted to set the max thread pool size to `max_ongoing_requests`,
which seems like a reasonable policy. If needed we can add a user-facing
API for this in the future.

TODO before merging:

- [x] Get it working for sync generators.
- [x] Add warning for default change (people can keep behavior by
changing to async def).
- [x] Add/update UserCallableWrapper tests.
- [x] Add/update some integration tests (verify that request context is
set correctly!).
- [x] Set maximum thread pool size.

## Related issue number

Closes #44354
Closes #44403
Closes #48903

---------

Signed-off-by: Edward Oakes <ed.nmi.oakes@gmail.com>
  • Loading branch information
edoakes authored Nov 26, 2024
1 parent 49e3061 commit 0f2c62c
Show file tree
Hide file tree
Showing 6 changed files with 374 additions and 41 deletions.
14 changes: 14 additions & 0 deletions python/ray/serve/_private/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,3 +355,17 @@
RAY_SERVE_FORCE_LOCAL_TESTING_MODE = (
os.environ.get("RAY_SERVE_FORCE_LOCAL_TESTING_MODE", "0") == "1"
)

# Run sync methods defined in the replica in a thread pool by default.
RAY_SERVE_RUN_SYNC_IN_THREADPOOL = (
os.environ.get("RAY_SERVE_RUN_SYNC_IN_THREADPOOL", "0") == "1"
)

RAY_SERVE_RUN_SYNC_IN_THREADPOOL_WARNING = (
"Calling sync method '{method_name}' directly on the "
"asyncio loop. In a future version, sync methods will be run in a "
"threadpool by default. Ensure your sync methods are thread safe "
"or keep the existing behavior by making them `async def`. Opt "
"into the new behavior by setting "
"RAY_SERVE_RUN_SYNC_IN_THREADPOOL=1."
)
6 changes: 5 additions & 1 deletion python/ray/serve/_private/local_testing_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
import ray
from ray import cloudpickle
from ray.serve._private.common import DeploymentID, RequestMetadata
from ray.serve._private.constants import SERVE_LOGGER_NAME
from ray.serve._private.constants import (
RAY_SERVE_RUN_SYNC_IN_THREADPOOL,
SERVE_LOGGER_NAME,
)
from ray.serve._private.replica import UserCallableWrapper
from ray.serve._private.replica_result import ReplicaResult
from ray.serve._private.router import Router
Expand Down Expand Up @@ -66,6 +69,7 @@ def make_local_deployment_handle(
deployment.init_args,
deployment.init_kwargs,
deployment_id=deployment_id,
run_sync_methods_in_threadpool=RAY_SERVE_RUN_SYNC_IN_THREADPOOL,
)
try:
logger.info(f"Initializing local replica class for {deployment_id}.")
Expand Down
145 changes: 127 additions & 18 deletions python/ray/serve/_private/replica.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import asyncio
import concurrent.futures
import functools
import inspect
import logging
import os
import pickle
import threading
import time
import traceback
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager
from functools import wraps
from importlib import import_module
from typing import (
Any,
Expand All @@ -23,6 +24,7 @@
)

import starlette.responses
from anyio import to_thread
from starlette.types import ASGIApp, Message

import ray
Expand All @@ -47,6 +49,8 @@
HEALTH_CHECK_METHOD,
RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE,
RAY_SERVE_REPLICA_AUTOSCALING_METRIC_RECORD_PERIOD_S,
RAY_SERVE_RUN_SYNC_IN_THREADPOOL,
RAY_SERVE_RUN_SYNC_IN_THREADPOOL_WARNING,
RECONFIGURE_METHOD,
SERVE_CONTROLLER_NAME,
SERVE_LOGGER_NAME,
Expand Down Expand Up @@ -274,6 +278,7 @@ def __init__(
init_args,
init_kwargs,
deployment_id=self._deployment_id,
run_sync_methods_in_threadpool=RAY_SERVE_RUN_SYNC_IN_THREADPOOL,
)

# Guards against calling the user's callable constructor multiple times.
Expand Down Expand Up @@ -602,6 +607,11 @@ async def initialize(self, deployment_config: DeploymentConfig):
self._user_callable_initialized = True

if deployment_config:
await asyncio.wrap_future(
self._user_callable_wrapper.set_sync_method_threadpool_limit(
deployment_config.max_ongoing_requests
)
)
await asyncio.wrap_future(
self._user_callable_wrapper.call_reconfigure(
deployment_config.user_config
Expand Down Expand Up @@ -635,6 +645,11 @@ async def reconfigure(self, deployment_config: DeploymentConfig):
if logging_config_changed:
self._configure_logger_and_profilers(deployment_config.logging_config)

await asyncio.wrap_future(
self._user_callable_wrapper.set_sync_method_threadpool_limit(
deployment_config.max_ongoing_requests
)
)
if user_config_changed:
await asyncio.wrap_future(
self._user_callable_wrapper.call_reconfigure(
Expand Down Expand Up @@ -990,6 +1005,7 @@ def __init__(
init_kwargs: Dict,
*,
deployment_id: DeploymentID,
run_sync_methods_in_threadpool: bool,
):
if not (inspect.isfunction(deployment_def) or inspect.isclass(deployment_def)):
raise TypeError(
Expand All @@ -1003,6 +1019,8 @@ def __init__(
self._is_function = inspect.isfunction(deployment_def)
self._deployment_id = deployment_id
self._destructor_called = False
self._run_sync_methods_in_threadpool = run_sync_methods_in_threadpool
self._warned_about_sync_method_change = False

# Will be populated in `initialize_callable`.
self._callable = None
Expand Down Expand Up @@ -1033,7 +1051,7 @@ def _run_on_user_code_event_loop(f: Callable) -> Callable:
f
), "_run_on_user_code_event_loop can only be used on coroutine functions."

@wraps(f)
@functools.wraps(f)
def wrapper(self, *args, **kwargs) -> concurrent.futures.Future:
return asyncio.run_coroutine_threadsafe(
f(self, *args, **kwargs),
Expand All @@ -1042,6 +1060,12 @@ def wrapper(self, *args, **kwargs) -> concurrent.futures.Future:

return wrapper

@_run_on_user_code_event_loop
async def set_sync_method_threadpool_limit(self, limit: int):
# NOTE(edoakes): the limit is thread local, so this must
# be run on the user code event loop.
to_thread.current_default_thread_limiter().total_tokens = limit

def _get_user_callable_method(self, method_name: str) -> Callable:
if self._is_function:
return self._callable
Expand Down Expand Up @@ -1082,17 +1106,89 @@ async def _send_user_result_over_asgi(
else:
await Response(result).send(scope, receive, send)

async def _call_func_or_gen(self, callable: Callable, *args, **kwargs) -> Any:
async def _call_func_or_gen(
self,
callable: Callable,
*,
args: Optional[Tuple[Any]] = None,
kwargs: Optional[Dict[str, Any]] = None,
request_metadata: Optional[RequestMetadata] = None,
generator_result_callback: Optional[Callable] = None,
run_sync_methods_in_threadpool_override: Optional[bool] = None,
) -> Tuple[Any, bool]:
"""Call the callable with the provided arguments.
This is a convenience wrapper that will work for `def`, `async def`,
generator, and async generator functions.
Returns the result and a boolean indicating if the result was a sync generator
that has already been consumed.
"""
result = callable(*args, **kwargs)
if inspect.iscoroutine(result):
result = await result
sync_gen_consumed = False
args = args if args is not None else tuple()
kwargs = kwargs if kwargs is not None else dict()
run_sync_in_threadpool = (
self._run_sync_methods_in_threadpool
if run_sync_methods_in_threadpool_override is None
else run_sync_methods_in_threadpool_override
)
is_sync_method = (
inspect.isfunction(callable) or inspect.ismethod(callable)
) and not (
inspect.iscoroutinefunction(callable)
or inspect.isasyncgenfunction(callable)
)

return result
if is_sync_method and run_sync_in_threadpool:
is_generator = inspect.isgeneratorfunction(callable)
if is_generator:
sync_gen_consumed = True
if request_metadata and not request_metadata.is_streaming:
# TODO(edoakes): make this check less redundant with the one in
# _handle_user_method_result.
raise TypeError(
f"Method '{callable.__name__}' returned a generator. "
"You must use `handle.options(stream=True)` to call "
"generators on a deployment."
)

def run_callable():
result = callable(*args, **kwargs)
if is_generator:
for r in result:
# TODO(edoakes): make this less redundant with the handling in
# _handle_user_method_result.
if request_metadata and request_metadata.is_grpc_request:
r = (request_metadata.grpc_context, r.SerializeToString())
generator_result_callback(r)

result = None

return result

# NOTE(edoakes): we use anyio.to_thread here because it's what Starlette
# uses (and therefore FastAPI too). The max size of the threadpool is
# set to max_ongoing_requests in the replica wrapper.
# anyio.to_thread propagates ContextVars to the worker thread automatically.
result = await to_thread.run_sync(run_callable)
else:
if (
is_sync_method
and not self._warned_about_sync_method_change
and run_sync_methods_in_threadpool_override is None
):
self._warned_about_sync_method_change = True
warnings.warn(
RAY_SERVE_RUN_SYNC_IN_THREADPOOL_WARNING.format(
method_name=callable.__name__,
)
)

result = callable(*args, **kwargs)
if inspect.iscoroutine(result):
result = await result

return result, sync_gen_consumed

@property
def user_callable(self) -> Optional[Callable]:
Expand Down Expand Up @@ -1129,8 +1225,10 @@ async def initialize_callable(self) -> Optional[ASGIApp]:
self._callable = self._deployment_def.__new__(self._deployment_def)
await self._call_func_or_gen(
self._callable.__init__,
*self._init_args,
**self._init_kwargs,
args=self._init_args,
kwargs=self._init_kwargs,
# Always run the constructor on the main user code thread.
run_sync_methods_in_threadpool_override=False,
)

if isinstance(self._callable, ASGIAppReplicaWrapper):
Expand Down Expand Up @@ -1192,7 +1290,7 @@ async def call_reconfigure(self, user_config: Any):
)
await self._call_func_or_gen(
getattr(self._callable, RECONFIGURE_METHOD),
user_config,
args=(user_config,),
)

def _prepare_args_for_http_request(
Expand Down Expand Up @@ -1264,6 +1362,7 @@ async def _handle_user_method_result(
user_method_name: str,
request_metadata: RequestMetadata,
*,
sync_gen_consumed: bool,
generator_result_callback: Optional[Callable],
is_asgi_app: bool,
asgi_args: Optional[ASGIArgs],
Expand Down Expand Up @@ -1297,7 +1396,7 @@ async def _handle_user_method_result(
# For the FastAPI codepath, the response has already been sent over
# ASGI, but for the vanilla deployment codepath we need to send it.
await self._send_user_result_over_asgi(result, asgi_args)
elif not request_metadata.is_http_request:
elif not request_metadata.is_http_request and not sync_gen_consumed:
# If a unary method is called with stream=True for anything EXCEPT
# an HTTP request, raise an error.
# HTTP requests are always streaming regardless of if the method
Expand Down Expand Up @@ -1382,12 +1481,20 @@ async def call_user_method(
request_args[0], request_metadata, user_method_params
)

result = await self._handle_user_method_result(
await self._call_func_or_gen(
user_method, *request_args, **request_kwargs
),
result, sync_gen_consumed = await self._call_func_or_gen(
user_method,
args=request_args,
kwargs=request_kwargs,
request_metadata=request_metadata,
generator_result_callback=generator_result_callback
if request_metadata.is_streaming
else None,
)
return await self._handle_user_method_result(
result,
user_method_name,
request_metadata,
sync_gen_consumed=sync_gen_consumed,
generator_result_callback=generator_result_callback,
is_asgi_app=is_asgi_app,
asgi_args=asgi_args,
Expand All @@ -1412,8 +1519,6 @@ async def call_user_method(
if receive_task is not None and not receive_task.done():
receive_task.cancel()

return result

@_run_on_user_code_event_loop
async def call_destructor(self):
"""Explicitly call the `__del__` method of the user callable.
Expand All @@ -1437,7 +1542,11 @@ async def call_destructor(self):
try:
if hasattr(self._callable, "__del__"):
# Make sure to accept `async def __del__(self)` as well.
await self._call_func_or_gen(self._callable.__del__)
await self._call_func_or_gen(
self._callable.__del__,
# Always run the destructor on the main user callable thread.
run_sync_methods_in_threadpool_override=False,
)

if hasattr(self._callable, "__serve_multiplex_wrapper"):
await getattr(self._callable, "__serve_multiplex_wrapper").shutdown()
Expand Down
22 changes: 22 additions & 0 deletions python/ray/serve/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -467,3 +467,25 @@ py_test_module_list(
"//python/ray/serve:serve_lib",
],
)


# Test currently off-by-default behavior to run replica sync methods in a threadpool.
# TODO(edoakes): remove this once the FF is flipped on by default.
py_test_module_list(
size = "small",
env = {"RAY_SERVE_RUN_SYNC_IN_THREADPOOL": "1"},
files = [
"test_replica_sync_methods.py",
],
name_suffix = "_with_run_sync_in_threadpool",
tags = [
"exclusive",
"no_windows",
"team:serve",
],
deps = [
":common",
":conftest",
"//python/ray/serve:serve_lib",
],
)
Loading

0 comments on commit 0f2c62c

Please sign in to comment.