Skip to content

Commit

Permalink
Sync with Remove setters and getters
Browse files Browse the repository at this point in the history
Fixes #issue_371
  • Loading branch information
ocelotl committed Mar 22, 2021
1 parent f8e51c4 commit a1bda20
Show file tree
Hide file tree
Showing 29 changed files with 168 additions and 356 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ on:
- 'release/*'
pull_request:
env:
CORE_REPO_SHA: d3694fc520f8542b232fd1065133286f4591dcec
CORE_REPO_SHA: 400e891609e91df375cbfc696b42bff983350d4a

jobs:
build:
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased](https://github.com/open-telemetry/opentelemetry-python-contrib/compare/v0.18b0...HEAD)
- Remove getters and setters from propagators
([#372](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/364))
- Updated instrumentations to use `opentelemetry.trace.use_span` instead of `Tracer.use_span()`
([#364](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/364))

Expand Down
5 changes: 0 additions & 5 deletions docs/nitpick-exceptions.ini
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,10 @@
class_references=
; TODO: Understand why sphinx is not able to find this local class
opentelemetry.propagators.textmap.TextMapPropagator
; - AwsXRayFormat
opentelemetry.propagators.textmap.DictGetter
; API
opentelemetry.propagators.textmap.Getter
; - DatadogFormat
; - AWSXRayFormat
opentelemetry.sdk.trace.id_generator.IdGenerator
; - AwsXRayIdGenerator
TextMapPropagatorT
; - AwsXRayFormat.extract

anys=
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import typing
from typing import Dict, Optional

from opentelemetry import trace
from opentelemetry.context import Context
from opentelemetry.exporter.datadog import constants
from opentelemetry.propagators.textmap import (
Getter,
Setter,
TextMapPropagator,
TextMapPropagatorT,
)
from opentelemetry.propagators.textmap import TextMapPropagator
from opentelemetry.trace import get_current_span, set_span_in_context


Expand All @@ -35,24 +30,15 @@ class DatadogFormat(TextMapPropagator):
ORIGIN_KEY = "x-datadog-origin"

def extract(
self,
getter: Getter[TextMapPropagatorT],
carrier: TextMapPropagatorT,
context: typing.Optional[Context] = None,
self, carrier: Dict[str, str], context: Optional[Context] = None,
) -> Context:
trace_id = extract_first_element(
getter.get(carrier, self.TRACE_ID_KEY)
)
trace_id = carrier.get(self.TRACE_ID_KEY)

span_id = extract_first_element(
getter.get(carrier, self.PARENT_ID_KEY)
)
span_id = carrier.get(self.PARENT_ID_KEY)

sampled = extract_first_element(
getter.get(carrier, self.SAMPLING_PRIORITY_KEY)
)
sampled = carrier.get(self.SAMPLING_PRIORITY_KEY)

origin = extract_first_element(getter.get(carrier, self.ORIGIN_KEY))
origin = carrier.get(self.ORIGIN_KEY)

trace_flags = trace.TraceFlags()
if sampled and int(sampled) in (
Expand Down Expand Up @@ -80,33 +66,22 @@ def extract(
)

def inject(
self,
set_in_carrier: Setter[TextMapPropagatorT],
carrier: TextMapPropagatorT,
context: typing.Optional[Context] = None,
self, carrier: Dict[str, str], context: Optional[Context] = None,
) -> None:
span = get_current_span(context)
span_context = span.get_span_context()
if span_context == trace.INVALID_SPAN_CONTEXT:
return
sampled = (trace.TraceFlags.SAMPLED & span.context.trace_flags) != 0
set_in_carrier(
carrier, self.TRACE_ID_KEY, format_trace_id(span.context.trace_id),
)
set_in_carrier(
carrier, self.PARENT_ID_KEY, format_span_id(span.context.span_id)
)
set_in_carrier(
carrier,
self.SAMPLING_PRIORITY_KEY,
str(constants.AUTO_KEEP if sampled else constants.AUTO_REJECT),
carrier[self.TRACE_ID_KEY] = format_trace_id(span.context.trace_id)
carrier[self.PARENT_ID_KEY] = format_span_id(span.context.span_id)
carrier[self.SAMPLING_PRIORITY_KEY] = str(
constants.AUTO_KEEP if sampled else constants.AUTO_REJECT
)
if constants.DD_ORIGIN in span.context.trace_state:
set_in_carrier(
carrier,
self.ORIGIN_KEY,
span.context.trace_state[constants.DD_ORIGIN],
)
carrier[self.ORIGIN_KEY] = span.context.trace_state[
constants.DD_ORIGIN
]

@property
def fields(self):
Expand All @@ -131,11 +106,3 @@ def format_trace_id(trace_id: int) -> str:
def format_span_id(span_id: int) -> str:
"""Format the span id for Datadog."""
return str(span_id)


def extract_first_element(
items: typing.Iterable[TextMapPropagatorT],
) -> typing.Optional[TextMapPropagatorT]:
if items is None:
return None
return next(iter(items), None)
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,12 @@

from opentelemetry import trace as trace_api
from opentelemetry.exporter.datadog import constants, propagator
from opentelemetry.propagators.textmap import DictGetter
from opentelemetry.sdk import trace
from opentelemetry.sdk.trace.id_generator import RandomIdGenerator
from opentelemetry.trace import get_current_span, set_span_in_context

FORMAT = propagator.DatadogFormat()

carrier_getter = DictGetter()


class TestDatadogFormat(unittest.TestCase):
@classmethod
Expand All @@ -45,7 +42,6 @@ def test_malformed_headers(self):
malformed_parent_id_key = FORMAT.PARENT_ID_KEY + "-x"
context = get_current_span(
FORMAT.extract(
carrier_getter,
{
malformed_trace_id_key: self.serialized_trace_id,
malformed_parent_id_key: self.serialized_parent_id,
Expand All @@ -63,7 +59,7 @@ def test_missing_trace_id(self):
FORMAT.PARENT_ID_KEY: self.serialized_parent_id,
}

ctx = FORMAT.extract(carrier_getter, carrier)
ctx = FORMAT.extract(carrier)
span_context = get_current_span(ctx).get_span_context()
self.assertEqual(span_context.trace_id, trace_api.INVALID_TRACE_ID)

Expand All @@ -73,15 +69,14 @@ def test_missing_parent_id(self):
FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
}

ctx = FORMAT.extract(carrier_getter, carrier)
ctx = FORMAT.extract(carrier)
span_context = get_current_span(ctx).get_span_context()
self.assertEqual(span_context.span_id, trace_api.INVALID_SPAN_ID)

def test_context_propagation(self):
"""Test the propagation of Datadog headers."""
parent_span_context = get_current_span(
FORMAT.extract(
carrier_getter,
{
FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
FORMAT.PARENT_ID_KEY: self.serialized_parent_id,
Expand Down Expand Up @@ -118,7 +113,7 @@ def test_context_propagation(self):

child_carrier = {}
child_context = set_span_in_context(child)
FORMAT.inject(dict.__setitem__, child_carrier, context=child_context)
FORMAT.inject(child_carrier, context=child_context)

self.assertEqual(
child_carrier[FORMAT.TRACE_ID_KEY], self.serialized_trace_id
Expand All @@ -138,7 +133,6 @@ def test_sampling_priority_auto_reject(self):
"""Test sampling priority rejected."""
parent_span_context = get_current_span(
FORMAT.extract(
carrier_getter,
{
FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
FORMAT.PARENT_ID_KEY: self.serialized_parent_id,
Expand All @@ -165,7 +159,7 @@ def test_sampling_priority_auto_reject(self):

child_carrier = {}
child_context = set_span_in_context(child)
FORMAT.inject(dict.__setitem__, child_carrier, context=child_context)
FORMAT.inject(child_carrier, context=child_context)

self.assertEqual(
child_carrier[FORMAT.SAMPLING_PRIORITY_KEY],
Expand All @@ -178,8 +172,6 @@ def test_fields(self, mock_get_current_span):

tracer = trace.TracerProvider().get_tracer("sdk_tracer_provider")

mock_set_in_carrier = Mock()

mock_get_current_span.configure_mock(
**{
"return_value": Mock(
Expand All @@ -193,13 +185,10 @@ def test_fields(self, mock_get_current_span):
}
)

carrier = {}

with tracer.start_as_current_span("parent"):
with tracer.start_as_current_span("child"):
FORMAT.inject(mock_set_in_carrier, {})

inject_fields = set()

for call in mock_set_in_carrier.mock_calls:
inject_fields.add(call[1][1])
FORMAT.inject(carrier)

self.assertEqual(FORMAT.fields, inject_fields)
self.assertEqual(FORMAT.fields, carrier.keys())
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ async def on_request_start(
trace.set_span_in_context(trace_config_ctx.span)
)

inject(type(params.headers).__setitem__, params.headers)
inject(params.headers)

async def on_request_end(
unused_session: aiohttp.ClientSession,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,11 @@
# limitations under the License.

"""
The opentelemetry-instrumentation-asgi package provides an ASGI middleware that can be used
on any ASGI framework (such as Django-channels / Quart) to track requests
timing through OpenTelemetry.
The opentelemetry-instrumentation-asgi package provides an ASGI middleware that
can be used on any ASGI framework (such as Django-channels / Quart) to track
requests timing through OpenTelemetry.
"""

import typing
import urllib
from functools import wraps
from typing import Tuple
Expand All @@ -29,46 +28,37 @@
from opentelemetry.instrumentation.asgi.version import __version__ # noqa
from opentelemetry.instrumentation.utils import http_status_to_status_code
from opentelemetry.propagate import extract
from opentelemetry.propagators.textmap import DictGetter
from opentelemetry.trace.status import Status, StatusCode
from opentelemetry.util.http import BaseCustomGetDictionary


class CarrierGetter(DictGetter):
def get(
self, carrier: dict, key: str
) -> typing.Optional[typing.List[str]]:
"""Getter implementation to retrieve a HTTP header value from the ASGI
scope.
class _ASGICustomGetDictionary(BaseCustomGetDictionary):
def get(self, name, default=None):

Args:
carrier: ASGI scope object
key: header name in scope
Returns:
A list with a single string with the header value if it exists,
else None.
"""
headers = carrier.get("headers")
if not headers:
return None
if "headers" not in self.keys():
return default

# ASGI header keys are in lower case
name = name.lower()

# asgi header keys are in lower case
key = key.lower()
decoded = [
_value.decode("utf8")
for (_key, _value) in headers
if _key.decode("utf8") == key
value.decode("utf8")
for key, value in self["headers"]
if key.decode("utf8") == name
]
if not decoded:
return None
return decoded

if not decoded:
return default

carrier_getter = CarrierGetter()
return decoded


def collect_request_attributes(scope):
"""Collects HTTP request attributes from the ASGI scope and returns a
dictionary to be used as span creation attributes."""

asgi_scope = _ASGICustomGetDictionary(scope)

server_host, port, http_url = get_host_port_url_tuple(scope)
query_string = scope.get("query_string")
if query_string and http_url:
Expand All @@ -88,10 +78,10 @@ def collect_request_attributes(scope):
if http_method:
result["http.method"] = http_method

http_host_value_list = carrier_getter.get(scope, "host")
http_host_value_list = asgi_scope.get("host")
if http_host_value_list:
result["http.server_name"] = ",".join(http_host_value_list)
http_user_agent = carrier_getter.get(scope, "user-agent")
http_user_agent = asgi_scope.get("user-agent")
if http_user_agent:
result["http.user_agent"] = http_user_agent[0]

Expand Down Expand Up @@ -186,7 +176,7 @@ async def __call__(self, scope, receive, send):
if self.excluded_urls and self.excluded_urls.url_disabled(url):
return await self.app(scope, receive, send)

token = context.attach(extract(carrier_getter, scope))
token = context.attach(extract(scope))
span_name, additional_attributes = self.span_details_callback(scope)

try:
Expand Down
Loading

0 comments on commit a1bda20

Please sign in to comment.