From 7ec8c59e2d61d13cb223f50f6a4973c51f8c5da5 Mon Sep 17 00:00:00 2001 From: Prashant Srivastava <50466688+srprash@users.noreply.github.com> Date: Wed, 26 Oct 2022 09:48:54 -0700 Subject: [PATCH] persist original trace header in lambda context (#362) --- aws_xray_sdk/core/lambda_launcher.py | 1 + tests/test_lambda_context.py | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/aws_xray_sdk/core/lambda_launcher.py b/aws_xray_sdk/core/lambda_launcher.py index f35b2d99..9efccc6b 100644 --- a/aws_xray_sdk/core/lambda_launcher.py +++ b/aws_xray_sdk/core/lambda_launcher.py @@ -142,5 +142,6 @@ def _initialize_context(self, trace_header): entityid=trace_header.parent, sampled=sampled, ) + segment.save_origin_trace_header(trace_header) setattr(self._local, 'segment', segment) setattr(self._local, 'entities', []) diff --git a/tests/test_lambda_context.py b/tests/test_lambda_context.py index 29b1cc42..98e1687c 100644 --- a/tests/test_lambda_context.py +++ b/tests/test_lambda_context.py @@ -8,7 +8,8 @@ TRACE_ID = '1-5759e988-bd862e3fe1be46a994272793' PARENT_ID = '53995c3f42cd8ad8' -HEADER_VAR = "Root=%s;Parent=%s;Sampled=1" % (TRACE_ID, PARENT_ID) +DATA = 'Foo=Bar' +HEADER_VAR = "Root=%s;Parent=%s;Sampled=1;%s" % (TRACE_ID, PARENT_ID, DATA) os.environ[lambda_launcher.LAMBDA_TRACE_HEADER_KEY] = HEADER_VAR context = lambda_launcher.LambdaContext() @@ -26,6 +27,7 @@ def test_facade_segment_generation(): assert segment.id == PARENT_ID assert segment.trace_id == TRACE_ID assert segment.sampled + assert DATA in segment.get_origin_trace_header().to_header_str() def test_put_subsegment(): @@ -43,6 +45,7 @@ def test_put_subsegment(): assert subsegment2.parent_id == subsegment.id assert subsegment.parent_id == segment.id assert subsegment2.parent_segment is segment + assert DATA in subsegment2.parent_segment.get_origin_trace_header().to_header_str() context.end_subsegment() assert context.get_trace_entity().id == subsegment.id @@ -60,6 +63,7 @@ def test_disable(): global_sdk_config.set_sdk_enabled(False) segment = context.get_trace_entity() assert not segment.sampled + assert DATA in segment.get_origin_trace_header().to_header_str() def test_non_initialized():