Implement LLMRouter #127

Merged 1 commit on Mar 27, 2024
111 changes: 109 additions & 2 deletions skyvern/forge/sdk/api/llm/api_handler_factory.py
@@ -1,3 +1,4 @@
import dataclasses
import json
from typing import Any

@@ -7,8 +8,12 @@

from skyvern.forge import app
from skyvern.forge.sdk.api.llm.config_registry import LLMConfigRegistry
from skyvern.forge.sdk.api.llm.exceptions import (
DuplicateCustomLLMProviderError,
InvalidLLMConfigError,
LLMProviderError,
)
from skyvern.forge.sdk.api.llm.models import LLMAPIHandler, LLMRouterConfig
from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, parse_api_response
from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.models import Step
@@ -20,10 +25,112 @@
class LLMAPIHandlerFactory:
_custom_handlers: dict[str, LLMAPIHandler] = {}

@staticmethod
def get_llm_api_handler_with_router(llm_key: str) -> LLMAPIHandler:
llm_config = LLMConfigRegistry.get_config(llm_key)
if not isinstance(llm_config, LLMRouterConfig):
raise InvalidLLMConfigError(llm_key)

router = litellm.Router(
model_list=[dataclasses.asdict(model) for model in llm_config.model_list],
redis_host=llm_config.redis_host,
redis_port=llm_config.redis_port,
routing_strategy=llm_config.routing_strategy,
fallbacks=[{llm_config.main_model_group: llm_config.fallback_model_group}]
if llm_config.fallback_model_group
else [],
num_retries=llm_config.num_retries,
retry_after=llm_config.retry_delay_seconds,
set_verbose=False if SettingsManager.get_settings().is_cloud_environment() else llm_config.set_verbose,
)
main_model_group = llm_config.main_model_group

async def llm_api_handler_with_router_and_fallback(
prompt: str,
step: Step | None = None,
screenshots: list[bytes] | None = None,
parameters: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""
            Custom LLM API handler that uses the LiteLLM router and falls back to OpenAI GPT-4 Vision.

Args:
prompt: The prompt to generate completions for.
step: The step object associated with the prompt.
screenshots: The screenshots associated with the prompt.
parameters: Additional parameters to be passed to the LLM router.

Returns:
The response from the LLM router.
"""
if parameters is None:
parameters = LLMAPIHandlerFactory.get_api_parameters()

if step:
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.LLM_PROMPT,
data=prompt.encode("utf-8"),
)
for screenshot in screenshots or []:
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.SCREENSHOT_LLM,
data=screenshot,
)

messages = await llm_messages_builder(prompt, screenshots)
if step:
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.LLM_REQUEST,
data=json.dumps(
{
"model": llm_key,
"messages": messages,
**parameters,
}
).encode("utf-8"),
)
try:
response = await router.acompletion(model=main_model_group, messages=messages, **parameters)
except openai.OpenAIError as e:
raise LLMProviderError(llm_key) from e
except Exception as e:
LOG.exception("LLM request failed unexpectedly", llm_key=llm_key)
raise LLMProviderError(llm_key) from e

if step:
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.LLM_RESPONSE,
data=response.model_dump_json(indent=2).encode("utf-8"),
)
llm_cost = litellm.completion_cost(completion_response=response)
await app.DATABASE.update_step(
task_id=step.task_id,
step_id=step.step_id,
organization_id=step.organization_id,
incremental_cost=llm_cost,
)
parsed_response = parse_api_response(response)
if step:
await app.ARTIFACT_MANAGER.create_artifact(
step=step,
artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
data=json.dumps(parsed_response, indent=2).encode("utf-8"),
)
return parsed_response

return llm_api_handler_with_router_and_fallback

@staticmethod
def get_llm_api_handler(llm_key: str) -> LLMAPIHandler:
llm_config = LLMConfigRegistry.get_config(llm_key)

if LLMConfigRegistry.is_router_config(llm_key):
return LLMAPIHandlerFactory.get_llm_api_handler_with_router(llm_key)

async def llm_api_handler(
prompt: str,
step: Step | None = None,
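For orientation, a minimal usage sketch (not part of the diff) of the factory entry point. The key "GPT4V_ROUTER" is hypothetical; the sketch assumes an LLMRouterConfig was registered under that name.

from typing import Any

from skyvern.forge.sdk.api.llm.api_handler_factory import LLMAPIHandlerFactory

async def extract_title() -> dict[str, Any]:
    # get_llm_api_handler dispatches to the router-backed handler whenever the
    # registered config for this key is an LLMRouterConfig.
    handler = LLMAPIHandlerFactory.get_llm_api_handler("GPT4V_ROUTER")  # hypothetical key
    # With step=None, no artifacts are written and no cost is recorded.
    return await handler(prompt="Extract the page title.", screenshots=None, parameters=None)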
14 changes: 9 additions & 5 deletions skyvern/forge/sdk/api/llm/config_registry.py
@@ -6,23 +6,27 @@
MissingLLMProviderEnvVarsError,
NoProviderEnabledError,
)
from skyvern.forge.sdk.api.llm.models import LLMConfig
from skyvern.forge.sdk.api.llm.models import LLMConfig, LLMRouterConfig
from skyvern.forge.sdk.settings_manager import SettingsManager

LOG = structlog.get_logger()


class LLMConfigRegistry:
_configs: dict[str, LLMRouterConfig | LLMConfig] = {}

@staticmethod
def is_router_config(llm_key: str) -> bool:
return isinstance(LLMConfigRegistry.get_config(llm_key), LLMRouterConfig)

@staticmethod
def validate_config(llm_key: str, config: LLMRouterConfig | LLMConfig) -> None:
missing_env_vars = config.get_missing_env_vars()
if missing_env_vars:
raise MissingLLMProviderEnvVarsError(llm_key, missing_env_vars)

@classmethod
def register_config(cls, llm_key: str, config: LLMRouterConfig | LLMConfig) -> None:
if llm_key in cls._configs:
raise DuplicateLLMConfigError(llm_key)

@@ -32,7 +36,7 @@ def register_config(cls, llm_key: str, config: LLMConfig) -> None:
cls._configs[llm_key] = config

@classmethod
def get_config(cls, llm_key: str) -> LLMRouterConfig | LLMConfig:
if llm_key not in cls._configs:
raise InvalidLLMConfigError(llm_key)

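A short sketch of the registry flow these changes enable (the key is a placeholder): get_config raises InvalidLLMConfigError for unregistered keys, and the new is_router_config lets callers branch on the config type without inspecting the union directly.

from skyvern.forge.sdk.api.llm.config_registry import LLMConfigRegistry
from skyvern.forge.sdk.api.llm.exceptions import InvalidLLMConfigError

def describe(llm_key: str) -> str:
    try:
        config = LLMConfigRegistry.get_config(llm_key)  # raises if never registered
    except InvalidLLMConfigError:
        return f"{llm_key}: not registered"
    kind = "router" if LLMConfigRegistry.is_router_config(llm_key) else "single-model"
    return f"{llm_key}: {kind} config ({type(config).__name__})"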
29 changes: 28 additions & 1 deletion skyvern/forge/sdk/api/llm/models.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Any, Awaitable, Literal, Protocol

from skyvern.forge.sdk.models import Step
from skyvern.forge.sdk.settings_manager import SettingsManager
@@ -21,6 +21,33 @@ def get_missing_env_vars(self) -> list[str]:
return missing_env_vars


@dataclass(frozen=True)
class LLMRouterModelConfig:
model_name: str
# https://litellm.vercel.app/docs/routing
litellm_params: dict[str, Any]
tpm: int | None = None
rpm: int | None = None


@dataclass(frozen=True)
class LLMRouterConfig(LLMConfig):
model_list: list[LLMRouterModelConfig]
redis_host: str
redis_port: int
main_model_group: str
fallback_model_group: str | None = None
routing_strategy: Literal[
"simple-shuffle",
"least-busy",
"usage-based-routing",
"latency-based-routing",
] = "usage-based-routing"
num_retries: int = 2
retry_delay_seconds: int = 15
set_verbose: bool = True


class LLMAPIHandler(Protocol):
def __call__(
self,
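To make the new dataclasses concrete, a hedged sketch of a single model-list entry; every value below is a placeholder, and litellm_params follows the LiteLLM routing schema linked in the comment above. Note that api_handler_factory.py converts each entry with dataclasses.asdict before passing the list to litellm.Router.

from skyvern.forge.sdk.api.llm.models import LLMRouterModelConfig

azure_gpt4v = LLMRouterModelConfig(
    model_name="gpt-4-vision",  # model group name; entries sharing it are load-balanced
    litellm_params={
        "model": "azure/gpt-4-vision-deployment",  # placeholder deployment
        "api_key": "<AZURE_API_KEY>",              # placeholder
        "api_base": "<AZURE_API_BASE>",            # placeholder
    },
    tpm=80_000,  # optional tokens-per-minute limit used by usage-based routing
    rpm=400,     # optional requests-per-minute limit
)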