Draft: Support execution limits in run_ functions #374

Draft · wants to merge 3 commits into main
26 changes: 19 additions & 7 deletions pydantic_ai_slim/pydantic_ai/agent.py
@@ -22,7 +22,7 @@
     result,
 )
 from .result import ResultData
-from .settings import ModelSettings, merge_model_settings
+from .settings import ExecutionLimitSettings, ModelSettings, merge_model_settings
 from .tools import (
     AgentDeps,
     RunContext,
@@ -191,6 +191,7 @@ async def run(
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
         model_settings: ModelSettings | None = None,
+        execution_limit_settings: ExecutionLimitSettings | None = None,
[Review comment from Contributor]
Suggested change:
-        execution_limit_settings: ExecutionLimitSettings | None = None,
+        execution_limits: ExecutionLimits | None = None,
This would be my preference.
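For context, a minimal usage sketch of the parameter this PR adds. The model string and prompt are placeholders, and the parameter name follows the current diff rather than the suggested rename:

```python
from pydantic_ai import Agent
from pydantic_ai.settings import ExecutionLimitSettings

agent = Agent('openai:gpt-4o')  # placeholder model

# Cap this run at 5 model requests and 10,000 total tokens; exceeding
# either limit raises partway through the run (see settings.py below).
result = agent.run_sync(
    'What is the capital of France?',
    execution_limit_settings=ExecutionLimitSettings(
        request_limit=5,
        total_tokens_limit=10_000,
    ),
)
print(result.data)
```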

         infer_name: bool = True,
     ) -> result.RunResult[ResultData]:
         """Run the agent with a user prompt in async mode.
@@ -211,8 +212,9 @@
             message_history: History of the conversation so far.
             model: Optional model to use for this run, required if `model` was not set when creating the agent.
             deps: Optional dependencies to use for this run.
-            infer_name: Whether to try to infer the agent name from the call frame if it's not set.
             model_settings: Optional settings to use for this model's request.
+            execution_limit_settings: Optional settings to use in order to limit model request or cost (token usage).
+            infer_name: Whether to try to infer the agent name from the call frame if it's not set.

         Returns:
             The result of the run.
@@ -238,8 +240,8 @@
             tool.current_retry = 0

         cost = result.Cost()
-
         model_settings = merge_model_settings(self.model_settings, model_settings)
+        execution_limit_settings = execution_limit_settings or ExecutionLimitSettings(request_limit=50)
[Review comment from Member Author]
Is this where we want to set the default?
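One behavior worth noting for this question: because the fallback is `execution_limit_settings or ExecutionLimitSettings(request_limit=50)` and any dataclass instance is truthy, the 50-request default only applies when the caller passes nothing. A sketch, assuming the diff as written:

```python
# Implicitly capped at 50 requests by the fallback above:
await agent.run('prompt')

# An explicit empty settings object is truthy, so it suppresses the default
# cap entirely; all limits are None and nothing is enforced:
await agent.run('prompt', execution_limit_settings=ExecutionLimitSettings())
```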


         run_step = 0
         while True:
@@ -254,6 +256,8 @@

             messages.append(model_response)
             cost += request_cost
+            # TODO: is this the right location? Should we move this earlier in the logic?
+            execution_limit_settings.increment(request_cost)
[Review comment from Contributor]
I personally would prefer if we added a request_count field to the Cost type, and then just did execution_limit_settings.validate(cost) here (rather than incrementing both cost and the limits).

I'd also prefer we rename Cost to Usage or similar, since that's really what it's representing now, and it would make it feel less weird to add the request_count field. But even if we don't rename it like that, I think it's reasonable to add request_count: int (or requests: int) as a field on the type currently known as Cost.
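A rough sketch of that alternative, using the reviewer's assumed names (a requests field on the type currently called Cost, and a validate method on the limits object; none of this exists in the diff):

```python
from dataclasses import dataclass

@dataclass
class Cost:  # or Usage, per the proposed rename
    requests: int = 0  # proposed field; each model response would carry requests=1
    request_tokens: int | None = None
    response_tokens: int | None = None
    total_tokens: int | None = None

# The run loop would then validate one aggregate instead of double-counting,
# assuming Cost.__add__ is updated to sum the new requests field:
#     cost += request_cost
#     execution_limit_settings.validate(cost)  # read-only check against limits
```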


             with _logfire.span('handle model response', run_step=run_step) as handle_span:
                 final_result, tool_responses = await self._handle_model_response(model_response, deps, messages)
@@ -284,6 +288,7 @@ def run_sync(
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
         model_settings: ModelSettings | None = None,
+        execution_limit_settings: ExecutionLimitSettings | None = None,
         infer_name: bool = True,
     ) -> result.RunResult[ResultData]:
         """Run the agent with a user prompt synchronously.
@@ -308,8 +313,9 @@ async def main():
             message_history: History of the conversation so far.
             model: Optional model to use for this run, required if `model` was not set when creating the agent.
             deps: Optional dependencies to use for this run.
-            infer_name: Whether to try to infer the agent name from the call frame if it's not set.
             model_settings: Optional settings to use for this model's request.
+            execution_limit_settings: Optional settings to use in order to limit model request or cost (token usage).
+            infer_name: Whether to try to infer the agent name from the call frame if it's not set.

         Returns:
             The result of the run.
@@ -322,8 +328,9 @@ async def main():
                 message_history=message_history,
                 model=model,
                 deps=deps,
-                infer_name=False,
+                execution_limit_settings=execution_limit_settings,
                 model_settings=model_settings,
+                infer_name=False,
             )
         )
@@ -336,6 +343,7 @@ async def run_stream(
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
         model_settings: ModelSettings | None = None,
+        execution_limit_settings: ExecutionLimitSettings | None = None,
         infer_name: bool = True,
     ) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]:
         """Run the agent with a user prompt in async mode, returning a streamed response.
@@ -357,8 +365,9 @@ async def main():
             message_history: History of the conversation so far.
             model: Optional model to use for this run, required if `model` was not set when creating the agent.
             deps: Optional dependencies to use for this run.
-            infer_name: Whether to try to infer the agent name from the call frame if it's not set.
             model_settings: Optional settings to use for this model's request.
+            execution_limit_settings: Optional settings to use in order to limit model request or cost (token usage).
+            infer_name: Whether to try to infer the agent name from the call frame if it's not set.

         Returns:
             The result of the run.
@@ -387,6 +396,7 @@ async def main():

         cost = result.Cost()
         model_settings = merge_model_settings(self.model_settings, model_settings)
+        execution_limit_settings = execution_limit_settings or ExecutionLimitSettings(request_limit=50)

         run_step = 0
         while True:
@@ -456,7 +466,9 @@ async def on_complete():
                 tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
                 handle_span.message = f'handle model response -> {tool_responses_str}'
                 # the model_response should have been fully streamed by now, we can add it's cost
-                cost += model_response.cost()
+                model_response_cost = model_response.cost()
+                execution_limit_settings.increment(model_response_cost)
+                cost += model_response_cost

     @contextmanager
     def override(
40 changes: 40 additions & 0 deletions pydantic_ai_slim/pydantic_ai/settings.py
@@ -1,8 +1,16 @@
 from __future__ import annotations

+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
 from httpx import Timeout
 from typing_extensions import TypedDict

+from .exceptions import UnexpectedModelBehavior
+
+if TYPE_CHECKING:
+    from .result import Cost
+

 class ModelSettings(TypedDict, total=False):
     """Settings to configure an LLM.
@@ -70,3 +78,35 @@ def merge_model_settings(base: ModelSettings | None, overrides: ModelSettings |
         return base | overrides
     else:
         return base or overrides
+
+
+@dataclass
+class ExecutionLimitSettings:
+    """Settings to configure an agent run."""
+
+    request_limit: int | None = None
+    request_tokens_limit: int | None = None
+    response_tokens_limit: int | None = None
+    total_tokens_limit: int | None = None
+
+    _request_count: int = 0
+    _request_tokens_count: int = 0
+    _response_tokens_count: int = 0
+    _total_tokens_count: int = 0
[Review thread on lines +92 to +95 (the private counter fields)]

Member Author: Should these be public?

Member Author: I would say yes if we want to also include this structure in RunContext...

Contributor (@dmontagu, Dec 17, 2024): I don't like the idea that the settings object also holds state; it feels to me like there should be a separate object for tracking state, and we can check the state against the settings. If I were a user, I'd be inclined to reuse an instance of ExecutionLimitSettings, which obviously will cause issues.

I would imagine we make a private type _UsageState or similar (which holds all the fields you are talking about here), and have one of ExecutionLimits and _UsageState have a method that accepts the other and raises an error if appropriate.

Contributor: If we want to put the usage state on the RunContext we can make it public, but I feel like we can do that later/separately. I'll note that I could imagine Samuel disagreeing with all this, and I wouldn't find that unreasonable.
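A minimal sketch of that split, assuming the reviewer's names ExecutionLimits and _UsageState (not part of the diff; the validate idea sketched earlier would live on one of the two):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class ExecutionLimits:
    """Pure configuration with no counters, so instances are safe to reuse across runs."""

    request_limit: int | None = None
    total_tokens_limit: int | None = None

@dataclass
class _UsageState:
    """Mutable per-run counters, created fresh inside each run_* call."""

    request_count: int = 0
    total_tokens_count: int = 0

    def check(self, limits: ExecutionLimits) -> None:
        # One object accepts the other and raises if any limit is exceeded.
        # RuntimeError is a placeholder; see the comment below about a dedicated error type.
        if limits.request_limit is not None and self.request_count > limits.request_limit:
            raise RuntimeError(f'exceeded request limit of {limits.request_limit}')
        if limits.total_tokens_limit is not None and self.total_tokens_count > limits.total_tokens_limit:
            raise RuntimeError(f'exceeded total tokens limit of {limits.total_tokens_limit}')
```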


+    def increment(self, cost: Cost) -> None:
+        self._request_count += 1
+        self._check_limit(self.request_limit, self._request_count, 'request count')
+
+        self._request_tokens_count += cost.request_tokens or 0
+        self._check_limit(self.request_tokens_limit, self._request_tokens_count, 'request tokens count')
+
+        self._response_tokens_count += cost.response_tokens or 0
+        self._check_limit(self.response_tokens_limit, self._response_tokens_count, 'response tokens count')
+
+        self._total_tokens_count += cost.total_tokens or 0
+        self._check_limit(self.total_tokens_limit, self._total_tokens_count, 'total tokens count')
+
+    def _check_limit(self, limit: int | None, count: int, limit_name: str) -> None:
+        if limit and limit < count:
+            raise UnexpectedModelBehavior(f'Exceeded {limit_name} limit of {limit} by {count - limit}')
[Review comment from Contributor]
I feel like this deserves its own exception, and probably one that doesn't inherit from UnexpectedModelBehavior (as this is more or less expected behavior).
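A hedged sketch of such an exception; the name UsageLimitExceeded is an assumption, not something this PR defines:

```python
# Hypothetical addition to pydantic_ai/exceptions.py:
class UsageLimitExceeded(Exception):
    """Raised when a run exceeds a user-configured execution limit.

    Deliberately not a subclass of UnexpectedModelBehavior: hitting a
    configured limit is expected, caller-requested behavior, not a model bug.
    """

# _check_limit would then raise it instead of UnexpectedModelBehavior:
#     raise UsageLimitExceeded(f'Exceeded {limit_name} limit of {limit} by {count - limit}')
```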
