diff --git a/asyncpg/protocol/prepared_stmt.pxd b/asyncpg/protocol/prepared_stmt.pxd index 4427bfdc..3906af25 100644 --- a/asyncpg/protocol/prepared_stmt.pxd +++ b/asyncpg/protocol/prepared_stmt.pxd @@ -29,7 +29,7 @@ cdef class PreparedStatementState: bint have_text_cols tuple rows_codecs - cdef _encode_bind_msg(self, args) + cdef _encode_bind_msg(self, args, int seqno = ?) cpdef _init_codecs(self) cdef _ensure_rows_decoder(self) cdef _ensure_args_encoder(self) diff --git a/asyncpg/protocol/prepared_stmt.pyx b/asyncpg/protocol/prepared_stmt.pyx index 5f1820de..63466db8 100644 --- a/asyncpg/protocol/prepared_stmt.pyx +++ b/asyncpg/protocol/prepared_stmt.pyx @@ -101,12 +101,25 @@ cdef class PreparedStatementState: def mark_closed(self): self.closed = True - cdef _encode_bind_msg(self, args): + cdef _encode_bind_msg(self, args, int seqno = -1): cdef: int idx WriteBuffer writer Codec codec + if not cpython.PySequence_Check(args): + if seqno >= 0: + raise exceptions.DataError( + f'invalid input in executemany() argument sequence ' + f'element #{seqno}: expected a sequence, got ' + f'{type(args).__name__}' + ) + else: + # Non executemany() callers do not pass user input directly, + # so bad input is a bug. + raise exceptions.InternalClientError( + f'Bind: expected a sequence, got {type(args).__name__}') + if len(args) > 32767: raise exceptions.InterfaceError( 'the number of query arguments cannot exceed 32767') @@ -159,19 +172,32 @@ cdef class PreparedStatementState: except exceptions.InterfaceError as e: # This is already a descriptive error, but annotate # with argument name for clarity. + pos = f'${idx + 1}' + if seqno >= 0: + pos = ( + f'{pos} in element #{seqno} of' + f' executemany() sequence' + ) raise e.with_msg( - f'query argument ${idx + 1}: {e.args[0]}') from None + f'query argument {pos}: {e.args[0]}' + ) from None except Exception as e: # Everything else is assumed to be an encoding error # due to invalid input. + pos = f'${idx + 1}' + if seqno >= 0: + pos = ( + f'{pos} in element #{seqno} of' + f' executemany() sequence' + ) value_repr = repr(arg) if len(value_repr) > 40: value_repr = value_repr[:40] + '...' raise exceptions.DataError( - 'invalid input for query argument' - ' ${n}: {v} ({msg})'.format( - n=idx + 1, v=value_repr, msg=e)) from e + f'invalid input for query argument' + f' {pos}: {value_repr} ({e})' + ) from e if self.have_text_cols: writer.write_int16(self.cols_num) diff --git a/asyncpg/protocol/protocol.pyx b/asyncpg/protocol/protocol.pyx index bbe8026e..bb548962 100644 --- a/asyncpg/protocol/protocol.pyx +++ b/asyncpg/protocol/protocol.pyx @@ -217,7 +217,7 @@ cdef class BaseProtocol(CoreProtocol): # Make sure the argument sequence is encoded lazily with # this generator expression to keep the memory pressure under # control. - data_gen = (state._encode_bind_msg(b) for b in args) + data_gen = (state._encode_bind_msg(b, i) for i, b in enumerate(args)) arg_bufs = iter(data_gen) waiter = self._new_waiter(timeout) diff --git a/tests/test_execute.py b/tests/test_execute.py index 8cf0d2f2..78d8c124 100644 --- a/tests/test_execute.py +++ b/tests/test_execute.py @@ -9,7 +9,7 @@ import asyncpg from asyncpg import _testbase as tb -from asyncpg.exceptions import UniqueViolationError +from asyncpg import exceptions class TestExecuteScript(tb.ConnectedTestCase): @@ -140,6 +140,25 @@ async def test_executemany_basic(self): ]) async def test_executemany_bad_input(self): + with self.assertRaisesRegex( + exceptions.DataError, + r"invalid input in executemany\(\) argument sequence element #1: " + r"expected a sequence", + ): + await self.con.executemany(''' + INSERT INTO exmany (b) VALUES($1) + ''', [(0,), {1: 0}]) + + with self.assertRaisesRegex( + exceptions.DataError, + r"invalid input for query argument \$1 in element #1 of " + r"executemany\(\) sequence: 'bad'", + ): + await self.con.executemany(''' + INSERT INTO exmany (b) VALUES($1) + ''', [(0,), ("bad",)]) + + async def test_executemany_error_in_input_gen(self): bad_data = ([1 / 0] for v in range(10)) with self.assertRaises(ZeroDivisionError): @@ -155,7 +174,7 @@ async def test_executemany_bad_input(self): ''', good_data) async def test_executemany_server_failure(self): - with self.assertRaises(UniqueViolationError): + with self.assertRaises(exceptions.UniqueViolationError): await self.con.executemany(''' INSERT INTO exmany VALUES($1, $2) ''', [ @@ -165,7 +184,7 @@ async def test_executemany_server_failure(self): self.assertEqual(result, []) async def test_executemany_server_failure_after_writes(self): - with self.assertRaises(UniqueViolationError): + with self.assertRaises(exceptions.UniqueViolationError): await self.con.executemany(''' INSERT INTO exmany VALUES($1, $2) ''', [('a' * 32768, x) for x in range(10)] + [ @@ -187,7 +206,7 @@ def gen(): else: yield 'a' * 32768, pos - with self.assertRaises(UniqueViolationError): + with self.assertRaises(exceptions.UniqueViolationError): await self.con.executemany(''' INSERT INTO exmany VALUES($1, $2) ''', gen()) @@ -260,7 +279,7 @@ async def test_executemany_client_failure_in_transaction(self): async def test_executemany_client_server_failure_conflict(self): self.con._transport.set_write_buffer_limits(65536 * 64, 16384 * 64) - with self.assertRaises(UniqueViolationError): + with self.assertRaises(exceptions.UniqueViolationError): await self.con.executemany(''' INSERT INTO exmany VALUES($1, 0) ''', (('a' * 32768,) for y in range(4, -1, -1) if y / y))