fix(python): add support for pyarrow 13+ #1804

Merged · 13 commits · Nov 5, 2023
1 change: 0 additions & 1 deletion crates/deltalake-core/src/operations/optimize.rs
@@ -1432,7 +1432,6 @@ pub(super) mod zorder {
assert_eq!(result.null_count(), 0);

let data: &BinaryArray = as_generic_binary_array(result.as_ref());
dbg!(data);
assert_eq!(data.value_data().len(), 3 * 16 * 3);
assert!(data.iter().all(|x| x.unwrap().len() == 3 * 16));
}
15 changes: 15 additions & 0 deletions python/deltalake/table.py
@@ -1435,3 +1435,18 @@ def z_order(
)
self.table.update_incremental()
return json.loads(metrics)


def _cast_to_equal_batch(
batch: pyarrow.RecordBatch, schema: pyarrow.Schema
) -> pyarrow.RecordBatch:
"""
Cast a batch to a schema if PyArrow already considers them equal.

This is mostly for mapping things like list field names, which arrow-rs
checks when looking at schema equality, but pyarrow does not.
"""
if batch.schema == schema:
return pyarrow.Table.from_batches([batch]).cast(schema).to_batches()[0]
else:
return batch
Collaborator:

@ion-elgreco If we take this out, do the tests still pass?

Collaborator Author (@ion-elgreco, Nov 4, 2023):

Yes, everything still passes. Shall I take it out?

Or would we require it later once we enable the compliant types again?

edit: Taking it out for now.

Collaborator:

Ideally I think we shouldn't need it. This was an earlier attempt at fixing the compliant-types issue.

Collaborator Author:

We should have backwards compatibility during reads and writes when we use the compliant types. I think Spark Delta also doesn't write these compliant types, but I may be wrong here.

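For context on the list-field-name issue discussed above, here is a minimal PyArrow sketch (column names are illustrative): PyArrow names list children "item" by default, while the Parquet-compliant layout uses "element", and, per the docstring of _cast_to_equal_batch, PyArrow ignores that name when comparing schemas while arrow-rs does not.

import pyarrow as pa

# PyArrow's default child name for list columns is "item"; the Parquet
# "compliant" nested layout names the child "element".
default_list = pa.list_(pa.int64())                          # list<item: int64>
compliant_list = pa.list_(pa.field("element", pa.int64()))   # list<element: int64>

batch = pa.RecordBatch.from_pydict({"values": [[1, 2], [3]]})
target = pa.schema([pa.field("values", compliant_list)])

# Per the docstring above, PyArrow treats these schemas as equal even though
# the child names differ, while arrow-rs does not - hence the extra cast that
# _cast_to_equal_batch performs to line the metadata up.
print(batch.schema == target)
print(pa.Table.from_batches([batch]).cast(target).schema)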
10 changes: 9 additions & 1 deletion python/deltalake/writer.py
@@ -40,7 +40,7 @@
from ._internal import batch_distinct
from ._internal import write_new_deltalake as _write_new_deltalake
from .exceptions import DeltaProtocolError, TableNotFoundError
from .table import MAX_SUPPORTED_WRITER_VERSION, DeltaTable
from .table import MAX_SUPPORTED_WRITER_VERSION, DeltaTable, _cast_to_equal_batch

try:
import pandas as pd # noqa: F811
@@ -320,6 +320,7 @@ def check_data_is_aligned_with_partition_filtering(
)

def validate_batch(batch: pa.RecordBatch) -> pa.RecordBatch:
batch = _cast_to_equal_batch(batch, schema)
checker.check_batch(batch)

if mode == "overwrite" and partition_filters:
@@ -342,6 +343,13 @@ def validate_batch(batch: pa.RecordBatch) -> pa.RecordBatch:
schema, (validate_batch(batch) for batch in batch_iter)
)

if file_options is not None:
file_options.update(use_compliant_nested_type=False)
else:
file_options = ds.ParquetFileFormat().make_write_options(
use_compliant_nested_type=False
)

ds.write_dataset(
data,
base_dir="/",
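The use_compliant_nested_type=False option above pins the pre-13 layout: PyArrow 13 changed the default of this Parquet write option to True, which writes list children as "element". A hedged sketch of the observable difference (file names are illustrative):

import pyarrow as pa
import pyarrow.parquet as pq

table = pa.table({"values": [[1, 2], [3]]})

# PyArrow >= 13 defaults to the compliant layout (child named "element").
pq.write_table(table, "compliant.parquet")
# Opting out keeps the earlier layout (child named "item"), matching what the
# writer above requests via make_write_options(use_compliant_nested_type=False).
pq.write_table(table, "legacy.parquet", use_compliant_nested_type=False)

print(pq.read_schema("compliant.parquet").field("values").type)  # list<element: int64>
print(pq.read_schema("legacy.parquet").field("values").type)     # list<item: int64>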
2 changes: 1 addition & 1 deletion python/pyproject.toml
@@ -17,7 +17,7 @@ classifiers = [
"Programming Language :: Python :: 3.11"
]
dependencies = [
"pyarrow>=8,<13",
"pyarrow>=8",
'typing-extensions;python_version<"3.8"',
]

36 changes: 26 additions & 10 deletions python/src/lib.rs
@@ -13,6 +13,7 @@ use std::time;
use std::time::{SystemTime, UNIX_EPOCH};

use arrow::pyarrow::PyArrowType;
use arrow_schema::DataType;
use chrono::{DateTime, Duration, FixedOffset, Utc};
use deltalake::arrow::compute::concat_batches;
use deltalake::arrow::ffi_stream::ArrowArrayStreamReader;
@@ -947,16 +948,31 @@ fn filestats_to_expression<'py>(
let mut expressions: Vec<PyResult<&PyAny>> = Vec::new();

let cast_to_type = |column_name: &String, value: PyObject, schema: &ArrowSchema| {
let column_type = PyArrowType(
schema
.field_with_name(column_name)
.map_err(|_| {
PyValueError::new_err(format!("Column not found in schema: {column_name}"))
})?
.data_type()
.clone(),
)
.into_py(py);
let column_type = schema
.field_with_name(column_name)
.map_err(|_| {
PyValueError::new_err(format!("Column not found in schema: {column_name}"))
})?
.data_type()
.clone();

let value = match column_type {
// Since PyArrow 13.0.0, casting string -> timestamp fails if it ends with "Z"
// and the target type is timezone naive.
DataType::Timestamp(_, _) if value.extract::<String>(py).is_ok() => {
value.call_method1(py, "rstrip", ("Z",))?
}
// PyArrow 13.0.0 lost the ability to cast from string to date32, so
// we have to implement that manually.
DataType::Date32 if value.extract::<String>(py).is_ok() => {
let date = Python::import(py, "datetime")?.getattr("date")?;
let date = date.call_method1("fromisoformat", (value,))?;
date.to_object(py)
}
_ => value,
};

let column_type = PyArrowType(column_type).into_py(py);
pa.call_method1("scalar", (value,))?
.call_method1("cast", (column_type,))
};
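The Rust change above works around two PyArrow 13 scalar-cast regressions by massaging the Python value before calling pa.scalar(...).cast(...) via pyo3. A Python rendering of what it does (values are illustrative):

import pyarrow as pa
from datetime import date

# 1. A trailing "Z" now breaks string -> timezone-naive timestamp casts,
#    so the suffix is stripped first.
ts = pa.scalar("2023-11-04T12:00:00Z".rstrip("Z")).cast(pa.timestamp("us"))

# 2. String -> date32 scalar casts no longer work, so the string is parsed
#    into a datetime.date before being handed to pyarrow.
d = pa.scalar(date.fromisoformat("2023-11-04")).cast(pa.date32())

print(ts, d)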
6 changes: 4 additions & 2 deletions python/tests/test_writer.py
@@ -308,7 +308,9 @@ def test_write_recordbatchreader(
tmp_path: pathlib.Path, existing_table: DeltaTable, sample_data: pa.Table
):
batches = existing_table.to_pyarrow_dataset().to_batches()
reader = RecordBatchReader.from_batches(sample_data.schema, batches)
reader = RecordBatchReader.from_batches(
existing_table.to_pyarrow_dataset().schema, batches
)

write_deltalake(tmp_path, reader, mode="overwrite")
assert DeltaTable(tmp_path).to_pyarrow_table() == sample_data
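The test above now builds the reader from the dataset's own schema rather than sample_data.schema, since RecordBatchReader.from_batches expects the schema to describe the batches exactly. A minimal illustration (hypothetical column):

import pyarrow as pa

batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3]})
# The schema handed to from_batches should match the batches as-is; a
# mismatch typically only surfaces once the stream is consumed.
reader = pa.RecordBatchReader.from_batches(batch.schema, [batch])
assert reader.read_all().num_rows == 3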
@@ -892,7 +894,7 @@ def comp():
# concurrently, then this will fail.
assert data.num_rows == sample_data.num_rows
try:
write_deltalake(dt.table_uri, data, mode="overwrite")
write_deltalake(dt.table_uri, sample_data, mode="overwrite")
except Exception as e:
exception = e

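Taken together, these changes let nested data round-trip under PyArrow 13+. A hedged end-to-end sketch mirroring the writer tests above (path and column names are illustrative):

import pyarrow as pa
from deltalake import DeltaTable, write_deltalake

data = pa.table({"id": [1, 2], "tags": [["a", "b"], ["c"]]})
write_deltalake("/tmp/nested_table", data, mode="overwrite")

# Read back and compare, as the writer tests above do.
assert DeltaTable("/tmp/nested_table").to_pyarrow_table() == data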