diff --git a/crates/deltalake-core/src/lib.rs b/crates/deltalake-core/src/lib.rs
index dfdbac97d6..6425ec3282 100644
--- a/crates/deltalake-core/src/lib.rs
+++ b/crates/deltalake-core/src/lib.rs
@@ -71,6 +71,7 @@
 #![deny(warnings)]
 #![deny(missing_docs)]
 #![allow(rustdoc::invalid_html_tags)]
+#![allow(clippy::nonminimal_bool)]
 
 #[cfg(all(feature = "parquet", feature = "parquet2"))]
 compile_error!(
diff --git a/crates/deltalake-core/src/operations/delete.rs b/crates/deltalake-core/src/operations/delete.rs
index b6c94f423b..35cf0b858f 100644
--- a/crates/deltalake-core/src/operations/delete.rs
+++ b/crates/deltalake-core/src/operations/delete.rs
@@ -172,6 +172,7 @@ async fn excute_non_empty_expr(
         None,
         writer_properties,
         false,
+        false,
     )
     .await?;
 
diff --git a/crates/deltalake-core/src/operations/merge.rs b/crates/deltalake-core/src/operations/merge.rs
index 8b0dd56708..061c2eb912 100644
--- a/crates/deltalake-core/src/operations/merge.rs
+++ b/crates/deltalake-core/src/operations/merge.rs
@@ -1013,6 +1013,7 @@ async fn execute(
         None,
         writer_properties,
         safe_cast,
+        false,
     )
     .await?;
 
diff --git a/crates/deltalake-core/src/operations/mod.rs b/crates/deltalake-core/src/operations/mod.rs
index a81e16578f..473ef1451b 100644
--- a/crates/deltalake-core/src/operations/mod.rs
+++ b/crates/deltalake-core/src/operations/mod.rs
@@ -13,6 +13,7 @@ use self::vacuum::VacuumBuilder;
 use crate::errors::{DeltaResult, DeltaTableError};
 use crate::table::builder::DeltaTableBuilder;
 use crate::DeltaTable;
+use std::collections::HashMap;
 
 #[cfg(all(feature = "arrow", feature = "parquet"))]
 pub mod convert_to_delta;
@@ -73,6 +74,22 @@ impl DeltaOps {
         }
     }
 
+    /// try from uri with storage options
+    pub async fn try_from_uri_with_storage_options(
+        uri: impl AsRef<str>,
+        storage_options: HashMap<String, String>,
+    ) -> DeltaResult<Self> {
+        let mut table = DeltaTableBuilder::from_uri(uri)
+            .with_storage_options(storage_options)
+            .build()?;
+        // We allow for uninitialized locations, since we may want to create the table
+        match table.load().await {
+            Ok(_) => Ok(table.into()),
+            Err(DeltaTableError::NotATable(_)) => Ok(table.into()),
+            Err(err) => Err(err),
+        }
+    }
+
     /// Create a new [`DeltaOps`] instance, backed by an un-initialized in memory table
     ///
     /// Using this will not persist any changes beyond the lifetime of the table object.
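The new `DeltaOps::try_from_uri_with_storage_options` entry point is what the Python `engine="rust"` writer (added later in this patch) uses to resolve a table, and it tolerates locations that are not yet Delta tables so a first write can create one. A minimal usage sketch from the Python side; the URI and the storage option key/value are placeholders, not values taken from this patch:

```python
import pyarrow as pa
from deltalake import write_deltalake

data = pa.table({"id": pa.array([1, 2, 3], pa.int64())})

# The rust engine resolves the table through
# DeltaOps::try_from_uri_with_storage_options, so an uninitialized
# location is accepted and the table is created on the first write.
write_deltalake(
    "s3://example-bucket/events",                 # hypothetical URI
    data,
    mode="append",
    engine="rust",
    storage_options={"AWS_REGION": "us-east-1"},  # illustrative key/value only
)
```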
diff --git a/crates/deltalake-core/src/operations/update.rs b/crates/deltalake-core/src/operations/update.rs index 907dec5998..fa44724beb 100644 --- a/crates/deltalake-core/src/operations/update.rs +++ b/crates/deltalake-core/src/operations/update.rs @@ -363,6 +363,7 @@ async fn execute( None, writer_properties, safe_cast, + false, ) .await?; diff --git a/crates/deltalake-core/src/operations/write.rs b/crates/deltalake-core/src/operations/write.rs index cb68b72bb2..8fd8ddd99e 100644 --- a/crates/deltalake-core/src/operations/write.rs +++ b/crates/deltalake-core/src/operations/write.rs @@ -43,7 +43,7 @@ use super::writer::{DeltaWriter, WriterConfig}; use super::{transaction::commit, CreateBuilder}; use crate::delta_datafusion::DeltaDataChecker; use crate::errors::{DeltaResult, DeltaTableError}; -use crate::kernel::{Action, Add, Remove, StructType}; +use crate::kernel::{Action, Add, Metadata, Remove, StructType}; use crate::logstore::LogStoreRef; use crate::protocol::{DeltaOperation, SaveMode}; use crate::storage::ObjectStoreRef; @@ -103,12 +103,20 @@ pub struct WriteBuilder { write_batch_size: Option, /// RecordBatches to be written into the table batches: Option>, + /// whether to overwrite the schema + overwrite_schema: bool, /// how to handle cast failures, either return NULL (safe=true) or return ERR (safe=false) safe_cast: bool, /// Parquet writer properties writer_properties: Option, /// Additional metadata to be added to commit app_metadata: Option>, + /// Name of the table, only used when table doesn't exist yet + name: Option, + /// Description of the table, only used when table doesn't exist yet + description: Option, + /// Configurations of the delta table, only used when table doesn't exist + configuration: HashMap>, } impl WriteBuilder { @@ -126,8 +134,12 @@ impl WriteBuilder { write_batch_size: None, batches: None, safe_cast: false, + overwrite_schema: false, writer_properties: None, app_metadata: None, + name: None, + description: None, + configuration: Default::default(), } } @@ -137,6 +149,12 @@ impl WriteBuilder { self } + /// Add overwrite_schema + pub fn with_overwrite_schema(mut self, overwrite_schema: bool) -> Self { + self.overwrite_schema = overwrite_schema; + self + } + /// When using `Overwrite` mode, replace data that matches a predicate pub fn with_replace_where(mut self, predicate: impl Into) -> Self { self.predicate = Some(predicate.into()); @@ -205,6 +223,31 @@ impl WriteBuilder { self } + /// Specify the table name. Optionally qualified with + /// a database name [database_name.] table_name. + pub fn with_table_name(mut self, name: impl Into) -> Self { + self.name = Some(name.into()); + self + } + + /// Comment to describe the table. + pub fn with_description(mut self, description: impl Into) -> Self { + self.description = Some(description.into()); + self + } + + /// Set configuration on created table + pub fn with_configuration( + mut self, + configuration: impl IntoIterator, Option>)>, + ) -> Self { + self.configuration = configuration + .into_iter() + .map(|(k, v)| (k.into(), v.map(|s| s.into()))) + .collect(); + self + } + async fn check_preconditions(&self) -> DeltaResult> { match self.log_store.is_delta_table_location().await? 
{ true => { @@ -229,10 +272,20 @@ impl WriteBuilder { }?; let mut builder = CreateBuilder::new() .with_log_store(self.log_store.clone()) - .with_columns(schema.fields().clone()); + .with_columns(schema.fields().clone()) + .with_configuration(self.configuration.clone()); if let Some(partition_columns) = self.partition_columns.as_ref() { builder = builder.with_partition_columns(partition_columns.clone()) } + + if let Some(name) = self.name.as_ref() { + builder = builder.with_table_name(name.clone()); + }; + + if let Some(desc) = self.description.as_ref() { + builder = builder.with_comment(desc.clone()); + }; + let (_, actions, _) = builder.into_table_and_actions()?; Ok(actions) } @@ -251,6 +304,7 @@ pub(crate) async fn write_execution_plan( write_batch_size: Option, writer_properties: Option, safe_cast: bool, + overwrite_schema: bool, ) -> DeltaResult> { let invariants = snapshot .current_metadata() @@ -258,7 +312,11 @@ pub(crate) async fn write_execution_plan( .unwrap_or_default(); // Use input schema to prevent wrapping partitions columns into a dictionary. - let schema = snapshot.input_schema().unwrap_or(plan.schema()); + let schema: ArrowSchemaRef = if overwrite_schema { + plan.schema() + } else { + snapshot.input_schema().unwrap_or(plan.schema()) + }; let checker = DeltaDataChecker::new(invariants); @@ -339,13 +397,14 @@ impl std::future::IntoFuture for WriteBuilder { Ok(this.partition_columns.unwrap_or_default()) }?; + let mut schema: ArrowSchemaRef = arrow_schema::Schema::empty().into(); let plan = if let Some(plan) = this.input { Ok(plan) } else if let Some(batches) = this.batches { if batches.is_empty() { Err(WriteError::MissingData) } else { - let schema = batches[0].schema(); + schema = batches[0].schema(); let table_schema = this .snapshot .physical_arrow_schema(this.log_store.object_store().clone()) @@ -353,9 +412,11 @@ impl std::future::IntoFuture for WriteBuilder { .or_else(|_| this.snapshot.arrow_schema()) .unwrap_or(schema.clone()); - if !can_cast_batch(schema.fields(), table_schema.fields()) { + if !can_cast_batch(schema.fields(), table_schema.fields()) + && !(this.overwrite_schema && matches!(this.mode, SaveMode::Overwrite)) + { return Err(DeltaTableError::Generic( - "Updating table schema not yet implemented".to_string(), + "Schema of data does not match table schema".to_string(), )); }; @@ -390,7 +451,7 @@ impl std::future::IntoFuture for WriteBuilder { vec![batches] }; - Ok(Arc::new(MemoryExec::try_new(&data, schema, None)?) + Ok(Arc::new(MemoryExec::try_new(&data, schema.clone(), None)?) as Arc) } } else { @@ -415,12 +476,31 @@ impl std::future::IntoFuture for WriteBuilder { this.write_batch_size, this.writer_properties, this.safe_cast, + this.overwrite_schema, ) .await?; actions.extend(add_actions.into_iter().map(Action::Add)); // Collect remove actions if we are overwriting the table if matches!(this.mode, SaveMode::Overwrite) { + // Update metadata with new schema + let table_schema = this + .snapshot + .physical_arrow_schema(this.log_store.object_store().clone()) + .await + .or_else(|_| this.snapshot.arrow_schema()) + .unwrap_or(schema.clone()); + + if schema != table_schema { + let mut metadata = this + .snapshot + .current_metadata() + .ok_or(DeltaTableError::NoMetadata)? 
+ .clone(); + metadata.schema = schema.clone().try_into()?; + let metadata_action = Metadata::try_from(metadata)?; + actions.push(Action::Metadata(metadata_action)); + } // This should never error, since now() will always be larger than UNIX_EPOCH let deletion_timestamp = SystemTime::now() .duration_since(UNIX_EPOCH) @@ -445,7 +525,10 @@ impl std::future::IntoFuture for WriteBuilder { match this.predicate { Some(_pred) => { - todo!("Overwriting data based on predicate is not yet implemented") + return Err(DeltaTableError::Generic( + "Overwriting data based on predicate is not yet implemented" + .to_string(), + )); } _ => { let remove_actions = this diff --git a/crates/deltalake-core/src/protocol/checkpoints.rs b/crates/deltalake-core/src/protocol/checkpoints.rs index 837483c35c..ef521159d9 100644 --- a/crates/deltalake-core/src/protocol/checkpoints.rs +++ b/crates/deltalake-core/src/protocol/checkpoints.rs @@ -468,22 +468,19 @@ fn apply_stats_conversion( data_type: &DataType, ) { if path.len() == 1 { - match data_type { - DataType::Primitive(PrimitiveType::Timestamp) => { - let v = context.get_mut(&path[0]); - - if let Some(v) = v { - let ts = v - .as_str() - .and_then(|s| time_utils::timestamp_micros_from_stats_string(s).ok()) - .map(|n| Value::Number(serde_json::Number::from(n))); - - if let Some(ts) = ts { - *v = ts; - } + if let DataType::Primitive(PrimitiveType::Timestamp) = data_type { + let v = context.get_mut(&path[0]); + + if let Some(v) = v { + let ts = v + .as_str() + .and_then(|s| time_utils::timestamp_micros_from_stats_string(s).ok()) + .map(|n| Value::Number(serde_json::Number::from(n))); + + if let Some(ts) = ts { + *v = ts; } } - _ => { /* noop */ } } } else { let next_context = context.get_mut(&path[0]).and_then(|v| v.as_object_mut()); diff --git a/crates/deltalake-core/src/writer/utils.rs b/crates/deltalake-core/src/writer/utils.rs index 49c3c6bfee..173340f368 100644 --- a/crates/deltalake-core/src/writer/utils.rs +++ b/crates/deltalake-core/src/writer/utils.rs @@ -5,13 +5,14 @@ use std::io::Write; use std::sync::Arc; use arrow::array::{ - as_boolean_array, as_generic_binary_array, as_primitive_array, as_string_array, Array, + as_boolean_array, as_generic_binary_array, as_largestring_array, as_primitive_array, + as_string_array, Array, }; use arrow::datatypes::{ - DataType, Date32Type, Date64Type, Int16Type, Int32Type, Int64Type, Int8Type, - Schema as ArrowSchema, SchemaRef as ArrowSchemaRef, TimeUnit, TimestampMicrosecondType, - TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, - UInt64Type, UInt8Type, + DataType, Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, + Int8Type, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef, TimeUnit, + TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, + TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; use arrow::json::ReaderBuilder; use arrow::record_batch::*; @@ -184,7 +185,10 @@ pub(crate) fn stringified_partition_value( DataType::UInt16 => as_primitive_array::(arr).value(0).to_string(), DataType::UInt32 => as_primitive_array::(arr).value(0).to_string(), DataType::UInt64 => as_primitive_array::(arr).value(0).to_string(), + DataType::Float32 => as_primitive_array::(arr).value(0).to_string(), + DataType::Float64 => as_primitive_array::(arr).value(0).to_string(), DataType::Utf8 => as_string_array(arr).value(0).to_string(), + DataType::LargeUtf8 => as_largestring_array(arr).value(0).to_string(), 
DataType::Boolean => as_boolean_array(arr).value(0).to_string(), DataType::Date32 => as_primitive_array::(arr) .value_as_date(0) diff --git a/python/deltalake/_internal.pyi b/python/deltalake/_internal.pyi index f751afa36f..d7c0e1a8f9 100644 --- a/python/deltalake/_internal.pyi +++ b/python/deltalake/_internal.pyi @@ -140,6 +140,19 @@ def write_new_deltalake( configuration: Optional[Mapping[str, Optional[str]]], storage_options: Optional[Dict[str, str]], ) -> None: ... +def write_to_deltalake( + table_uri: str, + data: pyarrow.RecordBatchReader, + partition_by: Optional[List[str]], + mode: str, + max_rows_per_group: int, + overwrite_schema: bool, + predicate: Optional[str], + name: Optional[str], + description: Optional[str], + configuration: Optional[Mapping[str, Optional[str]]], + storage_options: Optional[Dict[str, str]], +) -> None: ... def convert_to_deltalake( uri: str, partition_by: Optional[pyarrow.Schema], diff --git a/python/deltalake/writer.py b/python/deltalake/writer.py index 065803f5c7..626fb1a5d9 100644 --- a/python/deltalake/writer.py +++ b/python/deltalake/writer.py @@ -17,6 +17,7 @@ Optional, Tuple, Union, + overload, ) from urllib.parse import unquote @@ -37,7 +38,8 @@ from ._internal import DeltaDataChecker as _DeltaDataChecker from ._internal import batch_distinct from ._internal import convert_to_deltalake as _convert_to_deltalake -from ._internal import write_new_deltalake as _write_new_deltalake +from ._internal import write_new_deltalake as write_deltalake_pyarrow +from ._internal import write_to_deltalake as write_deltalake_rust from .exceptions import DeltaProtocolError, TableNotFoundError from .schema import ( convert_pyarrow_dataset, @@ -67,6 +69,68 @@ class AddAction: stats: str +@overload +def write_deltalake( + table_or_uri: Union[str, Path, DeltaTable], + data: Union[ + "pd.DataFrame", + ds.Dataset, + pa.Table, + pa.RecordBatch, + Iterable[pa.RecordBatch], + RecordBatchReader, + ], + *, + schema: Optional[pa.Schema] = ..., + partition_by: Optional[Union[List[str], str]] = ..., + filesystem: Optional[pa_fs.FileSystem] = None, + mode: Literal["error", "append", "overwrite", "ignore"] = ..., + file_options: Optional[ds.ParquetFileWriteOptions] = ..., + max_partitions: Optional[int] = ..., + max_open_files: int = ..., + max_rows_per_file: int = ..., + min_rows_per_group: int = ..., + max_rows_per_group: int = ..., + name: Optional[str] = ..., + description: Optional[str] = ..., + configuration: Optional[Mapping[str, Optional[str]]] = ..., + overwrite_schema: bool = ..., + storage_options: Optional[Dict[str, str]] = ..., + partition_filters: Optional[List[Tuple[str, str, Any]]] = ..., + large_dtypes: bool = ..., + engine: Literal["pyarrow"] = ..., +) -> None: + ... + + +@overload +def write_deltalake( + table_or_uri: Union[str, Path, DeltaTable], + data: Union[ + "pd.DataFrame", + ds.Dataset, + pa.Table, + pa.RecordBatch, + Iterable[pa.RecordBatch], + RecordBatchReader, + ], + *, + schema: Optional[pa.Schema] = ..., + partition_by: Optional[Union[List[str], str]] = ..., + mode: Literal["error", "append", "overwrite", "ignore"] = ..., + max_rows_per_group: int = ..., + name: Optional[str] = ..., + description: Optional[str] = ..., + configuration: Optional[Mapping[str, Optional[str]]] = ..., + overwrite_schema: bool = ..., + storage_options: Optional[Dict[str, str]] = ..., + predicate: Optional[str] = ..., + large_dtypes: bool = ..., + engine: Literal["rust"], +) -> None: + ... 
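These two overloads document the split between the engines: the pyarrow-only keywords (filesystem, file_options, the dataset-writer tuning knobs, partition_filters) stay on the first signature, while the rust signature adds `predicate` and keeps only the options the Rust writer understands. A minimal sketch of calling both engines, assuming a local placeholder path:

```python
import pyarrow as pa
from deltalake import DeltaTable, write_deltalake

data = pa.table({"x": pa.array([1, 2, 3], pa.int64())})

# Default engine: the existing pyarrow-based writer.
write_deltalake("/tmp/demo_table", data)

# Opt in to the Rust writer; only the keyword arguments listed in the
# second overload above apply to this engine.
write_deltalake(
    "/tmp/demo_table",
    data,
    mode="overwrite",
    overwrite_schema=False,
    engine="rust",
)

assert DeltaTable("/tmp/demo_table").to_pyarrow_table() == data
```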
+ + def write_deltalake( table_or_uri: Union[str, Path, DeltaTable], data: Union[ @@ -94,7 +158,9 @@ def write_deltalake( overwrite_schema: bool = False, storage_options: Optional[Dict[str, str]] = None, partition_filters: Optional[List[Tuple[str, str, Any]]] = None, + predicate: Optional[str] = None, large_dtypes: bool = False, + engine: Literal["pyarrow", "rust"] = "pyarrow", ) -> None: """Write to a Delta Lake table @@ -140,20 +206,20 @@ def write_deltalake( file_options: Optional write options for Parquet (ParquetFileWriteOptions). Can be provided with defaults using ParquetFileWriteOptions().make_write_options(). Please refer to https://github.com/apache/arrow/blob/master/python/pyarrow/_dataset_parquet.pyx#L492-L533 - for the list of available options - max_partitions: the maximum number of partitions that will be used. + for the list of available options. Only used in pyarrow engine. + max_partitions: the maximum number of partitions that will be used. Only used in pyarrow engine. max_open_files: Limits the maximum number of files that can be left open while writing. If an attempt is made to open too many files then the least recently used file will be closed. If this setting is set too low you may end up fragmenting your - data into many small files. + data into many small files. Only used in pyarrow engine. max_rows_per_file: Maximum number of rows per file. If greater than 0 then this will limit how many rows are placed in any single file. Otherwise there will be no limit and one file will be created in each output directory unless files need to be closed to respect max_open_files min_rows_per_group: Minimum number of rows per group. When the value is set, the dataset writer will batch incoming data and only write the row groups to the disk - when sufficient rows have accumulated. + when sufficient rows have accumulated. Only used in pyarrow engine. max_rows_per_group: Maximum number of rows per group. If the value is set, then the dataset writer may split up large incoming batches into multiple row groups. If this value is set, then min_rows_per_group should also be set. @@ -162,16 +228,22 @@ def write_deltalake( configuration: A map containing configuration options for the metadata action. overwrite_schema: If True, allows updating the schema of the table. storage_options: options passed to the native delta filesystem. Unused if 'filesystem' is defined. - partition_filters: the partition filters that will be used for partition overwrite. + predicate: When using `Overwrite` mode, replace data that matches a predicate. Only used in rust engine. + partition_filters: the partition filters that will be used for partition overwrite. Only used in pyarrow engine. 
large_dtypes: If True, the data schema is kept in large_dtypes, has no effect on pandas dataframe input """ - table, table_uri = try_get_table_and_table_uri(table_or_uri, storage_options) + if table is not None: + storage_options = table._storage_options or {} + storage_options.update(storage_options or {}) - # We need to write against the latest table version - if table: table.update_incremental() + __enforce_append_only(table=table, configuration=configuration, mode=mode) + + if isinstance(partition_by, str): + partition_by = [partition_by] + if isinstance(data, RecordBatchReader): data = convert_pyarrow_recordbatchreader(data, large_dtypes) elif isinstance(data, pa.RecordBatch): @@ -182,9 +254,13 @@ def write_deltalake( data = convert_pyarrow_dataset(data, large_dtypes) elif _has_pandas and isinstance(data, pd.DataFrame): if schema is not None: - data = pa.Table.from_pandas(data, schema=schema) + data = convert_pyarrow_table( + pa.Table.from_pandas(data, schema=schema), large_dtypes=large_dtypes + ) else: - data = convert_pyarrow_table(pa.Table.from_pandas(data), False) + data = convert_pyarrow_table( + pa.Table.from_pandas(data), large_dtypes=large_dtypes + ) elif isinstance(data, Iterable): if schema is None: raise ValueError("You must provide schema if data is Iterable") @@ -196,204 +272,224 @@ def write_deltalake( if schema is None: schema = data.schema - if filesystem is not None: - raise NotImplementedError("Filesystem support is not yet implemented. #570") + if engine == "rust": + if table is not None and mode == "ignore": + return - if table is not None: - storage_options = table._storage_options or {} - storage_options.update(storage_options or {}) + data = RecordBatchReader.from_batches(schema, (batch for batch in data)) + write_deltalake_rust( + table_uri=table_uri, + data=data, + partition_by=partition_by, + mode=mode, + max_rows_per_group=max_rows_per_group, + overwrite_schema=overwrite_schema, + predicate=predicate, + name=name, + description=description, + configuration=configuration, + storage_options=storage_options, + ) + if table: + table.update_incremental() + + elif engine == "pyarrow": + # We need to write against the latest table version + if filesystem is not None: + raise NotImplementedError( + "Filesystem support is not yet implemented. 
#570" + ) - filesystem = pa_fs.PyFileSystem(DeltaStorageHandler(table_uri, storage_options)) + filesystem = pa_fs.PyFileSystem(DeltaStorageHandler(table_uri, storage_options)) - __enforce_append_only(table=table, configuration=configuration, mode=mode) + if table: # already exists + if schema != table.schema().to_pyarrow( + as_large_types=large_dtypes + ) and not (mode == "overwrite" and overwrite_schema): + raise ValueError( + "Schema of data does not match table schema\n" + f"Data schema:\n{schema}\nTable Schema:\n{table.schema().to_pyarrow(as_large_types=large_dtypes)}" + ) + if mode == "error": + raise AssertionError("DeltaTable already exists.") + elif mode == "ignore": + return - if isinstance(partition_by, str): - partition_by = [partition_by] + current_version = table.version() - if table: # already exists - if schema != table.schema().to_pyarrow(as_large_types=large_dtypes) and not ( - mode == "overwrite" and overwrite_schema - ): - raise ValueError( - "Schema of data does not match table schema\n" - f"Data schema:\n{schema}\nTable Schema:\n{table.schema().to_pyarrow(as_large_types=large_dtypes)}" - ) + if partition_by: + assert partition_by == table.metadata().partition_columns + else: + partition_by = table.metadata().partition_columns - if mode == "error": - raise AssertionError("DeltaTable already exists.") - elif mode == "ignore": - return + else: # creating a new table + current_version = -1 - current_version = table.version() + dtype_map = { + pa.large_string(): pa.string(), + } + + def _large_to_normal_dtype(dtype: pa.DataType) -> pa.DataType: + try: + return dtype_map[dtype] + except KeyError: + return dtype if partition_by: - assert partition_by == table.metadata().partition_columns + if PYARROW_MAJOR_VERSION < 12: + partition_schema = pa.schema( + [ + pa.field(name, _large_to_normal_dtype(schema.field(name).type)) + for name in partition_by + ] + ) + else: + partition_schema = pa.schema( + [schema.field(name) for name in partition_by] + ) + partitioning = ds.partitioning(partition_schema, flavor="hive") else: - partition_by = table.metadata().partition_columns - - if table.protocol().min_writer_version > MAX_SUPPORTED_WRITER_VERSION: - raise DeltaProtocolError( - "This table's min_writer_version is " - f"{table.protocol().min_writer_version}, " - "but this method only supports version 2." + partitioning = None + + add_actions: List[AddAction] = [] + + def visitor(written_file: Any) -> None: + path, partition_values = get_partitions_from_path(written_file.path) + stats = get_file_stats_from_metadata(written_file.metadata) + + # PyArrow added support for written_file.size in 9.0.0 + if PYARROW_MAJOR_VERSION >= 9: + size = written_file.size + elif filesystem is not None: + size = filesystem.get_file_info([path])[0].size + else: + size = 0 + + add_actions.append( + AddAction( + path, + size, + partition_values, + int(datetime.now().timestamp() * 1000), + True, + json.dumps(stats, cls=DeltaJSONEncoder), + ) ) - else: # creating a new table - current_version = -1 - dtype_map = { - pa.large_string(): pa.string(), - } + if table is not None: + # We don't currently provide a way to set invariants + # (and maybe never will), so only enforce if already exist. + if table.protocol().min_writer_version > MAX_SUPPORTED_WRITER_VERSION: + raise DeltaProtocolError( + "This table's min_writer_version is " + f"{table.protocol().min_writer_version}, " + "but this method only supports version 2." 
+ ) - def _large_to_normal_dtype(dtype: pa.DataType) -> pa.DataType: - try: - return dtype_map[dtype] - except KeyError: - return dtype - - if partition_by: - if PYARROW_MAJOR_VERSION < 12: - partition_schema = pa.schema( - [ - pa.field(name, _large_to_normal_dtype(schema.field(name).type)) - for name in partition_by - ] + invariants = table.schema().invariants + checker = _DeltaDataChecker(invariants) + + def check_data_is_aligned_with_partition_filtering( + batch: pa.RecordBatch, + ) -> None: + if table is None: + return + existed_partitions: FrozenSet[ + FrozenSet[Tuple[str, Optional[str]]] + ] = table._table.get_active_partitions() + allowed_partitions: FrozenSet[ + FrozenSet[Tuple[str, Optional[str]]] + ] = table._table.get_active_partitions(partition_filters) + partition_values = pa.RecordBatch.from_arrays( + [ + batch.column(column_name) + for column_name in table.metadata().partition_columns + ], + table.metadata().partition_columns, + ) + partition_values = batch_distinct(partition_values) + for i in range(partition_values.num_rows): + # Map will maintain order of partition_columns + partition_map = { + column_name: encode_partition_value( + batch.column(column_name)[i].as_py() + ) + for column_name in table.metadata().partition_columns + } + partition = frozenset(partition_map.items()) + if ( + partition not in allowed_partitions + and partition in existed_partitions + ): + partition_repr = " ".join( + f"{key}={value}" for key, value in partition_map.items() + ) + raise ValueError( + f"Data should be aligned with partitioning. " + f"Data contained values for partition {partition_repr}" + ) + + def validate_batch(batch: pa.RecordBatch) -> pa.RecordBatch: + checker.check_batch(batch) + + if mode == "overwrite" and partition_filters: + check_data_is_aligned_with_partition_filtering(batch) + + return batch + + data = RecordBatchReader.from_batches( + schema, (validate_batch(batch) for batch in data) ) - else: - partition_schema = pa.schema([schema.field(name) for name in partition_by]) - partitioning = ds.partitioning(partition_schema, flavor="hive") - else: - partitioning = None - - add_actions: List[AddAction] = [] - - def visitor(written_file: Any) -> None: - path, partition_values = get_partitions_from_path(written_file.path) - stats = get_file_stats_from_metadata(written_file.metadata) - # PyArrow added support for written_file.size in 9.0.0 - if PYARROW_MAJOR_VERSION >= 9: - size = written_file.size - elif filesystem is not None: - size = filesystem.get_file_info([path])[0].size + if file_options is not None: + file_options.update(use_compliant_nested_type=False) else: - size = 0 - - add_actions.append( - AddAction( - path, - size, - partition_values, - int(datetime.now().timestamp() * 1000), - True, - json.dumps(stats, cls=DeltaJSONEncoder), - ) - ) - - if table is not None: - # We don't currently provide a way to set invariants - # (and maybe never will), so only enforce if already exist. 
- invariants = table.schema().invariants - checker = _DeltaDataChecker(invariants) - - def check_data_is_aligned_with_partition_filtering( - batch: pa.RecordBatch, - ) -> None: - if table is None: - return - existed_partitions: FrozenSet[ - FrozenSet[Tuple[str, Optional[str]]] - ] = table._table.get_active_partitions() - allowed_partitions: FrozenSet[ - FrozenSet[Tuple[str, Optional[str]]] - ] = table._table.get_active_partitions(partition_filters) - partition_values = pa.RecordBatch.from_arrays( - [ - batch.column(column_name) - for column_name in table.metadata().partition_columns - ], - table.metadata().partition_columns, + file_options = ds.ParquetFileFormat().make_write_options( + use_compliant_nested_type=False ) - partition_values = batch_distinct(partition_values) - for i in range(partition_values.num_rows): - # Map will maintain order of partition_columns - partition_map = { - column_name: encode_partition_value( - batch.column(column_name)[i].as_py() - ) - for column_name in table.metadata().partition_columns - } - partition = frozenset(partition_map.items()) - if ( - partition not in allowed_partitions - and partition in existed_partitions - ): - partition_repr = " ".join( - f"{key}={value}" for key, value in partition_map.items() - ) - raise ValueError( - f"Data should be aligned with partitioning. " - f"Data contained values for partition {partition_repr}" - ) - - def validate_batch(batch: pa.RecordBatch) -> pa.RecordBatch: - checker.check_batch(batch) - - if mode == "overwrite" and partition_filters: - check_data_is_aligned_with_partition_filtering(batch) - - return batch - - data = RecordBatchReader.from_batches( - schema, (validate_batch(batch) for batch in data) - ) - if file_options is not None: - file_options.update(use_compliant_nested_type=False) - else: - file_options = ds.ParquetFileFormat().make_write_options( - use_compliant_nested_type=False + ds.write_dataset( + data, + base_dir="/", + basename_template=f"{current_version + 1}-{uuid.uuid4()}-{{i}}.parquet", + format="parquet", + partitioning=partitioning, + # It will not accept a schema if using a RBR + schema=schema if not isinstance(data, RecordBatchReader) else None, + file_visitor=visitor, + existing_data_behavior="overwrite_or_ignore", + file_options=file_options, + max_open_files=max_open_files, + max_rows_per_file=max_rows_per_file, + min_rows_per_group=min_rows_per_group, + max_rows_per_group=max_rows_per_group, + filesystem=filesystem, + max_partitions=max_partitions, ) - ds.write_dataset( - data, - base_dir="/", - basename_template=f"{current_version + 1}-{uuid.uuid4()}-{{i}}.parquet", - format="parquet", - partitioning=partitioning, - # It will not accept a schema if using a RBR - schema=schema if not isinstance(data, RecordBatchReader) else None, - file_visitor=visitor, - existing_data_behavior="overwrite_or_ignore", - file_options=file_options, - max_open_files=max_open_files, - max_rows_per_file=max_rows_per_file, - min_rows_per_group=min_rows_per_group, - max_rows_per_group=max_rows_per_group, - filesystem=filesystem, - max_partitions=max_partitions, - ) - - if table is None: - _write_new_deltalake( - table_uri, - schema, - add_actions, - mode, - partition_by or [], - name, - description, - configuration, - storage_options, - ) + if table is None: + write_deltalake_pyarrow( + table_uri, + schema, + add_actions, + mode, + partition_by or [], + name, + description, + configuration, + storage_options, + ) + else: + table._table.create_write_transaction( + add_actions, + mode, + partition_by or [], + 
schema, + partition_filters, + ) + table.update_incremental() else: - table._table.create_write_transaction( - add_actions, - mode, - partition_by or [], - schema, - partition_filters, - ) - table.update_incremental() + raise ValueError("Only `pyarrow` or `rust` are valid inputs for the engine.") def convert_to_deltalake( diff --git a/python/src/lib.rs b/python/src/lib.rs index 69195e866d..e7d5ec818d 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -1126,6 +1126,64 @@ impl From<&PyAddAction> for Add { } } +#[pyfunction] +#[allow(clippy::too_many_arguments)] +fn write_to_deltalake( + table_uri: String, + data: PyArrowType, + mode: String, + max_rows_per_group: i64, + overwrite_schema: bool, + partition_by: Option>, + predicate: Option, + name: Option, + description: Option, + configuration: Option>>, + storage_options: Option>, +) -> PyResult<()> { + let batches = data.0.map(|batch| batch.unwrap()).collect::>(); + let save_mode = mode.parse().map_err(PythonError::from)?; + + let options = storage_options.clone().unwrap_or_default(); + let table = rt()? + .block_on(DeltaOps::try_from_uri_with_storage_options( + &table_uri, options, + )) + .map_err(PythonError::from)?; + + let mut builder = table + .write(batches) + .with_save_mode(save_mode) + .with_overwrite_schema(overwrite_schema) + .with_write_batch_size(max_rows_per_group as usize); + + if let Some(partition_columns) = partition_by { + builder = builder.with_partition_columns(partition_columns); + } + + if let Some(name) = &name { + builder = builder.with_table_name(name); + }; + + if let Some(description) = &description { + builder = builder.with_description(description); + }; + + if let Some(predicate) = &predicate { + builder = builder.with_replace_where(predicate); + }; + + if let Some(config) = configuration { + builder = builder.with_configuration(config); + }; + + rt()? 
+ .block_on(builder.into_future()) + .map_err(PythonError::from)?; + + Ok(()) +} + #[pyfunction] #[allow(clippy::too_many_arguments)] fn write_new_deltalake( @@ -1268,6 +1326,7 @@ fn _internal(py: Python, m: &PyModule) -> PyResult<()> { m.add("__version__", env!("CARGO_PKG_VERSION"))?; m.add_function(pyo3::wrap_pyfunction!(rust_core_version, m)?)?; m.add_function(pyo3::wrap_pyfunction!(write_new_deltalake, m)?)?; + m.add_function(pyo3::wrap_pyfunction!(write_to_deltalake, m)?)?; m.add_function(pyo3::wrap_pyfunction!(convert_to_deltalake, m)?)?; m.add_function(pyo3::wrap_pyfunction!(batch_distinct, m)?)?; m.add_class::()?; diff --git a/python/tests/test_benchmark.py b/python/tests/test_benchmark.py index fd32a7e4e6..d7299ca684 100644 --- a/python/tests/test_benchmark.py +++ b/python/tests/test_benchmark.py @@ -24,9 +24,12 @@ def sample_table() -> pa.Table: return tab +@pytest.mark.parametrize("engine", ["pyarrow", "rust"]) @pytest.mark.benchmark(group="write") -def test_benchmark_write(benchmark, sample_table, tmp_path): - benchmark(write_deltalake, str(tmp_path), sample_table, mode="overwrite") +def test_benchmark_write(benchmark, sample_table, tmp_path, engine): + benchmark( + write_deltalake, str(tmp_path), sample_table, mode="overwrite", engine=engine + ) dt = DeltaTable(str(tmp_path)) assert dt.to_pyarrow_table().sort_by("i") == sample_table diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py index 4330489e4a..0a63b16c70 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -6,7 +6,7 @@ import threading from datetime import date, datetime from math import inf -from typing import Any, Dict, Iterable, List +from typing import Any, Dict, Iterable, List, Literal from unittest.mock import Mock import pyarrow as pa @@ -17,7 +17,7 @@ from pyarrow.lib import RecordBatchReader from deltalake import DeltaTable, write_deltalake -from deltalake.exceptions import CommitFailedError, DeltaProtocolError +from deltalake.exceptions import CommitFailedError, DeltaError, DeltaProtocolError from deltalake.table import ProtocolVersions from deltalake.writer import try_get_table_and_table_uri @@ -29,24 +29,30 @@ _has_pandas = True +@pytest.mark.parametrize("engine", ["pyarrow", "rust"]) @pytest.mark.skip(reason="Waiting on #570") -def test_handle_existing(tmp_path: pathlib.Path, sample_data: pa.Table): +def test_handle_existing( + tmp_path: pathlib.Path, sample_data: pa.Table, engine: Literal["pyarrow", "rust"] +): # if uri points to a non-empty directory that isn't a delta table, error tmp_path p = tmp_path / "hello.txt" p.write_text("hello") with pytest.raises(OSError) as exception: - write_deltalake(tmp_path, sample_data, mode="overwrite") + write_deltalake(tmp_path, sample_data, mode="overwrite", engine=engine) assert "directory is not empty" in str(exception) -def test_roundtrip_basic(tmp_path: pathlib.Path, sample_data: pa.Table): +@pytest.mark.parametrize("engine", ["pyarrow", "rust"]) +def test_roundtrip_basic( + tmp_path: pathlib.Path, sample_data: pa.Table, engine: Literal["pyarrow", "rust"] +): # Check we can create the subdirectory tmp_path = tmp_path / "path" / "to" / "table" start_time = datetime.now().timestamp() - write_deltalake(tmp_path, sample_data) + write_deltalake(tmp_path, sample_data, engine=engine) end_time = datetime.now().timestamp() assert ("0" * 20 + ".json") in os.listdir(tmp_path / "_delta_log") @@ -71,7 +77,8 @@ def test_roundtrip_basic(tmp_path: pathlib.Path, sample_data: pa.Table): assert modification_time < end_time -def 
test_roundtrip_nulls(tmp_path: pathlib.Path): +@pytest.mark.parametrize("engine", ["pyarrow", "rust"]) +def test_roundtrip_nulls(tmp_path: pathlib.Path, engine: Literal["pyarrow", "rust"]): data = pa.table({"x": pa.array([None, None, 1, 2], type=pa.int64())}) # One row group will have values, one will be all nulls. # The first will have None in min and max stats, so we need to handle that. @@ -91,6 +98,7 @@ def test_roundtrip_nulls(tmp_path: pathlib.Path): min_rows_per_group=2, max_rows_per_group=2, mode="overwrite", + engine=engine, ) delta_table = DeltaTable(tmp_path) @@ -105,11 +113,23 @@ def test_enforce_schema(existing_table: DeltaTable, mode: str): bad_data = pa.table({"x": pa.array([1, 2, 3])}) with pytest.raises(ValueError): - write_deltalake(existing_table, bad_data, mode=mode) + write_deltalake(existing_table, bad_data, mode=mode, engine="pyarrow") table_uri = existing_table._table.table_uri() with pytest.raises(ValueError): - write_deltalake(table_uri, bad_data, mode=mode) + write_deltalake(table_uri, bad_data, mode=mode, engine="pyarrow") + + +@pytest.mark.parametrize("mode", ["append", "overwrite"]) +def test_enforce_schema_rust_writer(existing_table: DeltaTable, mode: str): + bad_data = pa.table({"x": pa.array([1, 2, 3])}) + + with pytest.raises(DeltaError): + write_deltalake(existing_table, bad_data, mode=mode, engine="rust") + + table_uri = existing_table._table.table_uri() + with pytest.raises(DeltaError): + write_deltalake(table_uri, bad_data, mode=mode, engine="rust") def test_update_schema(existing_table: DeltaTable): @@ -125,12 +145,59 @@ def test_update_schema(existing_table: DeltaTable): assert existing_table.schema().to_pyarrow() == new_data.schema -def test_local_path(tmp_path: pathlib.Path, sample_data: pa.Table, monkeypatch): +def test_update_schema_rust_writer(existing_table: DeltaTable): + new_data = pa.table({"x": pa.array([1, 2, 3])}) + + with pytest.raises(DeltaError): + write_deltalake( + existing_table, + new_data, + mode="append", + overwrite_schema=True, + engine="rust", + ) + with pytest.raises(DeltaError): + write_deltalake( + existing_table, + new_data, + mode="overwrite", + overwrite_schema=False, + engine="rust", + ) + with pytest.raises(DeltaError): + write_deltalake( + existing_table, + new_data, + mode="append", + overwrite_schema=False, + engine="rust", + ) + # TODO(ion): Remove this once we add schema overwrite support + write_deltalake( + existing_table, + new_data, + mode="overwrite", + overwrite_schema=True, + engine="rust", + ) + + read_data = existing_table.to_pyarrow_table() + assert new_data == read_data + assert existing_table.schema().to_pyarrow() == new_data.schema + + +@pytest.mark.parametrize("engine", ["pyarrow", "rust"]) +def test_local_path( + tmp_path: pathlib.Path, + sample_data: pa.Table, + monkeypatch, + engine: Literal["pyarrow", "rust"], +): monkeypatch.chdir(tmp_path) # Make tmp_path the working directory (tmp_path / "path/to/table").mkdir(parents=True) local_path = "./path/to/table" - write_deltalake(local_path, sample_data) + write_deltalake(local_path, sample_data, engine=engine) delta_table = DeltaTable(local_path) assert delta_table.schema().to_pyarrow() == sample_data.schema @@ -138,13 +205,15 @@ def test_local_path(tmp_path: pathlib.Path, sample_data: pa.Table, monkeypatch): assert table == sample_data -def test_roundtrip_metadata(tmp_path: pathlib.Path, sample_data: pa.Table): +@pytest.mark.parametrize("engine", ["pyarrow", "rust"]) +def test_roundtrip_metadata(tmp_path: pathlib.Path, sample_data: pa.Table, 
engine): write_deltalake( tmp_path, sample_data, name="test_name", description="test_desc", configuration={"delta.appendOnly": "false", "foo": "bar"}, + engine=engine, ) delta_table = DeltaTable(tmp_path) @@ -156,6 +225,7 @@ def test_roundtrip_metadata(tmp_path: pathlib.Path, sample_data: pa.Table): assert metadata.configuration == {"delta.appendOnly": "false", "foo": "bar"} +@pytest.mark.parametrize("engine", ["pyarrow", "rust"]) @pytest.mark.parametrize( "column", [ @@ -173,9 +243,9 @@ def test_roundtrip_metadata(tmp_path: pathlib.Path, sample_data: pa.Table): ], ) def test_roundtrip_partitioned( - tmp_path: pathlib.Path, sample_data: pa.Table, column: str + tmp_path: pathlib.Path, sample_data: pa.Table, column: str, engine ): - write_deltalake(tmp_path, sample_data, partition_by=column) + write_deltalake(tmp_path, sample_data, partition_by=column, engine=engine) delta_table = DeltaTable(tmp_path) assert delta_table.schema().to_pyarrow() == sample_data.schema @@ -189,11 +259,16 @@ def test_roundtrip_partitioned( assert add_path.count("/") == 1 -def test_roundtrip_null_partition(tmp_path: pathlib.Path, sample_data: pa.Table): +@pytest.mark.parametrize("engine", ["pyarrow", "rust"]) +def test_roundtrip_null_partition( + tmp_path: pathlib.Path, sample_data: pa.Table, engine +): sample_data = sample_data.add_column( 0, "utf8_with_nulls", pa.array(["a"] * 4 + [None]) ) - write_deltalake(tmp_path, sample_data, partition_by=["utf8_with_nulls"]) + write_deltalake( + tmp_path, sample_data, partition_by=["utf8_with_nulls"], engine=engine + ) delta_table = DeltaTable(tmp_path) assert delta_table.schema().to_pyarrow() == sample_data.schema @@ -203,8 +278,13 @@ def test_roundtrip_null_partition(tmp_path: pathlib.Path, sample_data: pa.Table) assert table == sample_data -def test_roundtrip_multi_partitioned(tmp_path: pathlib.Path, sample_data: pa.Table): - write_deltalake(tmp_path, sample_data, partition_by=["int32", "bool"]) +@pytest.mark.parametrize("engine", ["pyarrow", "rust"]) +def test_roundtrip_multi_partitioned( + tmp_path: pathlib.Path, sample_data: pa.Table, engine +): + write_deltalake( + tmp_path, sample_data, partition_by=["int32", "bool"], engine=engine + ) delta_table = DeltaTable(tmp_path) assert delta_table.schema().to_pyarrow() == sample_data.schema @@ -218,33 +298,41 @@ def test_roundtrip_multi_partitioned(tmp_path: pathlib.Path, sample_data: pa.Tab assert add_path.count("/") == 2 -def test_write_modes(tmp_path: pathlib.Path, sample_data: pa.Table): - write_deltalake(tmp_path, sample_data) +@pytest.mark.parametrize("engine", ["pyarrow", "rust"]) +def test_write_modes(tmp_path: pathlib.Path, sample_data: pa.Table, engine): + write_deltalake(tmp_path, sample_data, engine=engine) assert DeltaTable(tmp_path).to_pyarrow_table() == sample_data - with pytest.raises(AssertionError): - write_deltalake(tmp_path, sample_data, mode="error") + if engine == "pyarrow": + with pytest.raises(AssertionError): + write_deltalake(tmp_path, sample_data, mode="error") + elif engine == "rust": + with pytest.raises(DeltaError): + write_deltalake(tmp_path, sample_data, mode="error", engine="rust") - write_deltalake(tmp_path, sample_data, mode="ignore") + write_deltalake(tmp_path, sample_data, mode="ignore", engine="rust") assert ("0" * 19 + "1.json") not in os.listdir(tmp_path / "_delta_log") - write_deltalake(tmp_path, sample_data, mode="append") + write_deltalake(tmp_path, sample_data, mode="append", engine="rust") expected = pa.concat_tables([sample_data, sample_data]) assert 
DeltaTable(tmp_path).to_pyarrow_table() == expected - write_deltalake(tmp_path, sample_data, mode="overwrite") + write_deltalake(tmp_path, sample_data, mode="overwrite", engine="rust") assert DeltaTable(tmp_path).to_pyarrow_table() == sample_data -def test_append_only_should_append_only_with_the_overwrite_mode( - tmp_path: pathlib.Path, sample_data: pa.Table +@pytest.mark.parametrize("engine", ["pyarrow", "rust"]) +def test_append_only_should_append_only_with_the_overwrite_mode( # Create rust equivalent rust + tmp_path: pathlib.Path, sample_data: pa.Table, engine ): config = {"delta.appendOnly": "true"} - write_deltalake(tmp_path, sample_data, mode="append", configuration=config) + write_deltalake( + tmp_path, sample_data, mode="append", configuration=config, engine=engine + ) table = DeltaTable(tmp_path) - write_deltalake(table, sample_data, mode="append") + write_deltalake(table, sample_data, mode="append", engine=engine) data_store_types = [tmp_path, table] fail_modes = ["overwrite", "ignore", "error"] @@ -257,7 +345,7 @@ def test_append_only_should_append_only_with_the_overwrite_mode( f" 'append'. Mode is currently {mode}" ), ): - write_deltalake(data_store_type, sample_data, mode=mode) + write_deltalake(data_store_type, sample_data, mode=mode, engine=engine) expected = pa.concat_tables([sample_data, sample_data]) @@ -265,21 +353,45 @@ def test_append_only_should_append_only_with_the_overwrite_mode( assert table.version() == 1 -def test_writer_with_table(existing_table: DeltaTable, sample_data: pa.Table): - write_deltalake(existing_table, sample_data, mode="overwrite") +@pytest.mark.parametrize("engine", ["pyarrow", "rust"]) +def test_writer_with_table(existing_table: DeltaTable, sample_data: pa.Table, engine): + write_deltalake(existing_table, sample_data, mode="overwrite", engine=engine) assert existing_table.to_pyarrow_table() == sample_data -def test_fails_wrong_partitioning(existing_table: DeltaTable, sample_data: pa.Table): - with pytest.raises(AssertionError): - write_deltalake( - existing_table, sample_data, mode="append", partition_by="int32" - ) +@pytest.mark.parametrize("engine", ["pyarrow", "rust"]) +def test_fails_wrong_partitioning( + existing_table: DeltaTable, sample_data: pa.Table, engine +): + if engine == "pyarrow": + with pytest.raises(AssertionError): + write_deltalake( + existing_table, + sample_data, + mode="append", + partition_by="int32", + engine=engine, + ) + elif engine == "rust": + with pytest.raises( + DeltaError, + match='Generic error: Specified table partitioning does not match table partitioning: expected: [], got: ["int32"]', + ): + write_deltalake( + existing_table, + sample_data, + mode="append", + partition_by="int32", + engine=engine, + ) +@pytest.mark.parametrize("engine", ["pyarrow", "rust"]) @pytest.mark.pandas @pytest.mark.parametrize("schema_provided", [True, False]) -def test_write_pandas(tmp_path: pathlib.Path, sample_data: pa.Table, schema_provided): +def test_write_pandas( + tmp_path: pathlib.Path, sample_data: pa.Table, schema_provided, engine +): # When timestamp is converted to Pandas, it gets casted to ns resolution, # but Delta Lake schemas only support us resolution. 
sample_pandas = sample_data.to_pandas() @@ -287,23 +399,27 @@ def test_write_pandas(tmp_path: pathlib.Path, sample_data: pa.Table, schema_prov schema = sample_data.schema else: schema = None - write_deltalake(tmp_path, sample_pandas, schema=schema) + write_deltalake(tmp_path, sample_pandas, schema=schema, engine=engine) delta_table = DeltaTable(tmp_path) df = delta_table.to_pandas() assert_frame_equal(df, sample_pandas) +@pytest.mark.parametrize("engine", ["pyarrow", "rust"]) def test_write_iterator( - tmp_path: pathlib.Path, existing_table: DeltaTable, sample_data: pa.Table + tmp_path: pathlib.Path, existing_table: DeltaTable, sample_data: pa.Table, engine ): batches = existing_table.to_pyarrow_dataset().to_batches() with pytest.raises(ValueError): - write_deltalake(tmp_path, batches, mode="overwrite") + write_deltalake(tmp_path, batches, mode="overwrite", engine=engine) - write_deltalake(tmp_path, batches, schema=sample_data.schema, mode="overwrite") + write_deltalake( + tmp_path, batches, schema=sample_data.schema, mode="overwrite", engine=engine + ) assert DeltaTable(tmp_path).to_pyarrow_table() == sample_data +@pytest.mark.parametrize("engine", ["pyarrow", "rust"]) @pytest.mark.parametrize("large_dtypes", [True, False]) @pytest.mark.parametrize( "constructor", @@ -317,38 +433,48 @@ def test_write_dataset_table_recordbatch( tmp_path: pathlib.Path, existing_table: DeltaTable, sample_data: pa.Table, + engine: str, large_dtypes: bool, constructor, ): dataset = constructor(existing_table) - write_deltalake(tmp_path, dataset, mode="overwrite", large_dtypes=large_dtypes) + write_deltalake( + tmp_path, dataset, mode="overwrite", large_dtypes=large_dtypes, engine=engine + ) assert DeltaTable(tmp_path).to_pyarrow_table() == sample_data @pytest.mark.parametrize("large_dtypes", [True, False]) +@pytest.mark.parametrize("engine", ["pyarrow", "rust"]) def test_write_recordbatchreader( tmp_path: pathlib.Path, existing_table: DeltaTable, sample_data: pa.Table, large_dtypes: bool, + engine: Literal["pyarrow", "rust"], ): batches = existing_table.to_pyarrow_dataset().to_batches() reader = RecordBatchReader.from_batches( existing_table.to_pyarrow_dataset().schema, batches ) - write_deltalake(tmp_path, reader, mode="overwrite", large_dtypes=large_dtypes) + write_deltalake( + tmp_path, reader, mode="overwrite", large_dtypes=large_dtypes, engine=engine + ) assert DeltaTable(tmp_path).to_pyarrow_table() == sample_data -def test_writer_partitioning(tmp_path: pathlib.Path): +@pytest.mark.parametrize("engine", ["pyarrow", "rust"]) +def test_writer_partitioning( + tmp_path: pathlib.Path, engine: Literal["pyarrow", "rust"] +): test_strings = ["a=b", "hello world", "hello%20world"] data = pa.table( {"p": pa.array(test_strings), "x": pa.array(range(len(test_strings)))} ) - write_deltalake(tmp_path, data) + write_deltalake(tmp_path, data, engine=engine) assert DeltaTable(tmp_path).to_pyarrow_table() == data @@ -437,7 +563,8 @@ def test_writer_stats(existing_table: DeltaTable, sample_data: pa.Table): assert stats["maxValues"] == expected_maxs -def test_writer_null_stats(tmp_path: pathlib.Path): +@pytest.mark.parametrize("engine", ["pyarrow", "rust"]) +def test_writer_null_stats(tmp_path: pathlib.Path, engine: Literal["pyarrow", "rust"]): data = pa.table( { "int32": pa.array([1, None, 2, None], pa.int32()), @@ -445,7 +572,7 @@ def test_writer_null_stats(tmp_path: pathlib.Path): "str": pa.array([None] * 4, pa.string()), } ) - write_deltalake(tmp_path, data) + write_deltalake(tmp_path, data, engine=engine) table = 
DeltaTable(tmp_path) stats = get_stats(table) @@ -454,10 +581,15 @@ def test_writer_null_stats(tmp_path: pathlib.Path): assert stats["nullCount"] == expected_nulls -def test_writer_fails_on_protocol(existing_table: DeltaTable, sample_data: pa.Table): +@pytest.mark.parametrize("engine", ["pyarrow"]) +def test_writer_fails_on_protocol( + existing_table: DeltaTable, + sample_data: pa.Table, + engine: Literal["pyarrow", "rust"], +): existing_table.protocol = Mock(return_value=ProtocolVersions(1, 3)) with pytest.raises(DeltaProtocolError): - write_deltalake(existing_table, sample_data, mode="overwrite") + write_deltalake(existing_table, sample_data, mode="overwrite", engine=engine) @pytest.mark.parametrize( @@ -722,6 +854,74 @@ def test_partition_overwrite_unfiltered_data_fails( ) +@pytest.mark.parametrize( + "value_1,value_2,value_type,filter_string", + [ + (1, 2, pa.int64(), "1"), + (False, True, pa.bool_(), "false"), + (date(2022, 1, 1), date(2022, 1, 2), pa.date32(), "2022-01-01"), + ], +) +def test_replace_where_overwrite( + tmp_path: pathlib.Path, + value_1: Any, + value_2: Any, + value_type: pa.DataType, + filter_string: str, +): + sample_data = pa.table( + { + "p1": pa.array(["1", "1", "2", "2"], pa.string()), + "p2": pa.array([value_1, value_2, value_1, value_2], value_type), + "val": pa.array([1, 1, 1, 1], pa.int64()), + } + ) + write_deltalake(tmp_path, sample_data, mode="overwrite", partition_by=["p1", "p2"]) + + delta_table = DeltaTable(tmp_path) + assert ( + delta_table.to_pyarrow_table().sort_by( + [("p1", "ascending"), ("p2", "ascending")] + ) + == sample_data + ) + + sample_data = pa.table( + { + "p1": pa.array(["1", "1"], pa.string()), + "p2": pa.array([value_2, value_1], value_type), + "val": pa.array([2, 2], pa.int64()), + } + ) + expected_data = pa.table( + { + "p1": pa.array(["1", "1", "2", "2"], pa.string()), + "p2": pa.array([value_1, value_2, value_1, value_2], value_type), + "val": pa.array([2, 2, 1, 1], pa.int64()), + } + ) + + with pytest.raises( + DeltaError, + match="Generic DeltaTable error: Overwriting data based on predicate is not yet implemented", + ): + write_deltalake( + tmp_path, + sample_data, + mode="overwrite", + predicate="`p1` = 1", + engine="rust", + ) + + delta_table.update_incremental() + assert ( + delta_table.to_pyarrow_table().sort_by( + [("p1", "ascending"), ("p2", "ascending")] + ) + == expected_data + ) + + def test_partition_overwrite_with_new_partition( tmp_path: pathlib.Path, sample_data_for_partitioning: pa.Table ):
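The `test_replace_where_overwrite` case above pins the current rust-engine behaviour for `predicate`: the write is rejected with a `DeltaError` instead of the old `todo!()` panic, and the existing table data is left untouched. A condensed, hypothetical companion sketch of that expectation (the test name and sample data are illustrative, not part of this patch):

```python
import pathlib

import pyarrow as pa
import pytest

from deltalake import write_deltalake
from deltalake.exceptions import DeltaError


def test_replace_where_not_yet_implemented(tmp_path: pathlib.Path):
    data = pa.table({"p1": pa.array(["1", "2"]), "val": pa.array([1, 1], pa.int64())})
    write_deltalake(tmp_path, data, mode="overwrite", partition_by=["p1"])

    # The rust writer now surfaces a DeltaError rather than panicking on todo!().
    with pytest.raises(DeltaError, match="not yet implemented"):
        write_deltalake(
            tmp_path, data, mode="overwrite", predicate="`p1` = 1", engine="rust"
        )
```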