fix(python): add support for pyarrow 13+ #1804

Merged: 13 commits, Nov 5, 2023
1 change: 0 additions & 1 deletion crates/deltalake-core/src/operations/optimize.rs
@@ -1432,7 +1432,6 @@ pub(super) mod zorder {
         assert_eq!(result.null_count(), 0);

         let data: &BinaryArray = as_generic_binary_array(result.as_ref());
-        dbg!(data);
         assert_eq!(data.value_data().len(), 3 * 16 * 3);
         assert!(data.iter().all(|x| x.unwrap().len() == 3 * 16));
     }
7 changes: 7 additions & 0 deletions python/deltalake/writer.py
@@ -342,6 +342,13 @@ def validate_batch(batch: pa.RecordBatch) -> pa.RecordBatch:
         schema, (validate_batch(batch) for batch in batch_iter)
     )

+    if file_options is not None:
+        file_options.update(use_compliant_nested_type=False)
+    else:
+        file_options = ds.ParquetFileFormat().make_write_options(
+            use_compliant_nested_type=False
+        )
+
     ds.write_dataset(
         data,
         base_dir="/",
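For context, a minimal standalone sketch of the option this hunk forces (assuming pyarrow >= 8; the table contents and output path are placeholders, not from the PR):

```python
import pyarrow as pa
import pyarrow.dataset as ds

table = pa.table({"nested": [[1, 2], [3]]})  # placeholder data

# Keep pyarrow's pre-13 (non-compliant) Parquet nested-type layout.
# pyarrow 13 flipped the default of use_compliant_nested_type to True,
# which changes how list element fields are named on disk; pinning it
# to False keeps new files consistent with what older pyarrow wrote.
file_options = ds.ParquetFileFormat().make_write_options(
    use_compliant_nested_type=False
)
ds.write_dataset(table, "/tmp/out", format="parquet",
                 file_options=file_options)
```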
2 changes: 1 addition & 1 deletion python/pyproject.toml
@@ -17,7 +17,7 @@ classifiers = [
     "Programming Language :: Python :: 3.11"
 ]
 dependencies = [
-    "pyarrow>=8,<13",
+    "pyarrow>=8",
     'typing-extensions;python_version<"3.8"',
 ]

36 changes: 26 additions & 10 deletions python/src/lib.rs
@@ -13,6 +13,7 @@ use std::time;
 use std::time::{SystemTime, UNIX_EPOCH};

 use arrow::pyarrow::PyArrowType;
+use arrow_schema::DataType;
 use chrono::{DateTime, Duration, FixedOffset, Utc};
 use deltalake::arrow::compute::concat_batches;
 use deltalake::arrow::ffi_stream::ArrowArrayStreamReader;
@@ -947,16 +948,31 @@ fn filestats_to_expression<'py>(
     let mut expressions: Vec<PyResult<&PyAny>> = Vec::new();

     let cast_to_type = |column_name: &String, value: PyObject, schema: &ArrowSchema| {
-        let column_type = PyArrowType(
-            schema
-                .field_with_name(column_name)
-                .map_err(|_| {
-                    PyValueError::new_err(format!("Column not found in schema: {column_name}"))
-                })?
-                .data_type()
-                .clone(),
-        )
-        .into_py(py);
+        let column_type = schema
+            .field_with_name(column_name)
+            .map_err(|_| {
+                PyValueError::new_err(format!("Column not found in schema: {column_name}"))
+            })?
+            .data_type()
+            .clone();
+
+        let value = match column_type {
+            // Since PyArrow 13.0.0, casting string -> timestamp fails if it ends with "Z"
+            // and the target type is timezone naive.
+            DataType::Timestamp(_, _) if value.extract::<String>(py).is_ok() => {
+                value.call_method1(py, "rstrip", ("Z",))?
+            }
+            // PyArrow 13.0.0 lost the ability to cast from string to date32, so
+            // we have to implement that manually.
+            DataType::Date32 if value.extract::<String>(py).is_ok() => {
+                let date = Python::import(py, "datetime")?.getattr("date")?;
+                let date = date.call_method1("fromisoformat", (value,))?;
+                date.to_object(py)
+            }
+            _ => value,
+        };
+
+        let column_type = PyArrowType(column_type).into_py(py);
         pa.call_method1("scalar", (value,))?
             .call_method1("cast", (column_type,))
     };
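The two match arms above compensate for casting changes in pyarrow 13. A rough Python illustration of both behaviors (a sketch; the exact exception types may differ between pyarrow releases):

```python
import pyarrow as pa
from datetime import date

# 1) Casting a trailing-"Z" string to a timezone-naive timestamp fails
#    under pyarrow >= 13, which is why the Rust side strips the "Z" first.
try:
    pa.scalar("2021-01-01T00:00:00Z").cast(pa.timestamp("us"))
except pa.ArrowInvalid:
    pa.scalar("2021-01-01T00:00:00").cast(pa.timestamp("us"))  # succeeds

# 2) string -> date32 casts were dropped, so the value is parsed with
#    datetime.date.fromisoformat() and handed to pyarrow as a date object,
#    which pyarrow already represents as date32.
pa.scalar(date.fromisoformat("2021-01-01")).cast(pa.date32())
```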
6 changes: 4 additions & 2 deletions python/tests/test_writer.py
@@ -308,7 +308,9 @@ def test_write_recordbatchreader(
     tmp_path: pathlib.Path, existing_table: DeltaTable, sample_data: pa.Table
 ):
     batches = existing_table.to_pyarrow_dataset().to_batches()
-    reader = RecordBatchReader.from_batches(sample_data.schema, batches)
+    reader = RecordBatchReader.from_batches(
+        existing_table.to_pyarrow_dataset().schema, batches
+    )

     write_deltalake(tmp_path, reader, mode="overwrite")
     assert DeltaTable(tmp_path).to_pyarrow_table() == sample_data
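The test fix matters because `RecordBatchReader.from_batches` takes the declared schema on faith: if the scan yields batches whose schema drifts from the one declared (as can happen across pyarrow versions), consuming the reader fails. A minimal sketch with a hypothetical column `x`:

```python
import pyarrow as pa

# Build the reader from the schema the batches actually carry, so the
# declared schema and the batch schema cannot disagree.
batches = [pa.record_batch([pa.array([1, 2])], names=["x"])]
reader = pa.RecordBatchReader.from_batches(batches[0].schema, batches)
assert reader.read_all().num_rows == 2
```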
@@ -892,7 +894,7 @@ def comp():
     # concurrently, then this will fail.
     assert data.num_rows == sample_data.num_rows
     try:
-        write_deltalake(dt.table_uri, data, mode="overwrite")
+        write_deltalake(dt.table_uri, sample_data, mode="overwrite")
     except Exception as e:
         exception = e