Skip to content

Commit

Permalink
Convert the unittest tests to pytest test style.
Browse files Browse the repository at this point in the history
  • Loading branch information
decko committed May 11, 2023
1 parent 767adcc commit e312bec
Showing 1 changed file with 102 additions and 114 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,124 +12,112 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import contextlib
import typing
import unittest
import urllib.parse
from functools import partial
from unittest import mock

import pytest
import pytest_asyncio
import aiohttp
import aiohttp.test_utils
from http import HTTPMethod, HTTPStatus
from pkg_resources import iter_entry_points
from unittest import mock

from opentelemetry import trace as trace_api
from opentelemetry.test.test_base import TestBase
from opentelemetry.instrumentation.aiohttp_server import AioHttpServerInstrumentor
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.test.test_base import TestBase

from opentelemetry.test.globals_test import (
reset_trace_globals,
)


@pytest.fixture(scope="session")
def tracer():
test_base = TestBase()

tracer_provider, memory_exporter = test_base.create_tracer_provider()

reset_trace_globals()
trace_api.set_tracer_provider(tracer_provider)

yield tracer_provider, memory_exporter

reset_trace_globals()


async def default_handler(request, status=200):
return aiohttp.web.Response(status=status)


@pytest_asyncio.fixture
async def server_fixture(tracer, aiohttp_server):
_, memory_exporter = tracer

AioHttpServerInstrumentor().instrument()

app = aiohttp.web.Application()
app.add_routes(
[aiohttp.web.get("/test-path", default_handler)])

server = await aiohttp_server(app)
yield server, app

memory_exporter.clear()

AioHttpServerInstrumentor().uninstrument()


def test_checking_instrumentor_pkg_installed():
entry_points = iter_entry_points(
"opentelemetry_instrumentor", "aiohttp-server"
)

instrumentor = next(entry_points).load()()
assert (isinstance(instrumentor, AioHttpServerInstrumentor))


@pytest.mark.asyncio
@pytest.mark.parametrize("url, expected_method, expected_status_code", [
("/test-path", HTTPMethod.GET, HTTPStatus.OK),
("/not-found", HTTPMethod.GET, HTTPStatus.NOT_FOUND)
])
async def test_status_code_instrumentation(tracer, server_fixture,
aiohttp_client, url,
expected_method,
expected_status_code):
_, memory_exporter = tracer
server, app = server_fixture

assert len(memory_exporter.get_finished_spans()) == 0

client = await aiohttp_client(server)
await client.get(url)

assert len(memory_exporter.get_finished_spans()) == 1

[span] = memory_exporter.get_finished_spans()

assert expected_method == span.attributes[SpanAttributes.HTTP_METHOD]
assert expected_status_code == span.attributes[SpanAttributes.HTTP_STATUS_CODE]

assert f"http://{server.host}:{server.port}{url}" == span.attributes[
SpanAttributes.HTTP_URL
]


@pytest.mark.skip(reason="Historical purposes. Can't see the reason of this mock.")
def test_not_recording(self):
mock_tracer = mock.Mock()
mock_span = mock.Mock()
mock_span.is_recording.return_value = False
mock_tracer.start_span.return_value = mock_span
with mock.patch("opentelemetry.trace.get_tracer") as patched:
patched.start_span.return_value = mock_span
# pylint: disable=W0612
# host, port = run_with_test_server(
# self.get_default_request(), self.URL, self.default_handler
# )

def run_with_test_server(
runnable: typing.Callable, url: str, handler: typing.Callable
) -> typing.Tuple[str, int]:
async def do_request():
app = aiohttp.web.Application()
parsed_url = urllib.parse.urlparse(url)
app.add_routes([aiohttp.web.get(parsed_url.path, handler)])
app.add_routes([aiohttp.web.post(parsed_url.path, handler)])
app.add_routes([aiohttp.web.patch(parsed_url.path, handler)])

with contextlib.suppress(aiohttp.ClientError):
async with aiohttp.test_utils.TestServer(app) as server:
netloc = (server.host, server.port)
await server.start_server()
await runnable(server)
return netloc

loop = asyncio.get_event_loop()
return loop.run_until_complete(do_request())


class TestAioHttpServerIntegration(TestBase):
URL = "/test-path"

def setUp(self):
super().setUp()
AioHttpServerInstrumentor().instrument()

def tearDown(self):
super().tearDown()
AioHttpServerInstrumentor().uninstrument()

@staticmethod
# pylint:disable=unused-argument
async def default_handler(request, status=200):
return aiohttp.web.Response(status=status)

def assert_spans(self, num_spans: int):
finished_spans = self.memory_exporter.get_finished_spans()
self.assertEqual(num_spans, len(finished_spans))
if num_spans == 0:
return None
if num_spans == 1:
return finished_spans[0]
return finished_spans

@staticmethod
def get_default_request(url: str = URL):
async def default_request(server: aiohttp.test_utils.TestServer):
async with aiohttp.test_utils.TestClient(server) as session:
await session.get(url)

return default_request

def test_instrument(self):
host, port = run_with_test_server(
self.get_default_request(), self.URL, self.default_handler
)
span = self.assert_spans(1)
self.assertEqual("GET", span.attributes[SpanAttributes.HTTP_METHOD])
self.assertEqual(
f"http://{host}:{port}/test-path",
span.attributes[SpanAttributes.HTTP_URL],
)
self.assertEqual(200, span.attributes[SpanAttributes.HTTP_STATUS_CODE])

def test_status_codes(self):
error_handler = partial(self.default_handler, status=400)
host, port = run_with_test_server(
self.get_default_request(), self.URL, error_handler
)
span = self.assert_spans(1)
self.assertEqual("GET", span.attributes[SpanAttributes.HTTP_METHOD])
self.assertEqual(
f"http://{host}:{port}/test-path",
span.attributes[SpanAttributes.HTTP_URL],
)
self.assertEqual(400, span.attributes[SpanAttributes.HTTP_STATUS_CODE])

def test_not_recording(self):
mock_tracer = mock.Mock()
mock_span = mock.Mock()
mock_span.is_recording.return_value = False
mock_tracer.start_span.return_value = mock_span
with mock.patch("opentelemetry.trace.get_tracer"):
# pylint: disable=W0612
host, port = run_with_test_server(
self.get_default_request(), self.URL, self.default_handler
)

self.assertFalse(mock_span.is_recording())
self.assertTrue(mock_span.is_recording.called)
self.assertFalse(mock_span.set_attribute.called)
self.assertFalse(mock_span.set_status.called)


class TestLoadingAioHttpInstrumentor(unittest.TestCase):
def test_loading_instrumentor(self):
entry_points = iter_entry_points(
"opentelemetry_instrumentor", "aiohttp-server"
)

instrumentor = next(entry_points).load()()
self.assertIsInstance(instrumentor, AioHttpServerInstrumentor)
self.assertTrue(patched.start_span.called)
self.assertFalse(mock_span.is_recording())
self.assertTrue(mock_span.is_recording.called)
self.assertFalse(mock_span.set_attribute.called)
self.assertFalse(mock_span.set_status.called)

0 comments on commit e312bec

Please sign in to comment.