From 9e2dbecedc4fde0642fc1efcaa6fbe9b6f4e47dd Mon Sep 17 00:00:00 2001 From: Thiyagu55 <64461612+Thiyagu55@users.noreply.github.com> Date: Thu, 14 Jul 2022 15:59:34 +0530 Subject: [PATCH] Adding multiple db connections support for django-instrumentation's sqlcommenter (#1187) --- CHANGELOG.md | 2 + .../middleware/sqlcommenter_middleware.py | 60 +++++-------------- .../tests/test_middleware.py | 8 ++- .../tests/test_sqlcommenter.py | 19 +++++- .../opentelemetry/instrumentation/utils.py | 28 ++++++--- 5 files changed, 62 insertions(+), 55 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c8669d9752..545183b786 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased](https://github.com/open-telemetry/opentelemetry-python/compare/v1.12.0rc2-0.32b0...HEAD) +- Adding multiple db connections support for django-instrumentation's sqlcommenter + ([#1187](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1187)) ### Added - `opentelemetry-instrumentation-redis` add support to instrument RedisCluster clients diff --git a/instrumentation/opentelemetry-instrumentation-django/src/opentelemetry/instrumentation/django/middleware/sqlcommenter_middleware.py b/instrumentation/opentelemetry-instrumentation-django/src/opentelemetry/instrumentation/django/middleware/sqlcommenter_middleware.py index 556bd92938..5fe51fca52 100644 --- a/instrumentation/opentelemetry-instrumentation-django/src/opentelemetry/instrumentation/django/middleware/sqlcommenter_middleware.py +++ b/instrumentation/opentelemetry-instrumentation-django/src/opentelemetry/instrumentation/django/middleware/sqlcommenter_middleware.py @@ -13,15 +13,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from contextlib import ExitStack from logging import getLogger from typing import Any, Type, TypeVar -from urllib.parse import quote as urllib_quote # pylint: disable=no-name-in-module from django import conf, get_version -from django.db import connection +from django.db import connections from django.db.backends.utils import CursorDebugWrapper +from opentelemetry.instrumentation.utils import ( + _generate_sql_comment, + _get_opentelemetry_values, +) from opentelemetry.trace.propagation.tracecontext import ( TraceContextTextMapPropagator, ) @@ -44,7 +48,13 @@ def __init__(self, get_response) -> None: self.get_response = get_response def __call__(self, request) -> Any: - with connection.execute_wrapper(_QueryWrapper(request)): + with ExitStack() as stack: + for db_alias in connections: + stack.enter_context( + connections[db_alias].execute_wrapper( + _QueryWrapper(request) + ) + ) return self.get_response(request) @@ -105,49 +115,7 @@ def __call__(self, execute: Type[T], sql, params, many, context) -> T: sql += sql_comment # Add the query to the query log if debugging. - if context["cursor"].__class__ is CursorDebugWrapper: + if isinstance(context["cursor"], CursorDebugWrapper): context["connection"].queries_log.append(sql) return execute(sql, params, many, context) - - -def _generate_sql_comment(**meta) -> str: - """ - Return a SQL comment with comma delimited key=value pairs created from - **meta kwargs. - """ - key_value_delimiter = "," - - if not meta: # No entries added. - return "" - - # Sort the keywords to ensure that caching works and that testing is - # deterministic. It eases visual inspection as well. - return ( - " /*" - + key_value_delimiter.join( - f"{_url_quote(key)}={_url_quote(value)!r}" - for key, value in sorted(meta.items()) - if value is not None - ) - + "*/" - ) - - -def _url_quote(value) -> str: - if not isinstance(value, (str, bytes)): - return value - _quoted = urllib_quote(value) - # Since SQL uses '%' as a keyword, '%' is a by-product of url quoting - # e.g. foo,bar --> foo%2Cbar - # thus in our quoting, we need to escape it too to finally give - # foo,bar --> foo%%2Cbar - return _quoted.replace("%", "%%") - - -def _get_opentelemetry_values() -> dict or None: - """ - Return the OpenTelemetry Trace and Span IDs if Span ID is set in the - OpenTelemetry execution context. - """ - return _propagator.inject({}) diff --git a/instrumentation/opentelemetry-instrumentation-django/tests/test_middleware.py b/instrumentation/opentelemetry-instrumentation-django/tests/test_middleware.py index a40a7b82ee..05457de43d 100644 --- a/instrumentation/opentelemetry-instrumentation-django/tests/test_middleware.py +++ b/instrumentation/opentelemetry-instrumentation-django/tests/test_middleware.py @@ -86,7 +86,13 @@ class TestMiddleware(WsgiTestBase): @classmethod def setUpClass(cls): - conf.settings.configure(ROOT_URLCONF=modules[__name__]) + conf.settings.configure( + ROOT_URLCONF=modules[__name__], + DATABASES={ + "default": {}, + "other": {}, + }, # db.connections gets populated only at first test execution + ) super().setUpClass() def setUp(self): diff --git a/instrumentation/opentelemetry-instrumentation-django/tests/test_sqlcommenter.py b/instrumentation/opentelemetry-instrumentation-django/tests/test_sqlcommenter.py index 682dd5f4e9..b162cc1f2a 100644 --- a/instrumentation/opentelemetry-instrumentation-django/tests/test_sqlcommenter.py +++ b/instrumentation/opentelemetry-instrumentation-django/tests/test_sqlcommenter.py @@ -13,15 +13,16 @@ # limitations under the License. # pylint: disable=no-name-in-module - from unittest.mock import MagicMock, patch +import pytest from django import VERSION, conf from django.http import HttpResponse from django.test.utils import setup_test_environment, teardown_test_environment from opentelemetry.instrumentation.django import DjangoInstrumentor from opentelemetry.instrumentation.django.middleware.sqlcommenter_middleware import ( + SqlCommenter, _QueryWrapper, ) from opentelemetry.test.wsgitestutil import WsgiTestBase @@ -98,3 +99,19 @@ def test_query_wrapper(self, trace_capture): "Select 1 /*app_name='app',controller='view',route='route',traceparent='%%2Atraceparent%%3D%%2700-0000000" "00000000000000000deadbeef-000000000000beef-00'*/", ) + + @patch( + "opentelemetry.instrumentation.django.middleware.sqlcommenter_middleware._QueryWrapper" + ) + def test_multiple_connection_support(self, query_wrapper): + if not DJANGO_2_0: + pytest.skip() + + requests_mock = MagicMock() + get_response = MagicMock() + + sql_instance = SqlCommenter(get_response) + sql_instance(requests_mock) + + # check if query_wrapper is added to the context for 2 databases + self.assertEqual(query_wrapper.call_count, 2) diff --git a/opentelemetry-instrumentation/src/opentelemetry/instrumentation/utils.py b/opentelemetry-instrumentation/src/opentelemetry/instrumentation/utils.py index fea7608388..181d5b6fce 100644 --- a/opentelemetry-instrumentation/src/opentelemetry/instrumentation/utils.py +++ b/opentelemetry-instrumentation/src/opentelemetry/instrumentation/utils.py @@ -25,6 +25,11 @@ from opentelemetry.context import _SUPPRESS_INSTRUMENTATION_KEY # noqa: F401 from opentelemetry.propagate import extract from opentelemetry.trace import Span, StatusCode +from opentelemetry.trace.propagation.tracecontext import ( + TraceContextTextMapPropagator, +) + +propagator = TraceContextTextMapPropagator() def extract_attributes_from_object( @@ -119,24 +124,22 @@ def _start_internal_or_server_span( return span, token -_KEY_VALUE_DELIMITER = "," - - -def _generate_sql_comment(**meta): +def _generate_sql_comment(**meta) -> str: """ Return a SQL comment with comma delimited key=value pairs created from **meta kwargs. """ + key_value_delimiter = "," + if not meta: # No entries added. return "" # Sort the keywords to ensure that caching works and that testing is # deterministic. It eases visual inspection as well. - # pylint: disable=consider-using-f-string return ( " /*" - + _KEY_VALUE_DELIMITER.join( - "{}={!r}".format(_url_quote(key), _url_quote(value)) + + key_value_delimiter.join( + f"{_url_quote(key)}={_url_quote(value)!r}" for key, value in sorted(meta.items()) if value is not None ) @@ -155,6 +158,17 @@ def _url_quote(s): # pylint: disable=invalid-name return quoted.replace("%", "%%") +def _get_opentelemetry_values(): + """ + Return the OpenTelemetry Trace and Span IDs if Span ID is set in the + OpenTelemetry execution context. + """ + # Insert the W3C TraceContext generated + _headers = {} + propagator.inject(_headers) + return _headers + + def _generate_opentelemetry_traceparent(span: Span) -> str: meta = {} _version = "00"