Skip to content

Commit

Permalink
adding response_hook to redis instrumentor
Browse files Browse the repository at this point in the history
  • Loading branch information
ItayGibel-helios committed Sep 9, 2021
1 parent 97e9f2f commit 7fc376e
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 71 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased](https://github.com/open-telemetry/opentelemetry-python/compare/v1.5.0-0.24b0...HEAD)

### Added
- `opentelemetry-instrumentation-redis` added response_hook callback passed as an argument to the instrument method.
([#669](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/669))

### Changed
- `opentelemetry-instrumentation-botocore` Unpatch botocore Endpoint.prepare_request on uninstrument
([#664](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/664))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@
API
---
"""

from typing import Collection
import typing
from typing import Any, Collection

import redis
from wrapt import wrap_function_wrapper
Expand All @@ -57,9 +57,14 @@
from opentelemetry.instrumentation.redis.version import __version__
from opentelemetry.instrumentation.utils import unwrap
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace import Span

_DEFAULT_SERVICE = "redis"

_ResponseHookT = typing.Optional[
typing.Callable[[Span, redis.connection.Connection, Any], None]
]


def _set_connection_attributes(span, conn):
if not span.is_recording():
Expand All @@ -70,42 +75,64 @@ def _set_connection_attributes(span, conn):
span.set_attribute(key, value)


def _traced_execute_command(func, instance, args, kwargs):
tracer = getattr(redis, "_opentelemetry_tracer")
query = _format_command_args(args)
name = ""
if len(args) > 0 and args[0]:
name = args[0]
else:
name = instance.connection_pool.connection_kwargs.get("db", 0)
with tracer.start_as_current_span(
name, kind=trace.SpanKind.CLIENT
) as span:
if span.is_recording():
span.set_attribute(SpanAttributes.DB_STATEMENT, query)
_set_connection_attributes(span, instance)
span.set_attribute("db.redis.args_length", len(args))
return func(*args, **kwargs)


def _traced_execute_pipeline(func, instance, args, kwargs):
tracer = getattr(redis, "_opentelemetry_tracer")

cmds = [_format_command_args(c) for c, _ in instance.command_stack]
resource = "\n".join(cmds)

span_name = " ".join([args[0] for args, _ in instance.command_stack])

with tracer.start_as_current_span(
span_name, kind=trace.SpanKind.CLIENT
) as span:
if span.is_recording():
span.set_attribute(SpanAttributes.DB_STATEMENT, resource)
_set_connection_attributes(span, instance)
span.set_attribute(
"db.redis.pipeline_length", len(instance.command_stack)
)
return func(*args, **kwargs)
def _instrument(
tracer, response_hook: _ResponseHookT = None,
):
def _traced_execute_command(func, instance, args, kwargs):
query = _format_command_args(args)
name = ""
if len(args) > 0 and args[0]:
name = args[0]
else:
name = instance.connection_pool.connection_kwargs.get("db", 0)
with tracer.start_as_current_span(
name, kind=trace.SpanKind.CLIENT
) as span:
if span.is_recording():
span.set_attribute(SpanAttributes.DB_STATEMENT, query)
_set_connection_attributes(span, instance)
span.set_attribute("db.redis.args_length", len(args))
response = func(*args, **kwargs)
if callable(response_hook):
response_hook(span, instance, response)
return response

def _traced_execute_pipeline(func, instance, args, kwargs):
cmds = [_format_command_args(c) for c, _ in instance.command_stack]
resource = "\n".join(cmds)

span_name = " ".join([args[0] for args, _ in instance.command_stack])

with tracer.start_as_current_span(
span_name, kind=trace.SpanKind.CLIENT
) as span:
if span.is_recording():
span.set_attribute(SpanAttributes.DB_STATEMENT, resource)
_set_connection_attributes(span, instance)
span.set_attribute(
"db.redis.pipeline_length", len(instance.command_stack)
)
response = func(*args, **kwargs)
if callable(response_hook):
response_hook(span, instance, response)
return response

pipeline_class = (
"BasePipeline" if redis.VERSION < (3, 0, 0) else "Pipeline"
)
redis_class = "StrictRedis" if redis.VERSION < (3, 0, 0) else "Redis"

wrap_function_wrapper(
"redis", f"{redis_class}.execute_command", _traced_execute_command
)
wrap_function_wrapper(
"redis.client", f"{pipeline_class}.execute", _traced_execute_pipeline,
)
wrap_function_wrapper(
"redis.client",
f"{pipeline_class}.immediate_execute_command",
_traced_execute_command,
)


class RedisInstrumentor(BaseInstrumentor):
Expand All @@ -117,41 +144,18 @@ def instrumentation_dependencies(self) -> Collection[str]:
return _instruments

def _instrument(self, **kwargs):
"""Instruments the redis module
Args:
**kwargs: Optional arguments
``tracer_provider``: a TracerProvider, defaults to global.
``response_hook``: An optional callback which is invoked right before the span is finished processing a response.
"""
tracer_provider = kwargs.get("tracer_provider")
setattr(
redis,
"_opentelemetry_tracer",
trace.get_tracer(
__name__, __version__, tracer_provider=tracer_provider,
),
tracer = trace.get_tracer(
__name__, __version__, tracer_provider=tracer_provider
)

if redis.VERSION < (3, 0, 0):
wrap_function_wrapper(
"redis", "StrictRedis.execute_command", _traced_execute_command
)
wrap_function_wrapper(
"redis.client",
"BasePipeline.execute",
_traced_execute_pipeline,
)
wrap_function_wrapper(
"redis.client",
"BasePipeline.immediate_execute_command",
_traced_execute_command,
)
else:
wrap_function_wrapper(
"redis", "Redis.execute_command", _traced_execute_command
)
wrap_function_wrapper(
"redis.client", "Pipeline.execute", _traced_execute_pipeline
)
wrap_function_wrapper(
"redis.client",
"Pipeline.immediate_execute_command",
_traced_execute_command,
)
_instrument(tracer, response_hook=kwargs.get("response_hook"))

def _uninstrument(self, **kwargs):
if redis.VERSION < (3, 0, 0):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,34 @@ def test_instrument_uninstrument(self):

spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)

def test_response_hook(self):
redis_client = redis.Redis()
connection = redis.connection.Connection()
redis_client.connection = connection

response_attribute_name = "db.redis.response"

def response_hook(span, conn, response):
span.set_attribute(response_attribute_name, response)

RedisInstrumentor().uninstrument()
RedisInstrumentor().instrument(
tracer_provider=self.tracer_provider, response_hook=response_hook
)

test_value = "test_value"

with mock.patch.object(connection, "send_command"):
with mock.patch.object(
redis_client, "parse_response", return_value=test_value
):
redis_client.get("key")

spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)

span = spans[0]
self.assertEqual(
span.attributes.get(response_attribute_name), test_value
)

0 comments on commit 7fc376e

Please sign in to comment.