Skip to content

Commit

Permalink
lint fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
mnbbrown committed Jul 7, 2021
1 parent ed9ca71 commit 8812117
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
from opentelemetry.instrumentation.sqlalchemy.engine import (
EngineTracer,
_get_tracer,
_wrap_create_async_engine,
_wrap_create_engine,
)
from opentelemetry.instrumentation.sqlalchemy.package import _instruments
Expand Down Expand Up @@ -88,6 +89,13 @@ def _instrument(self, **kwargs):
"""
_w("sqlalchemy", "create_engine", _wrap_create_engine)
_w("sqlalchemy.engine", "create_engine", _wrap_create_engine)
if sqlalchemy.__version__.startswith("1.4"):
_w(
"sqlalchemy.ext.asyncio",
"create_async_engine",
_wrap_create_async_engine,
)

if kwargs.get("engine") is not None:
return EngineTracer(
_get_tracer(
Expand All @@ -100,3 +108,5 @@ def _instrument(self, **kwargs):
def _uninstrument(self, **kwargs):
unwrap(sqlalchemy, "create_engine")
unwrap(sqlalchemy.engine, "create_engine")
if sqlalchemy.__version__.startswith("1.4"):
unwrap(sqlalchemy.ext.asyncio, "create_async_engine")
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,16 @@ def _get_tracer(engine, tracer_provider=None):
)


# pylint: disable=unused-argument
def _wrap_create_async_engine(func, module, args, kwargs):
"""Trace the SQLAlchemy engine, creating an `EngineTracer`
object that will listen to SQLAlchemy events.
"""
engine = func(*args, **kwargs)
EngineTracer(_get_tracer(engine), engine.sync_engine)
return engine


# pylint: disable=unused-argument
def _wrap_create_engine(func, module, args, kwargs):
"""Trace the SQLAlchemy engine, creating an `EngineTracer`
Expand Down Expand Up @@ -78,7 +88,9 @@ def _operation_name(self, db_name, statement):
return " ".join(parts)

# pylint: disable=unused-argument
def _before_cur_exec(self, conn, cursor, statement, params, context, executemany):
def _before_cur_exec(
self, conn, cursor, statement, params, context, executemany
):
attrs, found = _get_attributes_from_url(conn.engine.url)
if not found:
attrs = _get_attributes_from_cursor(self.vendor, cursor, attrs)
Expand All @@ -98,15 +110,17 @@ def _before_cur_exec(self, conn, cursor, statement, params, context, executemany
context._span = span

# pylint: disable=unused-argument
def _after_cur_exec(self, conn, cursor, statement, params, context, executemany):
span = getattr(context, '_span', None)
def _after_cur_exec(
self, conn, cursor, statement, params, context, executemany
):
span = getattr(context, "_span", None)
if span is None:
return

span.end()

def _handle_error(self, context):
span = getattr(context.execution_context, '_span', None)
span = getattr(context.execution_context, "_span", None)
if span is None:
return

Expand All @@ -122,7 +136,6 @@ def _handle_error(self, context):
span.end()



def _get_attributes_from_url(url):
"""Set connection tags from the url. return true if successful."""
attrs = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from typing import Coroutine
from unittest import mock

from sqlalchemy import create_engine
import sqlalchemy
from sqlalchemy import create_engine

from opentelemetry import trace
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
from opentelemetry.test.test_base import TestBase
import asyncio


def _call_async(coro: Coroutine):
Expand Down Expand Up @@ -48,9 +49,9 @@ def test_trace_integration(self):
def test_async_trace_integration(self):
if sqlalchemy.__version__.startswith("1.3"):
return
from sqlalchemy.ext.asyncio import (
from sqlalchemy.ext.asyncio import ( # pylint: disable-all
create_async_engine,
) # pylint: disable-all
)

engine = create_async_engine("sqlite+aiosqlite:///:memory:")
SQLAlchemyInstrumentor().instrument(
Expand Down Expand Up @@ -95,3 +96,20 @@ def test_create_engine_wrapper(self):
self.assertEqual(len(spans), 1)
self.assertEqual(spans[0].name, "SELECT :memory:")
self.assertEqual(spans[0].kind, trace.SpanKind.CLIENT)

def test_create_async_engine_wrapper(self):
SQLAlchemyInstrumentor().instrument()
if sqlalchemy.__version__.startswith("1.3"):
return
from sqlalchemy.ext.asyncio import ( # pylint: disable-all
create_async_engine,
)

engine = create_async_engine("sqlite+aiosqlite:///:memory:")
cnx = _call_async(engine.connect())
_call_async(cnx.execute(sqlalchemy.text("SELECT 1 + 1;"))).fetchall()
_call_async(cnx.close())
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
self.assertEqual(spans[0].name, "SELECT :memory:")
self.assertEqual(spans[0].kind, trace.SpanKind.CLIENT)
Original file line number Diff line number Diff line change
Expand Up @@ -243,4 +243,6 @@ def insert_players(session):
close_all_sessions()

spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 5 if self.VENDOR not in ["postgresql"] else 3)
self.assertEqual(
len(spans), 5 if self.VENDOR not in ["postgresql"] else 3
)

0 comments on commit 8812117

Please sign in to comment.