Skip to content

Commit

Permalink
Add support for COPY IN
Browse files Browse the repository at this point in the history
This commit adds two new Connection methods: copy_to_table() and
copy_records_to_table() that allow copying data to the specified
table either in text or, in the latter case, record form.

Closes #123.
Closes #21.
  • Loading branch information
elprans committed May 11, 2017
1 parent 2c4e894 commit 10d95d4
Show file tree
Hide file tree
Showing 9 changed files with 640 additions and 10 deletions.
8 changes: 7 additions & 1 deletion asyncpg/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import shutil
import socket
import subprocess
import sys
import tempfile
import textwrap
import time
Expand Down Expand Up @@ -213,10 +214,15 @@ def start(self, wait=60, *, server_settings={}, **opts):
'pg_ctl start exited with status {:d}: {}'.format(
process.returncode, stderr.decode()))
else:
if os.getenv('ASYNCPG_DEBUG_SERVER'):
stdout = sys.stdout
else:
stdout = subprocess.DEVNULL

self._daemon_process = \
subprocess.Popen(
[self._postgres, '-D', self._data_dir, *extra_args],
stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
stdout=stdout, stderr=subprocess.STDOUT,
preexec_fn=ensure_dead_with_parent)

self._daemon_pid = self._daemon_process.pid
Expand Down
164 changes: 164 additions & 0 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import asyncio
import collections
import collections.abc
import struct
import time

Expand Down Expand Up @@ -451,6 +452,115 @@ async def copy_from_query(self, query, *args, output,

return await self._copy_out(copy_stmt, output, timeout)

async def copy_to_table(self, table_name, *, source,
                        columns=None, schema_name=None, timeout=None,
                        format=None, oids=None, freeze=None,
                        delimiter=None, null=None, header=None,
                        quote=None, escape=None, force_quote=None,
                        force_not_null=None, force_null=None,
                        encoding=None):
    """Copy data from *source* into the specified table.

    :param str table_name:
        The name of the table to copy data to.

    :param source:
        A :term:`path-like object <python:path-like object>`,
        or a :term:`file-like object <python:file-like object>`, or
        an :term:`asynchronous iterable <python:asynchronous iterable>`
        that returns ``bytes``, or an object supporting the
        :term:`buffer protocol <python:buffer protocol>`.

    :param list columns:
        An optional list of column names to copy.

    :param str schema_name:
        An optional schema name to qualify the table.

    :param float timeout:
        Optional timeout value in seconds.

    The remaining keyword arguments are ``COPY`` statement options,
    see `COPY statement documentation`_ for details.

    :return: The status string of the COPY command.

    .. versionadded:: 0.11.0

    .. _`COPY statement documentation`:
        https://www.postgresql.org/docs/current/static/sql-copy.html
    """
    # Quote the table name first, then prefix the quoted schema if given.
    qualified = utils._quote_ident(table_name)
    if schema_name:
        qualified = '{}.{}'.format(utils._quote_ident(schema_name), qualified)

    # An empty/absent column list means "no explicit column clause".
    column_clause = ''
    if columns:
        quoted_columns = [utils._quote_ident(name) for name in columns]
        column_clause = '(' + ', '.join(quoted_columns) + ')'

    # NOTE: ``force_quote`` is part of the signature but is not forwarded
    # to _format_copy_opts() here.
    copy_options = self._format_copy_opts(
        format=format,
        oids=oids,
        freeze=freeze,
        delimiter=delimiter,
        null=null,
        header=header,
        quote=quote,
        escape=escape,
        force_not_null=force_not_null,
        force_null=force_null,
        encoding=encoding,
    )

    statement = 'COPY {tab}{cols} FROM STDIN {opts}'.format(
        tab=qualified, cols=column_clause, opts=copy_options)

    return await self._copy_in(statement, source, timeout)

async def copy_records_to_table(self, table_name, *, records,
                                columns=None, schema_name=None,
                                timeout=None):
    """Copy row tuples into the specified table using binary COPY.

    :param str table_name:
        The name of the table to copy data to.

    :param records:
        An iterable returning row tuples to copy into the table.

    :param list columns:
        An optional list of column names to copy.

    :param str schema_name:
        An optional schema name to qualify the table.

    :param float timeout:
        Optional timeout value in seconds.

    :return: The status string of the COPY command.

    .. versionadded:: 0.11.0
    """
    qualified = utils._quote_ident(table_name)
    if schema_name:
        qualified = '{}.{}'.format(utils._quote_ident(schema_name), qualified)

    if columns:
        select_list = ', '.join(utils._quote_ident(name) for name in columns)
        column_clause = '({})'.format(select_list)
    else:
        # No explicit columns: introspect every column, and emit no
        # column clause in the COPY statement.
        select_list, column_clause = '*', ''

    # Prepare a one-row SELECT over the target columns.  Only the
    # prepared statement's internal state is used: it is handed to
    # _copy_in_records() together with the records.
    # NOTE(review): presumably this supplies the column type information
    # needed to encode the records -- confirm in the protocol code.
    intro_ps = await self.prepare(
        'SELECT {cols} FROM {tab} LIMIT 1'.format(
            tab=qualified, cols=select_list))

    statement = 'COPY {tab}{cols} FROM STDIN {opts}'.format(
        tab=qualified, cols=column_clause, opts='(FORMAT binary)')

    return await self._copy_in_records(
        statement, records, intro_ps._state, timeout)

def _format_copy_opts(self, *, format=None, oids=None, freeze=None,
delimiter=None, null=None, header=None, quote=None,
escape=None, force_quote=None, force_not_null=None,
Expand Down Expand Up @@ -519,6 +629,60 @@ async def _writer(data):
if opened_by_us:
f.close()

async def _copy_in(self, copy_stmt, source, timeout):
    """Execute *copy_stmt* (a ``COPY ... FROM STDIN``), feeding it *source*.

    *source* is classified in this order:

    1. a path-like object -- the file is opened here in binary read mode
       and closed again when the copy finishes or fails;
    2. an object with a ``read`` attribute -- treated as a file-like
       object and read in chunks; it is *not* closed here;
    3. an asynchronous iterable -- passed to the protocol as-is;
    4. anything else -- passed to the protocol as a single data object
       (assumed to support the buffer protocol).

    File reads and the open/close calls are dispatched to the loop's
    default executor so they do not block the event loop.

    :param str copy_stmt: The complete COPY statement to execute.
    :param source: The data source, see above.
    :param float timeout: Optional timeout value in seconds.
    :return: Whatever the protocol's ``copy_in()`` returns.
    """
    try:
        path = compat.fspath(source)
    except TypeError:
        # source is not a path-like object
        path = None

    f = None
    reader = None
    data = None
    opened_by_us = False
    run_in_executor = self._loop.run_in_executor

    if path is not None:
        # A path: open it for *reading*.  Opening with 'wb' would
        # truncate the very file we are asked to copy from and leave
        # a handle that cannot be read.
        f = await run_in_executor(None, open, path, 'rb')
        opened_by_us = True
    elif hasattr(source, 'read'):
        # file-like
        f = source
    elif isinstance(source, collections.abc.AsyncIterable):
        # an asynchronous iterable yielding chunks of data
        reader = source
    else:
        # assuming source is an instance supporting the buffer protocol.
        data = source

    if f is not None:
        # Copying from a file-like object: adapt blocking read() calls
        # into an asynchronous iterator of chunks.
        class _Reader:
            @compat.aiter_compat
            def __aiter__(self):
                return self

            async def __anext__(self):
                chunk = await run_in_executor(None, f.read, 524288)
                if len(chunk) == 0:
                    # An empty read marks end of file.
                    raise StopAsyncIteration
                return chunk

        reader = _Reader()

    try:
        return await self._protocol.copy_in(
            copy_stmt, reader, data, None, None, timeout)
    finally:
        # Only close files this method opened itself.
        if opened_by_us:
            await run_in_executor(None, f.close)

async def _copy_in_records(self, copy_stmt, records, intro_stmt, timeout):
    """Execute *copy_stmt*, feeding it *records* instead of raw data.

    The reader and data slots of the protocol's ``copy_in()`` call are
    left as ``None``; *records* and *intro_stmt* are passed in the two
    positions that follow them.
    """
    protocol = self._protocol
    status = await protocol.copy_in(
        copy_stmt, None, None, records, intro_stmt, timeout)
    return status

async def set_type_codec(self, typename, *,
schema='public', encoder, decoder, binary=False):
"""Set an encoder/decoder pair for the specified data type.
Expand Down
2 changes: 2 additions & 0 deletions asyncpg/protocol/consts.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@ DEF _BUFFER_FREELIST_SIZE = 256
DEF _RECORD_FREELIST_SIZE = 1024
DEF _MEMORY_FREELIST_SIZE = 1024
# Largest value representable in a signed 32-bit integer.
DEF _MAXINT32 = 2**31 - 1
# 524288 bytes = 512 KiB.
# NOTE(review): presumably the chunk size used when buffering COPY data;
# its use is not visible from this file -- confirm.
DEF _COPY_BUFFER_SIZE = 524288
# NOTE(review): this looks like the 11-byte signature that opens a
# PostgreSQL binary-format COPY stream -- confirm against the COPY
# documentation ("Binary Format" / "File Header").
DEF _COPY_SIGNATURE = b"PGCOPY\n\377\r\n\0"
7 changes: 7 additions & 0 deletions asyncpg/protocol/coreproto.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ cdef class CoreProtocol:
cdef _process__bind(self, char mtype)
cdef _process__copy_out(self, char mtype)
cdef _process__copy_out_data(self, char mtype)
cdef _process__copy_in(self, char mtype)
cdef _process__copy_in_data(self, char mtype)

cdef _parse_msg_authentication(self)
cdef _parse_msg_parameter_status(self)
Expand All @@ -124,6 +126,10 @@ cdef class CoreProtocol:
cdef _parse_msg_error_response(self, is_error)
cdef _parse_msg_command_complete(self)

cdef _write_copy_data_msg(self, object data)
cdef _write_copy_done_msg(self)
cdef _write_copy_fail_msg(self, str cause)

cdef _auth_password_message_cleartext(self)
cdef _auth_password_message_md5(self, bytes salt)

Expand Down Expand Up @@ -157,6 +163,7 @@ cdef class CoreProtocol:
cdef _close(self, str name, bint is_portal)
cdef _simple_query(self, str query)
cdef _copy_out(self, str copy_stmt)
cdef _copy_in(self, str copy_stmt)
cdef _terminate(self)

cdef _decode_row(self, const char* buf, ssize_t buf_len)
Expand Down
84 changes: 84 additions & 0 deletions asyncpg/protocol/coreproto.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ cdef class CoreProtocol:
state == PROTOCOL_COPY_OUT_DONE):
self._process__copy_out_data(mtype)

elif state == PROTOCOL_COPY_IN:
self._process__copy_in(mtype)

elif state == PROTOCOL_COPY_IN_DATA:
self._process__copy_in_data(mtype)

elif state == PROTOCOL_CANCELLED:
# discard all messages until the sync message
if mtype == b'E':
Expand Down Expand Up @@ -356,6 +362,33 @@ cdef class CoreProtocol:
self._parse_msg_ready_for_query()
self._push_result()

cdef _process__copy_in(self, char mtype):
    # Handle one backend message while in PROTOCOL_COPY_IN, i.e. after
    # the COPY statement was sent and before the server has answered.
    # Message types other than the three below are left untouched.
    if mtype == b'G':
        # CopyInResponse: switch to the data phase, then drop the
        # message body.
        self._set_state(PROTOCOL_COPY_IN_DATA)
        self.buffer.consume_message()

    elif mtype == b'Z':
        # ReadyForQuery: the command is over, publish the result.
        self._parse_msg_ready_for_query()
        self._push_result()

    elif mtype == b'E':
        # ErrorResponse
        self._parse_msg_error_response(True)

cdef _process__copy_in_data(self, char mtype):
    # Handle one backend message while in PROTOCOL_COPY_IN_DATA, i.e.
    # during and after the data phase of a COPY IN.
    # Message types other than the three below are left untouched.
    if mtype == b'C':
        # CommandComplete
        self._parse_msg_command_complete()

    elif mtype == b'Z':
        # ReadyForQuery: the command is over, publish the result.
        self._parse_msg_ready_for_query()
        self._push_result()

    elif mtype == b'E':
        # ErrorResponse
        self._parse_msg_error_response(True)

cdef _parse_msg_command_complete(self):
cdef:
char* cbuf
Expand Down Expand Up @@ -387,6 +420,42 @@ cdef class CoreProtocol:
self._on_result()
self.result = None

cdef _write_copy_data_msg(self, object data):
    # Send *data* to the server wrapped in a single 'd' message.
    #
    # *data* must support the buffer protocol; its bytes are copied
    # into the outgoing message buffer before the temporary memoryview
    # is released.
    cdef:
        WriteBuffer buf
        object mview
        Py_buffer *pybuf

    # The buffertype argument of PyMemoryView_GetContiguous() must be
    # PyBUF_READ or PyBUF_WRITE (debug builds of CPython assert this);
    # PyBUF_SIMPLE is a PyObject_GetBuffer() request flag, not a valid
    # buffertype.  We only read from the buffer.
    mview = PyMemoryView_GetContiguous(data, cpython.PyBUF_READ, b'C')

    try:
        pybuf = PyMemoryView_GET_BUFFER(mview)

        buf = WriteBuffer.new_message(b'd')
        buf.write_cstr(<const char *>pybuf.buf, pybuf.len)
        buf.end_message()
    finally:
        # pybuf points into mview, so release only after the copy.
        mview.release()

    self._write(buf)

cdef _write_copy_done_msg(self):
    # Send an empty 'c' message (no payload) to the server.
    cdef WriteBuffer msg = WriteBuffer.new_message(b'c')

    msg.end_message()
    self._write(msg)

cdef _write_copy_fail_msg(self, str cause):
    # Send an 'f' message carrying *cause*, encoded with the
    # connection encoding.  A None or empty cause is sent as ''.
    cdef WriteBuffer msg = WriteBuffer.new_message(b'f')

    msg.write_str(cause or '', self.encoding)
    msg.end_message()
    self._write(msg)

cdef _parse_data_msgs(self):
cdef:
ReadBuffer buf = self.buffer
Expand Down Expand Up @@ -592,6 +661,10 @@ cdef class CoreProtocol:
new_state == PROTOCOL_COPY_OUT_DONE):
self.state = new_state

elif (self.state == PROTOCOL_COPY_IN and
new_state == PROTOCOL_COPY_IN_DATA):
self.state = new_state

elif self.state == PROTOCOL_FAILED:
raise RuntimeError(
'cannot switch to state {}; '
Expand Down Expand Up @@ -810,6 +883,17 @@ cdef class CoreProtocol:
buf.end_message()
self._write(buf)

cdef _copy_in(self, str copy_stmt):
    # Start a COPY IN: check the connection, enter PROTOCOL_COPY_IN,
    # then send *copy_stmt* to the server as a 'Q' message.
    cdef:
        WriteBuffer query_msg

    self._ensure_connected()
    self._set_state(PROTOCOL_COPY_IN)

    query_msg = WriteBuffer.new_message(b'Q')
    query_msg.write_str(copy_stmt, self.encoding)
    query_msg.end_message()

    self._write(query_msg)

cdef _terminate(self):
cdef WriteBuffer buf
self._ensure_connected()
Expand Down
2 changes: 2 additions & 0 deletions asyncpg/protocol/protocol.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ cdef class BaseProtocol(CoreProtocol):

str last_query

bint writing_paused
bint closing

readonly uint64_t queries_count
Expand All @@ -58,6 +59,7 @@ cdef class BaseProtocol(CoreProtocol):
cdef _on_result__simple_query(self, object waiter)
cdef _on_result__bind(self, object waiter)
cdef _on_result__copy_out(self, object waiter)
cdef _on_result__copy_in(self, object waiter)

cdef _handle_waiter_on_connection_lost(self, cause)

Expand Down
Loading

0 comments on commit 10d95d4

Please sign in to comment.