Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Propagate cost and tokens #1740

Merged
merged 9 commits into from
Jun 10, 2024
2 changes: 1 addition & 1 deletion agenta-cli/agenta/client/backend/types/llm_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,6 @@ def dict(self, **kwargs: typing.Any) -> typing.Dict[str, typing.Any]:
return super().dict(**kwargs_with_defaults)

class Config:
frozen = True
frozen = False
smart_union = True
json_encoders = {dt.datetime: serialize_datetime}
37 changes: 33 additions & 4 deletions agenta-cli/agenta/sdk/tracing/llm_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
from agenta.sdk.tracing.tasks_manager import TaskQueue
from agenta.client.backend.client import AsyncAgentaApi
from agenta.client.backend.client import AsyncObservabilityClient
from agenta.client.backend.types.create_span import CreateSpan, SpanStatusCode
from agenta.client.backend.types.create_span import (
CreateSpan,
LlmTokens,
SpanStatusCode,
)

from bson.objectid import ObjectId

Expand Down Expand Up @@ -173,6 +177,24 @@ def start_span(
def update_span_status(self, span: CreateSpan, value: str):
    """Overwrite the status recorded on a span.

    Args:
        span: The span whose ``status`` attribute is replaced.
        value: The new status to store on the span.
    """
    setattr(span, "status", value)

def _update_span_cost(self, span: CreateSpan, cost: Optional[float]):
    """Accumulate ``cost`` onto ``span.cost``.

    If the span has no cost yet, the cost is stored; otherwise it is added
    to the existing value. Non-numeric values (including ``None``) are
    ignored so that a missing or malformed cost never breaks span handling.

    Args:
        span: The span whose ``cost`` attribute is updated in place.
        cost: The amount to add. Both ``int`` and ``float`` are accepted;
            anything else is silently skipped.
    """
    # bool is a subclass of int; a True/False "cost" is almost certainly a
    # mistake upstream, so do not count it as 1/0.
    if isinstance(cost, bool) or not isinstance(cost, (int, float)):
        return

    if span.cost is None:
        # Normalise to float so the stored value has a consistent type even
        # when the first reported cost is an integer.
        span.cost = float(cost)
    else:
        span.cost += cost

def _update_span_tokens(self, span: CreateSpan, tokens: Optional[dict]):
    """Accumulate token usage onto ``span.tokens``.

    ``tokens`` may be an ``LlmTokens`` instance (converted with ``.dict()``)
    or a plain dict. If the span has no token record yet, a new ``LlmTokens``
    is built from the dict; otherwise the ``prompt_tokens``,
    ``completion_tokens`` and ``total_tokens`` counters are incremented.
    Any other input (including ``None``) is ignored.

    Args:
        span: The span whose ``tokens`` attribute is updated in place.
        tokens: Usage counters to add, as an ``LlmTokens`` or a dict.
    """
    if isinstance(tokens, LlmTokens):
        tokens = tokens.dict()
    if not isinstance(tokens, dict):
        return

    if span.tokens is None:
        span.tokens = LlmTokens(**tokens)
        return

    # A usage dict is not guaranteed to carry every counter (presumably
    # ``LlmTokens.dict()`` can omit unset fields, and the model's counters
    # may be None -- TODO confirm against the LlmTokens definition), so
    # skip absent increments and treat a missing running total as zero
    # instead of raising KeyError/TypeError.
    for field in ("prompt_tokens", "completion_tokens", "total_tokens"):
        increment = tokens.get(field)
        if increment is None:
            continue
        current = getattr(span.tokens, field)
        setattr(span.tokens, field, (current or 0) + increment)

def end_span(self, outputs: Dict[str, Any]):
"""
Ends the active span, if it is a parent span, ends the trace too.
Expand All @@ -196,8 +218,12 @@ def end_span(self, outputs: Dict[str, Any]):

self.active_span.end_time = datetime.now(timezone.utc)
self.active_span.outputs = [outputs.get("message", "")]
self.active_span.cost = outputs.get("cost", None)
self.active_span.tokens = outputs.get("usage", None)
if self.active_span.spankind in [
"LLM",
"RETRIEVER",
]: # TODO: Remove this whole part. Setting the cost should be done through set_span_attribute
self._update_span_cost(self.active_span, outputs.get("cost", None))
self._update_span_tokens(self.active_span, outputs.get("usage", None))

# Push span to list of recorded spans
self.pending_spans.append(self.active_span)
Expand All @@ -213,7 +239,10 @@ def end_span(self, outputs: Dict[str, Any]):
self.end_trace(parent_span=self.active_span)

else:
self.active_span = self.span_dict.get(active_span_parent_id)
parent_span = self.span_dict[active_span_parent_id]
self._update_span_cost(parent_span, self.active_span.cost)
self._update_span_tokens(parent_span, self.active_span.tokens)
self.active_span = parent_span

def record_exception_and_end_trace(self, span_parent_id: str):
"""
Expand Down
Loading