From 86a5d28a151754cd6d631986ed73c6ff8f688522 Mon Sep 17 00:00:00 2001 From: Riccardo Magliocchetti Date: Wed, 16 Oct 2024 16:09:14 +0200 Subject: [PATCH] httpx: rewrote patching to use wrapt instead of subclassing client --- .../instrumentation/httpx/__init__.py | 298 ++++++++++++------ .../tests/test_httpx_integration.py | 32 +- 2 files changed, 214 insertions(+), 116 deletions(-) diff --git a/instrumentation/opentelemetry-instrumentation-httpx/src/opentelemetry/instrumentation/httpx/__init__.py b/instrumentation/opentelemetry-instrumentation-httpx/src/opentelemetry/instrumentation/httpx/__init__.py index 15ee59a183..9ca8f6e31b 100644 --- a/instrumentation/opentelemetry-instrumentation-httpx/src/opentelemetry/instrumentation/httpx/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-httpx/src/opentelemetry/instrumentation/httpx/__init__.py @@ -192,10 +192,10 @@ async def async_response_hook(span, request, response): """ import logging import typing -from asyncio import iscoroutinefunction from types import TracebackType import httpx +from wrapt import wrap_function_wrapper from opentelemetry.instrumentation._semconv import ( _get_schema_url, @@ -216,6 +216,7 @@ async def async_response_hook(span, request, response): from opentelemetry.instrumentation.utils import ( http_status_to_status_code, is_http_instrumentation_enabled, + unwrap, ) from opentelemetry.propagate import inject from opentelemetry.semconv.attributes.error_attributes import ERROR_TYPE @@ -728,44 +729,183 @@ def _instrument(self, **kwargs): ``async_request_hook``: Async ``request_hook`` for ``httpx.AsyncClient`` ``async_response_hook``: Async``response_hook`` for ``httpx.AsyncClient`` """ - self._original_client = httpx.Client - self._original_async_client = httpx.AsyncClient - request_hook = kwargs.get("request_hook") - response_hook = kwargs.get("response_hook") - async_request_hook = kwargs.get("async_request_hook") - async_response_hook = kwargs.get("async_response_hook") - if callable(request_hook): - _InstrumentedClient._request_hook = request_hook - if callable(async_request_hook) and iscoroutinefunction( - async_request_hook - ): - _InstrumentedAsyncClient._request_hook = async_request_hook - if callable(response_hook): - _InstrumentedClient._response_hook = response_hook - if callable(async_response_hook) and iscoroutinefunction( - async_response_hook - ): - _InstrumentedAsyncClient._response_hook = async_response_hook tracer_provider = kwargs.get("tracer_provider") - _InstrumentedClient._tracer_provider = tracer_provider - _InstrumentedAsyncClient._tracer_provider = tracer_provider - # Intentionally using a private attribute here, see: - # https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2538#discussion_r1610603719 - httpx.Client = httpx._api.Client = _InstrumentedClient - httpx.AsyncClient = _InstrumentedAsyncClient + self._request_hook = kwargs.get("request_hook") + self._response_hook = kwargs.get("response_hook") + self._async_request_hook = kwargs.get("async_request_hook") + self._async_response_hook = kwargs.get("async_response_hook") + + if getattr(self, "__instrumented", False): + print("already instrumented") + return + + _OpenTelemetrySemanticConventionStability._initialize() + self._sem_conv_opt_in_mode = _OpenTelemetrySemanticConventionStability._get_opentelemetry_stability_opt_in_mode( + _OpenTelemetryStabilitySignalType.HTTP, + ) + self._tracer = get_tracer( + __name__, + instrumenting_library_version=__version__, + tracer_provider=tracer_provider, + schema_url=_get_schema_url(self._sem_conv_opt_in_mode), + ) + + wrap_function_wrapper( + "httpx", + "HTTPTransport.handle_request", + self._handle_request_wrapper, + ) + wrap_function_wrapper( + "httpx", + "AsyncHTTPTransport.handle_async_request", + self._handle_async_request_wrapper, + ) + + self.__instrumented = True def _uninstrument(self, **kwargs): - httpx.Client = httpx._api.Client = self._original_client - httpx.AsyncClient = self._original_async_client - _InstrumentedClient._tracer_provider = None - _InstrumentedClient._request_hook = None - _InstrumentedClient._response_hook = None - _InstrumentedAsyncClient._tracer_provider = None - _InstrumentedAsyncClient._request_hook = None - _InstrumentedAsyncClient._response_hook = None + import httpx + + unwrap(httpx.HTTPTransport, "handle_request") + unwrap(httpx.AsyncHTTPTransport, "handle_async_request") + + def _handle_request_wrapper(self, wrapped, instance, args, kwargs): + if not is_http_instrumentation_enabled(): + return wrapped(*args, **kwargs) + + method, url, headers, stream, extensions = _extract_parameters( + args, kwargs + ) + method_original = method.decode() + span_name = _get_default_span_name(method_original) + span_attributes = {} + # apply http client response attributes according to semconv + _apply_request_client_attributes_to_span( + span_attributes, + url, + method_original, + self._sem_conv_opt_in_mode, + ) + + request_info = RequestInfo(method, url, headers, stream, extensions) + + with self._tracer.start_as_current_span( + span_name, kind=SpanKind.CLIENT, attributes=span_attributes + ) as span: + exception = None + if callable(self._request_hook): + self._request_hook(span, request_info) + + _inject_propagation_headers(headers, args, kwargs) + + try: + response = wrapped(*args, **kwargs) + except Exception as exc: # pylint: disable=W0703 + exception = exc + response = getattr(exc, "response", None) + + if isinstance(response, (httpx.Response, tuple)): + status_code, headers, stream, extensions, http_version = ( + _extract_response(response) + ) + + if span.is_recording(): + # apply http client response attributes according to semconv + _apply_response_client_attributes_to_span( + span, + status_code, + http_version, + self._sem_conv_opt_in_mode, + ) + if callable(self._response_hook): + self._response_hook( + span, + request_info, + ResponseInfo(status_code, headers, stream, extensions), + ) + + if exception: + if span.is_recording() and _report_new( + self._sem_conv_opt_in_mode + ): + span.set_attribute( + ERROR_TYPE, type(exception).__qualname__ + ) + raise exception.with_traceback(exception.__traceback__) + + return response + + async def _handle_async_request_wrapper( + self, wrapped, instance, args, kwargs + ): + if not is_http_instrumentation_enabled(): + return await wrapped(*args, **kwargs) + + method, url, headers, stream, extensions = _extract_parameters( + args, kwargs + ) + method_original = method.decode() + span_name = _get_default_span_name(method_original) + span_attributes = {} + # apply http client response attributes according to semconv + _apply_request_client_attributes_to_span( + span_attributes, + url, + method_original, + self._sem_conv_opt_in_mode, + ) + + request_info = RequestInfo(method, url, headers, stream, extensions) + + with self._tracer.start_as_current_span( + span_name, kind=SpanKind.CLIENT, attributes=span_attributes + ) as span: + exception = None + if callable(self._async_request_hook): + await self._async_request_hook(span, request_info) + + _inject_propagation_headers(headers, args, kwargs) + + try: + response = await wrapped(*args, **kwargs) + except Exception as exc: # pylint: disable=W0703 + exception = exc + response = getattr(exc, "response", None) + + if isinstance(response, (httpx.Response, tuple)): + status_code, headers, stream, extensions, http_version = ( + _extract_response(response) + ) + + if span.is_recording(): + # apply http client response attributes according to semconv + _apply_response_client_attributes_to_span( + span, + status_code, + http_version, + self._sem_conv_opt_in_mode, + ) + + if callable(self._async_response_hook): + await self._async_response_hook( + span, + request_info, + ResponseInfo(status_code, headers, stream, extensions), + ) + + if exception: + if span.is_recording() and _report_new( + self._sem_conv_opt_in_mode + ): + span.set_attribute( + ERROR_TYPE, type(exception).__qualname__ + ) + raise exception.with_traceback(exception.__traceback__) + + return response - @staticmethod def instrument_client( + self, client: typing.Union[httpx.Client, httpx.AsyncClient], tracer_provider: TracerProvider = None, request_hook: typing.Union[ @@ -785,67 +925,27 @@ def instrument_client( response_hook: A hook that receives the span, request, and response that is called right before the span ends """ - # pylint: disable=protected-access - if not hasattr(client, "_is_instrumented_by_opentelemetry"): - client._is_instrumented_by_opentelemetry = False - if not client._is_instrumented_by_opentelemetry: - if isinstance(client, httpx.Client): - client._original_transport = client._transport - client._original_mounts = client._mounts.copy() - transport = client._transport or httpx.HTTPTransport() - client._transport = SyncOpenTelemetryTransport( - transport, - tracer_provider=tracer_provider, - request_hook=request_hook, - response_hook=response_hook, - ) - client._is_instrumented_by_opentelemetry = True - client._mounts.update( - { - url_pattern: ( - SyncOpenTelemetryTransport( - transport, - tracer_provider=tracer_provider, - request_hook=request_hook, - response_hook=response_hook, - ) - if transport is not None - else transport - ) - for url_pattern, transport in client._original_mounts.items() - } - ) - - if isinstance(client, httpx.AsyncClient): - transport = client._transport or httpx.AsyncHTTPTransport() - client._original_mounts = client._mounts.copy() - client._transport = AsyncOpenTelemetryTransport( - transport, - tracer_provider=tracer_provider, - request_hook=request_hook, - response_hook=response_hook, - ) - client._is_instrumented_by_opentelemetry = True - client._mounts.update( - { - url_pattern: ( - AsyncOpenTelemetryTransport( - transport, - tracer_provider=tracer_provider, - request_hook=request_hook, - response_hook=response_hook, - ) - if transport is not None - else transport - ) - for url_pattern, transport in client._original_mounts.items() - } - ) - else: + if getattr(client, "_is_instrumented_by_opentelemetry", False): _logger.warning( "Attempting to instrument Httpx client while already instrumented" ) + return + + if hasattr(client._transport, "handle_request"): + wrap_function_wrapper( + client._transport, + "handle_request", + self._handle_request_wrapper, + ) + client._is_instrumented_by_opentelemetry = True + if hasattr(client._transport, "handle_async_request"): + wrap_function_wrapper( + client._transport, + "handle_async_request", + self._handle_async_request_wrapper, + ) + client._is_instrumented_by_opentelemetry = True @staticmethod def uninstrument_client( @@ -856,15 +956,9 @@ def uninstrument_client( Args: client: The httpx Client or AsyncClient instance """ - if hasattr(client, "_original_transport"): - client._transport = client._original_transport - del client._original_transport + if hasattr(client._transport, "handle_request"): + unwrap(client._transport, "handle_request") + client._is_instrumented_by_opentelemetry = False + elif hasattr(client._transport, "handle_async_request"): + unwrap(client._transport, "handle_async_request") client._is_instrumented_by_opentelemetry = False - if hasattr(client, "_original_mounts"): - client._mounts = client._original_mounts.copy() - del client._original_mounts - else: - _logger.warning( - "Attempting to uninstrument Httpx " - "client while already uninstrumented" - ) diff --git a/instrumentation/opentelemetry-instrumentation-httpx/tests/test_httpx_integration.py b/instrumentation/opentelemetry-instrumentation-httpx/tests/test_httpx_integration.py index 27535800cb..c6f15b559e 100644 --- a/instrumentation/opentelemetry-instrumentation-httpx/tests/test_httpx_integration.py +++ b/instrumentation/opentelemetry-instrumentation-httpx/tests/test_httpx_integration.py @@ -166,11 +166,14 @@ def setUp(self): ) ) + HTTPXClientInstrumentor().instrument() + # pylint: disable=invalid-name def tearDown(self): super().tearDown() self.env_patch.stop() respx.stop() + HTTPXClientInstrumentor().uninstrument() def assert_span( self, exporter: "SpanExporter" = None, num_spans: int = 1 @@ -743,6 +746,8 @@ def setUp(self): super().setUp() HTTPXClientInstrumentor().instrument() self.client = self.create_client() + + def tearDown(self): HTTPXClientInstrumentor().uninstrument() def create_proxy_mounts(self): @@ -769,6 +774,7 @@ def test_custom_tracer_provider(self): result = self.create_tracer_provider(resource=resource) tracer_provider, exporter = result + HTTPXClientInstrumentor().uninstrument() HTTPXClientInstrumentor().instrument( tracer_provider=tracer_provider ) @@ -787,6 +793,7 @@ def test_response_hook(self): else "response_hook" ) response_hook_kwargs = {response_hook_key: self.response_hook} + HTTPXClientInstrumentor().uninstrument() HTTPXClientInstrumentor().instrument( tracer_provider=self.tracer_provider, **response_hook_kwargs, @@ -808,6 +815,7 @@ def test_response_hook(self): HTTPXClientInstrumentor().uninstrument() def test_response_hook_sync_async_kwargs(self): + HTTPXClientInstrumentor().uninstrument() HTTPXClientInstrumentor().instrument( tracer_provider=self.tracer_provider, response_hook=_response_hook, @@ -819,7 +827,7 @@ def test_response_hook_sync_async_kwargs(self): self.assertEqual(result.text, "Hello!") span = self.assert_span() self.assertEqual( - span.attributes, + dict(span.attributes), { SpanAttributes.HTTP_METHOD: "GET", SpanAttributes.HTTP_URL: self.URL, @@ -836,6 +844,7 @@ def test_request_hook(self): else "request_hook" ) request_hook_kwargs = {request_hook_key: self.request_hook} + HTTPXClientInstrumentor().uninstrument() HTTPXClientInstrumentor().instrument( tracer_provider=self.tracer_provider, **request_hook_kwargs, @@ -849,6 +858,7 @@ def test_request_hook(self): HTTPXClientInstrumentor().uninstrument() def test_request_hook_sync_async_kwargs(self): + HTTPXClientInstrumentor().uninstrument() HTTPXClientInstrumentor().instrument( tracer_provider=self.tracer_provider, request_hook=_request_hook, @@ -863,6 +873,7 @@ def test_request_hook_sync_async_kwargs(self): HTTPXClientInstrumentor().uninstrument() def test_request_hook_no_span_update(self): + HTTPXClientInstrumentor().uninstrument() HTTPXClientInstrumentor().instrument( tracer_provider=self.tracer_provider, request_hook=self.no_update_request_hook, @@ -876,6 +887,7 @@ def test_request_hook_no_span_update(self): HTTPXClientInstrumentor().uninstrument() def test_not_recording(self): + HTTPXClientInstrumentor().uninstrument() with mock.patch("opentelemetry.trace.INVALID_SPAN") as mock_span: HTTPXClientInstrumentor().instrument( tracer_provider=trace.NoOpTracerProvider() @@ -894,6 +906,7 @@ def test_not_recording(self): HTTPXClientInstrumentor().uninstrument() def test_suppress_instrumentation_new_client(self): + HTTPXClientInstrumentor().uninstrument() HTTPXClientInstrumentor().instrument() with suppress_http_instrumentation(): client = self.create_client() @@ -904,6 +917,7 @@ def test_suppress_instrumentation_new_client(self): HTTPXClientInstrumentor().uninstrument() def test_instrument_client(self): + HTTPXClientInstrumentor().uninstrument() client = self.create_client() HTTPXClientInstrumentor().instrument_client(client) result = self.perform_request(self.URL, client=client) @@ -911,8 +925,6 @@ def test_instrument_client(self): self.assert_span(num_spans=1) def test_instrumentation_without_client(self): - - HTTPXClientInstrumentor().instrument() results = [ httpx.get(self.URL), httpx.request("GET", self.URL), @@ -930,10 +942,7 @@ def test_instrumentation_without_client(self): self.URL, ) - HTTPXClientInstrumentor().uninstrument() - def test_uninstrument(self): - HTTPXClientInstrumentor().instrument() HTTPXClientInstrumentor().uninstrument() client = self.create_client() result = self.perform_request(self.URL, client=client) @@ -943,6 +952,7 @@ def test_uninstrument(self): self.assert_span(num_spans=0) def test_uninstrument_client(self): + HTTPXClientInstrumentor().uninstrument() HTTPXClientInstrumentor().uninstrument_client(self.client) result = self.perform_request(self.URL) @@ -951,7 +961,6 @@ def test_uninstrument_client(self): self.assert_span(num_spans=0) def test_uninstrument_new_client(self): - HTTPXClientInstrumentor().instrument() client1 = self.create_client() HTTPXClientInstrumentor().uninstrument_client(client1) @@ -974,7 +983,6 @@ def test_uninstrument_new_client(self): def test_instrument_proxy(self): proxy_mounts = self.create_proxy_mounts() - HTTPXClientInstrumentor().instrument() client = self.create_client(mounts=proxy_mounts) self.perform_request(self.URL, client=client) self.assert_span(num_spans=1) @@ -983,9 +991,9 @@ def test_instrument_proxy(self): 2, (SyncOpenTelemetryTransport, AsyncOpenTelemetryTransport), ) - HTTPXClientInstrumentor().uninstrument() def test_instrument_client_with_proxy(self): + HTTPXClientInstrumentor().uninstrument() proxy_mounts = self.create_proxy_mounts() client = self.create_client(mounts=proxy_mounts) self.assert_proxy_mounts( @@ -1006,7 +1014,6 @@ def test_instrument_client_with_proxy(self): def test_uninstrument_client_with_proxy(self): proxy_mounts = self.create_proxy_mounts() - HTTPXClientInstrumentor().instrument() client = self.create_client(mounts=proxy_mounts) self.assert_proxy_mounts( client._mounts.values(), @@ -1069,7 +1076,7 @@ def create_client( transport: typing.Optional[SyncOpenTelemetryTransport] = None, **kwargs, ): - return httpx.Client(transport=transport, **kwargs) + return httpx.Client(**kwargs) def perform_request( self, @@ -1189,10 +1196,7 @@ class TestAsyncInstrumentationIntegration(BaseTestCases.BaseInstrumentorTest): def setUp(self): super().setUp() - HTTPXClientInstrumentor().instrument() - self.client = self.create_client() self.client2 = self.create_client() - HTTPXClientInstrumentor().uninstrument() def create_client( self,