diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py
index 6e437c7..5c9a683 100644
--- a/pydantic_ai_slim/pydantic_ai/agent.py
+++ b/pydantic_ai_slim/pydantic_ai/agent.py
@@ -22,7 +22,7 @@
     result,
 )
 from .result import ResultData
-from .settings import ModelSettings, merge_model_settings
+from .settings import ExecutionLimitSettings, ModelSettings, merge_model_settings
 from .tools import (
     AgentDeps,
     RunContext,
@@ -191,6 +191,7 @@ async def run(
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
         model_settings: ModelSettings | None = None,
+        execution_limit_settings: ExecutionLimitSettings | None = None,
         infer_name: bool = True,
     ) -> result.RunResult[ResultData]:
         """Run the agent with a user prompt in async mode.
@@ -211,8 +212,9 @@ async def run(
             message_history: History of the conversation so far.
             model: Optional model to use for this run, required if `model` was not set when creating the agent.
             deps: Optional dependencies to use for this run.
-            infer_name: Whether to try to infer the agent name from the call frame if it's not set.
             model_settings: Optional settings to use for this model's request.
+            execution_limit_settings: Optional settings to limit model requests or cost (token usage) for this run.
+            infer_name: Whether to try to infer the agent name from the call frame if it's not set.

         Returns:
             The result of the run.
@@ -238,8 +240,8 @@
                     tool.current_retry = 0

             cost = result.Cost()
-            model_settings = merge_model_settings(self.model_settings, model_settings)
+            execution_limit_settings = execution_limit_settings or ExecutionLimitSettings(request_limit=50)

             run_step = 0
             while True:
@@ -254,6 +256,8 @@
                 messages.append(model_response)
                 cost += request_cost

+                # TODO: is this the right location? Should we move this earlier in the logic?
+                execution_limit_settings.increment(request_cost)
                 with _logfire.span('handle model response', run_step=run_step) as handle_span:
                     final_result, tool_responses = await self._handle_model_response(model_response, deps, messages)
@@ -284,6 +288,7 @@ def run_sync(
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
         model_settings: ModelSettings | None = None,
+        execution_limit_settings: ExecutionLimitSettings | None = None,
         infer_name: bool = True,
     ) -> result.RunResult[ResultData]:
         """Run the agent with a user prompt synchronously.
@@ -308,8 +313,9 @@ async def main():
             message_history: History of the conversation so far.
             model: Optional model to use for this run, required if `model` was not set when creating the agent.
             deps: Optional dependencies to use for this run.
-            infer_name: Whether to try to infer the agent name from the call frame if it's not set.
             model_settings: Optional settings to use for this model's request.
+            execution_limit_settings: Optional settings to limit model requests or cost (token usage) for this run.
+            infer_name: Whether to try to infer the agent name from the call frame if it's not set.

         Returns:
             The result of the run.
@@ -322,8 +328,9 @@ async def main():
                 message_history=message_history,
                 model=model,
                 deps=deps,
-                infer_name=False,
+                execution_limit_settings=execution_limit_settings,
                 model_settings=model_settings,
+                infer_name=False,
             )
         )

@@ -336,6 +343,7 @@ async def run_stream(
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
         model_settings: ModelSettings | None = None,
+        execution_limit_settings: ExecutionLimitSettings | None = None,
         infer_name: bool = True,
     ) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]:
         """Run the agent with a user prompt in async mode, returning a streamed response.
@@ -357,8 +365,9 @@ async def main():
             message_history: History of the conversation so far.
             model: Optional model to use for this run, required if `model` was not set when creating the agent.
             deps: Optional dependencies to use for this run.
-            infer_name: Whether to try to infer the agent name from the call frame if it's not set.
             model_settings: Optional settings to use for this model's request.
+            execution_limit_settings: Optional settings to limit model requests or cost (token usage) for this run.
+            infer_name: Whether to try to infer the agent name from the call frame if it's not set.

         Returns:
             The result of the run.
@@ -387,6 +396,7 @@ async def main():
             cost = result.Cost()
             model_settings = merge_model_settings(self.model_settings, model_settings)
+            execution_limit_settings = execution_limit_settings or ExecutionLimitSettings(request_limit=50)

             run_step = 0
             while True:
@@ -456,7 +466,9 @@ async def on_complete():
                         tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
                         handle_span.message = f'handle model response -> {tool_responses_str}'
                         # the model_response should have been fully streamed by now, we can add it's cost
-                        cost += model_response.cost()
+                        model_response_cost = model_response.cost()
+                        execution_limit_settings.increment(model_response_cost)
+                        cost += model_response_cost

     @contextmanager
     def override(
diff --git a/pydantic_ai_slim/pydantic_ai/settings.py b/pydantic_ai_slim/pydantic_ai/settings.py
index d3e2d42..94ea6e5 100644
--- a/pydantic_ai_slim/pydantic_ai/settings.py
+++ b/pydantic_ai_slim/pydantic_ai/settings.py
@@ -1,8 +1,16 @@
 from __future__ import annotations

+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
 from httpx import Timeout
 from typing_extensions import TypedDict

+from .exceptions import UnexpectedModelBehavior
+
+if TYPE_CHECKING:
+    from .result import Cost
+

 class ModelSettings(TypedDict, total=False):
     """Settings to configure an LLM.
@@ -70,3 +78,35 @@ def merge_model_settings(base: ModelSettings | None, overrides: ModelSettings |
         return base | overrides
     else:
         return base or overrides
+
+
+@dataclass
+class ExecutionLimitSettings:
+    """Limits on the number of model requests and on token usage for an agent run."""
+
+    request_limit: int | None = None
+    request_tokens_limit: int | None = None
+    response_tokens_limit: int | None = None
+    total_tokens_limit: int | None = None
+
+    _request_count: int = 0
+    _request_tokens_count: int = 0
+    _response_tokens_count: int = 0
+    _total_tokens_count: int = 0
+
+    def increment(self, cost: Cost) -> None:
+        self._request_count += 1
+        self._check_limit(self.request_limit, self._request_count, 'request count')
+
+        self._request_tokens_count += cost.request_tokens or 0
+        self._check_limit(self.request_tokens_limit, self._request_tokens_count, 'request tokens count')
+
+        self._response_tokens_count += cost.response_tokens or 0
+        self._check_limit(self.response_tokens_limit, self._response_tokens_count, 'response tokens count')
+
+        self._total_tokens_count += cost.total_tokens or 0
+        self._check_limit(self.total_tokens_limit, self._total_tokens_count, 'total tokens count')
+
+    def _check_limit(self, limit: int | None, count: int, limit_name: str) -> None:
+        if limit and limit < count:
+            raise UnexpectedModelBehavior(f'Exceeded {limit_name} limit of {limit} by {count - limit}')
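A minimal sketch of how the `ExecutionLimitSettings.increment` method added above behaves, not part of the diff. It assumes `Cost` (from `pydantic_ai.result`) accepts `request_tokens`, `response_tokens` and `total_tokens` as keyword arguments, as the `increment` body implies; the concrete numbers are illustrative only.

```python
from pydantic_ai.exceptions import UnexpectedModelBehavior
from pydantic_ai.result import Cost
from pydantic_ai.settings import ExecutionLimitSettings

# Allow at most two model requests for this run.
limits = ExecutionLimitSettings(request_limit=2)

# Each model request reports a Cost; increment() tallies it and checks every configured limit.
limits.increment(Cost(request_tokens=100, response_tokens=50, total_tokens=150))
limits.increment(Cost(request_tokens=120, response_tokens=60, total_tokens=180))

try:
    limits.increment(Cost(request_tokens=90, response_tokens=40, total_tokens=130))
except UnexpectedModelBehavior as exc:
    print(exc)  # Exceeded request count limit of 2 by 1
```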
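And an end-to-end sketch of the new `execution_limit_settings` parameter threaded through `Agent.run_sync` by this diff. The model name, system prompt and user prompt are placeholders, and `result.data` refers to the existing `RunResult` API rather than anything added here; this is an assumed usage, not part of the patch.

```python
from pydantic_ai import Agent
from pydantic_ai.exceptions import UnexpectedModelBehavior
from pydantic_ai.settings import ExecutionLimitSettings

# Hypothetical agent; model name and prompts are illustrative only.
agent = Agent('openai:gpt-4o', system_prompt='Be concise.')

# Cap this run at three model requests and 5,000 total tokens.
limits = ExecutionLimitSettings(request_limit=3, total_tokens_limit=5_000)

try:
    result = agent.run_sync('Summarise the latest report.', execution_limit_settings=limits)
    print(result.data)
except UnexpectedModelBehavior as exc:
    # increment() raises once any configured limit is exceeded, aborting the run.
    print(f'Run aborted: {exc}')
```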