From 6fa0458ff3a2921a5b2f3757b7ed39cc29d51c45 Mon Sep 17 00:00:00 2001 From: joeydewaal <99046430+joeydewaal@users.noreply.github.com> Date: Sat, 25 Jan 2025 23:23:50 +0100 Subject: [PATCH] fix(Postgres) chunk pg_copy data (#3703) * fix(postgres) chunk pg_copy data * fix: cleanup after review --- sqlx-postgres/src/copy.rs | 22 +++++++++++++++------- sqlx-postgres/src/lib.rs | 3 +++ tests/postgres/postgres.rs | 22 +++++++++++++++++++++- 3 files changed, 39 insertions(+), 8 deletions(-) diff --git a/sqlx-postgres/src/copy.rs b/sqlx-postgres/src/copy.rs index fd9ed7e85c..1315ea0e20 100644 --- a/sqlx-postgres/src/copy.rs +++ b/sqlx-postgres/src/copy.rs @@ -129,6 +129,9 @@ impl PgPoolCopyExt for Pool { } } +// (1 GiB - 1) - 1 - length prefix (4 bytes) +pub const PG_COPY_MAX_DATA_LEN: usize = 0x3fffffff - 1 - 4; + /// A connection in streaming `COPY FROM STDIN` mode. /// /// Created by [PgConnection::copy_in_raw] or [Pool::copy_out_raw]. @@ -186,15 +189,20 @@ impl> PgCopyIn { /// Send a chunk of `COPY` data. /// + /// The data is sent in chunks if it exceeds the maximum length of a `CopyData` message (1 GiB - 6 + /// bytes) and may be partially sent if this call is cancelled. + /// /// If you're copying data from an `AsyncRead`, maybe consider [Self::read_from] instead. 
pub async fn send(&mut self, data: impl Deref) -> Result<&mut Self> { - self.conn - .as_deref_mut() - .expect("send_data: conn taken") - .inner - .stream - .send(CopyData(data)) - .await?; + for chunk in data.deref().chunks(PG_COPY_MAX_DATA_LEN) { + self.conn + .as_deref_mut() + .expect("send_data: conn taken") + .inner + .stream + .send(CopyData(chunk)) + .await?; + } Ok(self) } diff --git a/sqlx-postgres/src/lib.rs b/sqlx-postgres/src/lib.rs index 792f8bbdc0..bded75491c 100644 --- a/sqlx-postgres/src/lib.rs +++ b/sqlx-postgres/src/lib.rs @@ -34,6 +34,9 @@ mod value; #[doc(hidden)] pub mod any; +#[doc(hidden)] +pub use copy::PG_COPY_MAX_DATA_LEN; + #[cfg(feature = "migrate")] mod migrate; diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index 87a18db510..d89ecbb21e 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -3,7 +3,7 @@ use futures::{Stream, StreamExt, TryStreamExt}; use sqlx::postgres::types::Oid; use sqlx::postgres::{ PgAdvisoryLock, PgConnectOptions, PgConnection, PgDatabaseError, PgErrorPosition, PgListener, - PgPoolOptions, PgRow, PgSeverity, Postgres, + PgPoolOptions, PgRow, PgSeverity, Postgres, PG_COPY_MAX_DATA_LEN, }; use sqlx::{Column, Connection, Executor, Row, Statement, TypeInfo}; use sqlx_core::{bytes::Bytes, error::BoxDynError}; @@ -2042,3 +2042,23 @@ async fn test_issue_3052() { "expected encode error, got {too_large_error:?}", ); } + +#[sqlx_macros::test] +async fn test_pg_copy_chunked() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let mut row = "1".repeat(PG_COPY_MAX_DATA_LEN / 10 - 1); + row.push_str("\n"); + + // creates a payload of PG_COPY_MAX_DATA_LEN + 1 bytes in size + let mut payload = row.repeat(10); + payload.push_str("12345678\n"); + + assert_eq!(payload.len(), PG_COPY_MAX_DATA_LEN + 1); + + let mut copy = conn.copy_in_raw("COPY products(name) FROM STDIN").await?; + + assert!(copy.send(payload.as_bytes()).await.is_ok()); + assert!(copy.finish().await.is_ok()); + Ok(()) +}