Skip to content

Commit

Permalink
FEAT: Basic cancel support for image model (#2528)
Browse files Browse the repository at this point in the history
  • Loading branch information
codingl2k1 authored Nov 12, 2024
1 parent e5594a5 commit fe94552
Show file tree
Hide file tree
Showing 7 changed files with 207 additions and 14 deletions.
61 changes: 56 additions & 5 deletions xinference/api/restful_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,14 @@

from .._compat import BaseModel, Field
from .._version import get_versions
from ..constants import XINFERENCE_DEFAULT_ENDPOINT_PORT, XINFERENCE_DISABLE_METRICS
from ..constants import (
XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
XINFERENCE_DEFAULT_ENDPOINT_PORT,
XINFERENCE_DISABLE_METRICS,
)
from ..core.event import Event, EventCollectorActor, EventType
from ..core.supervisor import SupervisorActor
from ..core.utils import json_dumps
from ..core.utils import CancelMixin, json_dumps
from ..types import (
ChatCompletion,
Completion,
Expand Down Expand Up @@ -206,7 +210,7 @@ class BuildGradioImageInterfaceRequest(BaseModel):
model_ability: List[str]


class RESTfulAPI:
class RESTfulAPI(CancelMixin):
def __init__(
self,
supervisor_address: str,
Expand Down Expand Up @@ -1531,8 +1535,11 @@ async def create_images(self, request: Request) -> Response:
await self._report_error_event(model_uid, str(e))
raise HTTPException(status_code=500, detail=str(e))

request_id = None
try:
kwargs = json.loads(body.kwargs) if body.kwargs else {}
request_id = kwargs.get("request_id")
self._add_running_task(request_id)
image_list = await model.text_to_image(
prompt=body.prompt,
n=body.n,
Expand All @@ -1541,6 +1548,11 @@ async def create_images(self, request: Request) -> Response:
**kwargs,
)
return Response(content=image_list, media_type="application/json")
except asyncio.CancelledError:
err_str = f"The request has been cancelled: {request_id}"
logger.error(err_str)
await self._report_error_event(model_uid, err_str)
raise HTTPException(status_code=409, detail=err_str)
except RuntimeError as re:
logger.error(re, exc_info=True)
await self._report_error_event(model_uid, str(re))
Expand Down Expand Up @@ -1686,11 +1698,14 @@ async def create_variations(
await self._report_error_event(model_uid, str(e))
raise HTTPException(status_code=500, detail=str(e))

request_id = None
try:
if kwargs is not None:
parsed_kwargs = json.loads(kwargs)
else:
parsed_kwargs = {}
request_id = parsed_kwargs.get("request_id")
self._add_running_task(request_id)
image_list = await model_ref.image_to_image(
image=Image.open(image.file),
prompt=prompt,
Expand All @@ -1701,6 +1716,11 @@ async def create_variations(
**parsed_kwargs,
)
return Response(content=image_list, media_type="application/json")
except asyncio.CancelledError:
err_str = f"The request has been cancelled: {request_id}"
logger.error(err_str)
await self._report_error_event(model_uid, err_str)
raise HTTPException(status_code=409, detail=err_str)
except RuntimeError as re:
logger.error(re, exc_info=True)
await self._report_error_event(model_uid, str(re))
Expand Down Expand Up @@ -1734,11 +1754,14 @@ async def create_inpainting(
await self._report_error_event(model_uid, str(e))
raise HTTPException(status_code=500, detail=str(e))

request_id = None
try:
if kwargs is not None:
parsed_kwargs = json.loads(kwargs)
else:
parsed_kwargs = {}
request_id = parsed_kwargs.get("request_id")
self._add_running_task(request_id)
im = Image.open(image.file)
mask_im = Image.open(mask_image.file)
if not size:
Expand All @@ -1755,6 +1778,11 @@ async def create_inpainting(
**parsed_kwargs,
)
return Response(content=image_list, media_type="application/json")
except asyncio.CancelledError:
err_str = f"The request has been cancelled: {request_id}"
logger.error(err_str)
await self._report_error_event(model_uid, err_str)
raise HTTPException(status_code=409, detail=err_str)
except RuntimeError as re:
logger.error(re, exc_info=True)
await self._report_error_event(model_uid, str(re))
Expand Down Expand Up @@ -1782,17 +1810,25 @@ async def create_ocr(
await self._report_error_event(model_uid, str(e))
raise HTTPException(status_code=500, detail=str(e))

request_id = None
try:
if kwargs is not None:
parsed_kwargs = json.loads(kwargs)
else:
parsed_kwargs = {}
request_id = parsed_kwargs.get("request_id")
self._add_running_task(request_id)
im = Image.open(image.file)
text = await model_ref.ocr(
image=im,
**parsed_kwargs,
)
return Response(content=text, media_type="text/plain")
except asyncio.CancelledError:
err_str = f"The request has been cancelled: {request_id}"
logger.error(err_str)
await self._report_error_event(model_uid, err_str)
raise HTTPException(status_code=409, detail=err_str)
except RuntimeError as re:
logger.error(re, exc_info=True)
await self._report_error_event(model_uid, str(re))
Expand Down Expand Up @@ -2111,10 +2147,25 @@ async def get_model_events(self, model_uid: str) -> JSONResponse:
logger.error(e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))

async def abort_request(self, model_uid: str, request_id: str) -> JSONResponse:
async def abort_request(
self, request: Request, model_uid: str, request_id: str
) -> JSONResponse:
try:
payload = await request.json()
block_duration = payload.get(
"block_duration", XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION
)
logger.info(
"Abort request with model uid: %s, request id: %s, block duration: %s",
model_uid,
request_id,
block_duration,
)
supervisor_ref = await self._get_supervisor_ref()
res = await supervisor_ref.abort_request(model_uid, request_id)
res = await supervisor_ref.abort_request(
model_uid, request_id, block_duration
)
self._cancel_running_task(request_id, block_duration)
return JSONResponse(content=res)
except Exception as e:
logger.error(e, exc_info=True)
Expand Down
9 changes: 7 additions & 2 deletions xinference/client/restful/restful_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1357,7 +1357,7 @@ def query_engine_by_model_name(self, model_name: str):
response_data = response.json()
return response_data

def abort_request(self, model_uid: str, request_id: str):
def abort_request(self, model_uid: str, request_id: str, block_duration: int = 30):
"""
Abort a request.
Abort a submitted request. If the request is finished or not found, this method will be a no-op.
Expand All @@ -1369,13 +1369,18 @@ def abort_request(self, model_uid: str, request_id: str):
Model uid.
request_id: str
Request id.
block_duration: int
The duration to make the request id abort. If set to 0, the abort_request will be immediate, which may
prevent it from taking effect if it arrives before the request operation.
Returns
-------
Dict
Return empty dict.
"""
url = f"{self.base_url}/v1/models/{model_uid}/requests/{request_id}/abort"
response = requests.post(url, headers=self._headers)
response = requests.post(
url, headers=self._headers, json={"block_duration": block_duration}
)
if response.status_code != 200:
raise RuntimeError(
f"Failed to abort request, detail: {_get_error_string(response)}"
Expand Down
1 change: 1 addition & 0 deletions xinference/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,4 @@ def get_xinference_home() -> str:
XINFERENCE_ENV_TEXT_TO_IMAGE_BATCHING_SIZE, None
)
XINFERENCE_LAUNCH_MODEL_RETRY = 3
XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION = 30
13 changes: 10 additions & 3 deletions xinference/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import xoscar as xo

from ..constants import (
XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
XINFERENCE_LAUNCH_MODEL_RETRY,
XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE,
)
Expand All @@ -57,7 +58,7 @@
logger = logging.getLogger(__name__)

from ..device_utils import empty_cache
from .utils import json_dumps, log_async
from .utils import CancelMixin, json_dumps, log_async

try:
from torch.cuda import OutOfMemoryError
Expand Down Expand Up @@ -136,7 +137,7 @@ async def _async_wrapper(*args, **kwargs):
return _wrapper


class ModelActor(xo.StatelessActor):
class ModelActor(xo.StatelessActor, CancelMixin):
_replica_model_uid: Optional[str]

@classmethod
Expand Down Expand Up @@ -553,6 +554,7 @@ async def _call_wrapper_binary(self, fn: Callable, *args, **kwargs):

@oom_check
async def _call_wrapper(self, output_type: str, fn: Callable, *args, **kwargs):
self._add_running_task(kwargs.get("request_id"))
if self._lock is None:
if inspect.iscoroutinefunction(fn):
ret = await fn(*args, **kwargs)
Expand Down Expand Up @@ -761,9 +763,14 @@ async def chat(self, messages: List[Dict], *args, **kwargs):
prompt_tokens,
)

async def abort_request(self, request_id: str) -> str:
async def abort_request(
self,
request_id: str,
block_duration: int = XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
) -> str:
from .utils import AbortRequestMessage

self._cancel_running_task(request_id, block_duration)
if self.allow_batching():
if self._scheduler_ref is None:
return AbortRequestMessage.NOT_FOUND.name
Expand Down
10 changes: 8 additions & 2 deletions xinference/core/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import xoscar as xo

from ..constants import (
XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
XINFERENCE_DISABLE_HEALTH_CHECK,
XINFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD,
XINFERENCE_HEALTH_CHECK_INTERVAL,
Expand Down Expand Up @@ -1213,7 +1214,12 @@ async def list_cached_models(
return cached_models

@log_async(logger=logger)
async def abort_request(self, model_uid: str, request_id: str) -> Dict:
async def abort_request(
self,
model_uid: str,
request_id: str,
block_duration: int = XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
) -> Dict:
from .scheduler import AbortRequestMessage

res = {"msg": AbortRequestMessage.NO_OP.name}
Expand All @@ -1228,7 +1234,7 @@ async def abort_request(self, model_uid: str, request_id: str) -> Dict:
if worker_ref is None:
continue
model_ref = await worker_ref.get_model(model_uid=rep_mid)
result_info = await model_ref.abort_request(request_id)
result_info = await model_ref.abort_request(request_id, block_duration)
res["msg"] = result_info
if result_info == AbortRequestMessage.DONE.name:
break
Expand Down
69 changes: 67 additions & 2 deletions xinference/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import os
import random
import string
import uuid
import weakref
from enum import Enum
from typing import Dict, Generator, List, Optional, Tuple, Union

import orjson
from pynvml import nvmlDeviceGetCount, nvmlInit, nvmlShutdown

from .._compat import BaseModel
from ..constants import XINFERENCE_LOG_ARG_MAX_LENGTH
from ..constants import (
XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
XINFERENCE_LOG_ARG_MAX_LENGTH,
)

logger = logging.getLogger(__name__)

Expand All @@ -49,13 +54,20 @@ def log_async(
):
import time
from functools import wraps
from inspect import signature

def decorator(func):
func_name = func.__name__
sig = signature(func)

@wraps(func)
async def wrapped(*args, **kwargs):
request_id_str = kwargs.get("request_id", "")
try:
bound_args = sig.bind_partial(*args, **kwargs)
arguments = bound_args.arguments
except TypeError:
arguments = {}
request_id_str = arguments.get("request_id", "")
if not request_id_str:
request_id_str = uuid.uuid1()
if func_name == "text_to_image":
Expand Down Expand Up @@ -269,3 +281,56 @@ def assign_replica_gpu(
if isinstance(gpu_idx, list) and gpu_idx:
return gpu_idx[rep_id::replica]
return gpu_idx


class CancelMixin:
_CANCEL_TASK_NAME = "abort_block"

def __init__(self):
self._running_tasks: weakref.WeakValueDictionary[
str, asyncio.Task
] = weakref.WeakValueDictionary()

def _add_running_task(self, request_id: Optional[str]):
"""Add current asyncio task to the running task.
:param request_id: The corresponding request id.
"""
if request_id is None:
return
running_task = self._running_tasks.get(request_id)
if running_task is not None:
if running_task.get_name() == self._CANCEL_TASK_NAME:
raise Exception(f"The request has been aborted: {request_id}")
raise Exception(f"Duplicate request id: {request_id}")
current_task = asyncio.current_task()
assert current_task is not None
self._running_tasks[request_id] = current_task

def _cancel_running_task(
self,
request_id: Optional[str],
block_duration: int = XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
):
"""Cancel the running asyncio task.
:param request_id: The request id to cancel.
:param block_duration: The duration seconds to ensure the request can't be executed.
"""
if request_id is None:
return
running_task = self._running_tasks.pop(request_id, None)
if running_task is not None:
running_task.cancel()

async def block_task():
"""This task is for blocking the request for a duration."""
try:
await asyncio.sleep(block_duration)
logger.info("Abort block end for request: %s", request_id)
except asyncio.CancelledError:
logger.info("Abort block is cancelled for request: %s", request_id)

if block_duration > 0:
logger.info("Abort block start for request: %s", request_id)
self._running_tasks[request_id] = asyncio.create_task(
block_task(), name=self._CANCEL_TASK_NAME
)
Loading

0 comments on commit fe94552

Please sign in to comment.