Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ability to execute precompiled sqlalchemy queries #294

Merged
merged 4 commits into from
Jun 3, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 19 additions & 15 deletions aiomysql/sa/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import weakref

from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.compiler import SQLCompiler
from sqlalchemy.sql.dml import UpdateBase
from sqlalchemy.sql.ddl import DDLElement

Expand All @@ -16,23 +15,19 @@

class SAConnection:

def __init__(self, connection, engine):
def __init__(self, connection, engine, compiled_cache=None):
self._connection = connection
self._transaction = None
self._savepoint_seq = 0
self._weak_results = weakref.WeakSet()
self._engine = engine
self._dialect = engine.dialect

@property
def engine(self):
return self._engine
self._compiled_cache = compiled_cache

def execute(self, query, *multiparams, **params):
"""Executes a SQL query with optional parameters.

query - a SQL query string or any sqlalchemy expression
(optionally it could be compiled).
query - a SQL query string or any sqlalchemy expression.

*multiparams/**params - represent bound parameter values to be
used in the execution. Typically, the format is a dictionary
Expand Down Expand Up @@ -79,15 +74,24 @@ async def _execute(self, query, *multiparams, **params):

result_map = None

compiled = None
if isinstance(query, SQLCompiler):
compiled = query

if isinstance(query, str):
await cursor.execute(query, dp or None)
elif compiled or isinstance(query, ClauseElement):
compiled = compiled or query.compile(dialect=self._dialect)
# parameters = compiled.params
elif isinstance(query, ClauseElement):
if self._compiled_cache is not None:
key = (self._dialect, query)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not think we need self. _dialect as a part of the key, since it is always the same object.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agree, a kind of overkill)

compiled = self._compiled_cache.get(key)
if not compiled:
compiled = query.compile(dialect=self._dialect)
if (
dp and dp.keys() == compiled.params.keys()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the formatting looks weird

or
not (dp or compiled.params)
):
# we only want queries with bound params in cache
self._compiled_cache[key] = compiled
else:
compiled = query.compile(dialect=self._dialect)

if not isinstance(query, DDLElement):
if dp and isinstance(dp, (list, tuple)):
if isinstance(query, UpdateBase):
Expand Down
16 changes: 10 additions & 6 deletions aiomysql/sa/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,17 @@


def create_engine(minsize=1, maxsize=10, loop=None,
dialect=_dialect, pool_recycle=-1, **kwargs):
dialect=_dialect, pool_recycle=-1, compiled_cache=None,
**kwargs):
"""A coroutine for Engine creation.
Returns Engine instance with embedded connection pool.
The pool has *minsize* opened connections to PostgreSQL server.
"""
coro = _create_engine(minsize=minsize, maxsize=maxsize, loop=loop,
dialect=dialect, pool_recycle=pool_recycle, **kwargs)
dialect=dialect, pool_recycle=pool_recycle,
compiled_cache=compiled_cache, **kwargs)
compatible_cursor_classes = [Cursor]
# Without provided kwarg, default is default cursor from Connection class
if kwargs.get('cursorclass', Cursor) not in compatible_cursor_classes:
Expand All @@ -38,7 +40,8 @@ def create_engine(minsize=1, maxsize=10, loop=None,


async def _create_engine(minsize=1, maxsize=10, loop=None,
dialect=_dialect, pool_recycle=-1, **kwargs):
dialect=_dialect, pool_recycle=-1,
compiled_cache=None, **kwargs):

if loop is None:
loop = asyncio.get_event_loop()
Expand All @@ -47,7 +50,7 @@ async def _create_engine(minsize=1, maxsize=10, loop=None,
pool_recycle=pool_recycle, **kwargs)
conn = await pool.acquire()
try:
return Engine(dialect, pool, **kwargs)
return Engine(dialect, pool, compiled_cache=compiled_cache, **kwargs)
finally:
pool.release(conn)

Expand All @@ -61,9 +64,10 @@ class Engine:
create_engine coroutine.
"""

def __init__(self, dialect, pool, **kwargs):
def __init__(self, dialect, pool, compiled_cache=None, **kwargs):
self._dialect = dialect
self._pool = pool
self._compiled_cache = compiled_cache
self._conn_kw = kwargs

@property
Expand Down Expand Up @@ -124,7 +128,7 @@ def acquire(self):

async def _acquire(self):
raw = await self._pool.acquire()
conn = SAConnection(raw, self)
conn = SAConnection(raw, self, compiled_cache=self._compiled_cache)
return conn

def release(self, conn):
Expand Down
138 changes: 138 additions & 0 deletions tests/sa/test_sa_compiled_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import asyncio
from aiomysql import sa
from sqlalchemy import bindparam

import os
import unittest

from sqlalchemy import MetaData, Table, Column, Integer, String

meta = MetaData()
tbl = Table('sa_tbl_cache_test', meta,
Column('id', Integer, nullable=False,
primary_key=True),
Column('val', String(255)))


class TestCompiledCache(unittest.TestCase):
def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(None)
self.host = os.environ.get('MYSQL_HOST', 'localhost')
self.port = int(os.environ.get('MYSQL_PORT', 3306))
self.user = os.environ.get('MYSQL_USER', 'root')
self.db = os.environ.get('MYSQL_DB', 'test_pymysql')
self.password = os.environ.get('MYSQL_PASSWORD', '')
self.engine = self.loop.run_until_complete(self.make_engine())
self.loop.run_until_complete(self.start())

def tearDown(self):
self.engine.terminate()
self.loop.run_until_complete(self.engine.wait_closed())
self.loop.close()

async def make_engine(self, **kwargs):
return (await sa.create_engine(db=self.db,
user=self.user,
password=self.password,
host=self.host,
port=self.port,
loop=self.loop,
minsize=10,
**kwargs))

async def start(self):
async with self.engine.acquire() as conn:
tx = await conn.begin()
await conn.execute("DROP TABLE IF EXISTS "
"sa_tbl_cache_test")
await conn.execute("CREATE TABLE sa_tbl_cache_test"
"(id serial, val varchar(255))")
await conn.execute(tbl.insert().values(val='some_val_1'))
await conn.execute(tbl.insert().values(val='some_val_2'))
await conn.execute(tbl.insert().values(val='some_val_3'))
await tx.commit()

def test_cache(self):
async def go():
cache = dict()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pure dict is not option for production system, since it has unbounded size. Should we also provide basic cache implementation also? something like LRU?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is up to implementor. LRU could be sufficient in most cases
BTW, sqlalchemy also does not provide any defaults - its only demand is that cache is some dict-like object

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense

engine = await self.make_engine(compiled_cache=cache)
async with engine.acquire() as conn:
# check select with params not added to cache
q = tbl.select().where(tbl.c.val == 'some_val_1')
cursor = await conn.execute(q)
row = await cursor.fetchone()
self.assertEqual('some_val_1', row.val)
self.assertEqual(0, len(cache))

# check select with bound params added to cache
select_by_val = tbl.select().where(
tbl.c.val == bindparam('value')
)
cursor = await conn.execute(
select_by_val, {'value': 'some_val_3'}
)
row = await cursor.fetchone()
self.assertEqual('some_val_3', row.val)
self.assertEqual(1, len(cache))

cursor = await conn.execute(
select_by_val, value='some_val_2'
)
row = await cursor.fetchone()
self.assertEqual('some_val_2', row.val)
self.assertEqual(1, len(cache))

select_all = tbl.select()
cursor = await conn.execute(select_all)
rows = await cursor.fetchall()
self.assertEqual(3, len(rows))
self.assertEqual(2, len(cache))

# check insert with bound params not added to cache
await conn.execute(tbl.insert().values(val='some_val_4'))
self.assertEqual(2, len(cache))

# check insert with bound params added to cache
q = tbl.insert().values(val=bindparam('value'))
await conn.execute(q, value='some_val_5')
self.assertEqual(3, len(cache))

await conn.execute(q, value='some_val_6')
self.assertEqual(3, len(cache))

await conn.execute(q, {'value': 'some_val_7'})
self.assertEqual(3, len(cache))

cursor = await conn.execute(select_all)
rows = await cursor.fetchall()
self.assertEqual(7, len(rows))
self.assertEqual(3, len(cache))

# check update with params not added to cache
q = tbl.update().where(
tbl.c.val == 'some_val_1'
).values(val='updated_val_1')
await conn.execute(q)
self.assertEqual(3, len(cache))
cursor = await conn.execute(
select_by_val, value='updated_val_1'
)
row = await cursor.fetchone()
self.assertEqual('updated_val_1', row.val)

# check update with bound params added to cache
q = tbl.update().where(
tbl.c.val == bindparam('value')
).values(val=bindparam('update'))
await conn.execute(
q, value='some_val_2', update='updated_val_2'
)
self.assertEqual(4, len(cache))
cursor = await conn.execute(
select_by_val, value='updated_val_2'
)
row = await cursor.fetchone()
self.assertEqual('updated_val_2', row.val)

self.loop.run_until_complete(go())