Skip to content

Commit

Permalink
Add support for asynchronous iterables to copy_records_to_table()
Browse files Browse the repository at this point in the history
The `Connection.copy_records_to_table()` now allows the `records`
argument to be an asynchronous iterable.

Fixes: #689.
  • Loading branch information
elprans committed Aug 2, 2021
1 parent d076169 commit c416100
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 24 deletions.
31 changes: 25 additions & 6 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,8 @@ async def copy_records_to_table(self, table_name, *, records,
:param records:
An iterable returning row tuples to copy into the table.
:term:`Asynchronous iterables <python:asynchronous iterable>`
are also supported.
:param list columns:
An optional list of column names to copy.
Expand Down Expand Up @@ -901,7 +903,28 @@ async def copy_records_to_table(self, table_name, *, records,
>>> asyncio.get_event_loop().run_until_complete(run())
'COPY 2'
Asynchronous record iterables are also supported:
.. code-block:: pycon
>>> import asyncpg
>>> import asyncio
>>> async def run():
... con = await asyncpg.connect(user='postgres')
... async def record_gen(size):
... for i in range(size):
... yield (i,)
... result = await con.copy_records_to_table(
... 'mytable', records=record_gen(100))
... print(result)
...
>>> asyncio.get_event_loop().run_until_complete(run())
'COPY 100'
.. versionadded:: 0.11.0
.. versionchanged:: 0.24.0
The ``records`` argument may be an asynchronous iterable.
"""
tabname = utils._quote_ident(table_name)
if schema_name:
Expand All @@ -924,8 +947,8 @@ async def copy_records_to_table(self, table_name, *, records,
copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}'.format(
tab=tabname, cols=cols, opts=opts)

return await self._copy_in_records(
copy_stmt, records, intro_ps._state, timeout)
return await self._protocol.copy_in(
copy_stmt, None, None, records, intro_ps._state, timeout)

def _format_copy_opts(self, *, format=None, oids=None, freeze=None,
delimiter=None, null=None, header=None, quote=None,
Expand Down Expand Up @@ -1047,10 +1070,6 @@ async def __anext__(self):
if opened_by_us:
await run_in_executor(None, f.close)

async def _copy_in_records(self, copy_stmt, records, intro_stmt, timeout):
return await self._protocol.copy_in(
copy_stmt, None, None, records, intro_stmt, timeout)

async def set_type_codec(self, typename, *,
schema='public', encoder, decoder,
format='text'):
Expand Down
57 changes: 39 additions & 18 deletions asyncpg/protocol/protocol.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ cimport cpython
import asyncio
import builtins
import codecs
import collections
import collections.abc
import socket
import time
import weakref
Expand Down Expand Up @@ -438,23 +438,44 @@ cdef class BaseProtocol(CoreProtocol):
'no binary format encoder for '
'type {} (OID {})'.format(codec.name, codec.oid))

for row in records:
# Tuple header
wbuf.write_int16(<int16_t>num_cols)
# Tuple data
for i in range(num_cols):
item = row[i]
if item is None:
wbuf.write_int32(-1)
else:
codec = <Codec>cpython.PyTuple_GET_ITEM(codecs, i)
codec.encode(settings, wbuf, item)

if wbuf.len() >= _COPY_BUFFER_SIZE:
with timer:
await self.writing_allowed.wait()
self._write_copy_data_msg(wbuf)
wbuf = WriteBuffer.new()
if isinstance(records, collections.abc.AsyncIterable):
async for row in records:
# Tuple header
wbuf.write_int16(<int16_t>num_cols)
# Tuple data
for i in range(num_cols):
item = row[i]
if item is None:
wbuf.write_int32(-1)
else:
codec = <Codec>cpython.PyTuple_GET_ITEM(
codecs, i)
codec.encode(settings, wbuf, item)

if wbuf.len() >= _COPY_BUFFER_SIZE:
with timer:
await self.writing_allowed.wait()
self._write_copy_data_msg(wbuf)
wbuf = WriteBuffer.new()
else:
for row in records:
# Tuple header
wbuf.write_int16(<int16_t>num_cols)
# Tuple data
for i in range(num_cols):
item = row[i]
if item is None:
wbuf.write_int32(-1)
else:
codec = <Codec>cpython.PyTuple_GET_ITEM(
codecs, i)
codec.encode(settings, wbuf, item)

if wbuf.len() >= _COPY_BUFFER_SIZE:
with timer:
await self.writing_allowed.wait()
self._write_copy_data_msg(wbuf)
wbuf = WriteBuffer.new()

# End of binary copy.
wbuf.write_int16(-1)
Expand Down
23 changes: 23 additions & 0 deletions tests/test_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,29 @@ async def test_copy_records_to_table_1(self):
finally:
await self.con.execute('DROP TABLE copytab')

async def test_copy_records_to_table_async(self):
await self.con.execute('''
CREATE TABLE copytab_async(a text, b int, c timestamptz);
''')

try:
date = datetime.datetime.now(tz=datetime.timezone.utc)
delta = datetime.timedelta(days=1)

async def record_generator():
for i in range(100):
yield ('a-{}'.format(i), i, date + delta)

yield ('a-100', None, None)

res = await self.con.copy_records_to_table(
'copytab_async', records=record_generator())

self.assertEqual(res, 'COPY 101')

finally:
await self.con.execute('DROP TABLE copytab_async')

async def test_copy_records_to_table_no_binary_codec(self):
await self.con.execute('''
CREATE TABLE copytab(a uuid);
Expand Down

0 comments on commit c416100

Please sign in to comment.