Merge pull request #191 from mobiusml/deployment_health_check
Deployment Health Check and Automatic Restart
movchan74 authored Oct 29, 2024
2 parents 324f8f9 + c028279 commit 580090c
Showing 13 changed files with 179 additions and 10 deletions.
64 changes: 64 additions & 0 deletions aana/deployments/base_deployment.py
@@ -1,6 +1,44 @@
import inspect
from functools import wraps
from typing import Any

from aana.exceptions.runtime import InferenceException


def exception_handler(func):
    """AanaDeploymentHandle decorator to catch exceptions and store them in the deployment for health check purposes.

    Args:
        func (function): The function to decorate.

    Returns:
        function: The decorated function.
    """

    @wraps(func)
    async def wrapper(self, *args, **kwargs):
        self.num_requests_since_last_health_check += 1
        try:
            return await func(self, *args, **kwargs)
        except Exception as e:
            self.raised_exceptions.append(e)
            raise

    @wraps(func)
    async def wrapper_generator(self, *args, **kwargs):
        self.num_requests_since_last_health_check += 1
        try:
            async for item in func(self, *args, **kwargs):
                yield item
        except Exception as e:
            self.raised_exceptions.append(e)
            raise

    if inspect.isasyncgenfunction(func):
        return wrapper_generator
    else:
        return wrapper


class BaseDeployment:
    """Base class for all deployments.
@@ -13,6 +51,9 @@ def __init__(self):
        """Inits to unconfigured state."""
        self.config = None
        self._configured = False
        self.num_requests_since_last_health_check = 0
        self.raised_exceptions = []
        self.restart_exceptions = [InferenceException]

    async def reconfigure(self, config: dict[str, Any]):
        """Reconfigure the deployment.
@@ -23,6 +64,29 @@ async def reconfigure(self, config: dict[str, Any]):
        await self.apply_config(config)
        self._configured = True

    async def check_health(self):
        """Check the health of the deployment.

        Raises:
            Raises the exception that caused the deployment to be unhealthy.
        """
        raised_restart_exceptions = [
            exception
            for exception in self.raised_exceptions
            if exception.__class__ in self.restart_exceptions
        ]
        # Restart the deployment if more than 50% of the requests raised restart exceptions
        if self.num_requests_since_last_health_check != 0:
            ratio_restart_exceptions = (
                len(raised_restart_exceptions)
                / self.num_requests_since_last_health_check
            )
            if ratio_restart_exceptions > 0.5:
                raise raised_restart_exceptions[0]

        self.raised_exceptions = []
        self.num_requests_since_last_health_check = 0

    async def apply_config(self, config: dict[str, Any]):
        """Apply the configuration.
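Taken together, these pieces implement a simple restart policy: every call to a method decorated with exception_handler increments num_requests_since_last_health_check and records any raised exception, and the periodic check_health() re-raises the first recorded restart exception when more than half of the requests since the last check failed with one of the restart_exceptions. For example, if 6 of 10 requests raised InferenceException, the ratio is 0.6 > 0.5, check_health() raises, and Ray Serve replaces the replica; otherwise the counters reset and the next window starts fresh. A minimal sketch of a subclass that plugs into this mechanism (the deployment, its model, and the InferenceException arguments are hypothetical, not part of this commit):

from typing import Any

from aana.deployments.base_deployment import BaseDeployment, exception_handler
from aana.exceptions.runtime import InferenceException


class MyModelDeployment(BaseDeployment):
    """Hypothetical deployment illustrating the health-check hooks."""

    async def apply_config(self, config: dict[str, Any]):
        # Stand-in for real model loading.
        self.model = lambda text: text.upper()

    @exception_handler  # plain coroutine: wrapped by `wrapper`
    async def predict(self, text: str) -> dict:
        try:
            return {"output": self.model(text)}
        except Exception as e:
            # InferenceException is in restart_exceptions by default, so these
            # failures count toward the >50% restart threshold in check_health().
            raise InferenceException("my_model") from e  # assumed constructor args

    @exception_handler  # async generator: dispatched to `wrapper_generator`
    async def predict_stream(self, text: str):
        for token in self.model(text).split():
            yield {"token": token}

Ray Serve calls check_health() on each replica at a configurable interval (the health_check_period_s deployment option), so an unhealthy window is detected within one period and the replica is restarted automatically.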
7 changes: 6 additions & 1 deletion aana/deployments/base_text_generation_deployment.py
@@ -9,7 +9,7 @@
from aana.core.chat.chat_template import apply_chat_template
from aana.core.models.chat import ChatDialog, ChatMessage
from aana.core.models.sampling import SamplingParams
from aana.deployments.base_deployment import BaseDeployment
from aana.deployments.base_deployment import BaseDeployment, exception_handler


class LLMOutput(TypedDict):
@@ -57,6 +57,7 @@ class BaseTextGenerationDeployment(BaseDeployment):
    You can also override these methods to implement custom inference logic.
    """

    @exception_handler
    async def generate_stream(
        self, prompt: str, sampling_params: SamplingParams | None = None
    ) -> AsyncGenerator[LLMOutput, None]:
@@ -71,6 +72,7 @@ async def generate_stream(
        """
        raise NotImplementedError

    @exception_handler
    async def generate(
        self, prompt: str, sampling_params: SamplingParams | None = None
    ) -> LLMOutput:
@@ -88,6 +90,7 @@ async def generate(
            generated_text += chunk["text"]
        return LLMOutput(text=generated_text)

    @exception_handler
    async def generate_batch(
        self, prompts: list[str], sampling_params: SamplingParams | None = None
    ) -> LLMBatchOutput:
@@ -108,6 +111,7 @@ async def generate_batch(

        return LLMBatchOutput(texts=texts)

    @exception_handler
    async def chat(
        self, dialog: ChatDialog, sampling_params: SamplingParams | None = None
    ) -> ChatOutput:
@@ -127,6 +131,7 @@ async def chat(
        response_message = ChatMessage(content=response["text"], role="assistant")
        return ChatOutput(message=response_message)

    @exception_handler
    async def chat_stream(
        self, dialog: ChatDialog, sampling_params: SamplingParams | None = None
    ) -> AsyncGenerator[LLMOutput, None]:
3 changes: 2 additions & 1 deletion aana/deployments/haystack_component_deployment.py
@@ -5,7 +5,7 @@
from ray import serve

from aana.deployments.aana_deployment_handle import AanaDeploymentHandle
from aana.deployments.base_deployment import BaseDeployment
from aana.deployments.base_deployment import BaseDeployment, exception_handler
from aana.utils.asyncio import run_async
from aana.utils.core import import_from_path

@@ -84,6 +84,7 @@ async def apply_config(self, config: dict[str, Any]):

        self.component.warm_up()

    @exception_handler
    async def run(self, **data: dict[str, Any]) -> dict[str, Any]:
        """Run the model on the input data."""
        return self.component.run(**data)
4 changes: 3 additions & 1 deletion aana/deployments/hf_blip2_deployment.py
@@ -11,7 +11,7 @@
from aana.core.models.captions import Caption, CaptionsList
from aana.core.models.image import Image
from aana.core.models.types import Dtype
from aana.deployments.base_deployment import BaseDeployment
from aana.deployments.base_deployment import BaseDeployment, exception_handler
from aana.exceptions.runtime import InferenceException
from aana.processors.batch import BatchProcessor

@@ -106,6 +106,7 @@ async def apply_config(self, config: dict[str, Any]):
        self.processor = Blip2Processor.from_pretrained(self.model_id)
        self.model.to(self.device)

    @exception_handler
    async def generate(self, image: Image) -> CaptioningOutput:
        """Generate captions for the given image.
@@ -124,6 +125,7 @@ async def generate(self, image: Image) -> CaptioningOutput:
        )
        return CaptioningOutput(caption=captions["captions"][0])

    @exception_handler
    async def generate_batch(self, **kwargs) -> CaptioningBatchOutput:
        """Generate captions for the given images.
3 changes: 2 additions & 1 deletion aana/deployments/hf_pipeline_deployment.py
@@ -9,7 +9,7 @@
from aana.core.models.base import pydantic_protected_fields
from aana.core.models.custom_config import CustomConfig
from aana.core.models.image import Image
from aana.deployments.base_deployment import BaseDeployment
from aana.deployments.base_deployment import BaseDeployment, exception_handler


class HfPipelineConfig(BaseModel):
@@ -80,6 +80,7 @@ async def apply_config(self, config: dict[str, Any]):
        else:
            raise

    @exception_handler
    async def call(self, *args, **kwargs):
        """Call the pipeline.
2 changes: 2 additions & 0 deletions aana/deployments/hf_text_generation_deployment.py
@@ -14,6 +14,7 @@

from aana.core.models.base import merged_options, pydantic_protected_fields
from aana.core.models.sampling import SamplingParams
from aana.deployments.base_deployment import exception_handler
from aana.deployments.base_text_generation_deployment import (
    BaseTextGenerationDeployment,
    LLMOutput,
@@ -48,6 +49,7 @@ class HfTextGenerationConfig(BaseModel):
class BaseHfTextGenerationDeployment(BaseTextGenerationDeployment):
    """Base class for Hugging Face text generation deployments."""

    @exception_handler
    async def generate_stream(
        self, prompt: str, sampling_params: SamplingParams | None = None
    ) -> AsyncGenerator[LLMOutput, None]:
4 changes: 3 additions & 1 deletion aana/deployments/idefics_2_deployment.py
@@ -19,7 +19,7 @@
from aana.core.models.image_chat import ImageChatDialog
from aana.core.models.sampling import SamplingParams
from aana.core.models.types import Dtype
from aana.deployments.base_deployment import BaseDeployment
from aana.deployments.base_deployment import BaseDeployment, exception_handler
from aana.deployments.base_text_generation_deployment import ChatOutput, LLMOutput
from aana.exceptions.runtime import InferenceException
from aana.utils.streamer import async_streamer_adapter
@@ -88,6 +88,7 @@ async def apply_config(self, config: dict[str, Any]):
            self.model_id, **self.model_kwargs
        )

    @exception_handler
    async def chat_stream(
        self, dialog: ImageChatDialog, sampling_params: SamplingParams | None = None
    ) -> AsyncGenerator[LLMOutput, None]:
@@ -171,6 +172,7 @@ async def chat(

        return ChatOutput(message=ChatMessage(content=text, role="assistant"))

    @exception_handler
    async def chat_batch(
        self,
        dialogs: list[ImageChatDialog],
3 changes: 2 additions & 1 deletion aana/deployments/pyannote_speaker_diarization_deployment.py
@@ -14,7 +14,7 @@
    SpeakerDiarizationSegment,
)
from aana.core.models.time import TimeInterval
from aana.deployments.base_deployment import BaseDeployment
from aana.deployments.base_deployment import BaseDeployment, exception_handler
from aana.exceptions.runtime import InferenceException
from aana.processors.speaker import combine_homogeneous_speaker_diarization_segments

@@ -116,6 +116,7 @@ async def __inference(

        return speaker_segments

    @exception_handler
    async def diarize(
        self, audio: Audio, params: PyannoteSpeakerDiarizationParams | None = None
    ) -> SpeakerDiarizationOutput:
3 changes: 2 additions & 1 deletion aana/deployments/sentence_transformer_deployment.py
@@ -7,7 +7,7 @@
from typing_extensions import TypedDict

from aana.core.models.base import pydantic_protected_fields
from aana.deployments.base_deployment import BaseDeployment
from aana.deployments.base_deployment import BaseDeployment, exception_handler
from aana.exceptions.runtime import InferenceException
from aana.processors.batch import BatchProcessor

@@ -70,6 +70,7 @@ async def apply_config(self, config: dict[str, Any]):

        self.model = SentenceTransformer(self.model_id)

    @exception_handler
    async def embed_batch(self, **kwargs) -> np.ndarray:
        """Embed the given sentences.
3 changes: 2 additions & 1 deletion aana/deployments/vad_deployment.py
@@ -10,7 +10,7 @@
from aana.core.models.base import pydantic_protected_fields
from aana.core.models.time import TimeInterval
from aana.core.models.vad import VadParams, VadSegment
from aana.deployments.base_deployment import BaseDeployment
from aana.deployments.base_deployment import BaseDeployment, exception_handler
from aana.exceptions.runtime import InferenceException
from aana.processors.vad import BinarizeVadScores, VoiceActivitySegmentation
from aana.utils.download import download_model
@@ -211,6 +211,7 @@ async def __inference(self, audio: Audio) -> list[dict]:

        return vad_segments

    @exception_handler
    async def asr_preprocess_vad(
        self, audio: Audio, params: VadParams | None = None
    ) -> VadOutput:
15 changes: 14 additions & 1 deletion aana/deployments/vllm_deployment.py
@@ -25,7 +25,7 @@
from aana.core.models.image_chat import ImageChatDialog
from aana.core.models.sampling import SamplingParams
from aana.core.models.types import Dtype
from aana.deployments.base_deployment import BaseDeployment
from aana.deployments.base_deployment import BaseDeployment, exception_handler
from aana.deployments.base_text_generation_deployment import (
    ChatOutput,
    LLMBatchOutput,
@@ -72,6 +72,11 @@ class VLLMConfig(BaseModel):
class VLLMDeployment(BaseDeployment):
    """Deployment to serve large language models using vLLM."""

    def __init__(self):
        """Initialize the deployment."""
        super().__init__()
        self.engine = None

    async def apply_config(self, config: dict[str, Any]):
        """Apply the configuration.
@@ -123,6 +128,13 @@ async def apply_config(self, config: dict[str, Any]):
        self.tokenizer = self.engine.engine.tokenizer.tokenizer
        self.model_config = await self.engine.get_model_config()

    async def check_health(self):
        """Check the health of the deployment."""
        if self.engine:
            await self.engine.check_health()

        await super().check_health()

    def apply_chat_template(
        self, dialog: ChatDialog | ImageChatDialog
    ) -> tuple[str | list[int], dict | None]:
@@ -192,6 +204,7 @@ def replace_image_type(messages: list[dict], images: list[Image]) -> list[dict]:
        )
        return prompt, mm_data

    @exception_handler
    async def generate_stream(  # noqa: C901
        self,
        prompt: str | list[int],
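VLLMDeployment shows the intended pattern for engine-backed deployments: probe the engine first, then fall through to the base-class exception-ratio check; initializing self.engine to None in __init__ keeps the health check safe before apply_config has run. A sketch of the same override for any other deployment that wraps an engine with its own probe (the engine object and its check_health API are assumptions here, not part of this commit):

from aana.deployments.base_deployment import BaseDeployment


class MyEngineDeployment(BaseDeployment):
    def __init__(self):
        super().__init__()
        self.engine = None  # created in apply_config; None until configured

    async def check_health(self):
        if self.engine:
            # Engine-specific probe first (hypothetical API).
            await self.engine.check_health()
        # Then the base-class check: restart if >50% of recent requests
        # raised a restart exception.
        await super().check_health()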
4 changes: 3 additions & 1 deletion aana/deployments/whisper_deployment.py
@@ -18,7 +18,7 @@
from aana.core.models.whisper import (
    WhisperParams,
)
from aana.deployments.base_deployment import BaseDeployment
from aana.deployments.base_deployment import BaseDeployment, exception_handler
from aana.exceptions.runtime import InferenceException


@@ -150,6 +150,7 @@ async def apply_config(self, config: dict[str, Any]):
            self.model_size, device=self.device, compute_type=self.compute_type
        )

    @exception_handler
    async def transcribe(
        self, audio: Audio, params: WhisperParams | None = None
    ) -> WhisperOutput:
@@ -199,6 +200,7 @@ async def transcribe(
            transcription=asr_transcription,
        )

    @exception_handler
    async def transcribe_stream(
        self, audio: Audio, params: WhisperParams | None = None
    ) -> AsyncGenerator[WhisperOutput, None]: