From 1d33ff625a455d8445d4adc03c0fdbbec53ec5df Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Mon, 9 Aug 2021 17:16:28 -0700 Subject: [PATCH] Add support for asynchronous iterables to copy_records_to_table() (#713) The `Connection.copy_records_to_table()` now allows the `records` argument to be an asynchronous iterable. Fixes: #689. --- asyncpg/connection.py | 31 +++++++++++++++---- asyncpg/protocol/protocol.pyx | 57 ++++++++++++++++++++++++----------- tests/test_copy.py | 23 ++++++++++++++ 3 files changed, 87 insertions(+), 24 deletions(-) diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 4a656124..e01c6b65 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -872,6 +872,8 @@ async def copy_records_to_table(self, table_name, *, records, :param records: An iterable returning row tuples to copy into the table. + :term:`Asynchronous iterables <asynchronous-iterable>` + are also supported. :param list columns: An optional list of column names to copy. @@ -901,7 +903,28 @@ async def copy_records_to_table(self, table_name, *, records, >>> asyncio.get_event_loop().run_until_complete(run()) 'COPY 2' + Asynchronous record iterables are also supported: + + .. code-block:: pycon + + >>> import asyncpg + >>> import asyncio + >>> async def run(): + ... con = await asyncpg.connect(user='postgres') + ... async def record_gen(size): + ... for i in range(size): + ... yield (i,) + ... result = await con.copy_records_to_table( + ... 'mytable', records=record_gen(100)) + ... print(result) + ... + >>> asyncio.get_event_loop().run_until_complete(run()) + 'COPY 100' + .. versionadded:: 0.11.0 + + .. versionchanged:: 0.24.0 + The ``records`` argument may be an asynchronous iterable. 
""" tabname = utils._quote_ident(table_name) if schema_name: @@ -924,8 +947,8 @@ async def copy_records_to_table(self, table_name, *, records, copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}'.format( tab=tabname, cols=cols, opts=opts) - return await self._copy_in_records( - copy_stmt, records, intro_ps._state, timeout) + return await self._protocol.copy_in( + copy_stmt, None, None, records, intro_ps._state, timeout) def _format_copy_opts(self, *, format=None, oids=None, freeze=None, delimiter=None, null=None, header=None, quote=None, @@ -1047,10 +1070,6 @@ async def __anext__(self): if opened_by_us: await run_in_executor(None, f.close) - async def _copy_in_records(self, copy_stmt, records, intro_stmt, timeout): - return await self._protocol.copy_in( - copy_stmt, None, None, records, intro_stmt, timeout) - async def set_type_codec(self, typename, *, schema='public', encoder, decoder, format='text'): diff --git a/asyncpg/protocol/protocol.pyx b/asyncpg/protocol/protocol.pyx index 3a1594a5..dbe52e9e 100644 --- a/asyncpg/protocol/protocol.pyx +++ b/asyncpg/protocol/protocol.pyx @@ -13,7 +13,7 @@ cimport cpython import asyncio import builtins import codecs -import collections +import collections.abc import socket import time import weakref @@ -438,23 +438,44 @@ cdef class BaseProtocol(CoreProtocol): 'no binary format encoder for ' 'type {} (OID {})'.format(codec.name, codec.oid)) - for row in records: - # Tuple header - wbuf.write_int16(num_cols) - # Tuple data - for i in range(num_cols): - item = row[i] - if item is None: - wbuf.write_int32(-1) - else: - codec = cpython.PyTuple_GET_ITEM(codecs, i) - codec.encode(settings, wbuf, item) - - if wbuf.len() >= _COPY_BUFFER_SIZE: - with timer: - await self.writing_allowed.wait() - self._write_copy_data_msg(wbuf) - wbuf = WriteBuffer.new() + if isinstance(records, collections.abc.AsyncIterable): + async for row in records: + # Tuple header + wbuf.write_int16(num_cols) + # Tuple data + for i in range(num_cols): + item = row[i] 
+ if item is None: + wbuf.write_int32(-1) + else: + codec = cpython.PyTuple_GET_ITEM( + codecs, i) + codec.encode(settings, wbuf, item) + + if wbuf.len() >= _COPY_BUFFER_SIZE: + with timer: + await self.writing_allowed.wait() + self._write_copy_data_msg(wbuf) + wbuf = WriteBuffer.new() + else: + for row in records: + # Tuple header + wbuf.write_int16(num_cols) + # Tuple data + for i in range(num_cols): + item = row[i] + if item is None: + wbuf.write_int32(-1) + else: + codec = cpython.PyTuple_GET_ITEM( + codecs, i) + codec.encode(settings, wbuf, item) + + if wbuf.len() >= _COPY_BUFFER_SIZE: + with timer: + await self.writing_allowed.wait() + self._write_copy_data_msg(wbuf) + wbuf = WriteBuffer.new() # End of binary copy. wbuf.write_int16(-1) diff --git a/tests/test_copy.py b/tests/test_copy.py index dcac96ac..70c9388e 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -644,6 +644,29 @@ async def test_copy_records_to_table_1(self): finally: await self.con.execute('DROP TABLE copytab') + async def test_copy_records_to_table_async(self): + await self.con.execute(''' + CREATE TABLE copytab_async(a text, b int, c timestamptz); + ''') + + try: + date = datetime.datetime.now(tz=datetime.timezone.utc) + delta = datetime.timedelta(days=1) + + async def record_generator(): + for i in range(100): + yield ('a-{}'.format(i), i, date + delta) + + yield ('a-100', None, None) + + res = await self.con.copy_records_to_table( + 'copytab_async', records=record_generator()) + + self.assertEqual(res, 'COPY 101') + + finally: + await self.con.execute('DROP TABLE copytab_async') + async def test_copy_records_to_table_no_binary_codec(self): await self.con.execute(''' CREATE TABLE copytab(a uuid);