Skip to content

Commit

Permalink
httpx: rewrote patching to use wrapt instead of subclassing client
Browse files Browse the repository at this point in the history
  • Loading branch information
xrmx committed Oct 16, 2024
1 parent d7d7e96 commit 86a5d28
Show file tree
Hide file tree
Showing 2 changed files with 214 additions and 116 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,10 @@ async def async_response_hook(span, request, response):
"""
import logging
import typing
from asyncio import iscoroutinefunction
from types import TracebackType

import httpx
from wrapt import wrap_function_wrapper

from opentelemetry.instrumentation._semconv import (
_get_schema_url,
Expand All @@ -216,6 +216,7 @@ async def async_response_hook(span, request, response):
from opentelemetry.instrumentation.utils import (
http_status_to_status_code,
is_http_instrumentation_enabled,
unwrap,
)
from opentelemetry.propagate import inject
from opentelemetry.semconv.attributes.error_attributes import ERROR_TYPE
Expand Down Expand Up @@ -728,44 +729,183 @@ def _instrument(self, **kwargs):
``async_request_hook``: Async ``request_hook`` for ``httpx.AsyncClient``
``async_response_hook``: Async``response_hook`` for ``httpx.AsyncClient``
"""
self._original_client = httpx.Client
self._original_async_client = httpx.AsyncClient
request_hook = kwargs.get("request_hook")
response_hook = kwargs.get("response_hook")
async_request_hook = kwargs.get("async_request_hook")
async_response_hook = kwargs.get("async_response_hook")
if callable(request_hook):
_InstrumentedClient._request_hook = request_hook
if callable(async_request_hook) and iscoroutinefunction(
async_request_hook
):
_InstrumentedAsyncClient._request_hook = async_request_hook
if callable(response_hook):
_InstrumentedClient._response_hook = response_hook
if callable(async_response_hook) and iscoroutinefunction(
async_response_hook
):
_InstrumentedAsyncClient._response_hook = async_response_hook
tracer_provider = kwargs.get("tracer_provider")
_InstrumentedClient._tracer_provider = tracer_provider
_InstrumentedAsyncClient._tracer_provider = tracer_provider
# Intentionally using a private attribute here, see:
# https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2538#discussion_r1610603719
httpx.Client = httpx._api.Client = _InstrumentedClient
httpx.AsyncClient = _InstrumentedAsyncClient
self._request_hook = kwargs.get("request_hook")
self._response_hook = kwargs.get("response_hook")
self._async_request_hook = kwargs.get("async_request_hook")
self._async_response_hook = kwargs.get("async_response_hook")

if getattr(self, "__instrumented", False):
print("already instrumented")
return

_OpenTelemetrySemanticConventionStability._initialize()
self._sem_conv_opt_in_mode = _OpenTelemetrySemanticConventionStability._get_opentelemetry_stability_opt_in_mode(
_OpenTelemetryStabilitySignalType.HTTP,
)
self._tracer = get_tracer(
__name__,
instrumenting_library_version=__version__,
tracer_provider=tracer_provider,
schema_url=_get_schema_url(self._sem_conv_opt_in_mode),
)

wrap_function_wrapper(
"httpx",
"HTTPTransport.handle_request",
self._handle_request_wrapper,
)
wrap_function_wrapper(
"httpx",
"AsyncHTTPTransport.handle_async_request",
self._handle_async_request_wrapper,
)

self.__instrumented = True

def _uninstrument(self, **kwargs):
httpx.Client = httpx._api.Client = self._original_client
httpx.AsyncClient = self._original_async_client
_InstrumentedClient._tracer_provider = None
_InstrumentedClient._request_hook = None
_InstrumentedClient._response_hook = None
_InstrumentedAsyncClient._tracer_provider = None
_InstrumentedAsyncClient._request_hook = None
_InstrumentedAsyncClient._response_hook = None
import httpx

unwrap(httpx.HTTPTransport, "handle_request")
unwrap(httpx.AsyncHTTPTransport, "handle_async_request")

def _handle_request_wrapper(self, wrapped, instance, args, kwargs):
if not is_http_instrumentation_enabled():
return wrapped(*args, **kwargs)

method, url, headers, stream, extensions = _extract_parameters(
args, kwargs
)
method_original = method.decode()
span_name = _get_default_span_name(method_original)
span_attributes = {}
# apply http client response attributes according to semconv
_apply_request_client_attributes_to_span(
span_attributes,
url,
method_original,
self._sem_conv_opt_in_mode,
)

request_info = RequestInfo(method, url, headers, stream, extensions)

with self._tracer.start_as_current_span(
span_name, kind=SpanKind.CLIENT, attributes=span_attributes
) as span:
exception = None
if callable(self._request_hook):
self._request_hook(span, request_info)

_inject_propagation_headers(headers, args, kwargs)

try:
response = wrapped(*args, **kwargs)
except Exception as exc: # pylint: disable=W0703
exception = exc
response = getattr(exc, "response", None)

if isinstance(response, (httpx.Response, tuple)):
status_code, headers, stream, extensions, http_version = (
_extract_response(response)
)

if span.is_recording():
# apply http client response attributes according to semconv
_apply_response_client_attributes_to_span(
span,
status_code,
http_version,
self._sem_conv_opt_in_mode,
)
if callable(self._response_hook):
self._response_hook(
span,
request_info,
ResponseInfo(status_code, headers, stream, extensions),
)

if exception:
if span.is_recording() and _report_new(
self._sem_conv_opt_in_mode
):
span.set_attribute(
ERROR_TYPE, type(exception).__qualname__
)
raise exception.with_traceback(exception.__traceback__)

return response

async def _handle_async_request_wrapper(
self, wrapped, instance, args, kwargs
):
if not is_http_instrumentation_enabled():
return await wrapped(*args, **kwargs)

method, url, headers, stream, extensions = _extract_parameters(
args, kwargs
)
method_original = method.decode()
span_name = _get_default_span_name(method_original)
span_attributes = {}
# apply http client response attributes according to semconv
_apply_request_client_attributes_to_span(
span_attributes,
url,
method_original,
self._sem_conv_opt_in_mode,
)

request_info = RequestInfo(method, url, headers, stream, extensions)

with self._tracer.start_as_current_span(
span_name, kind=SpanKind.CLIENT, attributes=span_attributes
) as span:
exception = None
if callable(self._async_request_hook):
await self._async_request_hook(span, request_info)

_inject_propagation_headers(headers, args, kwargs)

try:
response = await wrapped(*args, **kwargs)
except Exception as exc: # pylint: disable=W0703
exception = exc
response = getattr(exc, "response", None)

if isinstance(response, (httpx.Response, tuple)):
status_code, headers, stream, extensions, http_version = (
_extract_response(response)
)

if span.is_recording():
# apply http client response attributes according to semconv
_apply_response_client_attributes_to_span(
span,
status_code,
http_version,
self._sem_conv_opt_in_mode,
)

if callable(self._async_response_hook):
await self._async_response_hook(
span,
request_info,
ResponseInfo(status_code, headers, stream, extensions),
)

if exception:
if span.is_recording() and _report_new(
self._sem_conv_opt_in_mode
):
span.set_attribute(
ERROR_TYPE, type(exception).__qualname__
)
raise exception.with_traceback(exception.__traceback__)

return response

@staticmethod
def instrument_client(
self,
client: typing.Union[httpx.Client, httpx.AsyncClient],
tracer_provider: TracerProvider = None,
request_hook: typing.Union[
Expand All @@ -785,67 +925,27 @@ def instrument_client(
response_hook: A hook that receives the span, request, and response
that is called right before the span ends
"""
# pylint: disable=protected-access
if not hasattr(client, "_is_instrumented_by_opentelemetry"):
client._is_instrumented_by_opentelemetry = False

if not client._is_instrumented_by_opentelemetry:
if isinstance(client, httpx.Client):
client._original_transport = client._transport
client._original_mounts = client._mounts.copy()
transport = client._transport or httpx.HTTPTransport()
client._transport = SyncOpenTelemetryTransport(
transport,
tracer_provider=tracer_provider,
request_hook=request_hook,
response_hook=response_hook,
)
client._is_instrumented_by_opentelemetry = True
client._mounts.update(
{
url_pattern: (
SyncOpenTelemetryTransport(
transport,
tracer_provider=tracer_provider,
request_hook=request_hook,
response_hook=response_hook,
)
if transport is not None
else transport
)
for url_pattern, transport in client._original_mounts.items()
}
)

if isinstance(client, httpx.AsyncClient):
transport = client._transport or httpx.AsyncHTTPTransport()
client._original_mounts = client._mounts.copy()
client._transport = AsyncOpenTelemetryTransport(
transport,
tracer_provider=tracer_provider,
request_hook=request_hook,
response_hook=response_hook,
)
client._is_instrumented_by_opentelemetry = True
client._mounts.update(
{
url_pattern: (
AsyncOpenTelemetryTransport(
transport,
tracer_provider=tracer_provider,
request_hook=request_hook,
response_hook=response_hook,
)
if transport is not None
else transport
)
for url_pattern, transport in client._original_mounts.items()
}
)
else:
if getattr(client, "_is_instrumented_by_opentelemetry", False):
_logger.warning(
"Attempting to instrument Httpx client while already instrumented"
)
return

if hasattr(client._transport, "handle_request"):
wrap_function_wrapper(
client._transport,
"handle_request",
self._handle_request_wrapper,
)
client._is_instrumented_by_opentelemetry = True
if hasattr(client._transport, "handle_async_request"):
wrap_function_wrapper(
client._transport,
"handle_async_request",
self._handle_async_request_wrapper,
)
client._is_instrumented_by_opentelemetry = True

@staticmethod
def uninstrument_client(
Expand All @@ -856,15 +956,9 @@ def uninstrument_client(
Args:
client: The httpx Client or AsyncClient instance
"""
if hasattr(client, "_original_transport"):
client._transport = client._original_transport
del client._original_transport
if hasattr(client._transport, "handle_request"):
unwrap(client._transport, "handle_request")
client._is_instrumented_by_opentelemetry = False
elif hasattr(client._transport, "handle_async_request"):
unwrap(client._transport, "handle_async_request")
client._is_instrumented_by_opentelemetry = False
if hasattr(client, "_original_mounts"):
client._mounts = client._original_mounts.copy()
del client._original_mounts
else:
_logger.warning(
"Attempting to uninstrument Httpx "
"client while already uninstrumented"
)
Loading

0 comments on commit 86a5d28

Please sign in to comment.