feat: support traffic.max_concurrency for api server and runner #3864

Merged
merged 7 commits into from May 31, 2023
16 changes: 11 additions & 5 deletions docs/source/concepts/runner.rst
@@ -418,10 +418,12 @@ can be specified for the ``nvidia.com/gpu`` key. For example, the following conf

For the detailed information on the meaning of each resource allocation configuration, see :doc:`/guides/scheduling`.

Timeout
^^^^^^^
Traffic Control
^^^^^^^^^^^^^^^

Runner timeout defines the amount of time in seconds to wait before calls to a runner are timed out on the API server.
As with the API server, you can also configure traffic settings, both for all runners and for an individual runner.
Specifically, ``traffic.timeout`` defines the amount of time in seconds that the runner will wait for a response from the model before timing out.
``traffic.max_concurrency`` defines the maximum number of concurrent requests the runner will accept before returning an error.

.. tab-set::

@@ -432,7 +434,9 @@ Runner timeout defines the amount of time in seconds to wait before calls a runn
:caption: ⚙️ `configuration.yml`

runners:
timeout: 60
traffic:
timeout: 60
max_concurrency: 10

.. tab-item:: Individual Runner
:sync: individual_runner
@@ -442,7 +446,9 @@ Runner timeout defines the amount of time in seconds to wait before calls a runn

runners:
iris_clf:
timeout: 60
traffic:
timeout: 60
max_concurrency: 10
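
Individual runner settings override the global ``runners`` defaults. A minimal sketch of the merge behavior, mirroring the lookup added to ``src/bentoml/_internal/server/runner_app.py`` in this PR (the config values below are illustrative):

```python
# Global runner traffic defaults, overridden per runner when present.
runners_config = {
    "traffic": {"timeout": 300, "max_concurrency": None},  # applies to all runners
    "iris_clf": {"traffic": {"timeout": 60}},  # overrides for this runner only
}

runner_name = "iris_clf"
traffic = runners_config.get("traffic", {}).copy()
if runner_name in runners_config:
    traffic.update(runners_config[runner_name].get("traffic", {}))

print(traffic)  # {'timeout': 60, 'max_concurrency': None}
```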

Access Logging
^^^^^^^^^^^^^^
23 changes: 22 additions & 1 deletion docs/source/guides/configuration.rst
@@ -152,7 +152,7 @@ The following options are available for the ``api_server`` section:
+=============+=============================================================+=================================================+
| ``workers`` | Number of API workers to spawn | null [#default_workers]_ |
+-------------+-------------------------------------------------------------+-------------------------------------------------+
| ``timeout`` | Timeout for API server in seconds | 60 |
| ``traffic`` | Traffic control for API server | See :ref:`guides/configuration:\`\`traffic\`\`` |
+-------------+-------------------------------------------------------------+-------------------------------------------------+
| ``backlog`` | Maximum number of connections to hold in backlog | 2048 |
+-------------+-------------------------------------------------------------+-------------------------------------------------+
@@ -169,6 +169,27 @@ The following options are available for the ``api_server`` section:
| ``tracing`` | Key and values to configure tracing exporter for API server | See :doc:`/guides/tracing` |
+-------------+-------------------------------------------------------------+-------------------------------------------------+

``traffic``
"""""""""""

You can control how the API server handles incoming traffic by setting the ``traffic`` field.

To set the maximum number of seconds the server will wait before timing out a request, set ``api_server.traffic.timeout``. The default value is ``60`` seconds:

.. code-block:: yaml

api_server:
traffic:
timeout: 120

To cap the number of requests the API server will process concurrently, set ``api_server.traffic.max_concurrency``; requests over the limit are rejected immediately with HTTP 429. By default there is no limit:

.. code-block:: yaml

api_server:
traffic:
max_concurrency: 50
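
A hedged client-side sketch of that behavior, assuming a Bento served locally on port 3000 exposing the ``echo_json`` endpoint used in this PR's test fixtures:

```python
import concurrent.futures

import requests  # any HTTP client works; requests is assumed to be installed

def call(i: int) -> int:
    # POST to a hypothetical locally served endpoint.
    resp = requests.post("http://127.0.0.1:3000/echo_json", json={"i": i})
    return resp.status_code

# Fire 100 requests at once against max_concurrency: 50.
with concurrent.futures.ThreadPoolExecutor(max_workers=100) as pool:
    codes = list(pool.map(call, range(100)))

# Expect a mix of 200s and 429 "Too many requests" once the limit is hit.
print(sorted(set(codes)))
```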

``metrics``
"""""""""""

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -269,8 +269,8 @@ target-version = "py310"
convention = "google"

[tool.ruff.isort]
lines-after-imports = 2
force-single-line = true
known-first-party = ["bentoml"]

[tool.isort]
profile = "black"
2 changes: 1 addition & 1 deletion src/bentoml/_internal/configuration/containers.py
@@ -145,7 +145,7 @@ def _finalize(self):
"resources",
"logging",
"metrics",
"timeout",
"traffic",
"workers_per_resource",
]
global_runner_cfg = {k: self.config["runners"][k] for k in RUNNER_CFG_KEYS}
27 changes: 25 additions & 2 deletions src/bentoml/_internal/configuration/v1/__init__.py
@@ -64,7 +64,10 @@
}
_API_SERVER_CONFIG = {
"workers": s.Or(s.And(int, ensure_larger_than_zero), None),
"timeout": s.And(int, ensure_larger_than_zero),
s.Optional("traffic"): {
"timeout": s.And(int, ensure_larger_than_zero),
"max_concurrency": s.Or(s.And(int, ensure_larger_than_zero), None),
},
"backlog": s.And(int, ensure_larger_than(64)),
"max_runner_connections": s.And(int, ensure_larger_than_zero),
"metrics": {
Expand Down Expand Up @@ -161,7 +164,10 @@
"enabled": bool,
"namespace": str,
},
s.Optional("timeout"): s.And(int, ensure_larger_than_zero),
s.Optional("traffic"): {
"timeout": s.And(int, ensure_larger_than_zero),
"max_concurrency": s.Or(s.And(int, ensure_larger_than_zero), None),
},
}
SCHEMA = s.Schema(
{
@@ -279,4 +285,21 @@ def migration(*, override_config: dict[str, t.Any]):
current=f"logging.formatting.{f}_format",
replace_with=f"api_server.logging.access.format.{f}",
)
# 7. move timeout to traffic.timeout
for namespace in ("api_server", "runners"):
rename_fields(
override_config,
current=f"{namespace}.timeout",
replace_with=f"{namespace}.traffic.timeout",
)
for key in override_config:
if key.startswith("runners."):
runner_name = key.split(".")[1]
if any(key.schema == runner_name for key in _RUNNER_CONFIG):
continue
rename_fields(
override_config,
current=f"runners.{runner_name}.timeout",
replace_with=f"runners.{runner_name}.traffic.timeout",
)
return unflatten(override_config)
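
For illustration, a hedged sketch of what the migration above does to a flattened override config (the runner name ``iris_clf`` is borrowed from the docs changes in this PR):

```python
# Flattened v1 overrides using the legacy top-level "timeout" keys.
old = {
    "api_server.timeout": 120,
    "runners.timeout": 300,
    "runners.iris_clf.timeout": 60,
}

# After the renames and unflatten(), the equivalent nested config would be:
new = {
    "api_server": {"traffic": {"timeout": 120}},
    "runners": {
        "traffic": {"timeout": 300},
        "iris_clf": {"traffic": {"timeout": 60}},
    },
}
```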
@@ -1,10 +1,12 @@
version: 1
api_server:
workers: ~ # cpu_count() will be used when null
timeout: 60
backlog: 2048
# the maximum number of connections that will be made to any given runner server at once
max_runner_connections: 16
traffic:
timeout: 60
max_concurrency: ~
metrics:
enabled: true
namespace: bentoml_api_server
@@ -67,7 +69,9 @@ api_server:
runners:
resources: ~
workers_per_resource: 1
timeout: 300
traffic:
timeout: 300
max_concurrency: ~
batching:
enabled: true
max_batch_size: 100
4 changes: 2 additions & 2 deletions src/bentoml/_internal/runner/runner_handle/remote.py
@@ -77,9 +77,9 @@ def runner_timeout(self) -> int:
"return the configured timeout for this runner."
runner_cfg = BentoMLContainer.runners_config.get()
if self._runner.name in runner_cfg:
return runner_cfg[self._runner.name]["timeout"]
return runner_cfg[self._runner.name].get("traffic", {})["timeout"]
else:
return runner_cfg["timeout"]
return runner_cfg.get("traffic", {})["timeout"]

def _close_conn(self) -> None:
if self._conn:
42 changes: 31 additions & 11 deletions src/bentoml/_internal/server/base_app.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import abc
import typing as t
import logging
@@ -6,12 +8,12 @@

from starlette.responses import PlainTextResponse
from starlette.exceptions import HTTPException
from starlette.middleware import Middleware

if TYPE_CHECKING:
from starlette.routing import BaseRoute
from starlette.requests import Request
from starlette.responses import Response
from starlette.middleware import Middleware
from starlette.applications import Starlette


@@ -21,41 +23,47 @@
class BaseAppFactory(abc.ABC):
_is_ready: bool = False

def __init__(
self, *, timeout: int | None = None, max_concurrency: int | None = None
) -> None:
self.timeout = timeout
self.max_concurrency = max_concurrency

@property
@abc.abstractmethod
def name(self) -> str:
...

@property
def on_startup(self) -> t.List[t.Callable[[], None]]:
def on_startup(self) -> list[t.Callable[[], None]]:
return [self.mark_as_ready]

@property
def on_shutdown(self) -> t.List[t.Callable[[], None]]:
def on_shutdown(self) -> list[t.Callable[[], None]]:
return []

def mark_as_ready(self) -> None:
self._is_ready = True

async def livez(self, _: "Request") -> "Response":
async def livez(self, _: Request) -> Response:
"""
Health check for BentoML API server.
Make sure it works with Kubernetes liveness probe
"""
return PlainTextResponse("\n", status_code=200)

async def readyz(self, _: "Request") -> "Response":
async def readyz(self, _: Request) -> Response:
if self._is_ready:
return PlainTextResponse("\n", status_code=200)
raise HTTPException(500)

def __call__(self) -> "Starlette":
def __call__(self) -> Starlette:
from starlette.applications import Starlette

from ..configuration import get_debug_mode

@contextlib.asynccontextmanager
async def lifespan(_: "Starlette") -> t.AsyncGenerator[None, None]:
async def lifespan(_: Starlette) -> t.AsyncGenerator[None, None]:
for on_startup in self.on_startup:
on_startup()
yield
@@ -70,15 +78,27 @@ async def lifespan(_: "Starlette") -> t.AsyncGenerator[None, None]:
)

@property
def routes(self) -> t.List["BaseRoute"]:
def routes(self) -> list[BaseRoute]:
from starlette.routing import Route

routes: t.List["BaseRoute"] = []
routes: list[BaseRoute] = []
routes.append(Route(path="/livez", name="livez", endpoint=self.livez))
routes.append(Route(path="/healthz", name="healthz", endpoint=self.livez))
routes.append(Route(path="/readyz", name="readyz", endpoint=self.readyz))
return routes

@property
def middlewares(self) -> t.List["Middleware"]:
return []
def middlewares(self) -> list[Middleware]:
from .http.traffic import TimeoutMiddleware
from .http.traffic import MaxConcurrencyMiddleware

results: list[Middleware] = []
if self.timeout:
results.append(Middleware(TimeoutMiddleware, timeout=self.timeout))
if self.max_concurrency:
results.append(
Middleware(
MaxConcurrencyMiddleware, max_concurrency=self.max_concurrency
)
)
return results
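
Note that the truthiness checks above mean a ``timeout`` or ``max_concurrency`` of ``None`` (and also ``0``) leaves the corresponding middleware out entirely, so an unset ``max_concurrency`` adds no per-request overhead:

```python
# Quick illustration of the gating condition used above.
for value in (None, 0, 60):
    print(f"value={value!r} -> middleware installed: {bool(value)}")
```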
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING

import anyio
@@ -29,3 +30,23 @@ async def __call__(
status_code=504,
)
await resp(scope, receive, send)


class MaxConcurrencyMiddleware:
def __init__(self, app: ext.ASGIApp, max_concurrency: int) -> None:
self.app = app
self._semaphore = asyncio.Semaphore(max_concurrency)

async def __call__(
self, scope: ext.ASGIScope, receive: ext.ASGIReceive, send: ext.ASGISend
) -> None:
if scope["type"] not in ("http", "websocket"):
return await self.app(scope, receive, send)

if self._semaphore.locked():
resp = PlainTextResponse("Too many requests", status_code=429)
await resp(scope, receive, send)
return

async with self._semaphore:
await self.app(scope, receive, send)
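
For readers outside BentoML, a standalone sketch of the fail-fast semaphore pattern used above (toy names, runnable without a server): when all permits are taken, ``locked()`` is true and new arrivals are rejected immediately instead of queuing on the semaphore.

```python
import asyncio

class ConcurrencyGate:
    """Toy stand-in for MaxConcurrencyMiddleware's gating logic."""

    def __init__(self, max_concurrency: int) -> None:
        self._semaphore = asyncio.Semaphore(max_concurrency)

    async def handle(self, request_id: int) -> str:
        if self._semaphore.locked():  # all permits in use: fail fast with 429
            return f"request {request_id}: 429 Too many requests"
        async with self._semaphore:
            await asyncio.sleep(0.1)  # simulate request processing
            return f"request {request_id}: 200 OK"

async def main() -> None:
    gate = ConcurrencyGate(max_concurrency=2)
    # Five concurrent requests against two permits: expect two 200s, three 429s.
    print("\n".join(await asyncio.gather(*(gate.handle(i) for i in range(5)))))

asyncio.run(main())
```

Rejecting with 429 rather than awaiting the semaphore keeps the waiting out of the server entirely; a client retrying with backoff effectively forms the queue instead.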
9 changes: 4 additions & 5 deletions src/bentoml/_internal/server/http_app.py
@@ -105,11 +105,15 @@ def __init__(
enable_metrics: bool = Provide[
BentoMLContainer.api_server_config.metrics.enabled
],
timeout: int = Provide[BentoMLContainer.api_server_config.traffic.timeout],
max_concurrency: int
| None = Provide[BentoMLContainer.api_server_config.traffic.max_concurrency],
):
self.bento_service = bento_service
self.enable_access_control = enable_access_control
self.access_control_options = access_control_options
self.enable_metrics = enable_metrics
super().__init__(timeout=timeout, max_concurrency=max_concurrency)

@property
def name(self) -> str:
@@ -253,11 +257,6 @@ def client_request_hook(span: Span, _: dict[str, t.Any]) -> None:
)
)

from .http.timeout import TimeoutMiddleware

api_server_timeout = BentoMLContainer.api_server_config.timeout.get()
middlewares.append(Middleware(TimeoutMiddleware, timeout=api_server_timeout))

return middlewares

@property
9 changes: 9 additions & 0 deletions src/bentoml/_internal/server/runner_app.py
@@ -54,6 +54,15 @@ def __init__(
self.enable_metrics = enable_metrics

self.dispatchers: dict[str, CorkDispatcher] = {}

runners_config = BentoMLContainer.runners_config.get()
traffic = runners_config.get("traffic", {}).copy()
if runner.name in runners_config:
traffic.update(runners_config[runner.name].get("traffic", {}))
super().__init__(
timeout=traffic["timeout"], max_concurrency=traffic["max_concurrency"]
)

for method in runner.runner_methods:
max_batch_size = method.max_batch_size if method.config.batchable else 1
self.dispatchers[method.name] = CorkDispatcher(
4 changes: 4 additions & 0 deletions tests/e2e/bento_server_http/configs/max_concurrency.yml
@@ -0,0 +1,4 @@
api_server:
traffic:
timeout: 60
max_concurrency: 2
3 changes: 2 additions & 1 deletion tests/e2e/bento_server_http/configs/timeout.yml
@@ -1,2 +1,3 @@
api_server:
timeout: 1
traffic:
timeout: 1
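
A hedged sketch of what this config exercises, assuming the Bento is served locally on port 3000 and the new ``echo_delay`` API from ``pickle_model.py`` below is exposed under that name: with ``traffic.timeout: 1``, a call that sleeps for 5 seconds should be cut off with the 504 produced by ``TimeoutMiddleware``.

```python
import requests  # assumed available

# /echo_delay is a hypothetical route for the echo_delay method added below.
resp = requests.post("http://127.0.0.1:3000/echo_delay", json={"delay": 5})
assert resp.status_code == 504  # gateway timeout from TimeoutMiddleware
```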
6 changes: 6 additions & 0 deletions tests/e2e/bento_server_http/pickle_model.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import time
import typing as t
from typing import TYPE_CHECKING

@@ -25,6 +26,11 @@ def echo_json(cls, input_datas: JSONSerializable) -> JSONSerializable:
def echo_obj(cls, input_datas: t.Any) -> t.Any:
return input_datas

def echo_delay(self, input_datas: dict[str, t.Any]) -> JSONSerializable:
delay = input_datas.get("delay", 5)
time.sleep(delay)
return input_datas

def echo_multi_ndarray(self, *input_arr: NDArray[t.Any]) -> tuple[NDArray[t.Any]]:
return input_arr
