From 79c1f08c6aa440562bdcf1a5c4e2d08d59b1e724 Mon Sep 17 00:00:00 2001
From: Aleksandr Movchan
Date: Mon, 28 Oct 2024 18:14:05 +0000
Subject: [PATCH 1/2] Implemented health check for deployments based on
 exceptions in responses.

---
 aana/deployments/base_deployment.py           | 66 +++++++++++++++++
 .../base_text_generation_deployment.py        |  7 +-
 .../haystack_component_deployment.py          |  3 +-
 aana/deployments/hf_blip2_deployment.py       |  4 +-
 aana/deployments/hf_pipeline_deployment.py    |  3 +-
 .../hf_text_generation_deployment.py          |  2 +
 aana/deployments/idefics_2_deployment.py      |  5 +-
 ...pyannote_speaker_diarization_deployment.py |  3 +-
 .../sentence_transformer_deployment.py        |  3 +-
 aana/deployments/vad_deployment.py            |  3 +-
 aana/deployments/vllm_deployment.py           | 19 ++++-
 aana/deployments/whisper_deployment.py        |  5 +-
 aana/tests/units/test_deployment_restart.py   | 74 +++++++++++++++++++
 13 files changed, 187 insertions(+), 10 deletions(-)
 create mode 100644 aana/tests/units/test_deployment_restart.py

diff --git a/aana/deployments/base_deployment.py b/aana/deployments/base_deployment.py
index 458744d9..d0ee5155 100644
--- a/aana/deployments/base_deployment.py
+++ b/aana/deployments/base_deployment.py
@@ -1,6 +1,44 @@
 import inspect
+from functools import wraps
 from typing import Any
 
+from aana.exceptions.runtime import InferenceException
+
+
+def exception_handler(func):
+    """Decorator for deployment methods that catches exceptions and stores them in the deployment for health check purposes.
+
+    Args:
+        func (function): The function to decorate.
+
+    Returns:
+        function: The decorated function.
+    """
+
+    @wraps(func)
+    async def wrapper(self, *args, **kwargs):
+        self.num_requests_since_last_health_check += 1
+        try:
+            return await func(self, *args, **kwargs)
+        except Exception as e:
+            self.raised_exceptions.append(e)
+            raise
+
+    @wraps(func)
+    async def wrapper_generator(self, *args, **kwargs):
+        self.num_requests_since_last_health_check += 1
+        try:
+            async for item in func(self, *args, **kwargs):
+                yield item
+        except Exception as e:
+            self.raised_exceptions.append(e)
+            raise
+
+    if inspect.isasyncgenfunction(func):
+        return wrapper_generator
+    else:
+        return wrapper
+
 
 class BaseDeployment:
     """Base class for all deployments.
@@ -13,6 +51,9 @@ def __init__(self):
         """Inits to unconfigured state."""
         self.config = None
         self._configured = False
+        self.num_requests_since_last_health_check = 0
+        self.raised_exceptions = []
+        self.restart_exceptions = [InferenceException]
 
     async def reconfigure(self, config: dict[str, Any]):
         """Reconfigure the deployment.
@@ -22,6 +63,31 @@ async def reconfigure(self, config: dict[str, Any]):
         self.config = config
         await self.apply_config(config)
         self._configured = True
+        if "restart_exceptions" in config:
+            self.restart_exceptions = config["restart_exceptions"]
+
+    async def check_health(self):
+        """Check the health of the deployment.
+
+        Raises:
+            Exception: The first raised restart exception if more than 50% of
+                the requests since the last health check raised one.
+ """ + raised_restart_exceptions = [ + exception + for exception in self.raised_exceptions + if exception.__class__ in self.restart_exceptions + ] + # Restart the deployment if more than 50% of the requests raised restart exceptions + if self.num_requests_since_last_health_check != 0: + ratio_restart_exceptions = ( + len(raised_restart_exceptions) + / self.num_requests_since_last_health_check + ) + if ratio_restart_exceptions > 0.5: + raise raised_restart_exceptions[0] + + self.raised_exceptions = [] + self.num_requests_since_last_health_check = 0 async def apply_config(self, config: dict[str, Any]): """Apply the configuration. diff --git a/aana/deployments/base_text_generation_deployment.py b/aana/deployments/base_text_generation_deployment.py index 44f51656..49b56961 100644 --- a/aana/deployments/base_text_generation_deployment.py +++ b/aana/deployments/base_text_generation_deployment.py @@ -9,7 +9,7 @@ from aana.core.chat.chat_template import apply_chat_template from aana.core.models.chat import ChatDialog, ChatMessage from aana.core.models.sampling import SamplingParams -from aana.deployments.base_deployment import BaseDeployment +from aana.deployments.base_deployment import BaseDeployment, exception_handler class LLMOutput(TypedDict): @@ -57,6 +57,7 @@ class BaseTextGenerationDeployment(BaseDeployment): You can also override these methods to implement custom inference logic. """ + @exception_handler async def generate_stream( self, prompt: str, sampling_params: SamplingParams | None = None ) -> AsyncGenerator[LLMOutput, None]: @@ -71,6 +72,7 @@ async def generate_stream( """ raise NotImplementedError + @exception_handler async def generate( self, prompt: str, sampling_params: SamplingParams | None = None ) -> LLMOutput: @@ -88,6 +90,7 @@ async def generate( generated_text += chunk["text"] return LLMOutput(text=generated_text) + @exception_handler async def generate_batch( self, prompts: list[str], sampling_params: SamplingParams | None = None ) -> LLMBatchOutput: @@ -108,6 +111,7 @@ async def generate_batch( return LLMBatchOutput(texts=texts) + @exception_handler async def chat( self, dialog: ChatDialog, sampling_params: SamplingParams | None = None ) -> ChatOutput: @@ -127,6 +131,7 @@ async def chat( response_message = ChatMessage(content=response["text"], role="assistant") return ChatOutput(message=response_message) + @exception_handler async def chat_stream( self, dialog: ChatDialog, sampling_params: SamplingParams | None = None ) -> AsyncGenerator[LLMOutput, None]: diff --git a/aana/deployments/haystack_component_deployment.py b/aana/deployments/haystack_component_deployment.py index 238da634..f411961b 100644 --- a/aana/deployments/haystack_component_deployment.py +++ b/aana/deployments/haystack_component_deployment.py @@ -5,7 +5,7 @@ from ray import serve from aana.deployments.aana_deployment_handle import AanaDeploymentHandle -from aana.deployments.base_deployment import BaseDeployment +from aana.deployments.base_deployment import BaseDeployment, exception_handler from aana.utils.asyncio import run_async from aana.utils.core import import_from_path @@ -84,6 +84,7 @@ async def apply_config(self, config: dict[str, Any]): self.component.warm_up() + @exception_handler async def run(self, **data: dict[str, Any]) -> dict[str, Any]: """Run the model on the input data.""" return self.component.run(**data) diff --git a/aana/deployments/hf_blip2_deployment.py b/aana/deployments/hf_blip2_deployment.py index 44580c5f..0a2bdd95 100644 --- a/aana/deployments/hf_blip2_deployment.py +++ 
@@ -11,7 +11,7 @@
 from aana.core.models.captions import Caption, CaptionsList
 from aana.core.models.image import Image
 from aana.core.models.types import Dtype
-from aana.deployments.base_deployment import BaseDeployment
+from aana.deployments.base_deployment import BaseDeployment, exception_handler
 from aana.exceptions.runtime import InferenceException
 from aana.processors.batch import BatchProcessor
 
@@ -106,6 +106,7 @@ async def apply_config(self, config: dict[str, Any]):
         self.processor = Blip2Processor.from_pretrained(self.model_id)
         self.model.to(self.device)
 
+    @exception_handler
     async def generate(self, image: Image) -> CaptioningOutput:
         """Generate captions for the given image.
 
@@ -124,6 +125,7 @@ async def generate(self, image: Image) -> CaptioningOutput:
         )
         return CaptioningOutput(caption=captions["captions"][0])
 
+    @exception_handler
     async def generate_batch(self, **kwargs) -> CaptioningBatchOutput:
         """Generate captions for the given images.
 
diff --git a/aana/deployments/hf_pipeline_deployment.py b/aana/deployments/hf_pipeline_deployment.py
index 1e8b4880..56b0ecb4 100644
--- a/aana/deployments/hf_pipeline_deployment.py
+++ b/aana/deployments/hf_pipeline_deployment.py
@@ -9,7 +9,7 @@
 from aana.core.models.base import pydantic_protected_fields
 from aana.core.models.custom_config import CustomConfig
 from aana.core.models.image import Image
-from aana.deployments.base_deployment import BaseDeployment
+from aana.deployments.base_deployment import BaseDeployment, exception_handler
 
 
 class HfPipelineConfig(BaseModel):
@@ -80,6 +80,7 @@ async def apply_config(self, config: dict[str, Any]):
             else:
                 raise
 
+    @exception_handler
     async def call(self, *args, **kwargs):
         """Call the pipeline.
 
diff --git a/aana/deployments/hf_text_generation_deployment.py b/aana/deployments/hf_text_generation_deployment.py
index 97c4f760..0d4be92b 100644
--- a/aana/deployments/hf_text_generation_deployment.py
+++ b/aana/deployments/hf_text_generation_deployment.py
@@ -14,6 +14,7 @@
 
 from aana.core.models.base import merged_options, pydantic_protected_fields
 from aana.core.models.sampling import SamplingParams
+from aana.deployments.base_deployment import exception_handler
 from aana.deployments.base_text_generation_deployment import (
     BaseTextGenerationDeployment,
     LLMOutput,
@@ -48,6 +49,7 @@ class HfTextGenerationConfig(BaseModel):
 class BaseHfTextGenerationDeployment(BaseTextGenerationDeployment):
     """Base class for Hugging Face text generation deployments."""
 
+    @exception_handler
     async def generate_stream(
         self, prompt: str, sampling_params: SamplingParams | None = None
     ) -> AsyncGenerator[LLMOutput, None]:
diff --git a/aana/deployments/idefics_2_deployment.py b/aana/deployments/idefics_2_deployment.py
index d6ca1d59..d03f3873 100644
--- a/aana/deployments/idefics_2_deployment.py
+++ b/aana/deployments/idefics_2_deployment.py
@@ -19,7 +19,7 @@
 from aana.core.models.image_chat import ImageChatDialog
 from aana.core.models.sampling import SamplingParams
 from aana.core.models.types import Dtype
-from aana.deployments.base_deployment import BaseDeployment
+from aana.deployments.base_deployment import BaseDeployment, exception_handler
 from aana.deployments.base_text_generation_deployment import ChatOutput, LLMOutput
 from aana.exceptions.runtime import InferenceException
 from aana.utils.streamer import async_streamer_adapter
@@ -88,6 +88,7 @@ async def apply_config(self, config: dict[str, Any]):
             self.model_id, **self.model_kwargs
         )
 
+    @exception_handler
     async def chat_stream(
         self,
         dialog: ImageChatDialog, sampling_params: SamplingParams | None = None
     ) -> AsyncGenerator[LLMOutput, None]:
@@ -153,6 +154,7 @@ async def chat_stream(
         except Exception as e:
             raise InferenceException(model_name=self.model_id) from e
 
+    @exception_handler
     async def chat(
         self, dialog: ImageChatDialog, sampling_params: SamplingParams | None = None
     ) -> ChatOutput:
@@ -171,6 +173,7 @@ async def chat(
 
         return ChatOutput(message=ChatMessage(content=text, role="assistant"))
 
+    @exception_handler
     async def chat_batch(
         self,
         dialogs: list[ImageChatDialog],
diff --git a/aana/deployments/pyannote_speaker_diarization_deployment.py b/aana/deployments/pyannote_speaker_diarization_deployment.py
index 5e494e82..081b4a50 100644
--- a/aana/deployments/pyannote_speaker_diarization_deployment.py
+++ b/aana/deployments/pyannote_speaker_diarization_deployment.py
@@ -14,7 +14,7 @@
     SpeakerDiarizationSegment,
 )
 from aana.core.models.time import TimeInterval
-from aana.deployments.base_deployment import BaseDeployment
+from aana.deployments.base_deployment import BaseDeployment, exception_handler
 from aana.exceptions.runtime import InferenceException
 from aana.processors.speaker import combine_homogeneous_speaker_diarization_segments
 
@@ -116,6 +116,7 @@ async def __inference(
 
         return speaker_segments
 
+    @exception_handler
     async def diarize(
         self, audio: Audio, params: PyannoteSpeakerDiarizationParams | None = None
     ) -> SpeakerDiarizationOutput:
diff --git a/aana/deployments/sentence_transformer_deployment.py b/aana/deployments/sentence_transformer_deployment.py
index ea9dd78a..c73ee1c1 100644
--- a/aana/deployments/sentence_transformer_deployment.py
+++ b/aana/deployments/sentence_transformer_deployment.py
@@ -7,7 +7,7 @@
 from typing_extensions import TypedDict
 
 from aana.core.models.base import pydantic_protected_fields
-from aana.deployments.base_deployment import BaseDeployment
+from aana.deployments.base_deployment import BaseDeployment, exception_handler
 from aana.exceptions.runtime import InferenceException
 from aana.processors.batch import BatchProcessor
 
@@ -70,6 +70,7 @@ async def apply_config(self, config: dict[str, Any]):
 
         self.model = SentenceTransformer(self.model_id)
 
+    @exception_handler
     async def embed_batch(self, **kwargs) -> np.ndarray:
         """Embed the given sentences.
diff --git a/aana/deployments/vad_deployment.py b/aana/deployments/vad_deployment.py
index 4e0fa4d9..b1a99845 100644
--- a/aana/deployments/vad_deployment.py
+++ b/aana/deployments/vad_deployment.py
@@ -10,7 +10,7 @@
 from aana.core.models.base import pydantic_protected_fields
 from aana.core.models.time import TimeInterval
 from aana.core.models.vad import VadParams, VadSegment
-from aana.deployments.base_deployment import BaseDeployment
+from aana.deployments.base_deployment import BaseDeployment, exception_handler
 from aana.exceptions.runtime import InferenceException
 from aana.processors.vad import BinarizeVadScores, VoiceActivitySegmentation
 from aana.utils.download import download_model
@@ -211,6 +211,7 @@ async def __inference(self, audio: Audio) -> list[dict]:
 
         return vad_segments
 
+    @exception_handler
     async def asr_preprocess_vad(
         self, audio: Audio, params: VadParams | None = None
     ) -> VadOutput:
diff --git a/aana/deployments/vllm_deployment.py b/aana/deployments/vllm_deployment.py
index d575b04d..32428926 100644
--- a/aana/deployments/vllm_deployment.py
+++ b/aana/deployments/vllm_deployment.py
@@ -25,7 +25,7 @@
 from aana.core.models.image_chat import ImageChatDialog
 from aana.core.models.sampling import SamplingParams
 from aana.core.models.types import Dtype
-from aana.deployments.base_deployment import BaseDeployment
+from aana.deployments.base_deployment import BaseDeployment, exception_handler
 from aana.deployments.base_text_generation_deployment import (
     ChatOutput,
     LLMBatchOutput,
@@ -72,6 +72,11 @@ class VLLMConfig(BaseModel):
 class VLLMDeployment(BaseDeployment):
     """Deployment to serve large language models using vLLM."""
 
+    def __init__(self):
+        """Initialize the deployment."""
+        super().__init__()
+        self.engine = None
+
     async def apply_config(self, config: dict[str, Any]):
         """Apply the configuration.
@@ -123,6 +128,13 @@ async def apply_config(self, config: dict[str, Any]):
         self.tokenizer = self.engine.engine.tokenizer.tokenizer
         self.model_config = await self.engine.get_model_config()
 
+    async def check_health(self):
+        """Check the health of the deployment."""
+        if self.engine:
+            await self.engine.check_health()
+
+        await super().check_health()
+
     def apply_chat_template(
         self, dialog: ChatDialog | ImageChatDialog
     ) -> tuple[str | list[int], dict | None]:
@@ -192,6 +204,7 @@ def replace_image_type(messages: list[dict], images: list[Image]) -> list[dict]:
         )
         return prompt, mm_data
 
+    @exception_handler
     async def generate_stream(  # noqa: C901
         self,
         prompt: str | list[int],
@@ -274,6 +287,7 @@ async def generate_stream(  # noqa: C901
         except Exception as e:
             raise InferenceException(model_name=self.model_id) from e
 
+    @exception_handler
     async def generate(
         self,
         prompt: str | list[int],
@@ -297,6 +311,7 @@ async def generate(
             generated_text += chunk["text"]
         return LLMOutput(text=generated_text)
 
+    @exception_handler
     async def generate_batch(
         self,
         prompts: list[str] | list[list[int]],
@@ -326,6 +341,7 @@ async def generate_batch(
 
         return LLMBatchOutput(texts=texts)
 
+    @exception_handler
     async def chat(
         self,
         dialog: ChatDialog | ImageChatDialog,
@@ -349,6 +365,7 @@ async def chat(
         response_message = ChatMessage(content=response["text"], role="assistant")
         return ChatOutput(message=response_message)
 
+    @exception_handler
     async def chat_stream(
         self,
         dialog: ChatDialog | ImageChatDialog,
diff --git a/aana/deployments/whisper_deployment.py b/aana/deployments/whisper_deployment.py
index c9040f38..688cba4f 100644
--- a/aana/deployments/whisper_deployment.py
+++ b/aana/deployments/whisper_deployment.py
@@ -18,7 +18,7 @@
 from aana.core.models.whisper import (
     WhisperParams,
 )
-from aana.deployments.base_deployment import BaseDeployment
+from aana.deployments.base_deployment import BaseDeployment, exception_handler
 from aana.exceptions.runtime import InferenceException
 
 
@@ -150,6 +150,7 @@ async def apply_config(self, config: dict[str, Any]):
             self.model_size, device=self.device, compute_type=self.compute_type
         )
 
+    @exception_handler
     async def transcribe(
         self, audio: Audio, params: WhisperParams | None = None
     ) -> WhisperOutput:
@@ -199,6 +200,7 @@ async def transcribe(
             transcription=asr_transcription,
         )
 
+    @exception_handler
     async def transcribe_stream(
         self, audio: Audio, params: WhisperParams | None = None
     ) -> AsyncGenerator[WhisperOutput, None]:
@@ -246,6 +248,7 @@ async def transcribe_stream(
             transcription=asr_transcription,
         )
 
+    @exception_handler
     async def transcribe_batch(
         self, audio_batch: list[Audio], params: WhisperParams | None = None
     ) -> WhisperBatchOutput:
diff --git a/aana/tests/units/test_deployment_restart.py b/aana/tests/units/test_deployment_restart.py
new file mode 100644
index 00000000..718a1f72
--- /dev/null
+++ b/aana/tests/units/test_deployment_restart.py
@@ -0,0 +1,74 @@
+# ruff: noqa: S101, S113
+import asyncio
+
+import pytest
+from ray import serve
+
+from aana.deployments.aana_deployment_handle import AanaDeploymentHandle
+from aana.deployments.base_deployment import BaseDeployment, exception_handler
+from aana.exceptions.runtime import InferenceException
+
+
+@serve.deployment(health_check_period_s=1, health_check_timeout_s=30)
+class Lowercase(BaseDeployment):
+    """Ray deployment that returns the lowercase version of a text."""
+
+    def __init__(self):
+        """Initialize the deployment."""
+        super().__init__()
+        self.active = True
+
+    @exception_handler
+    async def lower(self, text: str) -> dict:
+        """Lowercase the text.
+
+        Args:
+            text (str): The text to lowercase.
+
+        Returns:
+            dict: The lowercased text.
+        """
+        if text == "inference_exception" or not self.active:
+            self.active = False
+            raise InferenceException(model_name="lowercase_deployment")
+
+        return {"text": text.lower()}
+
+
+deployments = [
+    {
+        "name": "lowercase_deployment",
+        "instance": Lowercase,
+    }
+]
+
+
+@pytest.mark.asyncio
+async def test_deployment_restart(create_app):
+    """Test the Ray Serve app."""
+    create_app(deployments, [])
+
+    handle = await AanaDeploymentHandle.create("lowercase_deployment")
+
+    text = "Hello, World!"
+
+    # Test that the lowercase deployment works.
+    response = await handle.lower(text=text)
+    assert response == {"text": text.lower()}
+
+    # Cause an InferenceException in the deployment and make it inactive.
+    # After the deployment is inactive, the deployment should always raise an InferenceException.
+    with pytest.raises(InferenceException):
+        await handle.lower(text="inference_exception")
+
+    # The deployment should restart and work again; wait up to 60 seconds for the restart.
+    for _ in range(60):
+        await asyncio.sleep(1)
+        try:
+            response = await handle.lower(text=text)
+            if response == {"text": text.lower()}:
+                break
+        except:  # noqa: S110
+            pass
+
+    assert response == {"text": text.lower()}

From c0282790108c51d974394ab4f634d6dab6b2caad Mon Sep 17 00:00:00 2001
From: Aleksandr Movchan
Date: Tue, 29 Oct 2024 11:40:20 +0000
Subject: [PATCH 2/2] Removed exception_handler from some methods to prevent
 double counting and removed the restart_exceptions setting from reconfigure.

---
 aana/deployments/base_deployment.py      | 2 --
 aana/deployments/idefics_2_deployment.py | 1 -
 aana/deployments/vllm_deployment.py      | 4 ----
 aana/deployments/whisper_deployment.py   | 1 -
 4 files changed, 8 deletions(-)

diff --git a/aana/deployments/base_deployment.py b/aana/deployments/base_deployment.py
index d0ee5155..2cf8b8d9 100644
--- a/aana/deployments/base_deployment.py
+++ b/aana/deployments/base_deployment.py
@@ -63,8 +63,6 @@ async def reconfigure(self, config: dict[str, Any]):
         self.config = config
         await self.apply_config(config)
         self._configured = True
-        if "restart_exceptions" in config:
-            self.restart_exceptions = config["restart_exceptions"]
 
     async def check_health(self):
         """Check the health of the deployment.
diff --git a/aana/deployments/idefics_2_deployment.py b/aana/deployments/idefics_2_deployment.py
index d03f3873..6536bf88 100644
--- a/aana/deployments/idefics_2_deployment.py
+++ b/aana/deployments/idefics_2_deployment.py
@@ -154,7 +154,6 @@ async def chat_stream(
         except Exception as e:
             raise InferenceException(model_name=self.model_id) from e
 
-    @exception_handler
     async def chat(
         self, dialog: ImageChatDialog, sampling_params: SamplingParams | None = None
     ) -> ChatOutput:
diff --git a/aana/deployments/vllm_deployment.py b/aana/deployments/vllm_deployment.py
index 32428926..f068f31a 100644
--- a/aana/deployments/vllm_deployment.py
+++ b/aana/deployments/vllm_deployment.py
@@ -287,7 +287,6 @@ async def generate_stream(  # noqa: C901
         except Exception as e:
             raise InferenceException(model_name=self.model_id) from e
 
-    @exception_handler
     async def generate(
         self,
         prompt: str | list[int],
@@ -311,7 +310,6 @@ async def generate(
             generated_text += chunk["text"]
         return LLMOutput(text=generated_text)
 
-    @exception_handler
     async def generate_batch(
         self,
         prompts: list[str] | list[list[int]],
@@ -341,7 +339,6 @@ async def generate_batch(
 
         return LLMBatchOutput(texts=texts)
 
-    @exception_handler
     async def chat(
         self,
         dialog: ChatDialog | ImageChatDialog,
@@ -365,7 +362,6 @@ async def chat(
         response_message = ChatMessage(content=response["text"], role="assistant")
         return ChatOutput(message=response_message)
 
-    @exception_handler
     async def chat_stream(
         self,
         dialog: ChatDialog | ImageChatDialog,
diff --git a/aana/deployments/whisper_deployment.py b/aana/deployments/whisper_deployment.py
index 688cba4f..24d67300 100644
--- a/aana/deployments/whisper_deployment.py
+++ b/aana/deployments/whisper_deployment.py
@@ -248,7 +248,6 @@ async def transcribe_stream(
             transcription=asr_transcription,
         )
 
-    @exception_handler
     async def transcribe_batch(
         self, audio_batch: list[Audio], params: WhisperParams | None = None
     ) -> WhisperBatchOutput:
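
After PATCH 2/2, the only remaining hook for choosing which exceptions count toward a restart is the restart_exceptions attribute itself, since the config-based override was removed from reconfigure. Below is a minimal sketch of how a subclass would opt into the mechanism; it is illustrative only and not part of the patches, and MyDeployment, MyCustomError, and its run method are hypothetical names.

    # Illustrative sketch only -- MyDeployment and MyCustomError are hypothetical.
    from typing import Any

    from ray import serve

    from aana.deployments.base_deployment import BaseDeployment, exception_handler
    from aana.exceptions.runtime import InferenceException


    class MyCustomError(Exception):
        """Hypothetical error type that should also count toward a restart."""


    @serve.deployment(health_check_period_s=30, health_check_timeout_s=30)
    class MyDeployment(BaseDeployment):
        """Hypothetical deployment that opts into the exception-based health check."""

        def __init__(self):
            super().__init__()
            # PATCH 2/2 removed the `restart_exceptions` config key, so the
            # default list ([InferenceException]) is extended on the instance.
            self.restart_exceptions = [InferenceException, MyCustomError]

        async def apply_config(self, config: dict[str, Any]):
            """No configuration needed for this sketch."""

        @exception_handler  # counts the request and records any raised exception
        async def run(self, text: str) -> dict:
            # Real inference would go here; exceptions raised by decorated
            # methods feed the >50% failure-ratio check in check_health().
            return {"text": text}

With health_check_period_s=30, Ray Serve calls check_health every 30 seconds; if more than half of the requests since the previous check raised an exception whose class is in restart_exceptions, the first such exception is re-raised and Serve restarts the replica.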