diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py
index ed3a2eab90..d9d4678911 100644
--- a/xinference/api/restful_api.py
+++ b/xinference/api/restful_api.py
@@ -52,10 +52,14 @@
 from .._compat import BaseModel, Field
 from .._version import get_versions
-from ..constants import XINFERENCE_DEFAULT_ENDPOINT_PORT, XINFERENCE_DISABLE_METRICS
+from ..constants import (
+    XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
+    XINFERENCE_DEFAULT_ENDPOINT_PORT,
+    XINFERENCE_DISABLE_METRICS,
+)
 from ..core.event import Event, EventCollectorActor, EventType
 from ..core.supervisor import SupervisorActor
-from ..core.utils import json_dumps
+from ..core.utils import CancelMixin, json_dumps
 from ..types import (
     ChatCompletion,
     Completion,
@@ -206,7 +210,7 @@ class BuildGradioImageInterfaceRequest(BaseModel):
     model_ability: List[str]
 
 
-class RESTfulAPI:
+class RESTfulAPI(CancelMixin):
     def __init__(
         self,
         supervisor_address: str,
@@ -1531,8 +1535,11 @@ async def create_images(self, request: Request) -> Response:
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+        request_id = None
         try:
             kwargs = json.loads(body.kwargs) if body.kwargs else {}
+            request_id = kwargs.get("request_id")
+            self._add_running_task(request_id)
             image_list = await model.text_to_image(
                 prompt=body.prompt,
                 n=body.n,
@@ -1541,6 +1548,11 @@ async def create_images(self, request: Request) -> Response:
                 **kwargs,
             )
             return Response(content=image_list, media_type="application/json")
+        except asyncio.CancelledError:
+            err_str = f"The request has been cancelled: {request_id}"
+            logger.error(err_str)
+            await self._report_error_event(model_uid, err_str)
+            raise HTTPException(status_code=409, detail=err_str)
         except RuntimeError as re:
             logger.error(re, exc_info=True)
             await self._report_error_event(model_uid, str(re))
@@ -1686,11 +1698,14 @@ async def create_variations(
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+        request_id = None
         try:
             if kwargs is not None:
                 parsed_kwargs = json.loads(kwargs)
             else:
                 parsed_kwargs = {}
+            request_id = parsed_kwargs.get("request_id")
+            self._add_running_task(request_id)
             image_list = await model_ref.image_to_image(
                 image=Image.open(image.file),
                 prompt=prompt,
@@ -1701,6 +1716,11 @@ async def create_variations(
                 **parsed_kwargs,
             )
             return Response(content=image_list, media_type="application/json")
+        except asyncio.CancelledError:
+            err_str = f"The request has been cancelled: {request_id}"
+            logger.error(err_str)
+            await self._report_error_event(model_uid, err_str)
+            raise HTTPException(status_code=409, detail=err_str)
         except RuntimeError as re:
             logger.error(re, exc_info=True)
             await self._report_error_event(model_uid, str(re))
@@ -1734,11 +1754,14 @@ async def create_inpainting(
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+        request_id = None
         try:
             if kwargs is not None:
                 parsed_kwargs = json.loads(kwargs)
             else:
                 parsed_kwargs = {}
+            request_id = parsed_kwargs.get("request_id")
+            self._add_running_task(request_id)
             im = Image.open(image.file)
             mask_im = Image.open(mask_image.file)
             if not size:
@@ -1755,6 +1778,11 @@ async def create_inpainting(
                 **parsed_kwargs,
             )
             return Response(content=image_list, media_type="application/json")
+        except asyncio.CancelledError:
+            err_str = f"The request has been cancelled: {request_id}"
+            logger.error(err_str)
+            await self._report_error_event(model_uid, err_str)
+            raise HTTPException(status_code=409, detail=err_str)
         except RuntimeError as re:
             logger.error(re, exc_info=True)
             await self._report_error_event(model_uid, str(re))
@@ -1782,17 +1810,25 @@ async def create_ocr(
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+        request_id = None
         try:
             if kwargs is not None:
                 parsed_kwargs = json.loads(kwargs)
             else:
                 parsed_kwargs = {}
+            request_id = parsed_kwargs.get("request_id")
+            self._add_running_task(request_id)
             im = Image.open(image.file)
             text = await model_ref.ocr(
                 image=im,
                 **parsed_kwargs,
             )
             return Response(content=text, media_type="text/plain")
+        except asyncio.CancelledError:
+            err_str = f"The request has been cancelled: {request_id}"
+            logger.error(err_str)
+            await self._report_error_event(model_uid, err_str)
+            raise HTTPException(status_code=409, detail=err_str)
         except RuntimeError as re:
             logger.error(re, exc_info=True)
             await self._report_error_event(model_uid, str(re))
@@ -2111,10 +2147,25 @@ async def get_model_events(self, model_uid: str) -> JSONResponse:
             logger.error(e, exc_info=True)
             raise HTTPException(status_code=500, detail=str(e))
 
-    async def abort_request(self, model_uid: str, request_id: str) -> JSONResponse:
+    async def abort_request(
+        self, request: Request, model_uid: str, request_id: str
+    ) -> JSONResponse:
         try:
+            payload = await request.json()
+            block_duration = payload.get(
+                "block_duration", XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION
+            )
+            logger.info(
+                "Abort request with model uid: %s, request id: %s, block duration: %s",
+                model_uid,
+                request_id,
+                block_duration,
+            )
             supervisor_ref = await self._get_supervisor_ref()
-            res = await supervisor_ref.abort_request(model_uid, request_id)
+            res = await supervisor_ref.abort_request(
+                model_uid, request_id, block_duration
+            )
+            self._cancel_running_task(request_id, block_duration)
             return JSONResponse(content=res)
         except Exception as e:
             logger.error(e, exc_info=True)
diff --git a/xinference/client/restful/restful_client.py b/xinference/client/restful/restful_client.py
index dd5e3f1146..ed71a7bf05 100644
--- a/xinference/client/restful/restful_client.py
+++ b/xinference/client/restful/restful_client.py
@@ -1357,7 +1357,7 @@ def query_engine_by_model_name(self, model_name: str):
         response_data = response.json()
         return response_data
 
-    def abort_request(self, model_uid: str, request_id: str):
+    def abort_request(self, model_uid: str, request_id: str, block_duration: int = 30):
         """
         Abort a request.
         Abort a submitted request. If the request is finished or not found, this method will be a no-op.
@@ -1369,13 +1369,18 @@ def abort_request(self, model_uid: str, request_id: str):
             Model uid.
         request_id: str
             Request id.
+        block_duration: int
+            The duration in seconds during which the request id stays blocked, so that an abort arriving before
+            its request still takes effect. If set to 0, the abort is immediate and may miss such a request.
 
         Returns
        -------
         Dict
             Return empty dict.
""" url = f"{self.base_url}/v1/models/{model_uid}/requests/{request_id}/abort" - response = requests.post(url, headers=self._headers) + response = requests.post( + url, headers=self._headers, json={"block_duration": block_duration} + ) if response.status_code != 200: raise RuntimeError( f"Failed to abort request, detail: {_get_error_string(response)}" diff --git a/xinference/constants.py b/xinference/constants.py index 93e73e4d58..dd0adcb864 100644 --- a/xinference/constants.py +++ b/xinference/constants.py @@ -88,3 +88,4 @@ def get_xinference_home() -> str: XINFERENCE_ENV_TEXT_TO_IMAGE_BATCHING_SIZE, None ) XINFERENCE_LAUNCH_MODEL_RETRY = 3 +XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION = 30 diff --git a/xinference/core/model.py b/xinference/core/model.py index e911c71e6d..42453ddc69 100644 --- a/xinference/core/model.py +++ b/xinference/core/model.py @@ -41,6 +41,7 @@ import xoscar as xo from ..constants import ( + XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION, XINFERENCE_LAUNCH_MODEL_RETRY, XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE, ) @@ -57,7 +58,7 @@ logger = logging.getLogger(__name__) from ..device_utils import empty_cache -from .utils import json_dumps, log_async +from .utils import CancelMixin, json_dumps, log_async try: from torch.cuda import OutOfMemoryError @@ -136,7 +137,7 @@ async def _async_wrapper(*args, **kwargs): return _wrapper -class ModelActor(xo.StatelessActor): +class ModelActor(xo.StatelessActor, CancelMixin): _replica_model_uid: Optional[str] @classmethod @@ -553,6 +554,7 @@ async def _call_wrapper_binary(self, fn: Callable, *args, **kwargs): @oom_check async def _call_wrapper(self, output_type: str, fn: Callable, *args, **kwargs): + self._add_running_task(kwargs.get("request_id")) if self._lock is None: if inspect.iscoroutinefunction(fn): ret = await fn(*args, **kwargs) @@ -761,9 +763,14 @@ async def chat(self, messages: List[Dict], *args, **kwargs): prompt_tokens, ) - async def abort_request(self, request_id: str) -> str: + async def abort_request( + self, + request_id: str, + block_duration: int = XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION, + ) -> str: from .utils import AbortRequestMessage + self._cancel_running_task(request_id, block_duration) if self.allow_batching(): if self._scheduler_ref is None: return AbortRequestMessage.NOT_FOUND.name diff --git a/xinference/core/supervisor.py b/xinference/core/supervisor.py index 8f705217a3..c8f2f59ff6 100644 --- a/xinference/core/supervisor.py +++ b/xinference/core/supervisor.py @@ -35,6 +35,7 @@ import xoscar as xo from ..constants import ( + XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION, XINFERENCE_DISABLE_HEALTH_CHECK, XINFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD, XINFERENCE_HEALTH_CHECK_INTERVAL, @@ -1213,7 +1214,12 @@ async def list_cached_models( return cached_models @log_async(logger=logger) - async def abort_request(self, model_uid: str, request_id: str) -> Dict: + async def abort_request( + self, + model_uid: str, + request_id: str, + block_duration: int = XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION, + ) -> Dict: from .scheduler import AbortRequestMessage res = {"msg": AbortRequestMessage.NO_OP.name} @@ -1228,7 +1234,7 @@ async def abort_request(self, model_uid: str, request_id: str) -> Dict: if worker_ref is None: continue model_ref = await worker_ref.get_model(model_uid=rep_mid) - result_info = await model_ref.abort_request(request_id) + result_info = await model_ref.abort_request(request_id, block_duration) res["msg"] = result_info if result_info == AbortRequestMessage.DONE.name: break diff --git a/xinference/core/utils.py 
index d4caba8c54..278c570b20 100644
--- a/xinference/core/utils.py
+++ b/xinference/core/utils.py
@@ -11,11 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import asyncio
 import logging
 import os
 import random
 import string
 import uuid
+import weakref
 from enum import Enum
 from typing import Dict, Generator, List, Optional, Tuple, Union
 
@@ -23,7 +25,10 @@
 from pynvml import nvmlDeviceGetCount, nvmlInit, nvmlShutdown
 
 from .._compat import BaseModel
-from ..constants import XINFERENCE_LOG_ARG_MAX_LENGTH
+from ..constants import (
+    XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
+    XINFERENCE_LOG_ARG_MAX_LENGTH,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -49,13 +54,20 @@ def log_async(
 ):
     import time
     from functools import wraps
+    from inspect import signature
 
     def decorator(func):
         func_name = func.__name__
+        sig = signature(func)
 
         @wraps(func)
         async def wrapped(*args, **kwargs):
-            request_id_str = kwargs.get("request_id", "")
+            try:
+                bound_args = sig.bind_partial(*args, **kwargs)
+                arguments = bound_args.arguments
+            except TypeError:
+                arguments = {}
+            request_id_str = arguments.get("request_id", "")
             if not request_id_str:
                 request_id_str = uuid.uuid1()
             if func_name == "text_to_image":
@@ -269,3 +281,56 @@ def assign_replica_gpu(
     if isinstance(gpu_idx, list) and gpu_idx:
         return gpu_idx[rep_id::replica]
     return gpu_idx
+
+
+class CancelMixin:
+    _CANCEL_TASK_NAME = "abort_block"
+
+    def __init__(self):
+        self._running_tasks: weakref.WeakValueDictionary[
+            str, asyncio.Task
+        ] = weakref.WeakValueDictionary()
+
+    def _add_running_task(self, request_id: Optional[str]):
+        """Add the current asyncio task to the running tasks.
+        :param request_id: The corresponding request id.
+        """
+        if request_id is None:
+            return
+        running_task = self._running_tasks.get(request_id)
+        if running_task is not None:
+            if running_task.get_name() == self._CANCEL_TASK_NAME:
+                raise Exception(f"The request has been aborted: {request_id}")
+            raise Exception(f"Duplicate request id: {request_id}")
+        current_task = asyncio.current_task()
+        assert current_task is not None
+        self._running_tasks[request_id] = current_task
+
+    def _cancel_running_task(
+        self,
+        request_id: Optional[str],
+        block_duration: int = XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
+    ):
+        """Cancel the running asyncio task.
+        :param request_id: The request id to cancel.
+        :param block_duration: The duration in seconds during which the request id stays blocked from execution.
+ """ + if request_id is None: + return + running_task = self._running_tasks.pop(request_id, None) + if running_task is not None: + running_task.cancel() + + async def block_task(): + """This task is for blocking the request for a duration.""" + try: + await asyncio.sleep(block_duration) + logger.info("Abort block end for request: %s", request_id) + except asyncio.CancelledError: + logger.info("Abort block is cancelled for request: %s", request_id) + + if block_duration > 0: + logger.info("Abort block start for request: %s", request_id) + self._running_tasks[request_id] = asyncio.create_task( + block_task(), name=self._CANCEL_TASK_NAME + ) diff --git a/xinference/model/image/tests/test_stable_diffusion.py b/xinference/model/image/tests/test_stable_diffusion.py index e4da8014d0..04cb607201 100644 --- a/xinference/model/image/tests/test_stable_diffusion.py +++ b/xinference/model/image/tests/test_stable_diffusion.py @@ -18,6 +18,8 @@ import os.path import shutil import tempfile +import threading +import time import uuid from io import BytesIO @@ -195,6 +197,62 @@ def test_restful_api_for_image_with_mlsd_controlnet(setup): logger.info("test result %s", r) +@pytest.mark.parametrize("model_name", ["sd-turbo"]) +def test_restful_api_abort(setup, model_name): + endpoint, _ = setup + from ....client import Client + + client = Client(endpoint) + + model_uid = client.launch_model( + model_uid="my_controlnet", + model_name=model_name, + model_type="image", + ) + model = client.get_model(model_uid) + + request_id = str(uuid.uuid4()) + client.abort_request(model_uid, request_id, 1) + time.sleep(2) + r = model.text_to_image( + prompt="A cinematic shot of a baby raccoon wearing an intricate italian priest robe.", + size="512*512", + num_inference_steps=10, + request_id=request_id, + ) + assert "created" in r + + request_id = str(uuid.uuid4()) + client.abort_request(model_uid, request_id) + with pytest.raises( + RuntimeError, match=f"The request has been aborted: {request_id}" + ): + model.text_to_image( + prompt="A cinematic shot of a baby raccoon wearing an intricate italian priest robe.", + size="512*512", + num_inference_steps=10, + request_id=request_id, + ) + + request_id = str(uuid.uuid4()) + + def _abort(): + time.sleep(1) + client.abort_request(model_uid, request_id) + + t = threading.Thread(target=_abort) + t.start() + with pytest.raises( + RuntimeError, match=f"The request has been cancelled: {request_id}" + ): + model.text_to_image( + prompt="A cinematic shot of a baby raccoon wearing an intricate italian priest robe.", + size="512*512", + num_inference_steps=10, + request_id=request_id, + ) + + @pytest.mark.parametrize("model_name", ["sd-turbo", "sdxl-turbo"]) def test_restful_api_for_sd_turbo(setup, model_name): if model_name == "sdxl-turbo":