diff --git a/python/src/filesystem.rs b/python/src/filesystem.rs
index 6a26b3e6f4..453d05e480 100644
--- a/python/src/filesystem.rs
+++ b/python/src/filesystem.rs
@@ -1,5 +1,5 @@
 use crate::error::PythonError;
-use crate::utils::{delete_dir, rt, walk_tree};
+use crate::utils::{delete_dir, rt, walk_tree, warn};
 use crate::RawDeltaTable;
 use deltalake::storage::object_store::{MultipartUpload, PutPayloadMut};
 use deltalake::storage::{DynObjectStore, ListResult, ObjectStoreError, Path};
@@ -11,7 +11,7 @@ use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
 use std::sync::Arc;
 
-const DEFAULT_MAX_BUFFER_SIZE: usize = 4 * 1024 * 1024;
+const DEFAULT_MAX_BUFFER_SIZE: usize = 5 * 1024 * 1024;
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub(crate) struct FsConfig {
@@ -297,6 +297,7 @@ impl DeltaFileSystemHandler {
         &self,
         path: String,
         #[allow(unused)] metadata: Option<HashMap<String, String>>,
+        py: Python<'_>,
     ) -> PyResult<ObjectOutputStream> {
         let path = Self::parse_path(&path);
         let max_buffer_size = self
@@ -306,6 +307,19 @@ impl DeltaFileSystemHandler {
             .map_or(DEFAULT_MAX_BUFFER_SIZE, |v| {
                 v.parse::<usize>().unwrap_or(DEFAULT_MAX_BUFFER_SIZE)
             });
+        if max_buffer_size < DEFAULT_MAX_BUFFER_SIZE {
+            warn(
+                py,
+                "UserWarning",
+                format!(
+                    "You specified a `max_buffer_size` of {} bytes, which is less than the default of {} bytes. \
+                    Most object stores require multipart upload parts of at least that size; you may experience issues",
+                    max_buffer_size, DEFAULT_MAX_BUFFER_SIZE
+                )
+                .as_str(),
+                Some(2),
+            )?;
+        }
         let file = rt()
             .block_on(ObjectOutputStream::try_new(
                 self.inner.clone(),
@@ -537,18 +551,37 @@ impl ObjectOutputStream {
 
         Ok(())
     }
+
+    fn abort(&mut self) -> PyResult<()> {
+        rt().block_on(self.upload.abort())
+            .map_err(PythonError::from)?;
+        Ok(())
+    }
+
+    fn upload_buffer(&mut self) -> PyResult<()> {
+        let payload = std::mem::take(&mut self.buffer).freeze();
+        match rt().block_on(self.upload.put_part(payload)) {
+            Ok(_) => Ok(()),
+            Err(err) => {
+                self.abort()?;
+                Err(PyIOError::new_err(err.to_string()))
+            }
+        }
+    }
 }
 
 #[pymethods]
 impl ObjectOutputStream {
     fn close(&mut self, py: Python<'_>) -> PyResult<()> {
-        self.closed = true;
-        if !self.buffer.is_empty() {
-            self.flush(py)?;
-        }
-        py.allow_threads(|| match rt().block_on(self.upload.complete()) {
-            Ok(_) => Ok(()),
-            Err(err) => Err(PyIOError::new_err(err.to_string())),
+        py.allow_threads(|| {
+            self.closed = true;
+            if !self.buffer.is_empty() {
+                self.upload_buffer()?;
+            }
+            match rt().block_on(self.upload.complete()) {
+                Ok(_) => Ok(()),
+                Err(err) => Err(PyIOError::new_err(err.to_string())),
+            }
         })
     }
 
@@ -596,24 +629,33 @@ impl ObjectOutputStream {
         self.check_closed()?;
         let py = data.py();
         let bytes = data.as_bytes();
-        let len = bytes.len();
-        py.allow_threads(|| self.buffer.extend_from_slice(bytes));
-        if self.buffer.content_length() >= self.max_buffer_size {
-            self.flush(py)?;
-        }
-        Ok(len as i64)
+        py.allow_threads(|| {
+            let len = bytes.len();
+            for chunk in bytes.chunks(self.max_buffer_size) {
+                // this will never overflow
+                let remaining = self.max_buffer_size - self.buffer.content_length();
+                // if we have enough space to store this chunk, just append it
+                if chunk.len() < remaining {
+                    self.buffer.extend_from_slice(chunk);
+                    break;
+                }
+                // if we don't, fill as much as we can, flush the buffer, and then append the rest
+                // this won't panic since we've checked the size of the chunk
+                let (first, second) = chunk.split_at(remaining);
+                self.buffer.extend_from_slice(first);
+                self.upload_buffer()?;
+                // len(second) will always be < max_buffer_size, and we just
+                // emptied the buffer by flushing, so we won't overflow
+                // if len(chunk) just happened to be == remaining,
+                // the second slice is empty. this is a no-op
+                self.buffer.extend_from_slice(second);
+            }
+            Ok(len as i64)
+        })
     }
 
     fn flush(&mut self, py: Python<'_>) -> PyResult<()> {
-        let payload = std::mem::take(&mut self.buffer).freeze();
-        py.allow_threads(|| match rt().block_on(self.upload.put_part(payload)) {
-            Ok(_) => Ok(()),
-            Err(err) => {
-                rt().block_on(self.upload.abort())
-                    .map_err(PythonError::from)?;
-                Err(PyIOError::new_err(err.to_string()))
-            }
-        })
+        py.allow_threads(|| self.upload_buffer())
     }
 
     fn fileno(&self) -> PyResult<()> {
diff --git a/python/src/schema.rs b/python/src/schema.rs
index 6f1380709a..ba4ea4fa47 100644
--- a/python/src/schema.rs
+++ b/python/src/schema.rs
@@ -12,9 +12,10 @@ use deltalake::kernel::{
 };
 use pyo3::exceptions::{PyException, PyNotImplementedError, PyTypeError, PyValueError};
 use pyo3::prelude::*;
-use pyo3::types::IntoPyDict;
 use std::collections::HashMap;
 
+use crate::utils::warn;
+
 // PyO3 doesn't yet support converting classes with inheritance with Python
 // objects within Rust code, which we need here. So for now, we implement
 // the types with no inheritance. Later, we may add inheritance.
@@ -717,16 +718,11 @@
     }
 
     fn json<'py>(self_: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
-        let warnings_warn = PyModule::import_bound(py, "warnings")?.getattr("warn")?;
-        let deprecation_warning =
-            PyModule::import_bound(py, "builtins")?.getattr("DeprecationWarning")?;
-        let kwargs: [(&str, Bound<'py, PyAny>); 2] = [
-            ("category", deprecation_warning),
-            ("stacklevel", 2.to_object(py).into_bound(py)),
-        ];
-        warnings_warn.call(
-            ("Schema.json() is deprecated. Use json.loads(Schema.to_json()) instead.",),
-            Some(&kwargs.into_py_dict_bound(py)),
+        warn(
+            py,
+            "DeprecationWarning",
+            "Schema.json() is deprecated. Use json.loads(Schema.to_json()) instead.",
+            Some(2),
         )?;
 
         let super_ = self_.as_ref();
diff --git a/python/src/utils.rs b/python/src/utils.rs
index 6d0f69b242..5ec2fe0a65 100644
--- a/python/src/utils.rs
+++ b/python/src/utils.rs
@@ -3,6 +3,8 @@ use std::sync::{Arc, OnceLock};
 use deltalake::storage::{ListResult, ObjectStore, ObjectStoreError, ObjectStoreResult, Path};
 use futures::future::{join_all, BoxFuture, FutureExt};
 use futures::StreamExt;
+use pyo3::types::{IntoPyDict, PyAnyMethods, PyModule};
+use pyo3::{Bound, PyAny, PyResult, Python, ToPyObject};
 use tokio::runtime::Runtime;
 
 #[inline]
@@ -80,3 +82,20 @@ pub async fn delete_dir(storage: &dyn ObjectStore, prefix: &Path) -> ObjectStoreResult<()> {
     }
     Ok(())
 }
+
+pub fn warn<'py>(
+    py: Python<'py>,
+    warning_type: &str,
+    message: &str,
+    stack_level: Option<u8>,
+) -> PyResult<()> {
+    let warnings_warn = PyModule::import_bound(py, "warnings")?.getattr("warn")?;
+    let warning_type = PyModule::import_bound(py, "builtins")?.getattr(warning_type)?;
+    let stack_level = stack_level.unwrap_or(1);
+    let kwargs: [(&str, Bound<'py, PyAny>); 2] = [
+        ("category", warning_type),
+        ("stacklevel", stack_level.to_object(py).into_bound(py)),
+    ];
+    warnings_warn.call((message,), Some(&kwargs.into_py_dict_bound(py)))?;
+    Ok(())
+}
diff --git a/python/tests/test_fs.py b/python/tests/test_fs.py
index b5926eece8..87b24d8577 100644
--- a/python/tests/test_fs.py
+++ b/python/tests/test_fs.py
@@ -232,6 +232,20 @@ def test_roundtrip_azure_decoded_sas(azurite_sas_creds, sample_data: pa.Table):
     assert dt.version() == 0
 
 
+@pytest.mark.parametrize("storage_size", [1, 4 * 1024 * 1024, 5 * 1024 * 1024 - 1])
+def test_warning_for_small_max_buffer_size(tmp_path, storage_size):
+    storage_opts = {"max_buffer_size": str(storage_size)}
+    store = DeltaStorageHandler(str(tmp_path.absolute()), options=storage_opts)
+    with pytest.warns(UserWarning) as warnings:
+        store.open_output_stream("test")
+
+    assert len(warnings) == 1
+    assert (
+        f"You specified a `max_buffer_size` of {storage_size} bytes, which is less than the default of {5*1024*1024} bytes"
+        in str(warnings[0].message)
+    )
+
+
 def test_pickle_roundtrip(tmp_path):
     store = DeltaStorageHandler(str(tmp_path.absolute()))