Skip to content

Commit

Permalink
feat: support traffic.max_concurrency for api server and runner (#3864)
Browse files Browse the repository at this point in the history
Co-authored-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
  • Loading branch information
frostming and aarnphm authored May 31, 2023
1 parent bd56fa9 commit c724628
Show file tree
Hide file tree
Showing 18 changed files with 185 additions and 43 deletions.
16 changes: 11 additions & 5 deletions docs/source/concepts/runner.rst
Original file line number Diff line number Diff line change
Expand Up @@ -418,10 +418,12 @@ can be specified for the ``nvidia.com/gpu`` key. For example, the following conf
For the detailed information on the meaning of each resource allocation configuration, see :doc:`/guides/scheduling`.

Timeout
^^^^^^^
Traffic Control
^^^^^^^^^^^^^^^

Runner timeout defines the amount of time in seconds to wait before calls a runner is timed out on the API server.
As with the API server, you can configure traffic settings either for all runners collectively or for an individual runner.
Specifically, ``traffic.timeout`` defines the amount of time in seconds that the runner will wait for a response from the model before timing out.
``traffic.max_concurrency`` defines the maximum number of concurrent requests the runner will accept before returning an error.

.. tab-set::

Expand All @@ -432,7 +434,9 @@ Runner timeout defines the amount of time in seconds to wait before calls a runn
:caption: ⚙️ `configuration.yml`
runners:
timeout: 60
traffic:
timeout: 60
max_concurrency: 10
.. tab-item:: Individual Runner
:sync: individual_runner
Expand All @@ -442,7 +446,9 @@ Runner timeout defines the amount of time in seconds to wait before calls a runn
runners:
iris_clf:
timeout: 60
traffic:
timeout: 60
max_concurrency: 10
Access Logging
^^^^^^^^^^^^^^
Expand Down
23 changes: 22 additions & 1 deletion docs/source/guides/configuration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ The following options are available for the ``api_server`` section:
+=============+=============================================================+=================================================+
| ``workers`` | Number of API workers for to spawn | null [#default_workers]_ |
+-------------+-------------------------------------------------------------+-------------------------------------------------+
| ``timeout`` | Timeout for API server in seconds | 60 |
| ``traffic`` | Traffic control for API server | See :ref:`guides/configuration:\`\`traffic\`\`` |
+-------------+-------------------------------------------------------------+-------------------------------------------------+
| ``backlog`` | Maximum number of connections to hold in backlog | 2048 |
+-------------+-------------------------------------------------------------+-------------------------------------------------+
Expand All @@ -169,6 +169,27 @@ The following options are available for the ``api_server`` section:
| ``tracing`` | Key and values to configure tracing exporter for API server | See :doc:`/guides/tracing` |
+-------------+-------------------------------------------------------------+-------------------------------------------------+

``traffic``
"""""""""""

You can control the traffic of the API server by setting the ``traffic`` field.

To set the maximum number of seconds to wait before a response is received, set ``api_server.traffic.timeout``; the default value is ``60`` seconds:
.. code-block:: yaml
api_server:
traffic:
timeout: 120
To set the maximum number of requests allowed in the process queue across all runners, set ``api_server.traffic.max_concurrency``; by default there is no limit:

.. code-block:: yaml
api_server:
traffic:
max_concurrency: 50
``metrics``
"""""""""""

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,8 @@ target-version = "py310"
convention = "google"

[tool.ruff.isort]
lines-after-imports = 2
force-single-line = true
known-first-party = ["bentoml"]

[tool.isort]
profile = "black"
Expand Down
2 changes: 1 addition & 1 deletion src/bentoml/_internal/configuration/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def _finalize(self):
"resources",
"logging",
"metrics",
"timeout",
"traffic",
"workers_per_resource",
]
global_runner_cfg = {k: self.config["runners"][k] for k in RUNNER_CFG_KEYS}
Expand Down
27 changes: 25 additions & 2 deletions src/bentoml/_internal/configuration/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@
}
_API_SERVER_CONFIG = {
"workers": s.Or(s.And(int, ensure_larger_than_zero), None),
"timeout": s.And(int, ensure_larger_than_zero),
s.Optional("traffic"): {
"timeout": s.And(int, ensure_larger_than_zero),
"max_concurrency": s.Or(s.And(int, ensure_larger_than_zero), None),
},
"backlog": s.And(int, ensure_larger_than(64)),
"max_runner_connections": s.And(int, ensure_larger_than_zero),
"metrics": {
Expand Down Expand Up @@ -161,7 +164,10 @@
"enabled": bool,
"namespace": str,
},
s.Optional("timeout"): s.And(int, ensure_larger_than_zero),
s.Optional("traffic"): {
"timeout": s.And(int, ensure_larger_than_zero),
"max_concurrency": s.Or(s.And(int, ensure_larger_than_zero), None),
},
}
SCHEMA = s.Schema(
{
Expand Down Expand Up @@ -279,4 +285,21 @@ def migration(*, override_config: dict[str, t.Any]):
current=f"logging.formatting.{f}_format",
replace_with=f"api_server.logging.access.format.{f}",
)
# 7. move timeout to traffic.timeout
for namespace in ("api_server", "runners"):
rename_fields(
override_config,
current=f"{namespace}.timeout",
replace_with=f"{namespace}.traffic.timeout",
)
for key in override_config:
if key.startswith("runners."):
runner_name = key.split(".")[1]
if any(key.schema == runner_name for key in _RUNNER_CONFIG):
continue
rename_fields(
override_config,
current=f"runners.{runner_name}.timeout",
replace_with=f"runners.{runner_name}.traffic.timeout",
)
return unflatten(override_config)
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
version: 1
api_server:
workers: ~ # cpu_count() will be used when null
timeout: 60
backlog: 2048
# the maximum number of connections that will be made to any given runner server at once
max_runner_connections: 16
traffic:
timeout: 60
max_concurrency: ~
metrics:
enabled: true
namespace: bentoml_api_server
Expand Down Expand Up @@ -67,7 +69,9 @@ api_server:
runners:
resources: ~
workers_per_resource: 1
timeout: 300
traffic:
timeout: 300
max_concurrency: ~
batching:
enabled: true
max_batch_size: 100
Expand Down
4 changes: 2 additions & 2 deletions src/bentoml/_internal/runner/runner_handle/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ def runner_timeout(self) -> int:
"return the configured timeout for this runner."
runner_cfg = BentoMLContainer.runners_config.get()
if self._runner.name in runner_cfg:
return runner_cfg[self._runner.name]["timeout"]
return runner_cfg[self._runner.name].get("traffic", {})["timeout"]
else:
return runner_cfg["timeout"]
return runner_cfg.get("traffic", {})["timeout"]

def _close_conn(self) -> None:
if self._conn:
Expand Down
42 changes: 31 additions & 11 deletions src/bentoml/_internal/server/base_app.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import abc
import typing as t
import logging
Expand All @@ -6,12 +8,12 @@

from starlette.responses import PlainTextResponse
from starlette.exceptions import HTTPException
from starlette.middleware import Middleware

if TYPE_CHECKING:
from starlette.routing import BaseRoute
from starlette.requests import Request
from starlette.responses import Response
from starlette.middleware import Middleware
from starlette.applications import Starlette


Expand All @@ -21,41 +23,47 @@
class BaseAppFactory(abc.ABC):
_is_ready: bool = False

def __init__(
self, *, timeout: int | None = None, max_concurrency: int | None = None
) -> None:
self.timeout = timeout
self.max_concurrency = max_concurrency

@property
@abc.abstractmethod
def name(self) -> str:
...

@property
def on_startup(self) -> t.List[t.Callable[[], None]]:
def on_startup(self) -> list[t.Callable[[], None]]:
return [self.mark_as_ready]

@property
def on_shutdown(self) -> t.List[t.Callable[[], None]]:
def on_shutdown(self) -> list[t.Callable[[], None]]:
return []

def mark_as_ready(self) -> None:
self._is_ready = True

async def livez(self, _: "Request") -> "Response":
async def livez(self, _: Request) -> Response:
"""
Health check for BentoML API server.
Make sure it works with Kubernetes liveness probe
"""
return PlainTextResponse("\n", status_code=200)

async def readyz(self, _: "Request") -> "Response":
async def readyz(self, _: Request) -> Response:
if self._is_ready:
return PlainTextResponse("\n", status_code=200)
raise HTTPException(500)

def __call__(self) -> "Starlette":
def __call__(self) -> Starlette:
from starlette.applications import Starlette

from ..configuration import get_debug_mode

@contextlib.asynccontextmanager
async def lifespan(_: "Starlette") -> t.AsyncGenerator[None, None]:
async def lifespan(_: Starlette) -> t.AsyncGenerator[None, None]:
for on_startup in self.on_startup:
on_startup()
yield
Expand All @@ -70,15 +78,27 @@ async def lifespan(_: "Starlette") -> t.AsyncGenerator[None, None]:
)

@property
def routes(self) -> t.List["BaseRoute"]:
def routes(self) -> list[BaseRoute]:
from starlette.routing import Route

routes: t.List["BaseRoute"] = []
routes: list[BaseRoute] = []
routes.append(Route(path="/livez", name="livez", endpoint=self.livez))
routes.append(Route(path="/healthz", name="healthz", endpoint=self.livez))
routes.append(Route(path="/readyz", name="readyz", endpoint=self.readyz))
return routes

@property
def middlewares(self) -> t.List["Middleware"]:
return []
def middlewares(self) -> list[Middleware]:
from .http.traffic import TimeoutMiddleware
from .http.traffic import MaxConcurrencyMiddleware

results: list[Middleware] = []
if self.timeout:
results.append(Middleware(TimeoutMiddleware, timeout=self.timeout))
if self.max_concurrency:
results.append(
Middleware(
MaxConcurrencyMiddleware, max_concurrency=self.max_concurrency
)
)
return results
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING

import anyio
Expand Down Expand Up @@ -29,3 +30,23 @@ async def __call__(
status_code=504,
)
await resp(scope, receive, send)


class MaxConcurrencyMiddleware:
    """ASGI middleware that caps the number of in-flight requests.

    A request arriving while ``max_concurrency`` requests are already being
    processed is rejected immediately with HTTP 429 rather than queued.
    Non-request scopes (e.g. ``lifespan``) are passed through unlimited.
    """

    def __init__(self, app: ext.ASGIApp, max_concurrency: int) -> None:
        self.app = app
        # The semaphore's counter is the remaining request capacity.
        self._semaphore = asyncio.Semaphore(max_concurrency)

    async def __call__(
        self, scope: ext.ASGIScope, receive: ext.ASGIReceive, send: ext.ASGISend
    ) -> None:
        # Only throttle actual client connections.
        if scope["type"] not in ("http", "websocket"):
            await self.app(scope, receive, send)
            return

        # Reject instead of waiting: a locked semaphore means capacity is full.
        # NOTE(review): the 429 is an HTTP-style response even for websocket
        # scopes — confirm the server's handshake path tolerates this.
        if self._semaphore.locked():
            reject = PlainTextResponse("Too many requests", status_code=429)
            await reject(scope, receive, send)
            return

        async with self._semaphore:
            await self.app(scope, receive, send)
9 changes: 4 additions & 5 deletions src/bentoml/_internal/server/http_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,15 @@ def __init__(
enable_metrics: bool = Provide[
BentoMLContainer.api_server_config.metrics.enabled
],
timeout: int = Provide[BentoMLContainer.api_server_config.traffic.timeout],
max_concurrency: int
| None = Provide[BentoMLContainer.api_server_config.traffic.max_concurrency],
):
self.bento_service = bento_service
self.enable_access_control = enable_access_control
self.access_control_options = access_control_options
self.enable_metrics = enable_metrics
super().__init__(timeout=timeout, max_concurrency=max_concurrency)

@property
def name(self) -> str:
Expand Down Expand Up @@ -253,11 +257,6 @@ def client_request_hook(span: Span, _: dict[str, t.Any]) -> None:
)
)

from .http.timeout import TimeoutMiddleware

api_server_timeout = BentoMLContainer.api_server_config.timeout.get()
middlewares.append(Middleware(TimeoutMiddleware, timeout=api_server_timeout))

return middlewares

@property
Expand Down
9 changes: 9 additions & 0 deletions src/bentoml/_internal/server/runner_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,15 @@ def __init__(
self.enable_metrics = enable_metrics

self.dispatchers: dict[str, CorkDispatcher] = {}

runners_config = BentoMLContainer.runners_config.get()
traffic = runners_config.get("traffic", {}).copy()
if runner.name in runners_config:
traffic.update(runners_config[runner.name].get("traffic", {}))
super().__init__(
timeout=traffic["timeout"], max_concurrency=traffic["max_concurrency"]
)

for method in runner.runner_methods:
max_batch_size = method.max_batch_size if method.config.batchable else 1
self.dispatchers[method.name] = CorkDispatcher(
Expand Down
4 changes: 4 additions & 0 deletions tests/e2e/bento_server_http/configs/max_concurrency.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
api_server:
traffic:
timeout: 60
max_concurrency: 2
3 changes: 2 additions & 1 deletion tests/e2e/bento_server_http/configs/timeout.yml
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
api_server:
timeout: 1
traffic:
timeout: 1
6 changes: 6 additions & 0 deletions tests/e2e/bento_server_http/pickle_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import time
import typing as t
from typing import TYPE_CHECKING

Expand All @@ -25,6 +26,11 @@ def echo_json(cls, input_datas: JSONSerializable) -> JSONSerializable:
def echo_obj(cls, input_datas: t.Any) -> t.Any:
return input_datas

def echo_delay(self, input_datas: dict[str, t.Any]) -> JSONSerializable:
    """Block for ``input_datas["delay"]`` seconds (default 5), then echo the input.

    Used by the e2e HTTP server tests to simulate a slow model call so
    timeout/concurrency behavior can be exercised.
    """
    wait_seconds = input_datas.get("delay", 5)
    time.sleep(wait_seconds)
    return input_datas

def echo_multi_ndarray(self, *input_arr: NDArray[t.Any]) -> tuple[NDArray[t.Any]]:
return input_arr

Expand Down
Loading

0 comments on commit c724628

Please sign in to comment.