Skip to content

Commit

Permalink
fix asynchonous unary call traces
Browse files Browse the repository at this point in the history
  • Loading branch information
sengjea authored and sengjea committed Jun 7, 2021
1 parent 5b125b1 commit 691bac3
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,27 +33,6 @@
from opentelemetry.trace.status import Status, StatusCode


class _GuardedSpan:
def __init__(self, span):
self.span = span
self.generated_span = None
self._engaged = True

def __enter__(self):
self.generated_span = self.span.__enter__()
return self

def __exit__(self, *args, **kwargs):
if self._engaged:
self.generated_span = None
return self.span.__exit__(*args, **kwargs)
return False

def release(self):
self._engaged = False
return self.span


class _CarrierSetter(Setter):
"""We use a custom setter in order to be able to lower case
keys as is required by grpc.
Expand All @@ -68,7 +47,7 @@ def set(self, carrier: MutableMapping[str, str], key: str, value: str):

def _make_future_done_callback(span, rpc_info):
def callback(response_future):
with span:
with trace.use_span(span, end_on_exit=True):
code = response_future.code()
if code != grpc.StatusCode.OK:
rpc_info.error = code
Expand All @@ -94,17 +73,17 @@ def _start_span(self, method):
SpanAttributes.RPC_SERVICE: service,
}

return self._tracer.start_as_current_span(
name=method, kind=trace.SpanKind.CLIENT, attributes=attributes
return self._tracer.start_span(
name=method, kind=trace.SpanKind.CLIENT, attributes=attributes,
)

# pylint:disable=no-self-use
def _trace_result(self, guarded_span, rpc_info, result):
def _trace_result(self, span, rpc_info, result):
# If the RPC is called asynchronously, release the guard and add a
# callback so that the span can be finished once the future is done.
if isinstance(result, grpc.Future):
result.add_done_callback(
_make_future_done_callback(guarded_span.release(), rpc_info)
_make_future_done_callback(span, rpc_info)
)
return result
response = result
Expand All @@ -115,41 +94,43 @@ def _trace_result(self, guarded_span, rpc_info, result):
if isinstance(result, tuple):
response = result[0]
rpc_info.response = response

span.end()
return result

def _start_guarded_span(self, *args, **kwargs):
return _GuardedSpan(self._start_span(*args, **kwargs))

def intercept_unary(self, request, metadata, client_info, invoker):
if not metadata:
mutable_metadata = OrderedDict()
else:
mutable_metadata = OrderedDict(metadata)

with self._start_guarded_span(client_info.full_method) as guarded_span:
inject(mutable_metadata, setter=_carrier_setter)
metadata = tuple(mutable_metadata.items())

rpc_info = RpcInfo(
full_method=client_info.full_method,
metadata=metadata,
timeout=client_info.timeout,
request=request,
)

span = self._start_span(client_info.full_method)
with trace.use_span(span, record_exception=False, set_status_on_exception=False):
try:
result = invoker(request, metadata)
except grpc.RpcError as err:
guarded_span.generated_span.set_status(
Status(StatusCode.ERROR)
inject(mutable_metadata, setter=_carrier_setter)
metadata = tuple(mutable_metadata.items())

rpc_info = RpcInfo(
full_method=client_info.full_method,
metadata=metadata,
timeout=client_info.timeout,
request=request,
)
guarded_span.generated_span.set_attribute(
SpanAttributes.RPC_GRPC_STATUS_CODE, err.code().value[0]
)
raise err

return self._trace_result(guarded_span, rpc_info, result)
result = invoker(request, metadata)
except Exception as exc:
if isinstance(exc, grpc.RpcError):
span.set_attribute(
SpanAttributes.RPC_GRPC_STATUS_CODE, exc.code().value[0]
)
span.set_status(
Status(
status_code=StatusCode.ERROR,
description="{}: {}".format(type(exc).__name__, exc),
)
)
span.record_exception(exc)
span.end()
raise exc
return self._trace_result(span, rpc_info, result)

# For RPCs that stream responses, the result can be a generator. To record
# the span across the generated responses and detect any errors, we wrap
Expand All @@ -162,7 +143,8 @@ def _intercept_server_stream(
else:
mutable_metadata = OrderedDict(metadata)

with self._start_span(client_info.full_method) as span:
span = self._start_span(client_info.full_method)
with trace.use_span(span, end_on_exit=True):
inject(mutable_metadata, setter=_carrier_setter)
metadata = tuple(mutable_metadata.items())
rpc_info = RpcInfo(
Expand Down Expand Up @@ -199,27 +181,34 @@ def intercept_stream(
else:
mutable_metadata = OrderedDict(metadata)

with self._start_guarded_span(client_info.full_method) as guarded_span:
inject(mutable_metadata, setter=_carrier_setter)
metadata = tuple(mutable_metadata.items())
rpc_info = RpcInfo(
full_method=client_info.full_method,
metadata=metadata,
timeout=client_info.timeout,
request=request_or_iterator,
)
span = self._start_span(client_info.full_method)
with trace.use_span(span, record_exception=False, set_status_on_exception=False):
try:
inject(mutable_metadata, setter=_carrier_setter)
metadata = tuple(mutable_metadata.items())
rpc_info = RpcInfo(
full_method=client_info.full_method,
metadata=metadata,
timeout=client_info.timeout,
request=request_or_iterator,
)

rpc_info.request = request_or_iterator
rpc_info.request = request_or_iterator

try:
result = invoker(request_or_iterator, metadata)
except grpc.RpcError as err:
guarded_span.generated_span.set_status(
Status(StatusCode.ERROR)
except Exception as exc:
if isinstance(exc, grpc.RpcError):
span.set_attribute(
SpanAttributes.RPC_GRPC_STATUS_CODE, exc.code().value[0]
)
span.set_status(
Status(
status_code=StatusCode.ERROR,
description="{}: {}".format(type(exc).__name__, exc),
)
)
guarded_span.generated_span.set_attribute(
SpanAttributes.RPC_GRPC_STATUS_CODE, err.code().value[0],
)
raise err
span.record_exception(exc)
span.end()
raise exc

return self._trace_result(guarded_span, rpc_info, result)
return self._trace_result(span, rpc_info, result)
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ def simple_method(stub, error=False):
stub.SimpleMethod(request)


def simple_method_future(stub, error=False):
request = Request(
client_id=CLIENT_ID, request_data="error" if error else "data"
)
return stub.SimpleMethod.future(request)


def client_streaming_method(stub, error=False):
# create a generator
def request_messages():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
client_streaming_method,
server_streaming_method,
simple_method,
simple_method_future,
)
from ._server import create_test_server
from .protobuf.test_server_pb2 import Request
Expand Down Expand Up @@ -100,6 +101,20 @@ def tearDown(self):
self.server.stop(None)
self.channel.close()

def test_unary_unary_future(self):
simple_method_future(self._stub).result()
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
span = spans[0]

self.assertEqual(span.name, "/GRPCTestServer/SimpleMethod")
self.assertIs(span.kind, trace.SpanKind.CLIENT)

# Check version and name in span's instrumentation info
self.check_span_instrumentation_info(
span, opentelemetry.instrumentation.grpc
)

def test_unary_unary(self):
simple_method(self._stub)
spans = self.memory_exporter.get_finished_spans()
Expand Down

0 comments on commit 691bac3

Please sign in to comment.