Skip to content

Commit

Permalink
[Serve] Expose FastAPI docs path (#32863)
Browse files Browse the repository at this point in the history
For FastAPI integrated applications, we want to expose the Open api docs path.
  • Loading branch information
zcin authored Mar 2, 2023
1 parent 4d0ce8d commit fdf9866
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 12 deletions.
34 changes: 25 additions & 9 deletions python/ray/serve/_private/application_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
self.deploy_obj_ref = deploy_obj_ref
self.app_msg = ""
self.route_prefix = None
self.docs_path = None

# This set tracks old deployments that are being deleted
self.deployments_to_delete = set()
Expand Down Expand Up @@ -81,18 +82,30 @@ def deploy(self, deployment_params: List[Dict]) -> List[str]:

# Update route prefix for application
num_route_prefixes = 0
num_docs_paths = 0
for deploy_param in deployment_params:
if (
"route_prefix" in deploy_param
and deploy_param["route_prefix"] is not None
):
if deploy_param.get("route_prefix") is not None:
self.route_prefix = deploy_param["route_prefix"]
num_route_prefixes += 1
assert num_route_prefixes <= 1, (
f"Found multiple route prefix from application {self.name},"
" Please specify only one route prefix for the application "
"to avoid this issue."
)

if deploy_param.get("docs_path") is not None:
self.docs_path = deploy_param["docs_path"]
num_docs_paths += 1
if num_route_prefixes > 1:
raise RayServeException(
f'Found multiple route prefix from application "{self.name}",'
" Please specify only one route prefix for the application "
"to avoid this issue."
)
# NOTE(zcin) This will not catch multiple FastAPI deployments in the application
# if user sets the docs path to None in their FastAPI app.
if num_docs_paths > 1:
raise RayServeException(
f'Found multiple deployments in application "{self.name}" that have '
"a docs path. This may be due to using multiple FastAPI deployments "
"in your application. Please only include one deployment with a docs "
"path in your application to avoid this issue."
)

self.status = ApplicationStatus.DEPLOYING
return cur_deployments_to_delete
Expand Down Expand Up @@ -266,6 +279,9 @@ def get_app_status(self, name: str) -> ApplicationStatusInfo:
)
return self._application_states[name].get_application_status_info()

def get_docs_path(self, app_name: str):
return self._application_states[app_name].docs_path

def list_app_statuses(self) -> Dict[str, ApplicationStatusInfo]:
"""Return a dictionary with {app name: application info}"""
return {
Expand Down
3 changes: 3 additions & 0 deletions python/ray/serve/_private/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ def deploy_group(
version=deployment["version"],
route_prefix=deployment["route_prefix"],
is_driver_deployment=deployment["is_driver_deployment"],
docs_path=deployment["docs_path"],
)
)

Expand Down Expand Up @@ -493,6 +494,7 @@ def get_deploy_args(
version: Optional[str] = None,
route_prefix: Optional[str] = None,
is_driver_deployment: Optional[str] = None,
docs_path: Optional[str] = None,
) -> Dict:
"""
Takes a deployment's configuration, and returns the arguments needed
Expand Down Expand Up @@ -548,6 +550,7 @@ def get_deploy_args(
"route_prefix": route_prefix,
"deployer_job_id": ray.get_runtime_context().get_job_id(),
"is_driver_deployment": is_driver_deployment,
"docs_path": docs_path,
}

return controller_deploy_args
Expand Down
5 changes: 4 additions & 1 deletion python/ray/serve/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,8 @@ async def __del__(self):
super_cls.__del__()

ASGIAppWrapper.__name__ = cls.__name__
if hasattr(frozen_app, "docs_url"):
ASGIAppWrapper.__fastapi_docs_path__ = frozen_app.docs_url
return ASGIAppWrapper

return decorator
Expand Down Expand Up @@ -408,8 +410,8 @@ def decorator(_func_or_class):
ray_actor_options=(
ray_actor_options if ray_actor_options is not DEFAULT.VALUE else None
),
_internal=True,
is_driver_deployment=is_driver_deployment,
_internal=True,
)

# This handles both parametrized and non-parametrized usage of the
Expand Down Expand Up @@ -548,6 +550,7 @@ def run(
"route_prefix": deployment.route_prefix,
"url": deployment.url,
"is_driver_deployment": deployment._is_driver_deployment,
"docs_path": deployment._docs_path,
}
parameter_group.append(deployment_parameters)
client.deploy_group(
Expand Down
9 changes: 9 additions & 0 deletions python/ray/serve/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,10 +367,13 @@ def deploy(
replica_config_proto_bytes: bytes,
route_prefix: Optional[str],
deployer_job_id: Union[str, bytes],
docs_path: Optional[str] = None,
is_driver_deployment: Optional[bool] = False,
) -> bool:
if route_prefix is not None:
assert route_prefix.startswith("/")
if docs_path is not None:
assert docs_path.startswith("/")

deployment_config = DeploymentConfig.from_proto_bytes(
deployment_config_proto_bytes
Expand Down Expand Up @@ -719,6 +722,12 @@ def get_deployment_status(self, name: str) -> Union[None, bytes]:
return None
return status[0].to_proto().SerializeToString()

def get_docs_path(self, name: str):
"""Docs path for application.
Currently, this is the OpenAPI docs path for FastAPI-integrated applications."""
return self.application_state_manager.get_docs_path(name)

def delete_apps(self, names: Iterable[str]):
"""Delete applications based on names
Expand Down
12 changes: 11 additions & 1 deletion python/ray/serve/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ def __init__(
init_kwargs: Optional[Tuple[Any]] = None,
route_prefix: Union[str, None, DEFAULT] = DEFAULT.VALUE,
ray_actor_options: Optional[Dict] = None,
_internal=False,
is_driver_deployment: Optional[bool] = False,
_internal=False,
) -> None:
"""Construct a Deployment. CONSTRUCTOR SHOULDN'T BE USED DIRECTLY.
Expand Down Expand Up @@ -87,6 +87,15 @@ def __init__(
if init_kwargs is None:
init_kwargs = {}

docs_path = None
if (
inspect.isclass(func_or_class)
and hasattr(func_or_class, "__module__")
and func_or_class.__module__ == "ray.serve.api"
and hasattr(func_or_class, "__fastapi_docs_path__")
):
docs_path = func_or_class.__fastapi_docs_path__

self._func_or_class = func_or_class
self._name = name
self._version = version
Expand All @@ -96,6 +105,7 @@ def __init__(
self._route_prefix = route_prefix
self._ray_actor_options = ray_actor_options
self._is_driver_deployment = is_driver_deployment
self._docs_path = docs_path

@property
def name(self) -> str:
Expand Down
71 changes: 70 additions & 1 deletion python/ray/serve/tests/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@
import ray
from ray import serve
from ray.exceptions import GetTimeoutError
from ray.serve.exceptions import RayServeException
from ray.serve._private.client import ServeControllerClient
from ray.serve._private.http_util import make_fastapi_class_based_view
from ray.serve._private.utils import DEFAULT
from ray._private.test_utils import SignalActor
from ray._private.test_utils import SignalActor, wait_for_condition


def test_fastapi_function(serve_instance):
Expand Down Expand Up @@ -664,6 +666,73 @@ def incr(self):
assert resp.json() == [0, 0]


@pytest.mark.parametrize("two_fastapi", [True, False])
def test_two_fastapi_in_one_application(
serve_instance: ServeControllerClient, two_fastapi
):
"""
Check that a deployment graph that would normally work, will not deploy
successfully if there are two FastAPI deployments.
"""
app1 = FastAPI()
app2 = FastAPI()

class SubModel:
def add(self, a: int):
return a + 1

@serve.deployment
@serve.ingress(app1)
class Model:
def __init__(self, submodel):
self.submodel = submodel

@app1.get("/{a}")
async def func(self, a: int):
return await (await self.submodel.add.remote(a))

if two_fastapi:
SubModel = serve.deployment(serve.ingress(app2)(SubModel))
with pytest.raises(RayServeException) as e:
handle = serve.run(Model.bind(SubModel.bind()), name="app1")
assert "FastAPI" in str(e.value)
else:
handle = serve.run(Model.bind(serve.deployment(SubModel).bind()), name="app1")
assert ray.get(handle.func.remote(5)) == 6


@pytest.mark.parametrize(
"is_fastapi,docs_path",
[
(False, None), # Not integrated with FastAPI
(True, "/docs"), # Don't specify docs_url, use default
(True, "/documentation"), # Override default docs url
],
)
def test_fastapi_docs_path(
serve_instance: ServeControllerClient, is_fastapi, docs_path
):
# If not the default docs_url, override it.
if docs_path != "/docs":
app = FastAPI(docs_url=docs_path)
else:
app = FastAPI()

class Model:
@app.get("/{a}")
def func(a: int):
return {"result": a}

if is_fastapi:
Model = serve.ingress(app)(Model)

serve.run(serve.deployment(Model).bind(), name="app1")
wait_for_condition(
lambda: ray.get(serve_instance._controller.get_docs_path.remote("app1"))
== docs_path
)


if __name__ == "__main__":
import sys

Expand Down

0 comments on commit fdf9866

Please sign in to comment.