From d98a0c22efb98368c429610e89acf9bd1e45e299 Mon Sep 17 00:00:00 2001
From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com>
Date: Sun, 5 Nov 2023 01:22:42 +0100
Subject: [PATCH] fix(python): add support for pyarrow 13+ (#1804)

# Description

I built on top of @wjones127's branch in
https://github.com/delta-io/delta-rs/pull/1602. Since pyarrow v13 the
ParquetWriter defaults to `use_compliant_nested_type=True` (see the related
PR https://github.com/apache/arrow/pull/35146/files and the docs:
https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetWriter.html).
arrow/parquet-rs then fails when comparing schemas because it expects the
old, non-compliant nested types. For now we can support pyarrow 13+ by
disabling that option, or by updating the file options provided by the
user; a minimal sketch of the option follows the patch.

# Related Issue(s)

- Closes https://github.com/delta-io/delta-rs/issues/1744

# Documentation

---------

Co-authored-by: Will Jones
Co-authored-by: R. Tyler Croy
---
 .../deltalake-core/src/operations/optimize.rs |  1 -
 python/deltalake/writer.py                    |  7 ++++
 python/pyproject.toml                         |  2 +-
 python/src/lib.rs                             | 36 +++++++++++++------
 python/tests/test_writer.py                   |  6 ++--
 5 files changed, 38 insertions(+), 14 deletions(-)

diff --git a/crates/deltalake-core/src/operations/optimize.rs b/crates/deltalake-core/src/operations/optimize.rs
index 65b3731f57..ae9ab6cd65 100644
--- a/crates/deltalake-core/src/operations/optimize.rs
+++ b/crates/deltalake-core/src/operations/optimize.rs
@@ -1432,7 +1432,6 @@ pub(super) mod zorder {
         assert_eq!(result.null_count(), 0);
 
         let data: &BinaryArray = as_generic_binary_array(result.as_ref());
-        dbg!(data);
         assert_eq!(data.value_data().len(), 3 * 16 * 3);
         assert!(data.iter().all(|x| x.unwrap().len() == 3 * 16));
     }
diff --git a/python/deltalake/writer.py b/python/deltalake/writer.py
index 80fb245c2c..ef4ae3a57b 100644
--- a/python/deltalake/writer.py
+++ b/python/deltalake/writer.py
@@ -342,6 +342,13 @@ def validate_batch(batch: pa.RecordBatch) -> pa.RecordBatch:
             schema, (validate_batch(batch) for batch in batch_iter)
         )
 
+    if file_options is not None:
+        file_options.update(use_compliant_nested_type=False)
+    else:
+        file_options = ds.ParquetFileFormat().make_write_options(
+            use_compliant_nested_type=False
+        )
+
     ds.write_dataset(
         data,
         base_dir="/",
diff --git a/python/pyproject.toml b/python/pyproject.toml
index 5f7a4a5e9c..438a49cc56 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -17,7 +17,7 @@ classifiers = [
     "Programming Language :: Python :: 3.11"
 ]
 
 dependencies = [
-    "pyarrow>=8,<13",
+    "pyarrow>=8",
     'typing-extensions;python_version<"3.8"',
 ]
diff --git a/python/src/lib.rs b/python/src/lib.rs
index 93f71597ba..cc6b2202c3 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -13,6 +13,7 @@ use std::time;
 use std::time::{SystemTime, UNIX_EPOCH};
 
 use arrow::pyarrow::PyArrowType;
+use arrow_schema::DataType;
 use chrono::{DateTime, Duration, FixedOffset, Utc};
 use deltalake::arrow::compute::concat_batches;
 use deltalake::arrow::ffi_stream::ArrowArrayStreamReader;
@@ -947,16 +948,31 @@ fn filestats_to_expression<'py>(
     let mut expressions: Vec> = Vec::new();
 
     let cast_to_type = |column_name: &String, value: PyObject, schema: &ArrowSchema| {
-        let column_type = PyArrowType(
-            schema
-                .field_with_name(column_name)
-                .map_err(|_| {
-                    PyValueError::new_err(format!("Column not found in schema: {column_name}"))
-                })?
-                .data_type()
-                .clone(),
-        )
-        .into_py(py);
+        let column_type = schema
+            .field_with_name(column_name)
+            .map_err(|_| {
+                PyValueError::new_err(format!("Column not found in schema: {column_name}"))
+            })?
+            .data_type()
+            .clone();
+
+        let value = match column_type {
+            // Since PyArrow 13.0.0, casting string -> timestamp fails if it ends with "Z"
+            // and the target type is timezone naive.
+            DataType::Timestamp(_, _) if value.extract::<String>(py).is_ok() => {
+                value.call_method1(py, "rstrip", ("Z",))?
+            }
+            // PyArrow 13.0.0 lost the ability to cast from string to date32, so
+            // we have to implement that manually.
+            DataType::Date32 if value.extract::<String>(py).is_ok() => {
+                let date = Python::import(py, "datetime")?.getattr("date")?;
+                let date = date.call_method1("fromisoformat", (value,))?;
+                date.to_object(py)
+            }
+            _ => value,
+        };
+
+        let column_type = PyArrowType(column_type).into_py(py);
         pa.call_method1("scalar", (value,))?
             .call_method1("cast", (column_type,))
     };
diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py
index f63a1a51b9..d048f8b79b 100644
--- a/python/tests/test_writer.py
+++ b/python/tests/test_writer.py
@@ -308,7 +308,9 @@ def test_write_recordbatchreader(
     tmp_path: pathlib.Path, existing_table: DeltaTable, sample_data: pa.Table
 ):
     batches = existing_table.to_pyarrow_dataset().to_batches()
-    reader = RecordBatchReader.from_batches(sample_data.schema, batches)
+    reader = RecordBatchReader.from_batches(
+        existing_table.to_pyarrow_dataset().schema, batches
+    )
 
     write_deltalake(tmp_path, reader, mode="overwrite")
     assert DeltaTable(tmp_path).to_pyarrow_table() == sample_data
@@ -892,7 +894,7 @@ def comp():
             # concurrently, then this will fail.
             assert data.num_rows == sample_data.num_rows
             try:
-                write_deltalake(dt.table_uri, data, mode="overwrite")
+                write_deltalake(dt.table_uri, sample_data, mode="overwrite")
             except Exception as e:
                 exception = e
 
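
For reference, below is a minimal standalone sketch (not part of the patch) of the pyarrow option the writer now sets. It assumes pyarrow >= 13 is installed; the table contents and output directory are illustrative only.

```python
import pyarrow as pa
import pyarrow.dataset as ds

# Build Parquet write options that keep pyarrow's legacy (non-compliant)
# nested field naming (inner list field "item" rather than the
# Parquet-compliant "element"), which is what arrow/parquet-rs expects
# when it compares schemas.
file_options = ds.ParquetFileFormat().make_write_options(
    use_compliant_nested_type=False
)

# Illustrative table with a nested (list) column.
table = pa.table({"values": [[1, 2], [3]]})

# Any dataset write that passes `file_options` now emits the legacy
# nested field names, matching the behaviour of pyarrow < 13.
ds.write_dataset(
    table,
    base_dir="example_table",  # hypothetical output directory
    format="parquet",
    file_options=file_options,
)
```

This mirrors what `write_deltalake` now does internally: user-supplied `file_options` are updated with `use_compliant_nested_type=False`, and a fresh set of write options with that flag is created when none are supplied.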