fix(python): constrain multipart upload size to fixed length (#2606)
# Description

Object stores expect a fixed length for every part of a multipart upload except the last one. The original logic flushed the buffer whenever it exceeded the threshold, so part sizes could drift. Now the buffer is flushed exactly when the threshold is met, so every part has the same fixed size; only when the transaction is completed is the final part allowed to be smaller.
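
As a rough illustration of that invariant, here is a minimal Python sketch (the `FixedSizeBuffer` class, `upload_part` callback, and `PART_SIZE` constant are hypothetical stand-ins, not the actual Rust implementation in this PR):

```python
PART_SIZE = 5 * 1024 * 1024  # 5 MiB, mirroring DEFAULT_MAX_BUFFER_SIZE

class FixedSizeBuffer:
    """Accumulates writes and emits fixed-size parts; only the last part may be smaller."""

    def __init__(self, upload_part, part_size: int = PART_SIZE):
        self.upload_part = upload_part  # callable invoked with each completed part
        self.part_size = part_size
        self.buffer = bytearray()

    def write(self, data: bytes) -> int:
        view = memoryview(data)
        while len(view) > 0:
            # fill the buffer up to, and never past, the fixed part size
            take = min(self.part_size - len(self.buffer), len(view))
            self.buffer += view[:take]
            view = view[take:]
            if len(self.buffer) == self.part_size:
                self.upload_part(bytes(self.buffer))  # always exactly part_size bytes
                self.buffer.clear()
        return len(data)

    def close(self) -> None:
        if self.buffer:  # the final, possibly smaller, part
            self.upload_part(bytes(self.buffer))
            self.buffer.clear()
```

For example, with a toy part size of 4, `write(b"abcdefghij")` followed by `close()` yields parts `b"abcd"`, `b"efgh"`, and `b"ij"`.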

Bumps the constant to 5 MiB, since that is the minimum part size most object stores accept. Also emits a UserWarning if a smaller `max_buffer_size` is specified.
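
For instance, the buffer size can be tuned through the storage options (illustrative usage; the handler and option key mirror the test added below, and the path is hypothetical):

```python
from deltalake.fs import DeltaStorageHandler

# 10 MiB parts; any value below 5 MiB now emits a UserWarning
store = DeltaStorageHandler(
    "/tmp/example_table",  # hypothetical local path
    options={"max_buffer_size": str(10 * 1024 * 1024)},
)
stream = store.open_output_stream("data.bin")
stream.write(b"some bytes")
stream.close()
```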

Also releases the GIL in more places by moving the flushing logic into a shared helper that runs inside `allow_threads`.

# Related Issue(s)

Closes #2605 

# Documentation


See the [MultipartUpload](https://docs.rs/object_store/latest/object_store/trait.MultipartUpload.html) docs.
abhiaagarwal authored Jun 19, 2024
1 parent b7b572b commit 6205f00
Showing 4 changed files with 106 additions and 35 deletions.
90 changes: 66 additions & 24 deletions python/src/filesystem.rs
@@ -1,5 +1,5 @@
 use crate::error::PythonError;
-use crate::utils::{delete_dir, rt, walk_tree};
+use crate::utils::{delete_dir, rt, walk_tree, warn};
 use crate::RawDeltaTable;
 use deltalake::storage::object_store::{MultipartUpload, PutPayloadMut};
 use deltalake::storage::{DynObjectStore, ListResult, ObjectStoreError, Path};
@@ -11,7 +11,7 @@ use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
 use std::sync::Arc;

-const DEFAULT_MAX_BUFFER_SIZE: usize = 4 * 1024 * 1024;
+const DEFAULT_MAX_BUFFER_SIZE: usize = 5 * 1024 * 1024;

 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub(crate) struct FsConfig {
@@ -297,6 +297,7 @@ impl DeltaFileSystemHandler {
         &self,
         path: String,
         #[allow(unused)] metadata: Option<HashMap<String, String>>,
+        py: Python<'_>,
     ) -> PyResult<ObjectOutputStream> {
         let path = Self::parse_path(&path);
         let max_buffer_size = self
@@ -306,6 +307,19 @@ impl DeltaFileSystemHandler {
             .map_or(DEFAULT_MAX_BUFFER_SIZE, |v| {
                 v.parse::<usize>().unwrap_or(DEFAULT_MAX_BUFFER_SIZE)
             });
+        if max_buffer_size < DEFAULT_MAX_BUFFER_SIZE {
+            warn(
+                py,
+                "UserWarning",
+                format!(
+                    "You specified a `max_buffer_size` of {} bytes less than {} bytes. Most object
+                    stores expect a part size of at least that; you may experience issues",
+                    max_buffer_size, DEFAULT_MAX_BUFFER_SIZE
+                )
+                .as_str(),
+                Some(2),
+            )?;
+        }
         let file = rt()
             .block_on(ObjectOutputStream::try_new(
                 self.inner.clone(),
@@ -537,18 +551,37 @@ impl ObjectOutputStream {

         Ok(())
     }
+
+    fn abort(&mut self) -> PyResult<()> {
+        rt().block_on(self.upload.abort())
+            .map_err(PythonError::from)?;
+        Ok(())
+    }
+
+    fn upload_buffer(&mut self) -> PyResult<()> {
+        let payload = std::mem::take(&mut self.buffer).freeze();
+        match rt().block_on(self.upload.put_part(payload)) {
+            Ok(_) => Ok(()),
+            Err(err) => {
+                self.abort()?;
+                Err(PyIOError::new_err(err.to_string()))
+            }
+        }
+    }
 }

 #[pymethods]
 impl ObjectOutputStream {
     fn close(&mut self, py: Python<'_>) -> PyResult<()> {
-        self.closed = true;
-        if !self.buffer.is_empty() {
-            self.flush(py)?;
-        }
-        py.allow_threads(|| match rt().block_on(self.upload.complete()) {
-            Ok(_) => Ok(()),
-            Err(err) => Err(PyIOError::new_err(err.to_string())),
+        py.allow_threads(|| {
+            self.closed = true;
+            if !self.buffer.is_empty() {
+                self.upload_buffer()?;
+            }
+            match rt().block_on(self.upload.complete()) {
+                Ok(_) => Ok(()),
+                Err(err) => Err(PyIOError::new_err(err.to_string())),
+            }
         })
     }

@@ -596,24 +629,33 @@ impl ObjectOutputStream {
         self.check_closed()?;
         let py = data.py();
         let bytes = data.as_bytes();
-        let len = bytes.len();
-        py.allow_threads(|| self.buffer.extend_from_slice(bytes));
-        if self.buffer.content_length() >= self.max_buffer_size {
-            self.flush(py)?;
-        }
-        Ok(len as i64)
+        py.allow_threads(|| {
+            let len = bytes.len();
+            for chunk in bytes.chunks(self.max_buffer_size) {
+                // this will never overflow
+                let remaining = self.max_buffer_size - self.buffer.content_length();
+                // if we have enough space to store this chunk, just append it
+                if chunk.len() < remaining {
+                    self.buffer.extend_from_slice(chunk);
+                    break;
+                }
+                // if we don't, fill as much as we can, flush the buffer, and then append the rest
+                // this won't panic since we've checked the size of the chunk
+                let (first, second) = chunk.split_at(remaining);
+                self.buffer.extend_from_slice(first);
+                self.upload_buffer()?;
+                // len(second) will always be < max_buffer_size, and we just
+                // emptied the buffer by flushing, so we won't overflow
+                // if len(chunk) just happened to be == remaining,
+                // the second slice is empty. this is a no-op
+                self.buffer.extend_from_slice(second);
+            }
+            Ok(len as i64)
+        })
     }

     fn flush(&mut self, py: Python<'_>) -> PyResult<()> {
-        let payload = std::mem::take(&mut self.buffer).freeze();
-        py.allow_threads(|| match rt().block_on(self.upload.put_part(payload)) {
-            Ok(_) => Ok(()),
-            Err(err) => {
-                rt().block_on(self.upload.abort())
-                    .map_err(PythonError::from)?;
-                Err(PyIOError::new_err(err.to_string()))
-            }
-        })
+        py.allow_threads(|| self.upload_buffer())
     }

     fn fileno(&self) -> PyResult<()> {
18 changes: 7 additions & 11 deletions python/src/schema.rs
@@ -12,9 +12,10 @@ use deltalake::kernel::{
 };
 use pyo3::exceptions::{PyException, PyNotImplementedError, PyTypeError, PyValueError};
 use pyo3::prelude::*;
-use pyo3::types::IntoPyDict;
 use std::collections::HashMap;

+use crate::utils::warn;
+
 // PyO3 doesn't yet support converting classes with inheritance with Python
 // objects within Rust code, which we need here. So for now, we implement
 // the types with no inheritance. Later, we may add inheritance.
@@ -717,16 +718,11 @@ impl PySchema {
     }

     fn json<'py>(self_: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
-        let warnings_warn = PyModule::import_bound(py, "warnings")?.getattr("warn")?;
-        let deprecation_warning =
-            PyModule::import_bound(py, "builtins")?.getattr("DeprecationWarning")?;
-        let kwargs: [(&str, Bound<'py, PyAny>); 2] = [
-            ("category", deprecation_warning),
-            ("stacklevel", 2.to_object(py).into_bound(py)),
-        ];
-        warnings_warn.call(
-            ("Schema.json() is deprecated. Use json.loads(Schema.to_json()) instead.",),
-            Some(&kwargs.into_py_dict_bound(py)),
+        warn(
+            py,
+            "DeprecationWarning",
+            "Schema.json() is deprecated. Use json.loads(Schema.to_json()) instead.",
+            Some(2),
         )?;

         let super_ = self_.as_ref();
19 changes: 19 additions & 0 deletions python/src/utils.rs
@@ -3,6 +3,8 @@ use std::sync::{Arc, OnceLock};
 use deltalake::storage::{ListResult, ObjectStore, ObjectStoreError, ObjectStoreResult, Path};
 use futures::future::{join_all, BoxFuture, FutureExt};
 use futures::StreamExt;
+use pyo3::types::{IntoPyDict, PyAnyMethods, PyModule};
+use pyo3::{Bound, PyAny, PyResult, Python, ToPyObject};
 use tokio::runtime::Runtime;

 #[inline]
@@ -80,3 +82,20 @@ pub async fn delete_dir(storage: &dyn ObjectStore, prefix: &Path) -> ObjectStoreResult<()> {
     }
     Ok(())
 }
+
+pub fn warn<'py>(
+    py: Python<'py>,
+    warning_type: &str,
+    message: &str,
+    stack_level: Option<u8>,
+) -> PyResult<()> {
+    let warnings_warn = PyModule::import_bound(py, "warnings")?.getattr("warn")?;
+    let warning_type = PyModule::import_bound(py, "builtins")?.getattr(warning_type)?;
+    let stack_level = stack_level.unwrap_or(1);
+    let kwargs: [(&str, Bound<'py, PyAny>); 2] = [
+        ("category", warning_type),
+        ("stacklevel", stack_level.to_object(py).into_bound(py)),
+    ];
+    warnings_warn.call((message,), Some(&kwargs.into_py_dict_bound(py)))?;
+    Ok(())
+}
14 changes: 14 additions & 0 deletions python/tests/test_fs.py
@@ -232,6 +232,20 @@ def test_roundtrip_azure_decoded_sas(azurite_sas_creds, sample_data: pa.Table):
     assert dt.version() == 0


+@pytest.mark.parametrize("storage_size", [1, 4 * 1024 * 1024, 5 * 1024 * 1024 - 1])
+def test_warning_for_small_max_buffer_size(tmp_path, storage_size):
+    storage_opts = {"max_buffer_size": str(storage_size)}
+    store = DeltaStorageHandler(str(tmp_path.absolute()), options=storage_opts)
+    with pytest.warns(UserWarning) as warnings:
+        store.open_output_stream("test")
+
+    assert len(warnings) == 1
+    assert (
+        f"You specified a `max_buffer_size` of {storage_size} bytes less than {5*1024*1024} bytes"
+        in str(warnings[0].message)
+    )
+
+
 def test_pickle_roundtrip(tmp_path):
     store = DeltaStorageHandler(str(tmp_path.absolute()))

