diff --git a/instrumentation/opentelemetry-instrumentation-django/src/opentelemetry/instrumentation/django/middleware.py b/instrumentation/opentelemetry-instrumentation-django/src/opentelemetry/instrumentation/django/middleware.py index 91af787c28..8f7618ddd5 100644 --- a/instrumentation/opentelemetry-instrumentation-django/src/opentelemetry/instrumentation/django/middleware.py +++ b/instrumentation/opentelemetry-instrumentation-django/src/opentelemetry/instrumentation/django/middleware.py @@ -32,7 +32,13 @@ from opentelemetry.instrumentation.wsgi import wsgi_getter from opentelemetry.propagate import extract from opentelemetry.semconv.trace import SpanAttributes -from opentelemetry.trace import Span, SpanKind, use_span +from opentelemetry.trace import ( + INVALID_SPAN, + Span, + SpanKind, + get_current_span, + use_span, +) from opentelemetry.util.http import get_excluded_urls, get_traced_request_attrs try: @@ -184,11 +190,16 @@ def process_request(self, request): carrier_getter = wsgi_getter collect_request_attributes = wsgi_collect_request_attributes - token = attach(extract(carrier, getter=carrier_getter)) - + token = context = None + span_kind = SpanKind.INTERNAL + if get_current_span() is INVALID_SPAN: + context = extract(request_meta, getter=wsgi_getter) + token = attach(context) + span_kind = SpanKind.SERVER span = self._tracer.start_span( self._get_span_name(request), - kind=SpanKind.SERVER, + context, + kind=span_kind, start_time=request_meta.get( "opentelemetry-instrumentor-django.starttime_key" ), @@ -221,7 +232,8 @@ def process_request(self, request): request.META[self._environ_activation_key] = activation request.META[self._environ_span_key] = span - request.META[self._environ_token] = token + if token: + request.META[self._environ_token] = token if _DjangoMiddleware._otel_request_hook: _DjangoMiddleware._otel_request_hook( # pylint: disable=not-callable diff --git a/instrumentation/opentelemetry-instrumentation-django/tests/test_middleware.py b/instrumentation/opentelemetry-instrumentation-django/tests/test_middleware.py index 473a0fb3c0..5c329a1281 100644 --- a/instrumentation/opentelemetry-instrumentation-django/tests/test_middleware.py +++ b/instrumentation/opentelemetry-instrumentation-django/tests/test_middleware.py @@ -12,13 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import re from sys import modules from unittest.mock import Mock, patch from django import VERSION, conf +from django.core.wsgi import get_wsgi_application +from django.core.servers.basehttp import get_internal_wsgi_application from django.http import HttpRequest, HttpResponse -from django.test.client import Client -from django.test.utils import setup_test_environment, teardown_test_environment +from django.test.client import Client, ClientHandler, RequestFactory +from django.test.testcases import SimpleTestCase +from django.test.utils import override_settings, setup_test_environment, teardown_test_environment +from fastapi import applications +from opentelemetry import trace from opentelemetry.instrumentation.django import ( DjangoInstrumentor, @@ -28,6 +34,7 @@ TraceResponsePropagator, set_global_response_propagator, ) +from opentelemetry.instrumentation.wsgi import OpenTelemetryMiddleware from opentelemetry.sdk import resources from opentelemetry.sdk.trace import Span from opentelemetry.sdk.trace.id_generator import RandomIdGenerator @@ -41,6 +48,7 @@ format_trace_id, ) from opentelemetry.util.http import get_excluded_urls, get_traced_request_attrs +from packaging.markers import Op # pylint: disable=import-error from .views import ( @@ -401,10 +409,9 @@ def setUpClass(cls): def setUp(self): super().setUp() setup_test_environment() - resource = resources.Resource.create( - {"resource-key": "resource-value"} - ) - result = self.create_tracer_provider(resource=resource) + _django_instrumentor.instrument() + + result = self.create_tracer_provider() tracer_provider, exporter = result self.exporter = exporter _django_instrumentor.instrument(tracer_provider=tracer_provider) @@ -430,3 +437,67 @@ def test_tracer_provider_traced(self): self.assertEqual( span.resource.attributes["resource-key"], "resource-value" ) + +class TestDjangoWithOtherFramework(SimpleTestCase, TestBase, WsgiTestBase): + @classmethod + def setUpClass(cls): + conf.settings.configure(ROOT_URLCONF=modules[__name__], WSGI_APPLICATION="tests.wsgi.application") + super().setUpClass() + + def setUp(self): + super().setUp() + setup_test_environment() + + # conf.settings.WSGI_APPLICATION = "wsgi.application" + # application = get_internal_wsgi_application() + # application = get_wsgi_application() + # _DjangoMiddleware + # self.application = OpenTelemetryMiddleware(application) + # _django_instrumentor.instrument() + # self.application = OpenTelemetryMiddleware(application, tracer_provider=self.tracer_provider) + # conf.settings.configure(ROOT_URLCONF=modules[__name__], WSGI_APPLICATION="application") + # conf.settings.WSGI_APPLICATION="application" + + + def tearDown(self) -> None: + super().tearDown() + teardown_test_environment() + _django_instrumentor.uninstrument() + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + conf.settings = conf.LazySettings() + + + def test_with_another_framework(self): + environ = RequestFactory()._base_environ( + PATH_INFO="/span_name/1234/", + CONTENT_TYPE="text/html; charset=utf-8", + REQUEST_METHOD="GET" + ) + response_data = {} + def start_response(status, headers): + response_data["status"] = status + response_data["headers"] = headers + + result = self.create_tracer_provider() + tracer_provider, exporter = result + self.exporter = exporter + + + _django_instrumentor.instrument(tracer_provider=tracer_provider) + application = get_internal_wsgi_application() + application = OpenTelemetryMiddleware(application, tracer_provider=tracer_provider) + resp = application(environ, start_response) + + # resp = Client().get("/span_name/1234/") + # self.assertEqual(200, resp.status_code) + + # span_list = self.memory_exporter.get_finished_spans() + span_list = self.exporter.get_finished_spans() + # print(span_list) + self.assertEqual(trace.SpanKind.INTERNAL, span_list[0].kind) + + #Below line give me error "index out of the range" as there is only one span created where it should be 2. + self.assertEqual(trace.SpanKind.SERVER, span_list[1].kind)