Skip to content

Commit

Permalink
Improve diagnostics of invalid executemany() input
Browse files Browse the repository at this point in the history
This adds a check that elements of sequence passed to `executemany()`
are proper sequences themselves and notes the offending sequence element
number in the exception message.  For example:

    await self.con.executemany(
        "INSERT INTO exmany (b) VALUES($1)"
        [(0,), ("bad",)],
    )

    DataError: invalid input for query argument $1 in element #1 of
               executemany() sequence: 'bad' ('str' object cannot be
               interpreted as an integer)

Fixes: #807
  • Loading branch information
elprans committed Nov 16, 2021
1 parent f900b73 commit b3f5c5c
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 12 deletions.
2 changes: 1 addition & 1 deletion asyncpg/protocol/prepared_stmt.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ cdef class PreparedStatementState:
bint have_text_cols
tuple rows_codecs

cdef _encode_bind_msg(self, args)
cdef _encode_bind_msg(self, args, int seqno = ?)
cpdef _init_codecs(self)
cdef _ensure_rows_decoder(self)
cdef _ensure_args_encoder(self)
Expand Down
36 changes: 31 additions & 5 deletions asyncpg/protocol/prepared_stmt.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,25 @@ cdef class PreparedStatementState:
def mark_closed(self):
self.closed = True

cdef _encode_bind_msg(self, args):
cdef _encode_bind_msg(self, args, int seqno = -1):
cdef:
int idx
WriteBuffer writer
Codec codec

if not cpython.PySequence_Check(args):
if seqno >= 0:
raise exceptions.DataError(
f'invalid input in executemany() argument sequence '
f'element #{seqno}: expected a sequence, got '
f'{type(args).__name__}'
)
else:
# Non executemany() callers do not pass user input directly,
# so bad input is a bug.
raise exceptions.InternalClientError(
f'Bind: expected a sequence, got {type(args).__name__}')

if len(args) > 32767:
raise exceptions.InterfaceError(
'the number of query arguments cannot exceed 32767')
Expand Down Expand Up @@ -159,19 +172,32 @@ cdef class PreparedStatementState:
except exceptions.InterfaceError as e:
# This is already a descriptive error, but annotate
# with argument name for clarity.
pos = f'${idx + 1}'
if seqno >= 0:
pos = (
f'{pos} in element #{seqno} of'
f' executemany() sequence'
)
raise e.with_msg(
f'query argument ${idx + 1}: {e.args[0]}') from None
f'query argument {pos}: {e.args[0]}'
) from None
except Exception as e:
# Everything else is assumed to be an encoding error
# due to invalid input.
pos = f'${idx + 1}'
if seqno >= 0:
pos = (
f'{pos} in element #{seqno} of'
f' executemany() sequence'
)
value_repr = repr(arg)
if len(value_repr) > 40:
value_repr = value_repr[:40] + '...'

raise exceptions.DataError(
'invalid input for query argument'
' ${n}: {v} ({msg})'.format(
n=idx + 1, v=value_repr, msg=e)) from e
f'invalid input for query argument'
f' {pos}: {value_repr} ({e})'
) from e

if self.have_text_cols:
writer.write_int16(self.cols_num)
Expand Down
2 changes: 1 addition & 1 deletion asyncpg/protocol/protocol.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ cdef class BaseProtocol(CoreProtocol):
# Make sure the argument sequence is encoded lazily with
# this generator expression to keep the memory pressure under
# control.
data_gen = (state._encode_bind_msg(b) for b in args)
data_gen = (state._encode_bind_msg(b, i) for i, b in enumerate(args))
arg_bufs = iter(data_gen)

waiter = self._new_waiter(timeout)
Expand Down
29 changes: 24 additions & 5 deletions tests/test_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import asyncpg

from asyncpg import _testbase as tb
from asyncpg.exceptions import UniqueViolationError
from asyncpg import exceptions


class TestExecuteScript(tb.ConnectedTestCase):
Expand Down Expand Up @@ -140,6 +140,25 @@ async def test_executemany_basic(self):
])

async def test_executemany_bad_input(self):
with self.assertRaisesRegex(
exceptions.DataError,
r"invalid input in executemany\(\) argument sequence element #1: "
r"expected a sequence",
):
await self.con.executemany('''
INSERT INTO exmany (b) VALUES($1)
''', [(0,), {1: 0}])

with self.assertRaisesRegex(
exceptions.DataError,
r"invalid input for query argument \$1 in element #1 of "
r"executemany\(\) sequence: 'bad'",
):
await self.con.executemany('''
INSERT INTO exmany (b) VALUES($1)
''', [(0,), ("bad",)])

async def test_executemany_error_in_input_gen(self):
bad_data = ([1 / 0] for v in range(10))

with self.assertRaises(ZeroDivisionError):
Expand All @@ -155,7 +174,7 @@ async def test_executemany_bad_input(self):
''', good_data)

async def test_executemany_server_failure(self):
with self.assertRaises(UniqueViolationError):
with self.assertRaises(exceptions.UniqueViolationError):
await self.con.executemany('''
INSERT INTO exmany VALUES($1, $2)
''', [
Expand All @@ -165,7 +184,7 @@ async def test_executemany_server_failure(self):
self.assertEqual(result, [])

async def test_executemany_server_failure_after_writes(self):
with self.assertRaises(UniqueViolationError):
with self.assertRaises(exceptions.UniqueViolationError):
await self.con.executemany('''
INSERT INTO exmany VALUES($1, $2)
''', [('a' * 32768, x) for x in range(10)] + [
Expand All @@ -187,7 +206,7 @@ def gen():
else:
yield 'a' * 32768, pos

with self.assertRaises(UniqueViolationError):
with self.assertRaises(exceptions.UniqueViolationError):
await self.con.executemany('''
INSERT INTO exmany VALUES($1, $2)
''', gen())
Expand Down Expand Up @@ -260,7 +279,7 @@ async def test_executemany_client_failure_in_transaction(self):

async def test_executemany_client_server_failure_conflict(self):
self.con._transport.set_write_buffer_limits(65536 * 64, 16384 * 64)
with self.assertRaises(UniqueViolationError):
with self.assertRaises(exceptions.UniqueViolationError):
await self.con.executemany('''
INSERT INTO exmany VALUES($1, 0)
''', (('a' * 32768,) for y in range(4, -1, -1) if y / y))
Expand Down

0 comments on commit b3f5c5c

Please sign in to comment.