From c212218408f20e4ecbda7271f1313941e258c6cb Mon Sep 17 00:00:00 2001
From: ykeremy
Date: Wed, 27 Mar 2024 21:21:07 +0000
Subject: [PATCH] =?UTF-8?q?=F0=9F=94=84=20synced=20local=20'skyvern/'=20wi?=
 =?UTF-8?q?th=20remote=20'skyvern/'?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../forge/sdk/api/llm/api_handler_factory.py | 111 +++++++++++++++++-
 skyvern/forge/sdk/api/llm/config_registry.py |  14 ++-
 skyvern/forge/sdk/api/llm/models.py          |  29 ++++-
 3 files changed, 146 insertions(+), 8 deletions(-)

diff --git a/skyvern/forge/sdk/api/llm/api_handler_factory.py b/skyvern/forge/sdk/api/llm/api_handler_factory.py
index 421f4d68d..6a57a07d2 100644
--- a/skyvern/forge/sdk/api/llm/api_handler_factory.py
+++ b/skyvern/forge/sdk/api/llm/api_handler_factory.py
@@ -1,3 +1,4 @@
+import dataclasses
 import json
 from typing import Any
 
@@ -7,8 +8,12 @@
 
 from skyvern.forge import app
 from skyvern.forge.sdk.api.llm.config_registry import LLMConfigRegistry
-from skyvern.forge.sdk.api.llm.exceptions import DuplicateCustomLLMProviderError, LLMProviderError
-from skyvern.forge.sdk.api.llm.models import LLMAPIHandler
+from skyvern.forge.sdk.api.llm.exceptions import (
+    DuplicateCustomLLMProviderError,
+    InvalidLLMConfigError,
+    LLMProviderError,
+)
+from skyvern.forge.sdk.api.llm.models import LLMAPIHandler, LLMRouterConfig
 from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, parse_api_response
 from skyvern.forge.sdk.artifact.models import ArtifactType
 from skyvern.forge.sdk.models import Step
@@ -20,10 +25,112 @@
 class LLMAPIHandlerFactory:
     _custom_handlers: dict[str, LLMAPIHandler] = {}
 
+    @staticmethod
+    def get_llm_api_handler_with_router(llm_key: str) -> LLMAPIHandler:
+        llm_config = LLMConfigRegistry.get_config(llm_key)
+        if not isinstance(llm_config, LLMRouterConfig):
+            raise InvalidLLMConfigError(llm_key)
+
+        router = litellm.Router(
+            model_list=[dataclasses.asdict(model) for model in llm_config.model_list],
+            redis_host=llm_config.redis_host,
+            redis_port=llm_config.redis_port,
+            routing_strategy=llm_config.routing_strategy,
+            fallbacks=[{llm_config.main_model_group: llm_config.fallback_model_group}]
+            if llm_config.fallback_model_group
+            else [],
+            num_retries=llm_config.num_retries,
+            retry_after=llm_config.retry_delay_seconds,
+            set_verbose=False if SettingsManager.get_settings().is_cloud_environment() else llm_config.set_verbose,
+        )
+        main_model_group = llm_config.main_model_group
+
+        async def llm_api_handler_with_router_and_fallback(
+            prompt: str,
+            step: Step | None = None,
+            screenshots: list[bytes] | None = None,
+            parameters: dict[str, Any] | None = None,
+        ) -> dict[str, Any]:
+            """
+            Custom LLM API handler that utilizes the LiteLLM router and falls back to OpenAI GPT-4 Vision.
+
+            Args:
+                prompt: The prompt to generate completions for.
+                step: The step object associated with the prompt.
+                screenshots: The screenshots associated with the prompt.
+                parameters: Additional parameters to be passed to the LLM router.
+
+            Returns:
+                The response from the LLM router.
+            """
+            if parameters is None:
+                parameters = LLMAPIHandlerFactory.get_api_parameters()
+
+            if step:
+                await app.ARTIFACT_MANAGER.create_artifact(
+                    step=step,
+                    artifact_type=ArtifactType.LLM_PROMPT,
+                    data=prompt.encode("utf-8"),
+                )
+                for screenshot in screenshots or []:
+                    await app.ARTIFACT_MANAGER.create_artifact(
+                        step=step,
+                        artifact_type=ArtifactType.SCREENSHOT_LLM,
+                        data=screenshot,
+                    )
+
+            messages = await llm_messages_builder(prompt, screenshots)
+            if step:
+                await app.ARTIFACT_MANAGER.create_artifact(
+                    step=step,
+                    artifact_type=ArtifactType.LLM_REQUEST,
+                    data=json.dumps(
+                        {
+                            "model": llm_key,
+                            "messages": messages,
+                            **parameters,
+                        }
+                    ).encode("utf-8"),
+                )
+            try:
+                response = await router.acompletion(model=main_model_group, messages=messages, **parameters)
+            except openai.OpenAIError as e:
+                raise LLMProviderError(llm_key) from e
+            except Exception as e:
+                LOG.exception("LLM request failed unexpectedly", llm_key=llm_key)
+                raise LLMProviderError(llm_key) from e
+
+            if step:
+                await app.ARTIFACT_MANAGER.create_artifact(
+                    step=step,
+                    artifact_type=ArtifactType.LLM_RESPONSE,
+                    data=response.model_dump_json(indent=2).encode("utf-8"),
+                )
+                llm_cost = litellm.completion_cost(completion_response=response)
+                await app.DATABASE.update_step(
+                    task_id=step.task_id,
+                    step_id=step.step_id,
+                    organization_id=step.organization_id,
+                    incremental_cost=llm_cost,
+                )
+            parsed_response = parse_api_response(response)
+            if step:
+                await app.ARTIFACT_MANAGER.create_artifact(
+                    step=step,
+                    artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
+                    data=json.dumps(parsed_response, indent=2).encode("utf-8"),
+                )
+            return parsed_response
+
+        return llm_api_handler_with_router_and_fallback
+
     @staticmethod
     def get_llm_api_handler(llm_key: str) -> LLMAPIHandler:
         llm_config = LLMConfigRegistry.get_config(llm_key)
 
+        if LLMConfigRegistry.is_router_config(llm_key):
+            return LLMAPIHandlerFactory.get_llm_api_handler_with_router(llm_key)
+
         async def llm_api_handler(
             prompt: str,
             step: Step | None = None,
diff --git a/skyvern/forge/sdk/api/llm/config_registry.py b/skyvern/forge/sdk/api/llm/config_registry.py
index 7e966ce37..c43c4968e 100644
--- a/skyvern/forge/sdk/api/llm/config_registry.py
+++ b/skyvern/forge/sdk/api/llm/config_registry.py
@@ -6,23 +6,27 @@
     MissingLLMProviderEnvVarsError,
     NoProviderEnabledError,
 )
-from skyvern.forge.sdk.api.llm.models import LLMConfig
+from skyvern.forge.sdk.api.llm.models import LLMConfig, LLMRouterConfig
 from skyvern.forge.sdk.settings_manager import SettingsManager
 
 LOG = structlog.get_logger()
 
 
 class LLMConfigRegistry:
-    _configs: dict[str, LLMConfig] = {}
+    _configs: dict[str, LLMRouterConfig | LLMConfig] = {}
 
     @staticmethod
-    def validate_config(llm_key: str, config: LLMConfig) -> None:
+    def is_router_config(llm_key: str) -> bool:
+        return isinstance(LLMConfigRegistry.get_config(llm_key), LLMRouterConfig)
+
+    @staticmethod
+    def validate_config(llm_key: str, config: LLMRouterConfig | LLMConfig) -> None:
         missing_env_vars = config.get_missing_env_vars()
         if missing_env_vars:
             raise MissingLLMProviderEnvVarsError(llm_key, missing_env_vars)
 
     @classmethod
-    def register_config(cls, llm_key: str, config: LLMConfig) -> None:
+    def register_config(cls, llm_key: str, config: LLMRouterConfig | LLMConfig) -> None:
         if llm_key in cls._configs:
             raise DuplicateLLMConfigError(llm_key)
 
@@ -32,7 +36,7 @@ def register_config(cls, llm_key: str, config: LLMConfig) -> None:
         cls._configs[llm_key] = config
 
     @classmethod
-    def get_config(cls, llm_key: str) -> LLMConfig:
+    def get_config(cls, llm_key: str) -> LLMRouterConfig | LLMConfig:
         if llm_key not in cls._configs:
             raise InvalidLLMConfigError(llm_key)
 
diff --git a/skyvern/forge/sdk/api/llm/models.py b/skyvern/forge/sdk/api/llm/models.py
index 3fafa90fa..2fa2b18f6 100644
--- a/skyvern/forge/sdk/api/llm/models.py
+++ b/skyvern/forge/sdk/api/llm/models.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Any, Awaitable, Protocol
+from typing import Any, Awaitable, Literal, Protocol
 
 from skyvern.forge.sdk.models import Step
 from skyvern.forge.sdk.settings_manager import SettingsManager
@@ -21,6 +21,33 @@ def get_missing_env_vars(self) -> list[str]:
         return missing_env_vars
 
 
+@dataclass(frozen=True)
+class LLMRouterModelConfig:
+    model_name: str
+    # https://litellm.vercel.app/docs/routing
+    litellm_params: dict[str, Any]
+    tpm: int | None = None
+    rpm: int | None = None
+
+
+@dataclass(frozen=True)
+class LLMRouterConfig(LLMConfig):
+    model_list: list[LLMRouterModelConfig]
+    redis_host: str
+    redis_port: int
+    main_model_group: str
+    fallback_model_group: str | None = None
+    routing_strategy: Literal[
+        "simple-shuffle",
+        "least-busy",
+        "usage-based-routing",
+        "latency-based-routing",
+    ] = "usage-based-routing"
+    num_retries: int = 2
+    retry_delay_seconds: int = 15
+    set_verbose: bool = True
+
+
 class LLMAPIHandler(Protocol):
     def __call__(
         self,
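
For context on how the pieces in this patch compose, here is a minimal sketch against the classes introduced above. The llm_key "GPT4V_ROUTER", the model group name, and the litellm_params values are illustrative assumptions only; they are not defined anywhere in this change.

    import dataclasses

    from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory
    from skyvern.forge.sdk.api.llm.models import LLMRouterModelConfig

    # One LLMRouterModelConfig per deployment. get_llm_api_handler_with_router
    # converts each entry via dataclasses.asdict() to build litellm.Router's model_list.
    deployment = LLMRouterModelConfig(
        model_name="gpt-4-vision",  # assumed model group name
        litellm_params={  # assumed deployment params, per the litellm routing docs
            "model": "azure/gpt-4-vision",
            "api_key": "sk-...",
        },
        rpm=60,
    )
    print(dataclasses.asdict(deployment))
    # {'model_name': 'gpt-4-vision',
    #  'litellm_params': {'model': 'azure/gpt-4-vision', 'api_key': 'sk-...'},
    #  'tpm': None, 'rpm': 60}

    # Once an LLMRouterConfig is registered under a key, get_llm_api_handler
    # transparently returns the router-backed handler instead of the direct one:
    handler = LLMAPIHandlerFactory.get_llm_api_handler("GPT4V_ROUTER")  # assumed key
    # parsed = await handler(prompt="...", step=step, screenshots=screenshots)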