Skip to content

Commit

Permalink
Add support for SQLAlchemy Core
Browse files Browse the repository at this point in the history
This support can be used as a replacement, or to compliment the
existing support for SQLAlchemy ORM. Because it patches at the
core level, and all the ORM functions use core it will capture
database communications either though the ORM or though core.
  • Loading branch information
jonathangreen authored and Tyler Hargraves committed Mar 22, 2022
1 parent d30880f commit 2b95726
Show file tree
Hide file tree
Showing 5 changed files with 199 additions and 0 deletions.
2 changes: 2 additions & 0 deletions aws_xray_sdk/core/patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
'pymysql',
'psycopg2',
'pg8000',
'sqlalchemy_core',
)

NO_DOUBLE_PATCH = (
Expand All @@ -37,6 +38,7 @@
'pymysql',
'psycopg2',
'pg8000',
'sqlalchemy_core',
)

_PATCHED_MODULES = set()
Expand Down
3 changes: 3 additions & 0 deletions aws_xray_sdk/ext/sqlalchemy_core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .patch import patch, unpatch

__all__ = ['patch', 'unpatch']
86 changes: 86 additions & 0 deletions aws_xray_sdk/ext/sqlalchemy_core/patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import logging
import sys
if sys.version_info >= (3, 0, 0):
from urllib.parse import urlparse, uses_netloc
else:
from urlparse import urlparse, uses_netloc

import wrapt

from aws_xray_sdk.core import xray_recorder
from aws_xray_sdk.core.patcher import _PATCHED_MODULES
from aws_xray_sdk.core.utils import stacktrace
from aws_xray_sdk.ext.util import unwrap


def _sql_meta(instance, args):
try:
metadata = {}
url = urlparse(str(instance.engine.url))
# Add Scheme to uses_netloc or // will be missing from url.
uses_netloc.append(url.scheme)
if url.password is None:
metadata['url'] = url.geturl()
name = url.netloc
else:
# Strip password from URL
host_info = url.netloc.rpartition('@')[-1]
parts = url._replace(netloc='{}@{}'.format(url.username, host_info))
metadata['url'] = parts.geturl()
name = host_info
metadata['user'] = url.username
metadata['database_type'] = instance.engine.name
try:
version = getattr(instance.dialect, '{}_version'.format(instance.engine.driver))
version_str = '.'.join(map(str, version))
metadata['driver_version'] = "{}-{}".format(instance.engine.driver, version_str)
except AttributeError:
metadata['driver_version'] = instance.engine.driver
if instance.dialect.server_version_info is not None:
metadata['database_version'] = '.'.join(map(str, instance.dialect.server_version_info))
if xray_recorder.stream_sql:
metadata['sanitized_query'] = str(args[0])
except Exception:
metadata = None
name = None
logging.getLogger(__name__).exception('Error parsing sql metadata.')
return name, metadata


def _xray_traced_sqlalchemy_execute(wrapped, instance, args, kwargs):
name, sql = _sql_meta(instance, args)
if sql is not None:
subsegment = xray_recorder.begin_subsegment(name, namespace='remote')
else:
subsegment = None
try:
res = wrapped(*args, **kwargs)
except Exception:
if subsegment is not None:
exception = sys.exc_info()[1]
stack = stacktrace.get_stacktrace(limit=xray_recorder._max_trace_back)
subsegment.add_exception(exception, stack)
raise
finally:
if subsegment is not None:
subsegment.set_sql(sql)
xray_recorder.end_subsegment()
return res


def patch():
wrapt.wrap_function_wrapper(
'sqlalchemy.engine.base',
'Connection.execute',
_xray_traced_sqlalchemy_execute
)


def unpatch():
"""
Unpatch any previously patched modules.
This operation is idempotent.
"""
_PATCHED_MODULES.discard('sqlalchemy_core')
import sqlalchemy
unwrap(sqlalchemy.engine.base.Connection, 'execute')
Empty file.
108 changes: 108 additions & 0 deletions tests/ext/sqlalchemy_core/test_sqlalchemy_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from __future__ import absolute_import

import pytest
from sqlalchemy import create_engine, Column, Integer, String
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.sql.expression import Insert, Delete

from aws_xray_sdk.core import xray_recorder, patch
from aws_xray_sdk.core.context import Context

Base = declarative_base()


class User(Base):
__tablename__ = 'users'

id = Column(Integer, primary_key=True)
name = Column(String)
fullname = Column(String)
password = Column(String)


@pytest.fixture()
def engine():
"""
Clean up context storage on each test run and begin a segment
so that later subsegment can be attached. After each test run
it cleans up context storage again.
"""
from aws_xray_sdk.ext.sqlalchemy_core import unpatch
patch(('sqlalchemy_core',))
engine = create_engine('sqlite:///:memory:')
xray_recorder.configure(service='test', sampling=False, context=Context())
xray_recorder.begin_segment('name')
Base.metadata.create_all(engine)
xray_recorder.clear_trace_entities()
xray_recorder.begin_segment('name')
yield engine
xray_recorder.clear_trace_entities()
unpatch()


@pytest.fixture()
def connection(engine):
return engine.connect()


@pytest.fixture()
def session(engine):
Session = sessionmaker(bind=engine)
return Session()


def test_all(session):
""" Test calling all() on get all records.
Verify we run the query and return the SQL as metdata"""
session.query(User).all()
assert len(xray_recorder.current_segment().subsegments) == 1
sql_meta = xray_recorder.current_segment().subsegments[0].sql
assert sql_meta['url'] == 'sqlite:///:memory:'
assert sql_meta['sanitized_query'].startswith('SELECT')
assert sql_meta['sanitized_query'].endswith('FROM users')


def test_add(session):
""" Test calling add() on insert a row.
Verify we that we capture trace for the add"""
password = "123456"
john = User(name='John', fullname="John Doe", password=password)
session.add(john)
session.commit()
assert len(xray_recorder.current_segment().subsegments) == 1
sql_meta = xray_recorder.current_segment().subsegments[0].sql
assert sql_meta['sanitized_query'].startswith('INSERT INTO users')
assert password not in sql_meta['sanitized_query']


def test_filter_first(session):
""" Test calling filter().first() on get first filtered records.
Verify we run the query and return the SQL as metdata"""
session.query(User).filter(User.password=="mypassword!").first()
assert len(xray_recorder.current_segment().subsegments) == 1
sql_meta = xray_recorder.current_segment().subsegments[0].sql
assert sql_meta['sanitized_query'].startswith('SELECT')
assert 'FROM users' in sql_meta['sanitized_query']
assert "mypassword!" not in sql_meta['sanitized_query']


def test_connection_add(connection):
password = "123456"
statement = Insert(User).values(name='John', fullname="John Doe", password=password)
connection.execute(statement)
assert len(xray_recorder.current_segment().subsegments) == 1
sql_meta = xray_recorder.current_segment().subsegments[0].sql
assert sql_meta['sanitized_query'].startswith('INSERT INTO users')
assert sql_meta['url'] == 'sqlite:///:memory:'
assert password not in sql_meta['sanitized_query']

def test_connection_query(connection):
password = "123456"
statement = Delete(User).where(User.name == 'John').where(User.password == password)
connection.execute(statement)
assert len(xray_recorder.current_segment().subsegments) == 1
sql_meta = xray_recorder.current_segment().subsegments[0].sql
assert sql_meta['sanitized_query'].startswith('DELETE FROM users')
assert sql_meta['url'] == 'sqlite:///:memory:'
assert password not in sql_meta['sanitized_query']

0 comments on commit 2b95726

Please sign in to comment.