Skip to content

Commit

Permalink
ext: Use TestBase (#586)
Browse files Browse the repository at this point in the history
Update tests to use TestBase as described on #303.

Co-authored-by: Yusuke Tsutsumi <yusuke@tsutsumi.io>
Co-authored-by: Chris Kleinknecht <libc@google.com>
  • Loading branch information
3 people authored Apr 20, 2020
1 parent 920b4b4 commit 4115c1b
Show file tree
Hide file tree
Showing 13 changed files with 143 additions and 230 deletions.
4 changes: 4 additions & 0 deletions ext/opentelemetry-ext-dbapi/setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,9 @@ install_requires =
opentelemetry-api == 0.7.dev0
wrapt >= 1.0.0, < 2.0.0

[options.extras_require]
test =
opentelemetry-test == 0.7.dev0

[options.packages.find]
where = src
108 changes: 38 additions & 70 deletions ext/opentelemetry-ext-dbapi/tests/test_dbapi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from unittest import mock

from opentelemetry import trace as trace_api
from opentelemetry.ext.dbapi import DatabaseApiIntegration
from opentelemetry.test.test_base import TestBase


class TestDBApiIntegration(unittest.TestCase):
class TestDBApiIntegration(TestBase):
def setUp(self):
self.tracer = trace_api.DefaultTracer()
self.span = MockSpan()
self.start_current_span_patcher = mock.patch.object(
self.tracer,
"start_as_current_span",
autospec=True,
spec_set=True,
return_value=self.span,
)

self.start_as_current_span = self.start_current_span_patcher.start()

def tearDown(self):
self.start_current_span_patcher.stop()
super().setUp()
self.tracer = self.tracer_provider.get_tracer(__name__)

def test_span_succeeded(self):
connection_props = {
Expand All @@ -57,28 +43,25 @@ def test_span_succeeded(self):
)
cursor = mock_connection.cursor()
cursor.execute("Test query", ("param1Value", False))
self.assertTrue(self.start_as_current_span.called)
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
span = spans_list[0]
self.assertEqual(span.name, "testcomponent.testdatabase")
self.assertIs(span.kind, trace_api.SpanKind.CLIENT)

self.assertEqual(span.attributes["component"], "testcomponent")
self.assertEqual(span.attributes["db.type"], "testtype")
self.assertEqual(span.attributes["db.instance"], "testdatabase")
self.assertEqual(span.attributes["db.statement"], "Test query")
self.assertEqual(
self.start_as_current_span.call_args[0][0],
"testcomponent.testdatabase",
)
self.assertIs(
self.start_as_current_span.call_args[1]["kind"],
trace_api.SpanKind.CLIENT,
)
self.assertEqual(self.span.attributes["component"], "testcomponent")
self.assertEqual(self.span.attributes["db.type"], "testtype")
self.assertEqual(self.span.attributes["db.instance"], "testdatabase")
self.assertEqual(self.span.attributes["db.statement"], "Test query")
self.assertEqual(
self.span.attributes["db.statement.parameters"],
span.attributes["db.statement.parameters"],
"('param1Value', False)",
)
self.assertEqual(self.span.attributes["db.user"], "testuser")
self.assertEqual(self.span.attributes["net.peer.name"], "testhost")
self.assertEqual(self.span.attributes["net.peer.port"], 123)
self.assertEqual(span.attributes["db.user"], "testuser")
self.assertEqual(span.attributes["net.peer.name"], "testhost")
self.assertEqual(span.attributes["net.peer.port"], 123)
self.assertIs(
self.span.status.canonical_code,
span.status.canonical_code,
trace_api.status.StatusCanonicalCode.OK,
)

Expand All @@ -88,17 +71,18 @@ def test_span_failed(self):
mock_connect, {}, {}
)
cursor = mock_connection.cursor()
try:
with self.assertRaises(Exception):
cursor.execute("Test query", throw_exception=True)
except Exception: # pylint: disable=broad-except
self.assertEqual(
self.span.attributes["db.statement"], "Test query"
)
self.assertIs(
self.span.status.canonical_code,
trace_api.status.StatusCanonicalCode.UNKNOWN,
)
self.assertEqual(self.span.status.description, "Test Exception")

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
span = spans_list[0]
self.assertEqual(span.attributes["db.statement"], "Test query")
self.assertIs(
span.status.canonical_code,
trace_api.status.StatusCanonicalCode.UNKNOWN,
)
self.assertEqual(span.status.description, "Test Exception")

def test_executemany(self):
db_integration = DatabaseApiIntegration(self.tracer, "testcomponent")
Expand All @@ -107,8 +91,10 @@ def test_executemany(self):
)
cursor = mock_connection.cursor()
cursor.executemany("Test query")
self.assertTrue(self.start_as_current_span.called)
self.assertEqual(self.span.attributes["db.statement"], "Test query")
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
span = spans_list[0]
self.assertEqual(span.attributes["db.statement"], "Test query")

def test_callproc(self):
db_integration = DatabaseApiIntegration(self.tracer, "testcomponent")
Expand All @@ -117,9 +103,11 @@ def test_callproc(self):
)
cursor = mock_connection.cursor()
cursor.callproc("Test stored procedure")
self.assertTrue(self.start_as_current_span.called)
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
span = spans_list[0]
self.assertEqual(
self.span.attributes["db.statement"], "Test stored procedure"
span.attributes["db.statement"], "Test stored procedure"
)


Expand Down Expand Up @@ -159,23 +147,3 @@ def executemany(self, query, params=None, throw_exception=False):
def callproc(self, query, params=None, throw_exception=False):
if throw_exception:
raise Exception("Test Exception")


class MockSpan:
def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
return False

def __init__(self):
self.status = None
self.name = ""
self.kind = trace_api.SpanKind.INTERNAL
self.attributes = {}

def set_attribute(self, key, value):
self.attributes[key] = value

def set_status(self, status):
self.status = status
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,12 @@

import os
import time
import unittest

import mysql.connector

from opentelemetry import trace as trace_api
from opentelemetry.ext.mysql import trace_integration
from opentelemetry.sdk.trace import Tracer, TracerProvider
from opentelemetry.sdk.trace.export import SimpleExportSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import (
InMemorySpanExporter,
)
from opentelemetry.test.test_base import TestBase

MYSQL_USER = os.getenv("MYSQL_USER ", "testuser")
MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD ", "testpassword")
Expand All @@ -33,16 +28,13 @@
MYSQL_DB_NAME = os.getenv("MYSQL_DB_NAME ", "opentelemetry-tests")


class TestFunctionalMysql(unittest.TestCase):
class TestFunctionalMysql(TestBase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls._connection = None
cls._cursor = None
cls._tracer_provider = TracerProvider()
cls._tracer = Tracer(cls._tracer_provider, None)
cls._span_exporter = InMemorySpanExporter()
cls._span_processor = SimpleExportSpanProcessor(cls._span_exporter)
cls._tracer_provider.add_span_processor(cls._span_processor)
cls._tracer = cls.tracer_provider.get_tracer(__name__)
trace_integration(cls._tracer)
cls._connection = mysql.connector.connect(
user=MYSQL_USER,
Expand All @@ -58,11 +50,8 @@ def tearDownClass(cls):
if cls._connection:
cls._connection.close()

def setUp(self):
self._span_exporter.clear()

def validate_spans(self):
spans = self._span_exporter.get_finished_spans()
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 2)
for span in spans:
if span.name == "rootSpan":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,12 @@

import os
import time
import unittest

import psycopg2

from opentelemetry import trace as trace_api
from opentelemetry.ext.psycopg2 import trace_integration
from opentelemetry.sdk.trace import Tracer, TracerProvider
from opentelemetry.sdk.trace.export import SimpleExportSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import (
InMemorySpanExporter,
)
from opentelemetry.test.test_base import TestBase

POSTGRES_HOST = os.getenv("POSTGRESQL_HOST ", "localhost")
POSTGRES_PORT = int(os.getenv("POSTGRESQL_PORT ", "5432"))
Expand All @@ -33,16 +28,13 @@
POSTGRES_USER = os.getenv("POSTGRESQL_HOST ", "testuser")


class TestFunctionalPsycopg(unittest.TestCase):
class TestFunctionalPsycopg(TestBase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls._connection = None
cls._cursor = None
cls._tracer_provider = TracerProvider()
cls._tracer = Tracer(cls._tracer_provider, None)
cls._span_exporter = InMemorySpanExporter()
cls._span_processor = SimpleExportSpanProcessor(cls._span_exporter)
cls._tracer_provider.add_span_processor(cls._span_processor)
cls._tracer = cls.tracer_provider.get_tracer(__name__)
trace_integration(cls._tracer)
cls._connection = psycopg2.connect(
dbname=POSTGRES_DB_NAME,
Expand All @@ -61,11 +53,8 @@ def tearDownClass(cls):
if cls._connection:
cls._connection.close()

def setUp(self):
self._span_exporter.clear()

def validate_spans(self):
spans = self._span_exporter.get_finished_spans()
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 2)
for span in spans:
if span.name == "rootSpan":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,45 +13,33 @@
# limitations under the License.

import os
import typing
import unittest

from pymongo import MongoClient

from opentelemetry import trace as trace_api
from opentelemetry.ext.pymongo import trace_integration
from opentelemetry.sdk.trace import Span, Tracer, TracerProvider
from opentelemetry.sdk.trace.export import SimpleExportSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import (
InMemorySpanExporter,
)
from opentelemetry.test.test_base import TestBase

MONGODB_HOST = os.getenv("MONGODB_HOST ", "localhost")
MONGODB_PORT = int(os.getenv("MONGODB_PORT ", "27017"))
MONGODB_DB_NAME = os.getenv("MONGODB_DB_NAME ", "opentelemetry-tests")
MONGODB_COLLECTION_NAME = "test"


class TestFunctionalPymongo(unittest.TestCase):
class TestFunctionalPymongo(TestBase):
@classmethod
def setUpClass(cls):
cls._tracer_provider = TracerProvider()
cls._tracer = Tracer(cls._tracer_provider, None)
cls._span_exporter = InMemorySpanExporter()
cls._span_processor = SimpleExportSpanProcessor(cls._span_exporter)
cls._tracer_provider.add_span_processor(cls._span_processor)
super().setUpClass()
cls._tracer = cls.tracer_provider.get_tracer(__name__)
trace_integration(cls._tracer)
client = MongoClient(
MONGODB_HOST, MONGODB_PORT, serverSelectionTimeoutMS=2000
)
db = client[MONGODB_DB_NAME]
cls._collection = db[MONGODB_COLLECTION_NAME]

def setUp(self):
self._span_exporter.clear()

def validate_spans(self):
spans = self._span_exporter.get_finished_spans()
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 2)
for span in spans:
if span.name == "rootSpan":
Expand Down
5 changes: 5 additions & 0 deletions ext/opentelemetry-ext-grpc/setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,10 @@ install_requires =
opentelemetry-api == 0.7.dev0
grpcio ~= 1.27

[options.extras_require]
test =
opentelemetry-test == 0.7.dev0
opentelemetry-sdk == 0.7.dev0

[options.packages.find]
where = src
27 changes: 13 additions & 14 deletions ext/opentelemetry-ext-grpc/tests/test_server_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,15 @@
# pylint:disable=no-self-use

import threading
import unittest
from concurrent import futures
from contextlib import contextmanager
from unittest import mock

import grpc

from opentelemetry import trace
from opentelemetry.ext.grpc import server_interceptor
from opentelemetry.ext.grpc.grpcext import intercept_server
from opentelemetry.sdk import trace as trace_sdk
from opentelemetry.test.test_base import TestBase


class UnaryUnaryMethodHandler(grpc.RpcMethodHandler):
Expand All @@ -49,18 +47,16 @@ def service(self, handler_call_details):
return UnaryUnaryMethodHandler(self._unary_unary_handler)


class TestOpenTelemetryServerInterceptor(unittest.TestCase):
class TestOpenTelemetryServerInterceptor(TestBase):
def setUp(self):
super().setUp()
self.tracer = self.tracer_provider.get_tracer(__name__)

def test_create_span(self):
"""Check that the interceptor wraps calls with spans server-side."""

@contextmanager
def mock_start_as_current_span(*args, **kwargs):
yield mock.Mock(spec=trace.Span)

# Intercept gRPC calls...
tracer = mock.Mock(spec=trace.Tracer)
tracer.start_as_current_span.side_effect = mock_start_as_current_span
interceptor = server_interceptor(tracer)
interceptor = server_interceptor(self.tracer)

# No-op RPC handler
def handler(request, context):
Expand All @@ -84,9 +80,12 @@ def handler(request, context):
finally:
server.stop(None)

tracer.start_as_current_span.assert_called_once_with(
name="", kind=trace.SpanKind.SERVER
)
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
span = spans_list[0]

self.assertEqual(span.name, "")
self.assertIs(span.kind, trace.SpanKind.SERVER)

def test_span_lifetime(self):
"""Check that the span is active for the duration of the call."""
Expand Down
Loading

0 comments on commit 4115c1b

Please sign in to comment.