Add test for error case and rework things a bit
xrmx committed Jan 15, 2025
Parent: 1a3fb9c · Commit: c250004
Showing 4 changed files with 236 additions and 95 deletions.
File 1 of 4: botocore Bedrock extension (_BedrockRuntimeExtension)
@@ -24,6 +24,10 @@
 from opentelemetry.instrumentation.botocore.extensions.types import (
     _AttributeMapT,
     _AwsSdkExtension,
+    _BotoClientErrorT,
 )
+from opentelemetry.semconv._incubating.attributes.error_attributes import (
+    ERROR_TYPE,
+)
 from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import (
     GEN_AI_OPERATION_NAME,
@@ -40,6 +44,7 @@
     GenAiSystemValues,
 )
 from opentelemetry.trace.span import Span
+from opentelemetry.trace.status import Status, StatusCode
 
 _logger = logging.getLogger(__name__)
 
@@ -56,25 +61,17 @@ class _BedrockRuntimeExtension(_AwsSdkExtension):
     _HANDLED_OPERATIONS = {"Converse"}
 
     def extract_attributes(self, attributes: _AttributeMapT):
-        attributes[GEN_AI_SYSTEM] = GenAiSystemValues.AWS_BEDROCK.value
-
         if self._call_context.operation not in self._HANDLED_OPERATIONS:
             return
 
+        attributes[GEN_AI_SYSTEM] = GenAiSystemValues.AWS_BEDROCK.value
+
         model_id = self._call_context.params.get(_MODEL_ID_KEY)
         if model_id:
             attributes[GEN_AI_REQUEST_MODEL] = model_id
 
-            # FIXME: add other model patterns
-            text_model_patterns = [
-                "amazon.titan-text",
-                "anthropic.claude",
-                "meta.llama",
-            ]
-            if any(pattern in model_id for pattern in text_model_patterns):
-                attributes[GEN_AI_OPERATION_NAME] = (
-                    GenAiOperationNameValues.CHAT.value
-                )
+            attributes[GEN_AI_OPERATION_NAME] = (
+                GenAiOperationNameValues.CHAT.value
+            )
 
             if inference_config := self._call_context.params.get(
                 "inferenceConfig"
@@ -122,9 +119,7 @@ def on_success(self, span: Span, result: dict[str, Any]):
         if self._call_context.operation not in self._HANDLED_OPERATIONS:
             return
 
-        model_id = self._call_context.params.get(_MODEL_ID_KEY)
-
-        if not model_id:
+        if not span.is_recording():
             return
 
         if usage := result.get("usage"):
@@ -144,3 +139,11 @@ def on_success(self, span: Span, result: dict[str, Any]):
                 GEN_AI_RESPONSE_FINISH_REASONS,
                 [stop_reason],
             )
+
+    def on_error(self, span: Span, exception: _BotoClientErrorT):
+        if self._call_context.operation not in self._HANDLED_OPERATIONS:
+            return
+
+        span.set_status(Status(StatusCode.ERROR, str(exception)))
+        if span.is_recording():
+            span.set_attribute(ERROR_TYPE, type(exception).__qualname__)
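For orientation, here is a minimal sketch of the hook contract these methods appear to implement: extract_attributes() runs before the SDK call, and exactly one of on_success() or on_error() runs after it. This is an assumption for illustration, not code from this commit; instrumented_call, extension, span and do_api_call are placeholder names.

from botocore.exceptions import ClientError


def instrumented_call(extension, span, do_api_call):
    # Hypothetical driver; the real dispatch lives in the botocore
    # instrumentation, not in this extension module. ClientError is
    # assumed to correspond to the _BotoClientErrorT alias above.
    try:
        result = do_api_call()
    except ClientError as error:
        extension.on_error(span, error)  # sets ERROR status and error.type
        raise
    else:
        extension.on_success(span, result)  # sets usage/finish-reason attrs
        return result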
File 2 of 4: new test helper module bedrock_utils.py (named per the .bedrock_utils import in the test file below)
@@ -0,0 +1,116 @@
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Any

from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.semconv._incubating.attributes import (
    gen_ai_attributes as GenAIAttributes,
)


def assert_completion_attributes(
    span: ReadableSpan,
    request_model: str,
    response: dict[str, Any] | None,
    operation_name: str = "chat",
    request_top_p: int | None = None,
    request_temperature: int | None = None,
    request_max_tokens: int | None = None,
    request_stop_sequences: list[str] | None = None,
):
    if usage := (response and response.get("usage")):
        input_tokens = usage["inputTokens"]
        output_tokens = usage["outputTokens"]
    else:
        input_tokens, output_tokens = None, None

    if response:
        finish_reason = (response["stopReason"],)
    else:
        finish_reason = None

    return assert_all_attributes(
        span,
        request_model,
        input_tokens,
        output_tokens,
        finish_reason,
        operation_name,
        request_top_p,
        request_temperature,
        request_max_tokens,
        tuple(request_stop_sequences)
        if request_stop_sequences is not None
        else request_stop_sequences,
    )


def assert_equal_or_not_present(value, attribute_name, span):
    if value:
        assert value == span.attributes[attribute_name]
    else:
        assert attribute_name not in span.attributes


def assert_all_attributes(
    span: ReadableSpan,
    request_model: str,
    input_tokens: int | None = None,
    output_tokens: int | None = None,
    finish_reason: tuple[str] | None = None,
    operation_name: str = "chat",
    request_top_p: int | None = None,
    request_temperature: int | None = None,
    request_max_tokens: int | None = None,
    request_stop_sequences: tuple[str] | None = None,
):
    assert span.name == f"{operation_name} {request_model}"
    assert (
        operation_name
        == span.attributes[GenAIAttributes.GEN_AI_OPERATION_NAME]
    )
    assert (
        GenAIAttributes.GenAiSystemValues.AWS_BEDROCK.value
        == span.attributes[GenAIAttributes.GEN_AI_SYSTEM]
    )
    assert (
        request_model == span.attributes[GenAIAttributes.GEN_AI_REQUEST_MODEL]
    )

    assert_equal_or_not_present(
        input_tokens, GenAIAttributes.GEN_AI_USAGE_INPUT_TOKENS, span
    )
    assert_equal_or_not_present(
        output_tokens, GenAIAttributes.GEN_AI_USAGE_OUTPUT_TOKENS, span
    )
    assert_equal_or_not_present(
        finish_reason, GenAIAttributes.GEN_AI_RESPONSE_FINISH_REASONS, span
    )
    assert_equal_or_not_present(
        request_top_p, GenAIAttributes.GEN_AI_REQUEST_TOP_P, span
    )
    assert_equal_or_not_present(
        request_temperature, GenAIAttributes.GEN_AI_REQUEST_TEMPERATURE, span
    )
    assert_equal_or_not_present(
        request_max_tokens, GenAIAttributes.GEN_AI_REQUEST_MAX_TOKENS, span
    )
    assert_equal_or_not_present(
        request_stop_sequences,
        GenAIAttributes.GEN_AI_REQUEST_STOP_SEQUENCES,
        span,
    )
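A quick usage sketch for these helpers (hypothetical spans and made-up model ids, mirroring the tests): passing a full response asserts that token usage and finish reasons are present, while passing None, as the error-case test below does, asserts they are absent.

def check_spans(span_ok, span_err):
    # span_ok / span_err would be finished ReadableSpans pulled from an
    # in-memory exporter; both model ids are illustrative only.
    response = {
        "usage": {"inputTokens": 5, "outputTokens": 7},
        "stopReason": "end_turn",
    }
    assert_completion_attributes(span_ok, "amazon.titan-text-lite-v1", response)

    # No response: usage and finish-reason attributes must be absent.
    assert_completion_attributes(span_err, "does-not-exist", None)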
File 3 of 4: new VCR cassette for the invalid-model Converse request
@@ -0,0 +1,69 @@
interactions:
- request:
    body: |-
      {
        "messages": [
          {
            "role": "user",
            "content": [
              {
                "text": "Say this is a test"
              }
            ]
          }
        ]
      }
    headers:
      Content-Length:
      - '77'
      Content-Type:
      - !!binary |
        YXBwbGljYXRpb24vanNvbg==
      User-Agent:
      - !!binary |
        Qm90bzMvMS4zNS41NiBtZC9Cb3RvY29yZSMxLjM1LjU2IHVhLzIuMCBvcy9saW51eCM2LjEuMC0x
        MDM0LW9lbSBtZC9hcmNoI3g4Nl82NCBsYW5nL3B5dGhvbiMzLjEwLjEyIG1kL3B5aW1wbCNDUHl0
        aG9uIGNmZy9yZXRyeS1tb2RlI2xlZ2FjeSBCb3RvY29yZS8xLjM1LjU2
      X-Amz-Date:
      - !!binary |
        MjAyNTAxMTVUMTEwMTQ3Wg==
      X-Amz-Security-Token:
      - test_aws_security_token
      X-Amzn-Trace-Id:
      - !!binary |
        Um9vdD0xLWIzM2JhNTkxLTdkYmQ0ZDZmYTBmZTdmYzc2MTExOThmNztQYXJlbnQ9NzRmNmQ1NTEz
        MzkzMzUxNTtTYW1wbGVkPTE=
      amz-sdk-invocation-id:
      - !!binary |
        NTQ5MmQ0NTktNzhkNi00ZWY4LTlmMDMtZTA5ODhkZGRiZDI5
      amz-sdk-request:
      - !!binary |
        YXR0ZW1wdD0x
      authorization:
      - Bearer test_aws_authorization
    method: POST
    uri: https://bedrock-runtime.eu-central-1.amazonaws.com/model/does-not-exist/converse
  response:
    body:
      string: |-
        {
          "message": "The provided model identifier is invalid."
        }
    headers:
      Connection:
      - keep-alive
      Content-Length:
      - '55'
      Content-Type:
      - application/json
      Date:
      - Wed, 15 Jan 2025 11:01:47 GMT
      Set-Cookie: test_set_cookie
      x-amzn-ErrorType:
      - ValidationException:http://internal.amazon.com/coral/com.amazon.bedrock/
      x-amzn-RequestId:
      - d425bf99-8a4e-4d83-8d77-a48410dd82b2
    status:
      code: 400
      message: Bad Request
version: 1
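This cassette is replayed by VCR.py for the @pytest.mark.vcr() test below; the credential headers carry scrubbed placeholder values. A plausible sketch of how such scrubbing can be configured, assuming the pytest-recording plugin is in use (this fixture is illustrative and not part of the commit):

import pytest


@pytest.fixture(scope="module")
def vcr_config():
    # Replace secrets before requests are written to the cassette,
    # matching the placeholder values visible above.
    return {
        "filter_headers": [
            ("authorization", "Bearer test_aws_authorization"),
            ("X-Amz-Security-Token", "test_aws_security_token"),
        ],
    }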
File 4 of 4: Bedrock Converse test module
@@ -14,15 +14,15 @@
 
 from __future__ import annotations
 
-from typing import Any
-
 import boto3
 import pytest
 
-from opentelemetry.sdk.trace import ReadableSpan
-from opentelemetry.semconv._incubating.attributes import (
-    gen_ai_attributes as GenAIAttributes,
+from opentelemetry.semconv._incubating.attributes.error_attributes import (
+    ERROR_TYPE,
 )
+from opentelemetry.trace.status import StatusCode
+
+from .bedrock_utils import assert_completion_attributes
 
 BOTO3_VERSION = tuple(int(x) for x in boto3.__version__.split("."))
 
@@ -68,82 +68,35 @@ def test_converse_with_content(
     assert len(logs) == 0
 
 
-def assert_completion_attributes(
-    span: ReadableSpan,
-    request_model: str,
-    response: dict[str, Any],
-    operation_name: str = "chat",
-    request_top_p: int | None = None,
-    request_temperature: int | None = None,
-    request_max_tokens: int | None = None,
-    request_stop_sequences: list[str] | None = None,
-):
-    return assert_all_attributes(
-        span,
-        request_model,
-        response["usage"]["inputTokens"],
-        response["usage"]["outputTokens"],
-        (response["stopReason"],),
-        operation_name,
-        request_top_p,
-        request_temperature,
-        request_max_tokens,
-        tuple(request_stop_sequences),
-    )
-
-
-def assert_equal_or_not_present(value, attribute_name, span):
-    if value:
-        assert value == span.attributes[attribute_name]
-    else:
-        assert attribute_name not in span.attributes
-
-
-def assert_all_attributes(
-    span: ReadableSpan,
-    request_model: str,
-    input_tokens: int | None = None,
-    output_tokens: int | None = None,
-    finish_reason: tuple[str] | None = None,
-    operation_name: str = "chat",
-    request_top_p: int | None = None,
-    request_temperature: int | None = None,
-    request_max_tokens: int | None = None,
-    request_stop_sequences: tuple[str] | None = None,
-):
-    assert span.name == f"{operation_name} {request_model}"
-    assert (
-        operation_name
-        == span.attributes[GenAIAttributes.GEN_AI_OPERATION_NAME]
-    )
-    assert (
-        GenAIAttributes.GenAiSystemValues.AWS_BEDROCK.value
-        == span.attributes[GenAIAttributes.GEN_AI_SYSTEM]
-    )
-    assert (
-        request_model == span.attributes[GenAIAttributes.GEN_AI_REQUEST_MODEL]
-    )
-
-    assert_equal_or_not_present(
-        input_tokens, GenAIAttributes.GEN_AI_USAGE_INPUT_TOKENS, span
-    )
-    assert_equal_or_not_present(
-        output_tokens, GenAIAttributes.GEN_AI_USAGE_OUTPUT_TOKENS, span
-    )
-    assert_equal_or_not_present(
-        finish_reason, GenAIAttributes.GEN_AI_RESPONSE_FINISH_REASONS, span
-    )
-    assert_equal_or_not_present(
-        request_top_p, GenAIAttributes.GEN_AI_REQUEST_TOP_P, span
-    )
-    assert_equal_or_not_present(
-        request_temperature, GenAIAttributes.GEN_AI_REQUEST_TEMPERATURE, span
-    )
-    assert_equal_or_not_present(
-        request_max_tokens, GenAIAttributes.GEN_AI_REQUEST_MAX_TOKENS, span
-    )
-    assert_equal_or_not_present(
-        request_stop_sequences,
-        GenAIAttributes.GEN_AI_REQUEST_STOP_SEQUENCES,
-        span,
-    )
+@pytest.mark.skipif(
+    BOTO3_VERSION < (1, 35, 56), reason="Converse API not available"
+)
+@pytest.mark.vcr()
+def test_converse_with_invalid_model(
+    span_exporter,
+    log_exporter,
+    bedrock_runtime_client,
+    instrument_with_content,
+):
+    messages = [{"role": "user", "content": [{"text": "Say this is a test"}]}]
+
+    llm_model_value = "does-not-exist"
+    with pytest.raises(bedrock_runtime_client.exceptions.ValidationException):
+        bedrock_runtime_client.converse(
+            messages=messages,
+            modelId=llm_model_value,
+        )
+
+    (span,) = span_exporter.get_finished_spans()
+    assert_completion_attributes(
+        span,
+        llm_model_value,
+        None,
+        "chat",
+    )
+
+    assert span.status.status_code == StatusCode.ERROR
+    assert span.attributes[ERROR_TYPE] == "ValidationException"
+
+    logs = log_exporter.get_finished_logs()
+    assert len(logs) == 0