From 12979dd8816f8ba481452c00d853db8516634bc3 Mon Sep 17 00:00:00 2001 From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> Date: Tue, 23 Apr 2024 17:43:14 +0200 Subject: [PATCH] fix(python, rust): check timestamp_ntz in nested fields, add check_can_writestamp_ntz in pyarrow writer (#2443) # Description The nested fields weren't checked, which meant you could get a timestampNtz in your schema but not have the reader/writer features set. This check is now done recursively. --- crates/core/src/operations/create.rs | 8 +--- .../src/operations/transaction/protocol.rs | 27 +++++++++--- python/Cargo.toml | 2 +- python/deltalake/_internal.pyi | 1 + python/deltalake/writer.py | 1 + python/src/lib.rs | 13 +++++- python/tests/test_writer.py | 42 +++++++++++++++++++ 7 files changed, 80 insertions(+), 14 deletions(-) diff --git a/crates/core/src/operations/create.rs b/crates/core/src/operations/create.rs index f9a7f62183..62f803ae3c 100644 --- a/crates/core/src/operations/create.rs +++ b/crates/core/src/operations/create.rs @@ -235,16 +235,12 @@ impl CreateBuilder { ) }; - let contains_timestampntz = &self - .columns - .iter() - .any(|f| f.data_type() == &DataType::TIMESTAMPNTZ); - + let contains_timestampntz = PROTOCOL.contains_timestampntz(&self.columns); // TODO configure more permissive versions based on configuration. Also how should this ideally be handled? // We set the lowest protocol we can, and if subsequent writes use newer features we update metadata? let (min_reader_version, min_writer_version, writer_features, reader_features) = - if *contains_timestampntz { + if contains_timestampntz { let mut converted_writer_features = self .configuration .keys() diff --git a/crates/core/src/operations/transaction/protocol.rs b/crates/core/src/operations/transaction/protocol.rs index 7971ac883a..95a0e22d66 100644 --- a/crates/core/src/operations/transaction/protocol.rs +++ b/crates/core/src/operations/transaction/protocol.rs @@ -4,7 +4,9 @@ use lazy_static::lazy_static; use once_cell::sync::Lazy; use super::{TableReference, TransactionError}; -use crate::kernel::{Action, DataType, EagerSnapshot, ReaderFeatures, Schema, WriterFeatures}; +use crate::kernel::{ + Action, DataType, EagerSnapshot, ReaderFeatures, Schema, StructField, WriterFeatures, +}; use crate::protocol::DeltaOperation; use crate::table::state::DeltaTableState; @@ -77,17 +79,30 @@ impl ProtocolChecker { Ok(()) } + /// checks if table contains timestamp_ntz in any field including nested fields. + pub fn contains_timestampntz(&self, fields: &Vec) -> bool { + fn check_vec_fields(fields: &Vec) -> bool { + fields.iter().any(|f| _check_type(f.data_type())) + } + + fn _check_type(dtype: &DataType) -> bool { + match dtype { + &DataType::TIMESTAMPNTZ => true, + DataType::Array(inner) => _check_type(inner.element_type()), + DataType::Struct(inner) => check_vec_fields(inner.fields()), + _ => false, + } + } + check_vec_fields(fields) + } + /// Check can write_timestamp_ntz pub fn check_can_write_timestamp_ntz( &self, snapshot: &DeltaTableState, schema: &Schema, ) -> Result<(), TransactionError> { - let contains_timestampntz = schema - .fields() - .iter() - .any(|f| f.data_type() == &DataType::TIMESTAMPNTZ); - + let contains_timestampntz = self.contains_timestampntz(schema.fields()); let required_features: Option<&HashSet> = match snapshot.protocol().min_writer_version { 0..=6 => None, diff --git a/python/Cargo.toml b/python/Cargo.toml index cdf393957d..1589ecb11e 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "deltalake-python" -version = "0.17.0" +version = "0.17.1" authors = ["Qingping Hou ", "Will Jones "] homepage = "https://github.com/delta-io/delta-rs" license = "Apache-2.0" diff --git a/python/deltalake/_internal.pyi b/python/deltalake/_internal.pyi index d100cdb11f..7d2e6342a2 100644 --- a/python/deltalake/_internal.pyi +++ b/python/deltalake/_internal.pyi @@ -154,6 +154,7 @@ class RawDeltaTable: custom_metadata: Optional[Dict[str, str]], ) -> None: ... def cleanup_metadata(self) -> None: ... + def check_can_write_timestamp_ntz(self, schema: pyarrow.Schema) -> None: ... def rust_core_version() -> str: ... def write_new_deltalake( diff --git a/python/deltalake/writer.py b/python/deltalake/writer.py index 7924b072b9..2bc299215b 100644 --- a/python/deltalake/writer.py +++ b/python/deltalake/writer.py @@ -429,6 +429,7 @@ def visitor(written_file: Any) -> None: # We don't currently provide a way to set invariants # (and maybe never will), so only enforce if already exist. table_protocol = table.protocol() + table._table.check_can_write_timestamp_ntz(schema) if ( table_protocol.min_writer_version > MAX_SUPPORTED_PYARROW_WRITER_VERSION or table_protocol.min_writer_version diff --git a/python/src/lib.rs b/python/src/lib.rs index a64a5efe84..917bbb5750 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -33,7 +33,7 @@ use deltalake::operations::filesystem_check::FileSystemCheckBuilder; use deltalake::operations::merge::MergeBuilder; use deltalake::operations::optimize::{OptimizeBuilder, OptimizeType}; use deltalake::operations::restore::RestoreBuilder; -use deltalake::operations::transaction::{CommitBuilder, CommitProperties}; +use deltalake::operations::transaction::{CommitBuilder, CommitProperties, PROTOCOL}; use deltalake::operations::update::UpdateBuilder; use deltalake::operations::vacuum::VacuumBuilder; use deltalake::parquet::basic::Compression; @@ -175,6 +175,17 @@ impl RawDeltaTable { )) } + pub fn check_can_write_timestamp_ntz(&self, schema: PyArrowType) -> PyResult<()> { + let schema: StructType = (&schema.0).try_into().map_err(PythonError::from)?; + Ok(PROTOCOL + .check_can_write_timestamp_ntz( + self._table.snapshot().map_err(PythonError::from)?, + &schema, + ) + .map_err(|e| DeltaTableError::Generic(e.to_string())) + .map_err(PythonError::from)?) + } + pub fn load_version(&mut self, version: i64) -> PyResult<()> { Ok(rt() .block_on(self._table.load_version(version)) diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py index eb8244dbb3..4035191c85 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -1528,3 +1528,45 @@ def test_rust_decimal_cast(tmp_path: pathlib.Path): write_deltalake( tmp_path, data, mode="append", schema_mode="merge", engine="rust" ) + + +@pytest.mark.parametrize( + "array", + [ + pa.array([[datetime(2010, 1, 1)]]), + pa.array([{"foo": datetime(2010, 1, 1)}]), + pa.array([{"foo": [[datetime(2010, 1, 1)]]}]), + pa.array([{"foo": [[{"foo": datetime(2010, 1, 1)}]]}]), + ], +) +def test_write_timestamp_ntz_nested(tmp_path: pathlib.Path, array: pa.array): + data = pa.table({"x": array}) + write_deltalake(tmp_path, data, mode="append", engine="rust") + + dt = DeltaTable(tmp_path) + + protocol = dt.protocol() + assert protocol.min_reader_version == 3 + assert protocol.min_writer_version == 7 + assert protocol.reader_features == ["timestampNtz"] + assert protocol.writer_features == ["timestampNtz"] + + +def test_write_timestamp_ntz_on_table_with_features_not_enabled(tmp_path: pathlib.Path): + data = pa.table({"x": pa.array(["foo"])}) + write_deltalake(tmp_path, data, mode="append", engine="pyarrow") + + dt = DeltaTable(tmp_path) + + protocol = dt.protocol() + assert protocol.min_reader_version == 1 + assert protocol.min_writer_version == 2 + + data = pa.table({"x": pa.array([datetime(2010, 1, 1)])}) + with pytest.raises( + DeltaError, + match="Generic DeltaTable error: Writer features must be specified for writerversion >= 7, please specify: TimestampWithoutTimezone", + ): + write_deltalake( + tmp_path, data, mode="overwrite", engine="pyarrow", schema_mode="overwrite" + )