Skip to content

Commit

Permalink
opentelemetry-instrumentation-aws-lambda: Adding option to disable …
Browse files Browse the repository at this point in the history
…context propagation (#1466)

* `opentelemetry-instrumentation-aws-lambda`: Adding option to disable context propagation

Adding the following option to disable context propagation `disable_aws_context_propagation`. This is similar to the disableAwsContextPropagation option in the nodejs instrumentation.

* update changelog

* lint

* more lint
  • Loading branch information
Alex Boten authored Nov 23, 2022
1 parent 80d0b89 commit 8dbd142
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 98 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
([#685](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/685))
- Add metric instrumentation for tornado
([#1252](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1252))
- `opentelemetry-instrumentation-aws-lambda` Add option to disable aws context propagation
([#1466](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1466))

### Added

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,9 @@ def _default_event_context_extractor(lambda_event: Any) -> Context:


def _determine_parent_context(
lambda_event: Any, event_context_extractor: Callable[[Any], Context]
lambda_event: Any,
event_context_extractor: Callable[[Any], Context],
disable_aws_context_propagation: bool = False,
) -> Context:
"""Determine the parent context for the current Lambda invocation.
Expand All @@ -144,17 +146,25 @@ def _determine_parent_context(
Args:
lambda_event: user-defined, so it could be anything, but this
method counts it being a map with a 'headers' key
event_context_extractor: a method which takes the Lambda
Event as input and extracts an OTel Context from it. By default,
the context is extracted from the HTTP headers of an API Gateway
request.
disable_aws_context_propagation: By default, this instrumentation
will try to read the context from the `_X_AMZN_TRACE_ID` environment
variable set by Lambda, set this to `True` to disable this behavior.
Returns:
A Context with configuration found in the carrier.
"""
parent_context = None

xray_env_var = os.environ.get(_X_AMZN_TRACE_ID)
if not disable_aws_context_propagation:
xray_env_var = os.environ.get(_X_AMZN_TRACE_ID)

if xray_env_var:
parent_context = AwsXRayPropagator().extract(
{TRACE_HEADER_KEY: xray_env_var}
)
if xray_env_var:
parent_context = AwsXRayPropagator().extract(
{TRACE_HEADER_KEY: xray_env_var}
)

if (
parent_context
Expand Down Expand Up @@ -258,6 +268,7 @@ def _instrument(
flush_timeout,
event_context_extractor: Callable[[Any], Context],
tracer_provider: TracerProvider = None,
disable_aws_context_propagation: bool = False,
):
def _instrumented_lambda_handler_call(
call_wrapped, instance, args, kwargs
Expand All @@ -269,7 +280,9 @@ def _instrumented_lambda_handler_call(
lambda_event = args[0]

parent_context = _determine_parent_context(
lambda_event, event_context_extractor
lambda_event,
event_context_extractor,
disable_aws_context_propagation,
)

span_kind = None
Expand Down Expand Up @@ -368,6 +381,9 @@ def _instrument(self, **kwargs):
Event as input and extracts an OTel Context from it. By default,
the context is extracted from the HTTP headers of an API Gateway
request.
``disable_aws_context_propagation``: By default, this instrumentation
will try to read the context from the `_X_AMZN_TRACE_ID` environment
variable set by Lambda, set this to `True` to disable this behavior.
"""
lambda_handler = os.environ.get(ORIG_HANDLER, os.environ.get(_HANDLER))
# pylint: disable=attribute-defined-outside-init
Expand All @@ -377,11 +393,12 @@ def _instrument(self, **kwargs):
) = lambda_handler.rsplit(".", 1)

flush_timeout_env = os.environ.get(
OTEL_INSTRUMENTATION_AWS_LAMBDA_FLUSH_TIMEOUT, ""
OTEL_INSTRUMENTATION_AWS_LAMBDA_FLUSH_TIMEOUT, None
)
flush_timeout = 30000
try:
flush_timeout = int(flush_timeout_env)
if flush_timeout_env is not None:
flush_timeout = int(flush_timeout_env)
except ValueError:
logger.warning(
"Could not convert OTEL_INSTRUMENTATION_AWS_LAMBDA_FLUSH_TIMEOUT value %s to int",
Expand All @@ -396,6 +413,9 @@ def _instrument(self, **kwargs):
"event_context_extractor", _default_event_context_extractor
),
tracer_provider=kwargs.get("tracer_provider"),
disable_aws_context_propagation=kwargs.get(
"disable_aws_context_propagation", False
),
)

def _uninstrument(self, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dataclasses import dataclass
from importlib import import_module
from typing import Any, Callable, Dict
from unittest import mock

from mocks.api_gateway_http_api_event import (
Expand Down Expand Up @@ -155,103 +157,129 @@ def test_active_tracing(self):
test_env_patch.stop()

def test_parent_context_from_lambda_event(self):
test_env_patch = mock.patch.dict(
"os.environ",
{
**os.environ,
# NOT Active Tracing
_X_AMZN_TRACE_ID: MOCK_XRAY_TRACE_CONTEXT_NOT_SAMPLED,
# NOT using the X-Ray Propagator
OTEL_PROPAGATORS: "tracecontext",
},
)
test_env_patch.start()

AwsLambdaInstrumentor().instrument()

mock_execute_lambda(
{
"headers": {
TraceContextTextMapPropagator._TRACEPARENT_HEADER_NAME: MOCK_W3C_TRACE_CONTEXT_SAMPLED,
TraceContextTextMapPropagator._TRACESTATE_HEADER_NAME: f"{MOCK_W3C_TRACE_STATE_KEY}={MOCK_W3C_TRACE_STATE_VALUE},foo=1,bar=2",
}
}
)

spans = self.memory_exporter.get_finished_spans()
@dataclass
class TestCase:
name: str
custom_extractor: Callable[[Any], None]
context: Dict
expected_traceid: int
expected_parentid: int
xray_traceid: str
expected_state_value: str = None
expected_trace_state_len: int = 0
disable_aws_context_propagation: bool = False

assert spans

self.assertEqual(len(spans), 1)
span = spans[0]
self.assertEqual(span.get_span_context().trace_id, MOCK_W3C_TRACE_ID)

parent_context = span.parent
self.assertEqual(
parent_context.trace_id, span.get_span_context().trace_id
)
self.assertEqual(parent_context.span_id, MOCK_W3C_PARENT_SPAN_ID)
self.assertEqual(len(parent_context.trace_state), 3)
self.assertEqual(
parent_context.trace_state.get(MOCK_W3C_TRACE_STATE_KEY),
MOCK_W3C_TRACE_STATE_VALUE,
)
self.assertTrue(parent_context.is_remote)

test_env_patch.stop()

def test_using_custom_extractor(self):
def custom_event_context_extractor(lambda_event):
return get_global_textmap().extract(lambda_event["foo"]["headers"])

test_env_patch = mock.patch.dict(
"os.environ",
{
**os.environ,
# NOT Active Tracing
_X_AMZN_TRACE_ID: MOCK_XRAY_TRACE_CONTEXT_NOT_SAMPLED,
# NOT using the X-Ray Propagator
OTEL_PROPAGATORS: "tracecontext",
},
)
test_env_patch.start()

AwsLambdaInstrumentor().instrument(
event_context_extractor=custom_event_context_extractor,
)

mock_execute_lambda(
{
"foo": {
tests = [
TestCase(
name="no_custom_extractor",
custom_extractor=None,
context={
"headers": {
TraceContextTextMapPropagator._TRACEPARENT_HEADER_NAME: MOCK_W3C_TRACE_CONTEXT_SAMPLED,
TraceContextTextMapPropagator._TRACESTATE_HEADER_NAME: f"{MOCK_W3C_TRACE_STATE_KEY}={MOCK_W3C_TRACE_STATE_VALUE},foo=1,bar=2",
}
}
}
)

spans = self.memory_exporter.get_finished_spans()

assert spans

self.assertEqual(len(spans), 1)
span = spans[0]
self.assertEqual(span.get_span_context().trace_id, MOCK_W3C_TRACE_ID)

parent_context = span.parent
self.assertEqual(
parent_context.trace_id, span.get_span_context().trace_id
)
self.assertEqual(parent_context.span_id, MOCK_W3C_PARENT_SPAN_ID)
self.assertEqual(len(parent_context.trace_state), 3)
self.assertEqual(
parent_context.trace_state.get(MOCK_W3C_TRACE_STATE_KEY),
MOCK_W3C_TRACE_STATE_VALUE,
)
self.assertTrue(parent_context.is_remote)

test_env_patch.stop()
},
expected_traceid=MOCK_W3C_TRACE_ID,
expected_parentid=MOCK_W3C_PARENT_SPAN_ID,
expected_trace_state_len=3,
expected_state_value=MOCK_W3C_TRACE_STATE_VALUE,
xray_traceid=MOCK_XRAY_TRACE_CONTEXT_NOT_SAMPLED,
),
TestCase(
name="custom_extractor_not_sampled_xray",
custom_extractor=custom_event_context_extractor,
context={
"foo": {
"headers": {
TraceContextTextMapPropagator._TRACEPARENT_HEADER_NAME: MOCK_W3C_TRACE_CONTEXT_SAMPLED,
TraceContextTextMapPropagator._TRACESTATE_HEADER_NAME: f"{MOCK_W3C_TRACE_STATE_KEY}={MOCK_W3C_TRACE_STATE_VALUE},foo=1,bar=2",
}
}
},
expected_traceid=MOCK_W3C_TRACE_ID,
expected_parentid=MOCK_W3C_PARENT_SPAN_ID,
expected_trace_state_len=3,
expected_state_value=MOCK_W3C_TRACE_STATE_VALUE,
xray_traceid=MOCK_XRAY_TRACE_CONTEXT_NOT_SAMPLED,
),
TestCase(
name="custom_extractor_sampled_xray",
custom_extractor=custom_event_context_extractor,
context={
"foo": {
"headers": {
TraceContextTextMapPropagator._TRACEPARENT_HEADER_NAME: MOCK_W3C_TRACE_CONTEXT_SAMPLED,
TraceContextTextMapPropagator._TRACESTATE_HEADER_NAME: f"{MOCK_W3C_TRACE_STATE_KEY}={MOCK_W3C_TRACE_STATE_VALUE},foo=1,bar=2",
}
}
},
expected_traceid=MOCK_XRAY_TRACE_ID,
expected_parentid=MOCK_XRAY_PARENT_SPAN_ID,
xray_traceid=MOCK_XRAY_TRACE_CONTEXT_SAMPLED,
),
TestCase(
name="custom_extractor_sampled_xray_disable_aws_propagation",
custom_extractor=custom_event_context_extractor,
context={
"foo": {
"headers": {
TraceContextTextMapPropagator._TRACEPARENT_HEADER_NAME: MOCK_W3C_TRACE_CONTEXT_SAMPLED,
TraceContextTextMapPropagator._TRACESTATE_HEADER_NAME: f"{MOCK_W3C_TRACE_STATE_KEY}={MOCK_W3C_TRACE_STATE_VALUE},foo=1,bar=2",
}
}
},
disable_aws_context_propagation=True,
expected_traceid=MOCK_W3C_TRACE_ID,
expected_parentid=MOCK_W3C_PARENT_SPAN_ID,
expected_trace_state_len=3,
expected_state_value=MOCK_W3C_TRACE_STATE_VALUE,
xray_traceid=MOCK_XRAY_TRACE_CONTEXT_SAMPLED,
),
]
for test in tests:
test_env_patch = mock.patch.dict(
"os.environ",
{
**os.environ,
# NOT Active Tracing
_X_AMZN_TRACE_ID: test.xray_traceid,
# NOT using the X-Ray Propagator
OTEL_PROPAGATORS: "tracecontext",
},
)
test_env_patch.start()
AwsLambdaInstrumentor().instrument(
event_context_extractor=test.custom_extractor,
disable_aws_context_propagation=test.disable_aws_context_propagation,
)
mock_execute_lambda(test.context)
spans = self.memory_exporter.get_finished_spans()
assert spans
self.assertEqual(len(spans), 1)
span = spans[0]
self.assertEqual(
span.get_span_context().trace_id, test.expected_traceid
)

parent_context = span.parent
self.assertEqual(
parent_context.trace_id, span.get_span_context().trace_id
)
self.assertEqual(parent_context.span_id, test.expected_parentid)
self.assertEqual(
len(parent_context.trace_state), test.expected_trace_state_len
)
self.assertEqual(
parent_context.trace_state.get(MOCK_W3C_TRACE_STATE_KEY),
test.expected_state_value,
)
self.assertTrue(parent_context.is_remote)
self.memory_exporter.clear()
AwsLambdaInstrumentor().uninstrument()
test_env_patch.stop()

def test_lambda_no_error_with_invalid_flush_timeout(self):

Expand Down

0 comments on commit 8dbd142

Please sign in to comment.