Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve diagnostics of invalid executemany() input #848

Merged
merged 1 commit into from
Nov 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion asyncpg/protocol/prepared_stmt.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ cdef class PreparedStatementState:
bint have_text_cols
tuple rows_codecs

cdef _encode_bind_msg(self, args)
cdef _encode_bind_msg(self, args, int seqno = ?)
cpdef _init_codecs(self)
cdef _ensure_rows_decoder(self)
cdef _ensure_args_encoder(self)
Expand Down
36 changes: 31 additions & 5 deletions asyncpg/protocol/prepared_stmt.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,25 @@ cdef class PreparedStatementState:
def mark_closed(self):
self.closed = True

cdef _encode_bind_msg(self, args):
cdef _encode_bind_msg(self, args, int seqno = -1):
cdef:
int idx
WriteBuffer writer
Codec codec

if not cpython.PySequence_Check(args):
if seqno >= 0:
raise exceptions.DataError(
f'invalid input in executemany() argument sequence '
f'element #{seqno}: expected a sequence, got '
f'{type(args).__name__}'
)
else:
# Non executemany() callers do not pass user input directly,
# so bad input is a bug.
raise exceptions.InternalClientError(
f'Bind: expected a sequence, got {type(args).__name__}')

if len(args) > 32767:
raise exceptions.InterfaceError(
'the number of query arguments cannot exceed 32767')
Expand Down Expand Up @@ -159,19 +172,32 @@ cdef class PreparedStatementState:
except exceptions.InterfaceError as e:
# This is already a descriptive error, but annotate
# with argument name for clarity.
pos = f'${idx + 1}'
if seqno >= 0:
pos = (
f'{pos} in element #{seqno} of'
f' executemany() sequence'
)
raise e.with_msg(
f'query argument ${idx + 1}: {e.args[0]}') from None
f'query argument {pos}: {e.args[0]}'
) from None
except Exception as e:
# Everything else is assumed to be an encoding error
# due to invalid input.
pos = f'${idx + 1}'
if seqno >= 0:
pos = (
f'{pos} in element #{seqno} of'
f' executemany() sequence'
)
value_repr = repr(arg)
if len(value_repr) > 40:
value_repr = value_repr[:40] + '...'

raise exceptions.DataError(
'invalid input for query argument'
' ${n}: {v} ({msg})'.format(
n=idx + 1, v=value_repr, msg=e)) from e
f'invalid input for query argument'
f' {pos}: {value_repr} ({e})'
) from e

if self.have_text_cols:
writer.write_int16(self.cols_num)
Expand Down
2 changes: 1 addition & 1 deletion asyncpg/protocol/protocol.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ cdef class BaseProtocol(CoreProtocol):
# Make sure the argument sequence is encoded lazily with
# this generator expression to keep the memory pressure under
# control.
data_gen = (state._encode_bind_msg(b) for b in args)
data_gen = (state._encode_bind_msg(b, i) for i, b in enumerate(args))
arg_bufs = iter(data_gen)

waiter = self._new_waiter(timeout)
Expand Down
29 changes: 24 additions & 5 deletions tests/test_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import asyncpg

from asyncpg import _testbase as tb
from asyncpg.exceptions import UniqueViolationError
from asyncpg import exceptions


class TestExecuteScript(tb.ConnectedTestCase):
Expand Down Expand Up @@ -140,6 +140,25 @@ async def test_executemany_basic(self):
])

async def test_executemany_bad_input(self):
with self.assertRaisesRegex(
exceptions.DataError,
r"invalid input in executemany\(\) argument sequence element #1: "
r"expected a sequence",
):
await self.con.executemany('''
INSERT INTO exmany (b) VALUES($1)
''', [(0,), {1: 0}])

with self.assertRaisesRegex(
exceptions.DataError,
r"invalid input for query argument \$1 in element #1 of "
r"executemany\(\) sequence: 'bad'",
):
await self.con.executemany('''
INSERT INTO exmany (b) VALUES($1)
''', [(0,), ("bad",)])

async def test_executemany_error_in_input_gen(self):
bad_data = ([1 / 0] for v in range(10))

with self.assertRaises(ZeroDivisionError):
Expand All @@ -155,7 +174,7 @@ async def test_executemany_bad_input(self):
''', good_data)

async def test_executemany_server_failure(self):
with self.assertRaises(UniqueViolationError):
with self.assertRaises(exceptions.UniqueViolationError):
await self.con.executemany('''
INSERT INTO exmany VALUES($1, $2)
''', [
Expand All @@ -165,7 +184,7 @@ async def test_executemany_server_failure(self):
self.assertEqual(result, [])

async def test_executemany_server_failure_after_writes(self):
with self.assertRaises(UniqueViolationError):
with self.assertRaises(exceptions.UniqueViolationError):
await self.con.executemany('''
INSERT INTO exmany VALUES($1, $2)
''', [('a' * 32768, x) for x in range(10)] + [
Expand All @@ -187,7 +206,7 @@ def gen():
else:
yield 'a' * 32768, pos

with self.assertRaises(UniqueViolationError):
with self.assertRaises(exceptions.UniqueViolationError):
await self.con.executemany('''
INSERT INTO exmany VALUES($1, $2)
''', gen())
Expand Down Expand Up @@ -260,7 +279,7 @@ async def test_executemany_client_failure_in_transaction(self):

async def test_executemany_client_server_failure_conflict(self):
self.con._transport.set_write_buffer_limits(65536 * 64, 16384 * 64)
with self.assertRaises(UniqueViolationError):
with self.assertRaises(exceptions.UniqueViolationError):
await self.con.executemany('''
INSERT INTO exmany VALUES($1, 0)
''', (('a' * 32768,) for y in range(4, -1, -1) if y / y))
Expand Down