From 7f46e931e066f738c8d444464f072da663ce479d Mon Sep 17 00:00:00 2001 From: Viktor Ivanov Date: Tue, 30 May 2023 23:06:05 +0100 Subject: [PATCH 1/3] Fix async redis clients tracing --- .../instrumentation/redis/__init__.py | 123 +++++++++++++----- .../tests/test_redis.py | 49 +++++++ 2 files changed, 138 insertions(+), 34 deletions(-) diff --git a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py index c1068bda27..3cebdc59ca 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py @@ -157,6 +157,44 @@ def _set_connection_attributes(span, conn): span.set_attribute(key, value) +def _build_span_name(instance, cmd_args): + if len(cmd_args) > 0 and cmd_args[0]: + name = cmd_args[0] + else: + name = instance.connection_pool.connection_kwargs.get("db", 0) + return name + + +def _build_span_meta_data_for_pipeline(instance, sanitize_query): + try: + command_stack = ( + instance.command_stack + if hasattr(instance, "command_stack") + else instance._command_stack + ) + + cmds = [ + _format_command_args( + c.args if hasattr(c, "args") else c[0], sanitize_query + ) + for c in command_stack + ] + resource = "\n".join(cmds) + + span_name = " ".join( + [ + (c.args[0] if hasattr(c, "args") else c[0][0]) + for c in command_stack + ] + ) + except (AttributeError, IndexError): + command_stack = [] + resource = "" + span_name = "" + + return command_stack, resource, span_name + + def _instrument( tracer, request_hook: _RequestHookT = None, @@ -165,11 +203,8 @@ def _instrument( ): def _traced_execute_command(func, instance, args, kwargs): query = _format_command_args(args, sanitize_query) + name = _build_span_name(instance, args) - if len(args) > 0 and args[0]: - name = args[0] - else: - name = instance.connection_pool.connection_kwargs.get("db", 0) with tracer.start_as_current_span( name, kind=trace.SpanKind.CLIENT ) as span: @@ -185,31 +220,11 @@ def _traced_execute_command(func, instance, args, kwargs): return response def _traced_execute_pipeline(func, instance, args, kwargs): - try: - command_stack = ( - instance.command_stack - if hasattr(instance, "command_stack") - else instance._command_stack - ) - - cmds = [ - _format_command_args( - c.args if hasattr(c, "args") else c[0], sanitize_query - ) - for c in command_stack - ] - resource = "\n".join(cmds) - - span_name = " ".join( - [ - (c.args[0] if hasattr(c, "args") else c[0][0]) - for c in command_stack - ] - ) - except (AttributeError, IndexError): - command_stack = [] - resource = "" - span_name = "" + ( + command_stack, + resource, + span_name, + ) = _build_span_meta_data_for_pipeline(instance, sanitize_query) with tracer.start_as_current_span( span_name, kind=trace.SpanKind.CLIENT @@ -254,32 +269,72 @@ def _traced_execute_pipeline(func, instance, args, kwargs): "ClusterPipeline.execute", _traced_execute_pipeline, ) + + async def _async_traced_execute_command(func, instance, args, kwargs): + query = _format_command_args(args, sanitize_query) + name = _build_span_name(instance, args) + + with tracer.start_as_current_span( + name, kind=trace.SpanKind.CLIENT + ) as span: + if span.is_recording(): + span.set_attribute(SpanAttributes.DB_STATEMENT, query) + _set_connection_attributes(span, instance) + span.set_attribute("db.redis.args_length", len(args)) + if callable(request_hook): + request_hook(span, instance, args, kwargs) + response = await func(*args, **kwargs) + if callable(response_hook): + response_hook(span, instance, response) + return response + + async def _async_traced_execute_pipeline(func, instance, args, kwargs): + ( + command_stack, + resource, + span_name, + ) = _build_span_meta_data_for_pipeline(instance, sanitize_query) + + with tracer.start_as_current_span( + span_name, kind=trace.SpanKind.CLIENT + ) as span: + if span.is_recording(): + span.set_attribute(SpanAttributes.DB_STATEMENT, resource) + _set_connection_attributes(span, instance) + span.set_attribute( + "db.redis.pipeline_length", len(command_stack) + ) + response = await func(*args, **kwargs) + if callable(response_hook): + response_hook(span, instance, response) + return response + if redis.VERSION >= _REDIS_ASYNCIO_VERSION: wrap_function_wrapper( "redis.asyncio", f"{redis_class}.execute_command", - _traced_execute_command, + _async_traced_execute_command, ) wrap_function_wrapper( "redis.asyncio.client", f"{pipeline_class}.execute", - _traced_execute_pipeline, + _async_traced_execute_pipeline, ) wrap_function_wrapper( "redis.asyncio.client", f"{pipeline_class}.immediate_execute_command", - _traced_execute_command, + _async_traced_execute_command, ) if redis.VERSION >= _REDIS_ASYNCIO_CLUSTER_VERSION: wrap_function_wrapper( "redis.asyncio.cluster", "RedisCluster.execute_command", - _traced_execute_command, + _async_traced_execute_command, ) wrap_function_wrapper( "redis.asyncio.cluster", "ClusterPipeline.execute", - _traced_execute_pipeline, + _async_traced_execute_pipeline, ) diff --git a/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py b/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py index 56a0df6a0a..a7e3ca7885 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py +++ b/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import asyncio from unittest import mock import redis +import redis.asyncio from opentelemetry import trace from opentelemetry.instrumentation.redis import RedisInstrumentor @@ -21,6 +23,24 @@ from opentelemetry.trace import SpanKind +class AsyncMock: + """A sufficient async mock implementation. + + Python 3.7 doesn't have an inbuilt async mock class, so this is used. + """ + + def __init__(self): + self.mock = mock.Mock() + + async def __call__(self, *args, **kwargs): + f = asyncio.Future() + f.set_result("random") + return f + + def __getattr__(self, item): + return AsyncMock() + + class TestRedis(TestBase): def setUp(self): super().setUp() @@ -87,6 +107,35 @@ def test_instrument_uninstrument(self): spans = self.memory_exporter.get_finished_spans() self.assertEqual(len(spans), 1) + def test_instrument_uninstrument_async_client_command(self): + redis_client = redis.asyncio.Redis() + + with mock.patch.object(redis_client, "connection", AsyncMock()): + asyncio.run(redis_client.get("key")) + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + self.memory_exporter.clear() + + # Test uninstrument + RedisInstrumentor().uninstrument() + + with mock.patch.object(redis_client, "connection", AsyncMock()): + asyncio.run(redis_client.get("key")) + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 0) + self.memory_exporter.clear() + + # Test instrument again + RedisInstrumentor().instrument() + + with mock.patch.object(redis_client, "connection", AsyncMock()): + asyncio.run(redis_client.get("key")) + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + def test_response_hook(self): redis_client = redis.Redis() connection = redis.connection.Connection() From a4e2f5e28d956e5ea513aa4586910d10a0fc1cb3 Mon Sep 17 00:00:00 2001 From: Viktor Ivanov Date: Tue, 30 May 2023 23:06:13 +0100 Subject: [PATCH 2/3] Update changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ee92bdab7a..97f07cb211 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +### Fixed +- Fix async redis clients not being traced correctly ([#1830](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1830)) + ## Version 1.18.0/0.39b0 (2023-05-10) - `opentelemetry-instrumentation-system-metrics` Add `process.` prefix to `runtime.memory`, `runtime.cpu.time`, and `runtime.gc_count`. Change `runtime.memory` from count to UpDownCounter. ([#1735](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1735)) From a00990d7ee7c8aaee2f3d73dee3935a5fbe8b7eb Mon Sep 17 00:00:00 2001 From: Viktor Ivanov Date: Wed, 31 May 2023 21:15:29 +0100 Subject: [PATCH 3/3] Add functional integration tests and fix linting issues --- .../instrumentation/redis/__init__.py | 1 + .../tests/test_redis.py | 6 +- .../tests/redis/test_redis_functional.py | 132 ++++++++++++++++++ 3 files changed, 136 insertions(+), 3 deletions(-) diff --git a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py index 3cebdc59ca..9495f38896 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py @@ -195,6 +195,7 @@ def _build_span_meta_data_for_pipeline(instance, sanitize_query): return command_stack, resource, span_name +# pylint: disable=R0915 def _instrument( tracer, request_hook: _RequestHookT = None, diff --git a/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py b/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py index a7e3ca7885..35cf3ac215 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py +++ b/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py @@ -33,9 +33,9 @@ def __init__(self): self.mock = mock.Mock() async def __call__(self, *args, **kwargs): - f = asyncio.Future() - f.set_result("random") - return f + future = asyncio.Future() + future.set_result("random") + return future def __getattr__(self, item): return AsyncMock() diff --git a/tests/opentelemetry-docker-tests/tests/redis/test_redis_functional.py b/tests/opentelemetry-docker-tests/tests/redis/test_redis_functional.py index 675a37fa9f..bbd7b17e2c 100644 --- a/tests/opentelemetry-docker-tests/tests/redis/test_redis_functional.py +++ b/tests/opentelemetry-docker-tests/tests/redis/test_redis_functional.py @@ -13,6 +13,7 @@ # limitations under the License. import asyncio +from time import time_ns import redis import redis.asyncio @@ -326,6 +327,29 @@ def test_basics(self): ) self.assertEqual(span.attributes.get("db.redis.args_length"), 2) + def test_execute_command_traced_full_time(self): + """Command should be traced for coroutine execution time, not creation time.""" + coro_created_time = None + finish_time = None + + async def pipeline_simple(): + nonlocal coro_created_time + nonlocal finish_time + + # delay coroutine creation from coroutine execution + coro = self.redis_client.get("foo") + coro_created_time = time_ns() + await coro + finish_time = time_ns() + + async_call(pipeline_simple()) + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + span = spans[0] + self.assertTrue(span.start_time > coro_created_time) + self.assertTrue(span.end_time < finish_time) + def test_pipeline_traced(self): async def pipeline_simple(): async with self.redis_client.pipeline( @@ -348,6 +372,35 @@ async def pipeline_simple(): ) self.assertEqual(span.attributes.get("db.redis.pipeline_length"), 3) + def test_pipeline_traced_full_time(self): + """Command should be traced for coroutine execution time, not creation time.""" + coro_created_time = None + finish_time = None + + async def pipeline_simple(): + async with self.redis_client.pipeline( + transaction=False + ) as pipeline: + nonlocal coro_created_time + nonlocal finish_time + pipeline.set("blah", 32) + pipeline.rpush("foo", "éé") + pipeline.hgetall("xxx") + + # delay coroutine creation from coroutine execution + coro = pipeline.execute() + coro_created_time = time_ns() + await coro + finish_time = time_ns() + + async_call(pipeline_simple()) + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + span = spans[0] + self.assertTrue(span.start_time > coro_created_time) + self.assertTrue(span.end_time < finish_time) + def test_pipeline_immediate(self): async def pipeline_immediate(): async with self.redis_client.pipeline() as pipeline: @@ -367,6 +420,33 @@ async def pipeline_immediate(): span.attributes.get(SpanAttributes.DB_STATEMENT), "SET b 2" ) + def test_pipeline_immediate_traced_full_time(self): + """Command should be traced for coroutine execution time, not creation time.""" + coro_created_time = None + finish_time = None + + async def pipeline_simple(): + async with self.redis_client.pipeline( + transaction=False + ) as pipeline: + nonlocal coro_created_time + nonlocal finish_time + pipeline.set("a", 1) + + # delay coroutine creation from coroutine execution + coro = pipeline.immediate_execute_command("SET", "b", 2) + coro_created_time = time_ns() + await coro + finish_time = time_ns() + + async_call(pipeline_simple()) + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + span = spans[0] + self.assertTrue(span.start_time > coro_created_time) + self.assertTrue(span.end_time < finish_time) + def test_parent(self): """Ensure OpenTelemetry works with redis.""" ot_tracer = trace.get_tracer("redis_svc") @@ -416,6 +496,29 @@ def test_basics(self): ) self.assertEqual(span.attributes.get("db.redis.args_length"), 2) + def test_execute_command_traced_full_time(self): + """Command should be traced for coroutine execution time, not creation time.""" + coro_created_time = None + finish_time = None + + async def pipeline_simple(): + nonlocal coro_created_time + nonlocal finish_time + + # delay coroutine creation from coroutine execution + coro = self.redis_client.get("foo") + coro_created_time = time_ns() + await coro + finish_time = time_ns() + + async_call(pipeline_simple()) + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + span = spans[0] + self.assertTrue(span.start_time > coro_created_time) + self.assertTrue(span.end_time < finish_time) + def test_pipeline_traced(self): async def pipeline_simple(): async with self.redis_client.pipeline( @@ -438,6 +541,35 @@ async def pipeline_simple(): ) self.assertEqual(span.attributes.get("db.redis.pipeline_length"), 3) + def test_pipeline_traced_full_time(self): + """Command should be traced for coroutine execution time, not creation time.""" + coro_created_time = None + finish_time = None + + async def pipeline_simple(): + async with self.redis_client.pipeline( + transaction=False + ) as pipeline: + nonlocal coro_created_time + nonlocal finish_time + pipeline.set("blah", 32) + pipeline.rpush("foo", "éé") + pipeline.hgetall("xxx") + + # delay coroutine creation from coroutine execution + coro = pipeline.execute() + coro_created_time = time_ns() + await coro + finish_time = time_ns() + + async_call(pipeline_simple()) + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + span = spans[0] + self.assertTrue(span.start_time > coro_created_time) + self.assertTrue(span.end_time < finish_time) + def test_parent(self): """Ensure OpenTelemetry works with redis.""" ot_tracer = trace.get_tracer("redis_svc")