Skip to content

Commit

Permalink
fix warnings and make _do_get_result coroutine
Browse files Browse the repository at this point in the history
  • Loading branch information
jettify committed Jul 17, 2016
1 parent 60f9abb commit 0202d23
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 4 deletions.
8 changes: 8 additions & 0 deletions aiomysql/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,14 @@ def select_db(self, db):
yield from self._execute_command(COMMAND.COM_INIT_DB, db)
yield from self._read_ok_packet()

@asyncio.coroutine
def show_warnings(self):
"""SHOW WARNINGS"""
yield from self._execute_command(COMMAND.COM_QUERY, "SHOW WARNINGS")
result = MySQLResult(self)
yield from result.read()
return result.rows

def escape(self, obj):
""" Escape whatever value you pass to it"""
if isinstance(obj, str):
Expand Down
25 changes: 21 additions & 4 deletions aiomysql/cursors.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import re
import warnings

from pymysql.err import (
Warning, Error, InterfaceError, DataError,
Expand Down Expand Up @@ -187,7 +188,7 @@ def nextset(self):
if not current_result.has_next:
return
yield from conn.next_result()
self._do_get_result()
yield from self._do_get_result()
return True

def _escape_args(self, args, conn):
Expand Down Expand Up @@ -457,8 +458,9 @@ def _query(self, q):
conn = self._get_db()
self._last_executed = q
yield from conn.query(q)
self._do_get_result()
yield from self._do_get_result()

@asyncio.coroutine
def _do_get_result(self):
conn = self._get_db()
self._rownumber = 0
Expand All @@ -468,6 +470,20 @@ def _do_get_result(self):
self._lastrowid = result.insert_id
self._rows = result.rows

if result.warning_count > 0:
yield from self._show_warnings(conn)

@asyncio.coroutine
def _show_warnings(self, conn):
if self._result and self._result.has_next:
return
ws = yield from conn.show_warnings()
if ws is None:
return
for w in ws:
msg = w[-1]
warnings.warn(str(msg), Warning, 4)

Warning = Warning
Error = Error
InterfaceError = InterfaceError
Expand Down Expand Up @@ -506,8 +522,9 @@ class _DictCursorMixin:
# You can override this to use OrderedDict or other dict-like types.
dict_type = dict

@asyncio.coroutine
def _do_get_result(self):
super()._do_get_result()
yield from super()._do_get_result()
fields = []
if self._description:
for f in self._result.fields:
Expand Down Expand Up @@ -565,7 +582,7 @@ def _query(self, q):
conn = self._get_db()
self._last_executed = q
yield from conn.query(q, unbuffered=True)
self._do_get_result()
yield from self._do_get_result()
return self._rowcount

@asyncio.coroutine
Expand Down

0 comments on commit 0202d23

Please sign in to comment.