Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert #1097 #1660

Merged
merged 3 commits into from
Feb 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def instrument_connection(
Returns:
An instrumented connection.
"""
if isinstance(connection, _TracedConnectionProxy):
if isinstance(connection, wrapt.ObjectProxy):
_logger.warning("Connection already instrumented")
return connection

Expand All @@ -230,8 +230,8 @@ def uninstrument_connection(connection):
Returns:
An uninstrumented connection.
"""
if isinstance(connection, _TracedConnectionProxy):
return connection._connection
if isinstance(connection, wrapt.ObjectProxy):
return connection.__wrapped__

_logger.warning("Connection is not instrumented")
return connection
Expand Down Expand Up @@ -320,22 +320,14 @@ def get_connection_attributes(self, connection):
self.span_attributes[SpanAttributes.NET_PEER_PORT] = port


class _TracedConnectionProxy:
pass


def get_traced_connection_proxy(
connection, db_api_integration, *args, **kwargs
):
# pylint: disable=abstract-method
class TracedConnectionProxy(type(connection), _TracedConnectionProxy):
def __init__(self, connection):
self._connection = connection

def __getattr__(self, name):
return object.__getattribute__(
object.__getattribute__(self, "_connection"), name
)
class TracedConnectionProxy(wrapt.ObjectProxy):
# pylint: disable=unused-argument
def __init__(self, connection, *args, **kwargs):
wrapt.ObjectProxy.__init__(self, connection)

def __getattribute__(self, name):
if object.__getattribute__(self, name):
Expand All @@ -347,16 +339,17 @@ def __getattribute__(self, name):

def cursor(self, *args, **kwargs):
return get_traced_cursor_proxy(
self._connection.cursor(*args, **kwargs), db_api_integration
self.__wrapped__.cursor(*args, **kwargs), db_api_integration
)

# For some reason this is necessary as trying to access the close
# method of self._connection via __getattr__ leads to unexplained
# errors.
def close(self):
self._connection.close()
def __enter__(self):
self.__wrapped__.__enter__()
return self

def __exit__(self, *args, **kwargs):
self.__wrapped__.__exit__(*args, **kwargs)

return TracedConnectionProxy(connection)
return TracedConnectionProxy(connection, *args, **kwargs)


class CursorTracer:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,14 +325,14 @@ def test_callproc(self):

@mock.patch("opentelemetry.instrumentation.dbapi")
def test_wrap_connect(self, mock_dbapi):
dbapi.wrap_connect(self.tracer, MockConnectionEmpty(), "connect", "-")
dbapi.wrap_connect(self.tracer, mock_dbapi, "connect", "-")
connection = mock_dbapi.connect()
self.assertEqual(mock_dbapi.connect.call_count, 1)
self.assertIsInstance(connection._connection, mock.Mock)
self.assertIsInstance(connection.__wrapped__, mock.Mock)

@mock.patch("opentelemetry.instrumentation.dbapi")
def test_unwrap_connect(self, mock_dbapi):
dbapi.wrap_connect(self.tracer, MockConnectionEmpty(), "connect", "-")
dbapi.wrap_connect(self.tracer, mock_dbapi, "connect", "-")
connection = mock_dbapi.connect()
self.assertEqual(mock_dbapi.connect.call_count, 1)

Expand All @@ -342,21 +342,19 @@ def test_unwrap_connect(self, mock_dbapi):
self.assertIsInstance(connection, mock.Mock)

def test_instrument_connection(self):
connection = MockConnectionEmpty()
connection = mock.Mock()
# Avoid get_attributes failing because can't concatenate mock
# pylint: disable=attribute-defined-outside-init
connection.database = "-"
connection2 = dbapi.instrument_connection(self.tracer, connection, "-")
self.assertIs(connection2._connection, connection)
self.assertIs(connection2.__wrapped__, connection)

def test_uninstrument_connection(self):
connection = MockConnectionEmpty()
connection = mock.Mock()
# Set connection.database to avoid a failure because mock can't
# be concatenated
# pylint: disable=attribute-defined-outside-init
connection.database = "-"
connection2 = dbapi.instrument_connection(self.tracer, connection, "-")
self.assertIs(connection2._connection, connection)
self.assertIs(connection2.__wrapped__, connection)

connection3 = dbapi.uninstrument_connection(connection2)
self.assertIs(connection3, connection)
Expand All @@ -372,12 +370,10 @@ def mock_connect(*args, **kwargs):
server_host = kwargs.get("server_host")
server_port = kwargs.get("server_port")
user = kwargs.get("user")
return MockConnectionWithAttributes(
database, server_port, server_host, user
)
return MockConnection(database, server_port, server_host, user)


class MockConnectionWithAttributes:
class MockConnection:
def __init__(self, database, server_port, server_host, user):
self.database = database
self.server_port = server_port
Expand Down Expand Up @@ -410,7 +406,3 @@ def executemany(self, query, params=None, throw_exception=False):
def callproc(self, query, params=None, throw_exception=False):
if throw_exception:
raise Exception("Test Exception")


class MockConnectionEmpty:
pass
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import Mock, patch
from unittest import mock

import mysql.connector

Expand All @@ -23,15 +23,6 @@
from opentelemetry.test.test_base import TestBase


def mock_connect(*args, **kwargs):
class MockConnection:
def cursor(self):
# pylint: disable=no-self-use
return Mock()

return MockConnection()


def connect_and_execute_query():
cnx = mysql.connector.connect(database="test")
cursor = cnx.cursor()
Expand All @@ -47,9 +38,9 @@ def tearDown(self):
with self.disable_logging():
MySQLInstrumentor().uninstrument()

@patch("mysql.connector.connect", new=mock_connect)
@mock.patch("mysql.connector.connect")
# pylint: disable=unused-argument
def test_instrumentor(self):
def test_instrumentor(self, mock_connect):
MySQLInstrumentor().instrument()

connect_and_execute_query()
Expand All @@ -71,8 +62,9 @@ def test_instrumentor(self):
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

@patch("mysql.connector.connect", new=mock_connect)
def test_custom_tracer_provider(self):
@mock.patch("mysql.connector.connect")
# pylint: disable=unused-argument
def test_custom_tracer_provider(self, mock_connect):
resource = resources.Resource.create({})
result = self.create_tracer_provider(resource=resource)
tracer_provider, exporter = result
Expand All @@ -86,9 +78,9 @@ def test_custom_tracer_provider(self):

self.assertIs(span.resource, resource)

@patch("mysql.connector.connect", new=mock_connect)
@mock.patch("mysql.connector.connect")
# pylint: disable=unused-argument
def test_instrument_connection(self):
def test_instrument_connection(self, mock_connect):
cnx, query = connect_and_execute_query()

spans_list = self.memory_exporter.get_finished_spans()
Expand All @@ -101,18 +93,18 @@ def test_instrument_connection(self):
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

@patch("mysql.connector.connect", new=mock_connect)
def test_instrument_connection_no_op_tracer_provider(self):
@mock.patch("mysql.connector.connect")
def test_instrument_connection_no_op_tracer_provider(self, mock_connect):
tracer_provider = trace_api.NoOpTracerProvider()
MySQLInstrumentor().instrument(tracer_provider=tracer_provider)
connect_and_execute_query()

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 0)

@patch("mysql.connector.connect", new=mock_connect)
@mock.patch("mysql.connector.connect")
# pylint: disable=unused-argument
def test_uninstrument_connection(self):
def test_uninstrument_connection(self, mock_connect):
MySQLInstrumentor().instrument()
cnx, query = connect_and_execute_query()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import Mock, patch
from unittest import mock

import pymysql

Expand All @@ -22,24 +22,15 @@
from opentelemetry.test.test_base import TestBase


def mock_connect(*args, **kwargs):
class MockConnection:
def cursor(self):
# pylint: disable=no-self-use
return Mock()

return MockConnection()


class TestPyMysqlIntegration(TestBase):
def tearDown(self):
super().tearDown()
with self.disable_logging():
PyMySQLInstrumentor().uninstrument()

@patch("pymysql.connect", new=mock_connect)
@mock.patch("pymysql.connect")
# pylint: disable=unused-argument
def test_instrumentor(self):
def test_instrumentor(self, mock_connect):
PyMySQLInstrumentor().instrument()

cnx = pymysql.connect(database="test")
Expand Down Expand Up @@ -67,9 +58,9 @@ def test_instrumentor(self):
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

@patch("pymysql.connect", new=mock_connect)
@mock.patch("pymysql.connect")
# pylint: disable=unused-argument
def test_custom_tracer_provider(self):
def test_custom_tracer_provider(self, mock_connect):
resource = resources.Resource.create({})
result = self.create_tracer_provider(resource=resource)
tracer_provider, exporter = result
Expand All @@ -87,9 +78,9 @@ def test_custom_tracer_provider(self):

self.assertIs(span.resource, resource)

@patch("pymysql.connect", new=mock_connect)
@mock.patch("pymysql.connect")
# pylint: disable=unused-argument
def test_instrument_connection(self):
def test_instrument_connection(self, mock_connect):
cnx = pymysql.connect(database="test")
query = "SELECT * FROM test"
cursor = cnx.cursor()
Expand All @@ -105,9 +96,9 @@ def test_instrument_connection(self):
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

@patch("pymysql.connect", new=mock_connect)
@mock.patch("pymysql.connect")
# pylint: disable=unused-argument
def test_uninstrument_connection(self):
def test_uninstrument_connection(self, mock_connect):
PyMySQLInstrumentor().instrument()
cnx = pymysql.connect(database="test")
query = "SELECT * FROM test"
Expand Down