Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for asynchronous iterables to copy_records_to_table() #713

Merged
merged 2 commits into from
Aug 10, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 25 additions & 6 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,8 @@ async def copy_records_to_table(self, table_name, *, records,

:param records:
An iterable returning row tuples to copy into the table.
:term:`Asynchronous iterables <python:asynchronous iterable>`
are also supported.

:param list columns:
An optional list of column names to copy.
Expand Down Expand Up @@ -901,7 +903,28 @@ async def copy_records_to_table(self, table_name, *, records,
>>> asyncio.get_event_loop().run_until_complete(run())
'COPY 2'

Asynchronous record iterables are also supported:

.. code-block:: pycon

>>> import asyncpg
>>> import asyncio
>>> async def run():
... con = await asyncpg.connect(user='postgres')
... async def record_gen(size):
... for i in range(size):
... yield (i,)
... result = await con.copy_records_to_table(
... 'mytable', records=record_gen(100))
... print(result)
...
>>> asyncio.get_event_loop().run_until_complete(run())
'COPY 100'

.. versionadded:: 0.11.0

.. versionchanged:: 0.24.0
The ``records`` argument may be an asynchronous iterable.
"""
tabname = utils._quote_ident(table_name)
if schema_name:
Expand All @@ -924,8 +947,8 @@ async def copy_records_to_table(self, table_name, *, records,
copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}'.format(
tab=tabname, cols=cols, opts=opts)

return await self._copy_in_records(
copy_stmt, records, intro_ps._state, timeout)
return await self._protocol.copy_in(
copy_stmt, None, None, records, intro_ps._state, timeout)

def _format_copy_opts(self, *, format=None, oids=None, freeze=None,
delimiter=None, null=None, header=None, quote=None,
Expand Down Expand Up @@ -1047,10 +1070,6 @@ async def __anext__(self):
if opened_by_us:
await run_in_executor(None, f.close)

async def _copy_in_records(self, copy_stmt, records, intro_stmt, timeout):
    """Run a COPY ... FROM STDIN statement fed from *records*.

    Thin wrapper that forwards to ``self._protocol.copy_in()`` and
    returns whatever that coroutine returns.

    :param copy_stmt: The COPY statement text to execute.
    :param records: The row source handed through to the protocol.
    :param intro_stmt: Passed through unchanged to the protocol.
    :param timeout: Passed through unchanged to the protocol.
    """
    # The second and third positional arguments of copy_in() are not
    # used on the records path, so both are passed as None.
    # NOTE(review): presumably these are the file/stream source slots --
    # confirm against the protocol's copy_in() signature.
    copy_status = await self._protocol.copy_in(
        copy_stmt, None, None, records, intro_stmt, timeout)
    return copy_status

async def set_type_codec(self, typename, *,
schema='public', encoder, decoder,
format='text'):
Expand Down
57 changes: 39 additions & 18 deletions asyncpg/protocol/protocol.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ cimport cpython
import asyncio
import builtins
import codecs
import collections
import collections.abc
elprans marked this conversation as resolved.
Show resolved Hide resolved
import socket
import time
import weakref
Expand Down Expand Up @@ -438,23 +438,44 @@ cdef class BaseProtocol(CoreProtocol):
'no binary format encoder for '
'type {} (OID {})'.format(codec.name, codec.oid))

for row in records:
# Tuple header
wbuf.write_int16(<int16_t>num_cols)
# Tuple data
for i in range(num_cols):
item = row[i]
if item is None:
wbuf.write_int32(-1)
else:
codec = <Codec>cpython.PyTuple_GET_ITEM(codecs, i)
codec.encode(settings, wbuf, item)

if wbuf.len() >= _COPY_BUFFER_SIZE:
with timer:
await self.writing_allowed.wait()
self._write_copy_data_msg(wbuf)
wbuf = WriteBuffer.new()
if isinstance(records, collections.abc.AsyncIterable):
async for row in records:
# Tuple header
wbuf.write_int16(<int16_t>num_cols)
# Tuple data
for i in range(num_cols):
item = row[i]
if item is None:
wbuf.write_int32(-1)
else:
codec = <Codec>cpython.PyTuple_GET_ITEM(
codecs, i)
codec.encode(settings, wbuf, item)

if wbuf.len() >= _COPY_BUFFER_SIZE:
with timer:
await self.writing_allowed.wait()
self._write_copy_data_msg(wbuf)
wbuf = WriteBuffer.new()
else:
elprans marked this conversation as resolved.
Show resolved Hide resolved
for row in records:
# Tuple header
wbuf.write_int16(<int16_t>num_cols)
# Tuple data
for i in range(num_cols):
item = row[i]
if item is None:
wbuf.write_int32(-1)
else:
codec = <Codec>cpython.PyTuple_GET_ITEM(
codecs, i)
codec.encode(settings, wbuf, item)

if wbuf.len() >= _COPY_BUFFER_SIZE:
with timer:
await self.writing_allowed.wait()
self._write_copy_data_msg(wbuf)
wbuf = WriteBuffer.new()

# End of binary copy.
wbuf.write_int16(-1)
Expand Down
23 changes: 23 additions & 0 deletions tests/test_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,29 @@ async def test_copy_records_to_table_1(self):
finally:
await self.con.execute('DROP TABLE copytab')

async def test_copy_records_to_table_async(self):
    """copy_records_to_table() accepts an asynchronous generator.

    Feeds 101 rows through an async generator and checks that the
    command status reported by the copy is ``'COPY 101'``.
    """
    await self.con.execute(
        '\nCREATE TABLE copytab_async(a text, b int, c timestamptz);\n')

    try:
        # A single timezone-aware timestamp (now + 1 day) shared by
        # every fully populated row.
        tomorrow = (
            datetime.datetime.now(tz=datetime.timezone.utc)
            + datetime.timedelta(days=1)
        )

        async def rows():
            # 100 fully populated rows ...
            for n in range(100):
                yield ('a-{}'.format(n), n, tomorrow)

            # ... followed by one row whose last two fields are None.
            yield ('a-100', None, None)

        status = await self.con.copy_records_to_table(
            'copytab_async', records=rows())

        self.assertEqual(status, 'COPY 101')

    finally:
        # Drop the table even if the copy or the assertion failed.
        await self.con.execute('DROP TABLE copytab_async')

async def test_copy_records_to_table_no_binary_codec(self):
await self.con.execute('''
CREATE TABLE copytab(a uuid);
Expand Down