diff --git a/agenta-cli/agenta/client/backend/types/llm_tokens.py b/agenta-cli/agenta/client/backend/types/llm_tokens.py
index 0c5a71755b..7e9fd8358b 100644
--- a/agenta-cli/agenta/client/backend/types/llm_tokens.py
+++ b/agenta-cli/agenta/client/backend/types/llm_tokens.py
@@ -33,6 +33,6 @@ def dict(self, **kwargs: typing.Any) -> typing.Dict[str, typing.Any]:
         return super().dict(**kwargs_with_defaults)
 
     class Config:
-        frozen = True
+        frozen = False
         smart_union = True
         json_encoders = {dt.datetime: serialize_datetime}
diff --git a/agenta-cli/agenta/sdk/tracing/llm_tracing.py b/agenta-cli/agenta/sdk/tracing/llm_tracing.py
index d1309c3015..23a806569a 100644
--- a/agenta-cli/agenta/sdk/tracing/llm_tracing.py
+++ b/agenta-cli/agenta/sdk/tracing/llm_tracing.py
@@ -7,7 +7,11 @@
 from agenta.sdk.tracing.tasks_manager import TaskQueue
 from agenta.client.backend.client import AsyncAgentaApi
 from agenta.client.backend.client import AsyncObservabilityClient
-from agenta.client.backend.types.create_span import CreateSpan, SpanStatusCode
+from agenta.client.backend.types.create_span import (
+    CreateSpan,
+    LlmTokens,
+    SpanStatusCode,
+)
 
 from bson.objectid import ObjectId
 
@@ -173,6 +177,28 @@ def start_span(
     def update_span_status(self, span: CreateSpan, value: str):
         span.status = value
 
+    def _update_span_cost(self, span: CreateSpan, cost: Optional[float]):
+        """Add `cost` to `span.cost`, ignoring non-numeric values."""
+        if isinstance(cost, (int, float)):
+            if span.cost is None:
+                span.cost = cost
+            else:
+                span.cost += cost
+
+    def _update_span_tokens(self, span: CreateSpan, tokens: Optional[dict]):
+        """Add token usage (dict or LlmTokens) to `span.tokens`."""
+        if isinstance(tokens, LlmTokens):
+            tokens = tokens.dict()
+        if isinstance(tokens, dict):
+            if span.tokens is None:
+                span.tokens = LlmTokens(**tokens)
+            else:
+                # Missing or None counters count as zero so a partial usage
+                # payload cannot raise KeyError/TypeError while a span is ending.
+                for key in ("prompt_tokens", "completion_tokens", "total_tokens"):
+                    current = getattr(span.tokens, key) or 0
+                    setattr(span.tokens, key, current + (tokens.get(key) or 0))
+
     def end_span(self, outputs: Dict[str, Any]):
         """
         Ends the active span, if it is a parent span, ends the trace too.
@@ -196,8 +222,12 @@ def end_span(self, outputs: Dict[str, Any]):
 
         self.active_span.end_time = datetime.now(timezone.utc)
         self.active_span.outputs = [outputs.get("message", "")]
-        self.active_span.cost = outputs.get("cost", None)
-        self.active_span.tokens = outputs.get("usage", None)
+        if self.active_span.spankind in [
+            "LLM",
+            "RETRIEVER",
+        ]:  # TODO: Remove this whole part. Setting the cost should be done through set_span_attribute
+            self._update_span_cost(self.active_span, outputs.get("cost", None))
+            self._update_span_tokens(self.active_span, outputs.get("usage", None))
 
         # Push span to list of recorded spans
         self.pending_spans.append(self.active_span)
@@ -213,7 +243,10 @@ def end_span(self, outputs: Dict[str, Any]):
             self.end_trace(parent_span=self.active_span)
 
         else:
-            self.active_span = self.span_dict.get(active_span_parent_id)
+            parent_span = self.span_dict[active_span_parent_id]
+            self._update_span_cost(parent_span, self.active_span.cost)
+            self._update_span_tokens(parent_span, self.active_span.tokens)
+            self.active_span = parent_span
 
     def record_exception_and_end_trace(self, span_parent_id: str):
         """