Skip to content

Commit

Permalink
converted textmap propagator getter to a class and added keys method
Browse files Browse the repository at this point in the history
  • Loading branch information
nprajilesh committed Oct 2, 2020
1 parent 14fad78 commit ee81915
Show file tree
Hide file tree
Showing 20 changed files with 254 additions and 94 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,24 +39,24 @@ class DatadogFormat(TextMapPropagator):

def extract(
self,
get_from_carrier: Getter[TextMapPropagatorT],
get_from_carrier: Getter,
carrier: TextMapPropagatorT,
context: typing.Optional[Context] = None,
) -> Context:
trace_id = extract_first_element(
get_from_carrier(carrier, self.TRACE_ID_KEY)
get_from_carrier.get(carrier, self.TRACE_ID_KEY)
)

span_id = extract_first_element(
get_from_carrier(carrier, self.PARENT_ID_KEY)
get_from_carrier.get(carrier, self.PARENT_ID_KEY)
)

sampled = extract_first_element(
get_from_carrier(carrier, self.SAMPLING_PRIORITY_KEY)
get_from_carrier.get(carrier, self.SAMPLING_PRIORITY_KEY)
)

origin = extract_first_element(
get_from_carrier(carrier, self.ORIGIN_KEY)
get_from_carrier.get(carrier, self.ORIGIN_KEY)
)

trace_flags = trace.TraceFlags()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,18 @@
FORMAT = propagator.DatadogFormat()


def get_as_list(dict_object, key):
value = dict_object.get(key)
return [value] if value is not None else []
class Getter:
@staticmethod
def get(dict_object, key):
value = dict_object.get(key)
return [value] if value is not None else []

@staticmethod
def keys(dict_object):
return dict_object.keys()


get_as_list = Getter()


class TestDatadogFormat(unittest.TestCase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,32 @@
from opentelemetry.trace.status import Status, StatusCanonicalCode


def get_header_from_scope(scope: dict, header_name: str) -> typing.List[str]:
"""Retrieve a HTTP header value from the ASGI scope.
class Getter:
@staticmethod
def get(scope: dict, header_name: str) -> typing.List[str]:
"""Retrieve a HTTP header value from the ASGI scope.
Returns:
A list with a single string with the header value if it exists, else an empty list.
"""
headers = scope.get("headers")
return [
value.decode("utf8")
for (key, value) in headers
if key.decode("utf8") == header_name
]
Returns:
A list with a single string with the header value if it exists, else an empty list.
"""
headers = scope.get("headers")
return [
value.decode("utf8")
for (key, value) in headers
if key.decode("utf8") == header_name
]

@staticmethod
def keys(scope: dict) -> typing.List[str]:
"""Retrieve all the HTTP header keys for an ASGI scope..
Returns:
A list with all the keys in scope.
"""
return scope.keys()


get_header_from_scope = Getter()


def collect_request_attributes(scope):
Expand Down Expand Up @@ -72,10 +86,10 @@ def collect_request_attributes(scope):
http_method = scope.get("method")
if http_method:
result["http.method"] = http_method
http_host_value = ",".join(get_header_from_scope(scope, "host"))
http_host_value = ",".join(get_header_from_scope.get(scope, "host"))
if http_host_value:
result["http.server_name"] = http_host_value
http_user_agent = get_header_from_scope(scope, "user-agent")
http_user_agent = get_header_from_scope.get(scope, "user-agent")
if len(http_user_agent) > 0:
result["http.user_agent"] = http_user_agent[0]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def _trace_prerun(self, *args, **kwargs):
return

request = task.request
carrier_extractor = Getter()
tracectx = propagators.extract(carrier_extractor, request) or {}
parent = get_current_span(tracectx)

Expand Down Expand Up @@ -248,8 +249,14 @@ def _trace_retry(*args, **kwargs):
span.set_attribute(_TASK_RETRY_REASON_KEY, str(reason))


def carrier_extractor(carrier, key):
value = getattr(carrier, key, [])
if isinstance(value, str) or not isinstance(value, Iterable):
value = (value,)
return value
class Getter:
@staticmethod
def get(carrier, key):
value = getattr(carrier, key, [])
if isinstance(value, str) or not isinstance(value, Iterable):
value = (value,)
return value

@staticmethod
def keys(carrier):
return carrier.keys()
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,18 @@ def _check_error_code(span, servicer_context, rpc_info):
rpc_info.error = servicer_context.code


class Getter:
@staticmethod
def get(metadata, key) -> List[str]:
md_dict = {md.key: md.value for md in metadata}
return [md_dict[key]] if key in md_dict else []

@staticmethod
def keys(metadata) -> List[str]:
md_dict = {md.key: md.value for md in metadata}
return md_dict.keys()


class OpenTelemetryServerInterceptor(
grpcext.UnaryServerInterceptor, grpcext.StreamServerInterceptor
):
Expand All @@ -121,11 +133,8 @@ def __init__(self, tracer):
def _set_remote_context(self, servicer_context):
metadata = servicer_context.invocation_metadata()
if metadata:
md_dict = {md.key: md.value for md in metadata}

def get_from_grpc_metadata(metadata, key) -> List[str]:
return [md_dict[key]] if key in md_dict else []

get_from_grpc_metadata = Getter()
# Update the context with the traceparent from the RPC metadata.
ctx = propagators.extract(get_from_grpc_metadata, metadata)
token = attach(ctx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,47 @@ def tracer(self) -> "TracerShim":
return self._tracer


class Getter:
"""This class provides an interface that enables extracting propagated
fields from a carrier
"""

@staticmethod
def get(carrier, key):
"""Function that can retrieve zero
or more values from the carrier. In the case that
the value does not exist, returns an empty list.
Args:
carrier: and object which contains values that are
used to construct a Context. This object
must be paired with an appropriate get_from_carrier
which understands how to extract a value from it.
key: key of a field in carrier.
Returns:
first value of the propagation key or an empty list if the key doesn't exist.
"""

value = carrier.get(key)
return [value] if value is not None else []

@staticmethod
def keys(carrier):
"""Function that can retrieve all the keys in a carrier object.
Args:
carrier: and object which contains values that are
used to construct a Context. This object
must be paired with an appropriate get_from_carrier
which understands how to extract a value from it.
Returns:
list of keys from the carrier.
"""

return carrier.keys()


class TracerShim(Tracer):
"""Wraps a :class:`opentelemetry.trace.Tracer` object.
Expand Down Expand Up @@ -706,10 +747,7 @@ def extract(self, format: object, carrier: object):
if format not in self._supported_formats:
raise UnsupportedFormatException

def get_as_list(dict_object, key):
value = dict_object.get(key)
return [value] if value is not None else []

get_as_list = Getter()
propagator = propagators.get_global_textmap()
ctx = propagator.extract(get_as_list, carrier)
span = get_current_span(ctx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,15 @@ def _log_exception(tracer, func, handler, args, kwargs):
return func(*args, **kwargs)


def _get_header_from_request_headers(
headers: dict, header_name: str
) -> typing.List[str]:
header = headers.get(header_name)
return [header] if header else []
class Getter:
@staticmethod
def get(headers: dict, header_name: str) -> typing.List[str]:
header = headers.get(header_name)
return [header] if header else []

@staticmethod
def keys(headers) -> typing.List[str]:
return headers.keys()


def _get_attributes_from_request(request):
Expand Down Expand Up @@ -206,6 +210,7 @@ def _get_operation_name(handler, request):


def _start_span(tracer, handler, start_time) -> _TraceContext:
_get_header_from_request_headers = Getter()
token = context.attach(
propagators.extract(
_get_header_from_request_headers, handler.request.headers,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,31 @@ def hello():
_HTTP_VERSION_PREFIX = "HTTP/"


def get_header_from_environ(
environ: dict, header_name: str
) -> typing.List[str]:
"""Retrieve a HTTP header value from the PEP3333-conforming WSGI environ.
class Getter:
@staticmethod
def get(environ: dict, header_name: str) -> typing.List[str]:
"""Retrieve a HTTP header value from the PEP3333-conforming WSGI environ.
Returns:
A list with a single string with the header value if it exists, else an empty list.
"""
environ_key = "HTTP_" + header_name.upper().replace("-", "_")
value = environ.get(environ_key)
if value is not None:
return [value]
return []
Returns:
A list with a single string with the header value if it exists, else an empty list.
"""
environ_key = "HTTP_" + header_name.upper().replace("-", "_")
value = environ.get(environ_key)
if value is not None:
return [value]
return []

@staticmethod
def keys(environ: dict) -> typing.List[str]:
"""Retrieve all the HTTP header keys for an PEP3333-conforming WSGI environ.
Returns:
A list with all the keys in environ.
"""
return environ.keys()


get_header_from_environ = Getter()


def setifnotnone(dic, key, value):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class BaggagePropagator(textmap.TextMapPropagator):

def extract(
self,
get_from_carrier: textmap.Getter[textmap.TextMapPropagatorT],
get_from_carrier: textmap.Getter,
carrier: textmap.TextMapPropagatorT,
context: typing.Optional[Context] = None,
) -> Context:
Expand All @@ -43,7 +43,7 @@ def extract(
context = get_current()

header = _extract_first_element(
get_from_carrier(carrier, self._BAGGAGE_HEADER_NAME)
get_from_carrier.get(carrier, self._BAGGAGE_HEADER_NAME)
)

if not header or len(header) > self.MAX_HEADER_LENGTH:
Expand Down
8 changes: 4 additions & 4 deletions opentelemetry-api/src/opentelemetry/propagators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,16 @@ def example_route():


def extract(
get_from_carrier: textmap.Getter[textmap.TextMapPropagatorT],
get_from_carrier: textmap.Getter,
carrier: textmap.TextMapPropagatorT,
context: typing.Optional[Context] = None,
) -> Context:
""" Uses the configured propagator to extract a Context from the carrier.
Args:
get_from_carrier: a function that can retrieve zero
or more values from the carrier. In the case that
the value does not exist, return an empty list.
get_from_carrier: an object which contains a get function that can retrieve zero
or more values from the carrier and a keys function that can get all the keys
from carrier.
carrier: and object which contains values that are
used to construct a Context. This object
must be paired with an appropriate get_from_carrier
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(

def extract(
self,
get_from_carrier: textmap.Getter[textmap.TextMapPropagatorT],
get_from_carrier: textmap.Getter,
carrier: textmap.TextMapPropagatorT,
context: typing.Optional[Context] = None,
) -> Context:
Expand Down
Loading

0 comments on commit ee81915

Please sign in to comment.