Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Serve] Expose FastAPI docs path #32863

Merged
merged 12 commits into from
Mar 2, 2023
34 changes: 25 additions & 9 deletions python/ray/serve/_private/application_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
self.deploy_obj_ref = deploy_obj_ref
self.app_msg = ""
self.route_prefix = None
self.docs_path = None

# This set tracks old deployments that are being deleted
self.deployments_to_delete = set()
Expand Down Expand Up @@ -81,18 +82,30 @@ def deploy(self, deployment_params: List[Dict]) -> List[str]:

# Update route prefix for application
num_route_prefixes = 0
num_docs_paths = 0
for deploy_param in deployment_params:
if (
"route_prefix" in deploy_param
and deploy_param["route_prefix"] is not None
):
if deploy_param.get("route_prefix") is not None:
self.route_prefix = deploy_param["route_prefix"]
num_route_prefixes += 1
assert num_route_prefixes <= 1, (
f"Found multiple route prefix from application {self.name},"
" Please specify only one route prefix for the application "
"to avoid this issue."
)

if deploy_param.get("docs_path") is not None:
self.docs_path = deploy_param["docs_path"]
num_docs_paths += 1
if num_route_prefixes > 1:
raise RayServeException(
f'Found multiple route prefix from application "{self.name}",'
" Please specify only one route prefix for the application "
"to avoid this issue."
)
# NOTE(zcin) This will not catch multiple FastAPI deployments in the application
# if user sets the docs path to None in their FastAPI app.
if num_docs_paths > 1:
raise RayServeException(
f'Found multiple deployments in application "{self.name}" that have '
"a docs path. This may be due to using multiple FastAPI deployments "
"in your application. Please only include one deployment with a docs "
"path in your application to avoid this issue."
)

self.status = ApplicationStatus.DEPLOYING
return cur_deployments_to_delete
Expand Down Expand Up @@ -266,6 +279,9 @@ def get_app_status(self, name: str) -> ApplicationStatusInfo:
)
return self._application_states[name].get_application_status_info()

def get_docs_path(self, app_name: str):
return self._application_states[app_name].docs_path

def list_app_statuses(self) -> Dict[str, ApplicationStatusInfo]:
"""Return a dictionary with {app name: application info}"""
return {
Expand Down
3 changes: 3 additions & 0 deletions python/ray/serve/_private/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ def deploy_group(
version=deployment["version"],
route_prefix=deployment["route_prefix"],
is_driver_deployment=deployment["is_driver_deployment"],
docs_path=deployment["docs_path"],
)
)

Expand Down Expand Up @@ -493,6 +494,7 @@ def get_deploy_args(
version: Optional[str] = None,
route_prefix: Optional[str] = None,
is_driver_deployment: Optional[str] = None,
docs_path: Optional[str] = None,
) -> Dict:
"""
Takes a deployment's configuration, and returns the arguments needed
Expand Down Expand Up @@ -548,6 +550,7 @@ def get_deploy_args(
"route_prefix": route_prefix,
"deployer_job_id": ray.get_runtime_context().get_job_id(),
"is_driver_deployment": is_driver_deployment,
"docs_path": docs_path,
}

return controller_deploy_args
Expand Down
5 changes: 4 additions & 1 deletion python/ray/serve/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,8 @@ async def __del__(self):
super_cls.__del__()

ASGIAppWrapper.__name__ = cls.__name__
if hasattr(frozen_app, "docs_url"):
ASGIAppWrapper.__fastapi_docs_path__ = frozen_app.docs_url
return ASGIAppWrapper

return decorator
Expand Down Expand Up @@ -408,8 +410,8 @@ def decorator(_func_or_class):
ray_actor_options=(
ray_actor_options if ray_actor_options is not DEFAULT.VALUE else None
),
_internal=True,
is_driver_deployment=is_driver_deployment,
_internal=True,
)

# This handles both parametrized and non-parametrized usage of the
Expand Down Expand Up @@ -548,6 +550,7 @@ def run(
"route_prefix": deployment.route_prefix,
"url": deployment.url,
"is_driver_deployment": deployment._is_driver_deployment,
"docs_path": deployment._docs_path,
}
parameter_group.append(deployment_parameters)
client.deploy_group(
Expand Down
9 changes: 9 additions & 0 deletions python/ray/serve/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,10 +367,13 @@ def deploy(
replica_config_proto_bytes: bytes,
route_prefix: Optional[str],
deployer_job_id: Union[str, bytes],
docs_path: Optional[str] = None,
is_driver_deployment: Optional[bool] = False,
) -> bool:
if route_prefix is not None:
assert route_prefix.startswith("/")
if docs_path is not None:
assert docs_path.startswith("/")

deployment_config = DeploymentConfig.from_proto_bytes(
deployment_config_proto_bytes
Expand Down Expand Up @@ -719,6 +722,12 @@ def get_deployment_status(self, name: str) -> Union[None, bytes]:
return None
return status[0].to_proto().SerializeToString()

def get_docs_path(self, name: str):
"""Docs path for application.

Currently, this is the OpenAPI docs path for FastAPI-integrated applications."""
return self.application_state_manager.get_docs_path(name)

def delete_apps(self, names: Iterable[str]):
"""Delete applications based on names

Expand Down
12 changes: 11 additions & 1 deletion python/ray/serve/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ def __init__(
init_kwargs: Optional[Tuple[Any]] = None,
route_prefix: Union[str, None, DEFAULT] = DEFAULT.VALUE,
ray_actor_options: Optional[Dict] = None,
_internal=False,
is_driver_deployment: Optional[bool] = False,
_internal=False,
) -> None:
"""Construct a Deployment. CONSTRUCTOR SHOULDN'T BE USED DIRECTLY.

Expand Down Expand Up @@ -87,6 +87,15 @@ def __init__(
if init_kwargs is None:
init_kwargs = {}

docs_path = None
if (
inspect.isclass(func_or_class)
and hasattr(func_or_class, "__module__")
and func_or_class.__module__ == "ray.serve.api"
and hasattr(func_or_class, "__fastapi_docs_path__")
):
docs_path = func_or_class.__fastapi_docs_path__

self._func_or_class = func_or_class
self._name = name
self._version = version
Expand All @@ -96,6 +105,7 @@ def __init__(
self._route_prefix = route_prefix
self._ray_actor_options = ray_actor_options
self._is_driver_deployment = is_driver_deployment
self._docs_path = docs_path

@property
def name(self) -> str:
Expand Down
71 changes: 70 additions & 1 deletion python/ray/serve/tests/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@
import ray
from ray import serve
from ray.exceptions import GetTimeoutError
from ray.serve.exceptions import RayServeException
from ray.serve._private.client import ServeControllerClient
from ray.serve._private.http_util import make_fastapi_class_based_view
from ray.serve._private.utils import DEFAULT
from ray._private.test_utils import SignalActor
from ray._private.test_utils import SignalActor, wait_for_condition


def test_fastapi_function(serve_instance):
Expand Down Expand Up @@ -664,6 +666,73 @@ def incr(self):
assert resp.json() == [0, 0]


@pytest.mark.parametrize("two_fastapi", [True, False])
def test_two_fastapi_in_one_application(
serve_instance: ServeControllerClient, two_fastapi
):
"""
Check that a deployment graph that would normally work, will not deploy
successfully if there are two FastAPI deployments.
"""
app1 = FastAPI()
app2 = FastAPI()

class SubModel:
def add(self, a: int):
return a + 1

@serve.deployment
@serve.ingress(app1)
class Model:
def __init__(self, submodel):
self.submodel = submodel

@app1.get("/{a}")
async def func(self, a: int):
return await (await self.submodel.add.remote(a))

if two_fastapi:
SubModel = serve.deployment(serve.ingress(app2)(SubModel))
with pytest.raises(RayServeException) as e:
handle = serve.run(Model.bind(SubModel.bind()), name="app1")
assert "FastAPI" in str(e.value)
else:
handle = serve.run(Model.bind(serve.deployment(SubModel).bind()), name="app1")
assert ray.get(handle.func.remote(5)) == 6


@pytest.mark.parametrize(
"is_fastapi,docs_path",
[
(False, None), # Not integrated with FastAPI
(True, "/docs"), # Don't specify docs_url, use default
(True, "/documentation"), # Override default docs url
],
)
def test_fastapi_docs_path(
serve_instance: ServeControllerClient, is_fastapi, docs_path
):
# If not the default docs_url, override it.
if docs_path != "/docs":
app = FastAPI(docs_url=docs_path)
else:
app = FastAPI()

class Model:
@app.get("/{a}")
def func(a: int):
return {"result": a}

if is_fastapi:
Model = serve.ingress(app)(Model)

serve.run(serve.deployment(Model).bind(), name="app1")
wait_for_condition(
lambda: ray.get(serve_instance._controller.get_docs_path.remote("app1"))
== docs_path
)


if __name__ == "__main__":
import sys

Expand Down