Skip to content

Commit

Permalink
adding response_hook to redis instrumentor (#669)
Browse files Browse the repository at this point in the history
  • Loading branch information
ItayGibel-helios authored Sep 14, 2021
1 parent 291e508 commit db636a4
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 71 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `opentelemetry-instrumentation-elasticsearch` Added `response_hook` and `request_hook` callbacks
([#670](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/670))

### Added
- `opentelemetry-instrumentation-redis` added request_hook and response_hook callbacks passed as arguments to the instrument method.
([#669](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/669))

### Changed
- `opentelemetry-instrumentation-botocore` Unpatch botocore Endpoint.prepare_request on uninstrument
([#664](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/664))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,43 @@
client = redis.StrictRedis(host="localhost", port=6379)
client.get("my-key")
The `instrument` method accepts the following keyword args:
tracer_provider (TracerProvider) - an optional tracer provider
request_hook (Callable) - a function with extra user-defined logic to be performed before performing the request
this function signature is: def request_hook(span: Span, instance: redis.connection.Connection, args, kwargs) -> None
response_hook (Callable) - a function with extra user-defined logic to be performed after performing the request
this function signature is: def response_hook(span: Span, instance: redis.connection.Connection, response) -> None
for example:
.. code: python
from opentelemetry.instrumentation.redis import RedisInstrumentor
import redis
def request_hook(span, instance, args, kwargs):
if span and span.is_recording():
span.set_attribute("custom_user_attribute_from_request_hook", "some-value")
def response_hook(span, instance, response):
if span and span.is_recording():
span.set_attribute("custom_user_attribute_from_response_hook", "some-value")
# Instrument redis with hooks
RedisInstrumentor().instrument(request_hook=request_hook, response_hook=response_hook)
# This will report a span with the default settings and the custom attributes added from the hooks
client = redis.StrictRedis(host="localhost", port=6379)
client.get("my-key")
API
---
"""

from typing import Collection
import typing
from typing import Any, Collection

import redis
from wrapt import wrap_function_wrapper
Expand All @@ -57,9 +89,19 @@
from opentelemetry.instrumentation.redis.version import __version__
from opentelemetry.instrumentation.utils import unwrap
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace import Span

_DEFAULT_SERVICE = "redis"

_RequestHookT = typing.Optional[
typing.Callable[
[Span, redis.connection.Connection, typing.List, typing.Dict], None
]
]
_ResponseHookT = typing.Optional[
typing.Callable[[Span, redis.connection.Connection, Any], None]
]


def _set_connection_attributes(span, conn):
if not span.is_recording():
Expand All @@ -70,42 +112,68 @@ def _set_connection_attributes(span, conn):
span.set_attribute(key, value)


def _traced_execute_command(func, instance, args, kwargs):
tracer = getattr(redis, "_opentelemetry_tracer")
query = _format_command_args(args)
name = ""
if len(args) > 0 and args[0]:
name = args[0]
else:
name = instance.connection_pool.connection_kwargs.get("db", 0)
with tracer.start_as_current_span(
name, kind=trace.SpanKind.CLIENT
) as span:
if span.is_recording():
span.set_attribute(SpanAttributes.DB_STATEMENT, query)
_set_connection_attributes(span, instance)
span.set_attribute("db.redis.args_length", len(args))
return func(*args, **kwargs)


def _traced_execute_pipeline(func, instance, args, kwargs):
tracer = getattr(redis, "_opentelemetry_tracer")

cmds = [_format_command_args(c) for c, _ in instance.command_stack]
resource = "\n".join(cmds)

span_name = " ".join([args[0] for args, _ in instance.command_stack])

with tracer.start_as_current_span(
span_name, kind=trace.SpanKind.CLIENT
) as span:
if span.is_recording():
span.set_attribute(SpanAttributes.DB_STATEMENT, resource)
_set_connection_attributes(span, instance)
span.set_attribute(
"db.redis.pipeline_length", len(instance.command_stack)
)
return func(*args, **kwargs)
def _instrument(
tracer,
request_hook: _RequestHookT = None,
response_hook: _ResponseHookT = None,
):
def _traced_execute_command(func, instance, args, kwargs):
query = _format_command_args(args)
name = ""
if len(args) > 0 and args[0]:
name = args[0]
else:
name = instance.connection_pool.connection_kwargs.get("db", 0)
with tracer.start_as_current_span(
name, kind=trace.SpanKind.CLIENT
) as span:
if span.is_recording():
span.set_attribute(SpanAttributes.DB_STATEMENT, query)
_set_connection_attributes(span, instance)
span.set_attribute("db.redis.args_length", len(args))
if callable(request_hook):
request_hook(span, instance, args, kwargs)
response = func(*args, **kwargs)
if callable(response_hook):
response_hook(span, instance, response)
return response

def _traced_execute_pipeline(func, instance, args, kwargs):
cmds = [_format_command_args(c) for c, _ in instance.command_stack]
resource = "\n".join(cmds)

span_name = " ".join([args[0] for args, _ in instance.command_stack])

with tracer.start_as_current_span(
span_name, kind=trace.SpanKind.CLIENT
) as span:
if span.is_recording():
span.set_attribute(SpanAttributes.DB_STATEMENT, resource)
_set_connection_attributes(span, instance)
span.set_attribute(
"db.redis.pipeline_length", len(instance.command_stack)
)
response = func(*args, **kwargs)
if callable(response_hook):
response_hook(span, instance, response)
return response

pipeline_class = (
"BasePipeline" if redis.VERSION < (3, 0, 0) else "Pipeline"
)
redis_class = "StrictRedis" if redis.VERSION < (3, 0, 0) else "Redis"

wrap_function_wrapper(
"redis", f"{redis_class}.execute_command", _traced_execute_command
)
wrap_function_wrapper(
"redis.client", f"{pipeline_class}.execute", _traced_execute_pipeline,
)
wrap_function_wrapper(
"redis.client",
f"{pipeline_class}.immediate_execute_command",
_traced_execute_command,
)


class RedisInstrumentor(BaseInstrumentor):
Expand All @@ -117,41 +185,22 @@ def instrumentation_dependencies(self) -> Collection[str]:
return _instruments

def _instrument(self, **kwargs):
"""Instruments the redis module
Args:
**kwargs: Optional arguments
``tracer_provider``: a TracerProvider, defaults to global.
``response_hook``: An optional callback which is invoked right before the span is finished processing a response.
"""
tracer_provider = kwargs.get("tracer_provider")
setattr(
redis,
"_opentelemetry_tracer",
trace.get_tracer(
__name__, __version__, tracer_provider=tracer_provider,
),
tracer = trace.get_tracer(
__name__, __version__, tracer_provider=tracer_provider
)
_instrument(
tracer,
request_hook=kwargs.get("request_hook"),
response_hook=kwargs.get("response_hook"),
)

if redis.VERSION < (3, 0, 0):
wrap_function_wrapper(
"redis", "StrictRedis.execute_command", _traced_execute_command
)
wrap_function_wrapper(
"redis.client",
"BasePipeline.execute",
_traced_execute_pipeline,
)
wrap_function_wrapper(
"redis.client",
"BasePipeline.immediate_execute_command",
_traced_execute_command,
)
else:
wrap_function_wrapper(
"redis", "Redis.execute_command", _traced_execute_command
)
wrap_function_wrapper(
"redis.client", "Pipeline.execute", _traced_execute_pipeline
)
wrap_function_wrapper(
"redis.client",
"Pipeline.immediate_execute_command",
_traced_execute_command,
)

def _uninstrument(self, **kwargs):
if redis.VERSION < (3, 0, 0):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,64 @@ def test_instrument_uninstrument(self):

spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)

def test_response_hook(self):
redis_client = redis.Redis()
connection = redis.connection.Connection()
redis_client.connection = connection

response_attribute_name = "db.redis.response"

def response_hook(span, conn, response):
span.set_attribute(response_attribute_name, response)

RedisInstrumentor().uninstrument()
RedisInstrumentor().instrument(
tracer_provider=self.tracer_provider, response_hook=response_hook
)

test_value = "test_value"

with mock.patch.object(connection, "send_command"):
with mock.patch.object(
redis_client, "parse_response", return_value=test_value
):
redis_client.get("key")

spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)

span = spans[0]
self.assertEqual(
span.attributes.get(response_attribute_name), test_value
)

def test_request_hook(self):
redis_client = redis.Redis()
connection = redis.connection.Connection()
redis_client.connection = connection

custom_attribute_name = "my.request.attribute"

def request_hook(span, conn, args, kwargs):
if span and span.is_recording():
span.set_attribute(custom_attribute_name, args[0])

RedisInstrumentor().uninstrument()
RedisInstrumentor().instrument(
tracer_provider=self.tracer_provider, request_hook=request_hook
)

test_value = "test_value"

with mock.patch.object(connection, "send_command"):
with mock.patch.object(
redis_client, "parse_response", return_value=test_value
):
redis_client.get("key")

spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)

span = spans[0]
self.assertEqual(span.attributes.get(custom_attribute_name), "GET")

0 comments on commit db636a4

Please sign in to comment.