Skip to content

Commit

Permalink
Enable global propagator for AWS instrumentation
Browse files Browse the repository at this point in the history
  • Loading branch information
ocelotl committed Jul 24, 2024
1 parent e799a74 commit b1d44c1
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 85 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- `opentelemetry-instrumentation-flask` Add `http.route` and `http.target` to metric attributes
([#2621](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2621))
- `opentelemetry-instrumentation-aws-lambda` Enable global propagator for AWS instrumentation
([#2708](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2708))
- `opentelemetry-instrumentation-sklearn` Deprecated the sklearn instrumentation
([#2708](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2708))
- `opentelemetry-instrumentation-pyramid` Record exceptions raised when serving a request
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,6 @@ def custom_event_context_extractor(lambda_event):
from opentelemetry.instrumentation.utils import unwrap
from opentelemetry.metrics import MeterProvider, get_meter_provider
from opentelemetry.propagate import get_global_textmap
from opentelemetry.propagators.aws.aws_xray_propagator import (
TRACE_HEADER_KEY,
AwsXRayPropagator,
)
from opentelemetry.semconv.resource import ResourceAttributes
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace import (
Expand All @@ -96,7 +92,6 @@ def custom_event_context_extractor(lambda_event):
get_tracer,
get_tracer_provider,
)
from opentelemetry.trace.propagation import get_current_span
from opentelemetry.trace.status import Status, StatusCode

logger = logging.getLogger(__name__)
Expand All @@ -107,9 +102,6 @@ def custom_event_context_extractor(lambda_event):
OTEL_INSTRUMENTATION_AWS_LAMBDA_FLUSH_TIMEOUT = (
"OTEL_INSTRUMENTATION_AWS_LAMBDA_FLUSH_TIMEOUT"
)
OTEL_LAMBDA_DISABLE_AWS_CONTEXT_PROPAGATION = (
"OTEL_LAMBDA_DISABLE_AWS_CONTEXT_PROPAGATION"
)


def _default_event_context_extractor(lambda_event: Any) -> Context:
Expand Down Expand Up @@ -145,7 +137,6 @@ def _default_event_context_extractor(lambda_event: Any) -> Context:
def _determine_parent_context(
lambda_event: Any,
event_context_extractor: Callable[[Any], Context],
disable_aws_context_propagation: bool = False,
) -> Context:
"""Determine the parent context for the current Lambda invocation.
Expand All @@ -159,36 +150,14 @@ def _determine_parent_context(
Event as input and extracts an OTel Context from it. By default,
the context is extracted from the HTTP headers of an API Gateway
request.
disable_aws_context_propagation: By default, this instrumentation
will try to read the context from the `_X_AMZN_TRACE_ID` environment
variable set by Lambda, set this to `True` to disable this behavior.
Returns:
A Context with configuration found in the carrier.
"""
parent_context = None

if not disable_aws_context_propagation:
xray_env_var = os.environ.get(_X_AMZN_TRACE_ID)

if xray_env_var:
parent_context = AwsXRayPropagator().extract(
{TRACE_HEADER_KEY: xray_env_var}
)

if (
parent_context
and get_current_span(parent_context)
.get_span_context()
.trace_flags.sampled
):
return parent_context
if event_context_extractor is None:
return _default_event_context_extractor(lambda_event)

if event_context_extractor:
parent_context = event_context_extractor(lambda_event)
else:
parent_context = _default_event_context_extractor(lambda_event)

return parent_context
return event_context_extractor(lambda_event)


def _set_api_gateway_v1_proxy_attributes(
Expand Down Expand Up @@ -286,14 +255,15 @@ def _instrument(
flush_timeout,
event_context_extractor: Callable[[Any], Context],
tracer_provider: TracerProvider = None,
disable_aws_context_propagation: bool = False,
meter_provider: MeterProvider = None,
):

# pylint: disable=too-many-locals
# pylint: disable=too-many-statements
def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
call_wrapped, instance, args, kwargs
):

orig_handler_name = ".".join(
[wrapped_module_name, wrapped_function_name]
)
Expand All @@ -303,7 +273,6 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
parent_context = _determine_parent_context(
lambda_event,
event_context_extractor,
disable_aws_context_propagation,
)

try:
Expand Down Expand Up @@ -451,9 +420,6 @@ def _instrument(self, **kwargs):
Event as input and extracts an OTel Context from it. By default,
the context is extracted from the HTTP headers of an API Gateway
request.
``disable_aws_context_propagation``: By default, this instrumentation
will try to read the context from the `_X_AMZN_TRACE_ID` environment
variable set by Lambda, set this to `True` to disable this behavior.
"""
lambda_handler = os.environ.get(ORIG_HANDLER, os.environ.get(_HANDLER))
# pylint: disable=attribute-defined-outside-init
Expand All @@ -475,16 +441,6 @@ def _instrument(self, **kwargs):
flush_timeout_env,
)

disable_aws_context_propagation = kwargs.get(
"disable_aws_context_propagation", False
) or os.getenv(
OTEL_LAMBDA_DISABLE_AWS_CONTEXT_PROPAGATION, "False"
).strip().lower() in (
"true",
"1",
"t",
)

_instrument(
self._wrapped_module_name,
self._wrapped_function_name,
Expand All @@ -493,7 +449,6 @@ def _instrument(self, **kwargs):
"event_context_extractor", _default_event_context_extractor
),
tracer_provider=kwargs.get("tracer_provider"),
disable_aws_context_propagation=disable_aws_context_propagation,
meter_provider=kwargs.get("meter_provider"),
)

Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,16 @@
# limitations under the License.
import os
from dataclasses import dataclass
from importlib import import_module
from importlib import import_module, reload
from typing import Any, Callable, Dict
from unittest import mock

from opentelemetry import propagate
from opentelemetry.environment_variables import OTEL_PROPAGATORS
from opentelemetry.instrumentation.aws_lambda import (
_HANDLER,
_X_AMZN_TRACE_ID,
OTEL_INSTRUMENTATION_AWS_LAMBDA_FLUSH_TIMEOUT,
OTEL_LAMBDA_DISABLE_AWS_CONTEXT_PROPAGATION,
AwsLambdaInstrumentor,
)
from opentelemetry.propagate import get_global_textmap
Expand Down Expand Up @@ -56,6 +56,7 @@ def __init__(self, aws_request_id, invoked_function_arn):
)

MOCK_XRAY_TRACE_ID = 0x5FB7331105E8BB83207FA31D4D9CDB4C

MOCK_XRAY_TRACE_ID_STR = f"{MOCK_XRAY_TRACE_ID:x}"
MOCK_XRAY_PARENT_SPAN_ID = 0x3328B8445A6DBAD2
MOCK_XRAY_TRACE_CONTEXT_COMMON = f"Root={TRACE_ID_VERSION}-{MOCK_XRAY_TRACE_ID_STR[:TRACE_ID_FIRST_PART_LENGTH]}-{MOCK_XRAY_TRACE_ID_STR[TRACE_ID_FIRST_PART_LENGTH:]};Parent={MOCK_XRAY_PARENT_SPAN_ID:x}"
Expand All @@ -81,6 +82,7 @@ def mock_execute_lambda(event=None):
"""Mocks the AWS Lambda execution.
NOTE: We don't use `moto`'s `mock_lambda` because we are not instrumenting
calls to AWS Lambda using the AWS SDK. Instead, we are instrumenting AWS
Lambda itself.
Expand Down Expand Up @@ -122,10 +124,13 @@ def test_active_tracing(self):
{
**os.environ,
# Using Active tracing
OTEL_PROPAGATORS: "xray_lambda",
_X_AMZN_TRACE_ID: MOCK_XRAY_TRACE_CONTEXT_SAMPLED,
},
)

test_env_patch.start()
reload(propagate)

AwsLambdaInstrumentor().instrument()

Expand Down Expand Up @@ -173,8 +178,7 @@ class TestCase:
xray_traceid: str
expected_state_value: str = None
expected_trace_state_len: int = 0
disable_aws_context_propagation: bool = False
disable_aws_context_propagation_envvar: str = ""
propagators: str = "tracecontext"

def custom_event_context_extractor(lambda_event):
return get_global_textmap().extract(lambda_event["foo"]["headers"])
Expand Down Expand Up @@ -226,9 +230,10 @@ def custom_event_context_extractor(lambda_event):
expected_traceid=MOCK_XRAY_TRACE_ID,
expected_parentid=MOCK_XRAY_PARENT_SPAN_ID,
xray_traceid=MOCK_XRAY_TRACE_CONTEXT_SAMPLED,
propagators="xray_lambda",
),
TestCase(
name="custom_extractor_sampled_xray_disable_aws_propagation",
name="custom_extractor_sampled_xray",
custom_extractor=custom_event_context_extractor,
context={
"foo": {
Expand All @@ -238,24 +243,21 @@ def custom_event_context_extractor(lambda_event):
}
}
},
disable_aws_context_propagation=True,
expected_traceid=MOCK_W3C_TRACE_ID,
expected_parentid=MOCK_W3C_PARENT_SPAN_ID,
expected_trace_state_len=3,
expected_state_value=MOCK_W3C_TRACE_STATE_VALUE,
xray_traceid=MOCK_XRAY_TRACE_CONTEXT_SAMPLED,
),
TestCase(
name="no_custom_extractor_xray_disable_aws_propagation_via_env_var",
name="no_custom_extractor_xray",
custom_extractor=None,
context={
"headers": {
TraceContextTextMapPropagator._TRACEPARENT_HEADER_NAME: MOCK_W3C_TRACE_CONTEXT_SAMPLED,
TraceContextTextMapPropagator._TRACESTATE_HEADER_NAME: f"{MOCK_W3C_TRACE_STATE_KEY}={MOCK_W3C_TRACE_STATE_VALUE},foo=1,bar=2",
}
},
disable_aws_context_propagation=False,
disable_aws_context_propagation_envvar="true",
expected_traceid=MOCK_W3C_TRACE_ID,
expected_parentid=MOCK_W3C_PARENT_SPAN_ID,
expected_trace_state_len=3,
Expand All @@ -264,21 +266,21 @@ def custom_event_context_extractor(lambda_event):
),
]
for test in tests:

test_env_patch = mock.patch.dict(
"os.environ",
{
**os.environ,
# NOT Active Tracing
_X_AMZN_TRACE_ID: test.xray_traceid,
OTEL_LAMBDA_DISABLE_AWS_CONTEXT_PROPAGATION: test.disable_aws_context_propagation_envvar,
# NOT using the X-Ray Propagator
OTEL_PROPAGATORS: "tracecontext",
OTEL_PROPAGATORS: test.propagators,
},
)
test_env_patch.start()
reload(propagate)

AwsLambdaInstrumentor().instrument(
event_context_extractor=test.custom_extractor,
disable_aws_context_propagation=test.disable_aws_context_propagation,
)
mock_execute_lambda(test.context)
spans = self.memory_exporter.get_finished_spans()
Expand Down Expand Up @@ -374,6 +376,7 @@ def test_lambda_handles_invalid_event_source(self):
},
)
test_env_patch.start()
reload(propagate)

AwsLambdaInstrumentor().instrument()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from opentelemetry.propagators.aws.aws_xray_propagator import AwsXRayPropagator
from opentelemetry.propagators.aws.aws_xray_propagator import (
AwsXRayLambdaPropagator,
AwsXRayPropagator,
)

__all__ = ["AwsXRayPropagator"]
__all__ = ["AwsXRayPropagator", "AwsXRayLambdaPropagator"]
Original file line number Diff line number Diff line change
Expand Up @@ -328,9 +328,9 @@ def fields(self):
return {TRACE_HEADER_KEY}


class AwsXrayLambdaPropagator(AwsXRayPropagator):
class AwsXRayLambdaPropagator(AwsXRayPropagator):
"""Implementation of the AWS X-Ray Trace Header propagation protocol but
with special handling for Lambda's ``_X_AMZN_TRACE_ID` environment
with special handling for Lambda's ``_X_AMZN_TRACE_ID`` environment
variable.
"""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from opentelemetry.context import get_current
from opentelemetry.propagators.aws.aws_xray_propagator import (
TRACE_HEADER_KEY,
AwsXrayLambdaPropagator,
AwsXRayLambdaPropagator,
)
from opentelemetry.propagators.textmap import DefaultGetter
from opentelemetry.sdk.trace import ReadableSpan
Expand All @@ -40,7 +40,7 @@ class AwsXRayLambdaPropagatorTest(TestCase):
def test_extract_no_environment_variable(self):

actual_context = get_current_span(
AwsXrayLambdaPropagator().extract(
AwsXRayLambdaPropagator().extract(
{}, context=get_current(), getter=DefaultGetter()
)
).get_span_context()
Expand All @@ -57,7 +57,7 @@ def test_extract_no_environment_variable_valid_context(self):
with use_span(NonRecordingSpan(SpanContext(1, 2, False))):

actual_context = get_current_span(
AwsXrayLambdaPropagator().extract(
AwsXRayLambdaPropagator().extract(
{}, context=get_current(), getter=DefaultGetter()
)
).get_span_context()
Expand All @@ -83,7 +83,7 @@ def test_extract_no_environment_variable_valid_context(self):
def test_extract_from_environment_variable(self):

actual_context = get_current_span(
AwsXrayLambdaPropagator().extract(
AwsXRayLambdaPropagator().extract(
{}, context=get_current(), getter=DefaultGetter()
)
).get_span_context()
Expand All @@ -108,7 +108,7 @@ def test_extract_from_environment_variable(self):
)
def test_add_link_from_environment_variable(self):

propagator = AwsXrayLambdaPropagator()
propagator = AwsXRayLambdaPropagator()

default_getter = DefaultGetter()

Expand Down

0 comments on commit b1d44c1

Please sign in to comment.