From d5792c813076e3e9391bbf78d2c7cc225bf66edd Mon Sep 17 00:00:00 2001 From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> Date: Sat, 14 Sep 2024 11:00:26 +0200 Subject: [PATCH] allow multiple features and add python tests --- crates/core/src/operations/add_feature.rs | 48 +++++++------- crates/core/src/protocol/mod.rs | 2 +- python/deltalake/_internal.pyi | 2 +- python/deltalake/table.py | 5 +- python/src/lib.rs | 5 +- python/tests/conftest.py | 7 ++ python/tests/test_alter.py | 81 ++++++++++++++++++++++- 7 files changed, 119 insertions(+), 31 deletions(-) diff --git a/crates/core/src/operations/add_feature.rs b/crates/core/src/operations/add_feature.rs index bab8176aa1..8c5715af96 100644 --- a/crates/core/src/operations/add_feature.rs +++ b/crates/core/src/operations/add_feature.rs @@ -1,9 +1,10 @@ //! Enable table features use futures::future::BoxFuture; +use itertools::Itertools; use super::transaction::{CommitBuilder, CommitProperties}; -use crate::kernel::TableFeatures; +use crate::kernel::{ReaderFeatures, TableFeatures, WriterFeatures}; use crate::logstore::LogStoreRef; use crate::protocol::DeltaOperation; use crate::table::state::DeltaTableState; @@ -15,7 +16,7 @@ pub struct AddTableFeatureBuilder { /// A snapshot of the table's state snapshot: DeltaTableState, /// Name of the feature - name: Option, + name: Vec, /// Allow protocol versions to be increased by setting features allow_protocol_versions_increase: bool, /// Delta object store for handling data files @@ -30,7 +31,7 @@ impl AddTableFeatureBuilder { /// Create a new builder pub fn new(log_store: LogStoreRef, snapshot: DeltaTableState) -> Self { Self { - name: None, + name: vec![], allow_protocol_versions_increase: false, snapshot, log_store, @@ -39,8 +40,9 @@ impl AddTableFeatureBuilder { } /// Specify the feature to be added - pub fn with_feature>(mut self, name: S) -> Self { - self.name = Some(name.into()); + pub fn with_feature>(mut self, name: Vec) -> Self { + self.name + .extend(name.into_iter().map(Into::into).collect_vec()); self } @@ -66,35 +68,35 @@ impl std::future::IntoFuture for AddTableFeatureBuilder { let this = self; Box::pin(async move { - let name = this - .name - .ok_or(DeltaTableError::Generic("No features provided".to_string()))?; - - let (reader_feature, writer_feature) = name.to_reader_writer_features(); + let name = if this.name.is_empty() { + return Err(DeltaTableError::Generic("No features provided".to_string())); + } else { + this.name + }; + let (reader_features, writer_features): ( + Vec>, + Vec>, + ) = name.iter().map(|v| v.to_reader_writer_features()).unzip(); + let reader_features = reader_features.into_iter().flatten().collect_vec(); + let writer_features = writer_features.into_iter().flatten().collect_vec(); let mut protocol = this.snapshot.protocol().clone(); if !this.allow_protocol_versions_increase { - if reader_feature.is_some() - && writer_feature.is_some() - && protocol.min_reader_version == 3 - && protocol.min_writer_version == 7 + if !reader_features.is_empty() + && !writer_features.is_empty() + && !(protocol.min_reader_version == 3 && protocol.min_writer_version == 7) { return Err(DeltaTableError::Generic("Table feature enables reader and writer feature, but reader is not v3, and writer not v7. Set allow_protocol_versions_increase or increase versions explicitly through set_tbl_properties".to_string())); - } else if reader_feature.is_some() && protocol.min_reader_version < 3 { + } else if !reader_features.is_empty() && protocol.min_reader_version < 3 { return Err(DeltaTableError::Generic("Table feature enables reader feature, but min_reader is not v3. Set allow_protocol_versions_increase or increase version explicitly through set_tbl_properties".to_string())); - } else if writer_feature.is_some() && protocol.min_writer_version < 7 { + } else if !writer_features.is_empty() && protocol.min_writer_version < 7 { return Err(DeltaTableError::Generic("Table feature enables writer feature, but min_writer is not v7. Set allow_protocol_versions_increase or increase version explicitly through set_tbl_properties".to_string())); } } - if let Some(reader_feature) = reader_feature { - protocol = protocol.with_reader_features(vec![reader_feature]); - } - - if let Some(writer_feature) = writer_feature { - protocol = protocol.with_writer_features(vec![writer_feature]); - } + protocol = protocol.with_reader_features(reader_features); + protocol = protocol.with_writer_features(writer_features); let operation = DeltaOperation::AddFeature { name }; diff --git a/crates/core/src/protocol/mod.rs b/crates/core/src/protocol/mod.rs index 9c3fb0608e..f82f48411a 100644 --- a/crates/core/src/protocol/mod.rs +++ b/crates/core/src/protocol/mod.rs @@ -367,7 +367,7 @@ pub enum DeltaOperation { /// Add table features to a table AddFeature { /// Name of the feature - name: TableFeatures, + name: Vec, }, /// Drops constraints from a table diff --git a/python/deltalake/_internal.pyi b/python/deltalake/_internal.pyi index 034fd1af96..02a3765e02 100644 --- a/python/deltalake/_internal.pyi +++ b/python/deltalake/_internal.pyi @@ -125,7 +125,7 @@ class RawDeltaTable: ) -> None: ... def add_feature( self, - feature: TableFeatures, + feature: List[TableFeatures], allow_protocol_versions_increase: bool, commit_properties: Optional[CommitProperties], post_commithook_properties: Optional[PostCommitHookProperties], diff --git a/python/deltalake/table.py b/python/deltalake/table.py index f674ab6fd5..9150be697c 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -1802,7 +1802,7 @@ def __init__(self, table: DeltaTable) -> None: def add_feature( self, - feature: TableFeatures, + feature: Union[TableFeatures, List[TableFeatures]], allow_protocol_versions_increase: bool = False, commit_properties: Optional[CommitProperties] = None, post_commithook_properties: Optional[PostCommitHookProperties] = None, @@ -1829,7 +1829,8 @@ def add_feature( ProtocolVersions(min_reader_version=1, min_writer_version=7, writer_features=['appendOnly'], reader_features=None) ``` """ - + if isinstance(feature, TableFeatures): + feature = [feature] self.table._table.add_feature( feature, allow_protocol_versions_increase, diff --git a/python/src/lib.rs b/python/src/lib.rs index 1134ca274f..eca97424f8 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -24,7 +24,6 @@ use deltalake::datafusion::physical_plan::ExecutionPlan; use deltalake::datafusion::prelude::SessionContext; use deltalake::delta_datafusion::DeltaDataChecker; use deltalake::errors::DeltaTableError; -use deltalake::kernel::TableFeatures as KernelTableFeatures; use deltalake::kernel::{ scalars::ScalarExt, Action, Add, Invariant, LogicalFile, Remove, StructType, }; @@ -572,7 +571,7 @@ impl RawDeltaTable { pub fn add_feature( &mut self, py: Python, - feature: TableFeatures, + feature: Vec, allow_protocol_versions_increase: bool, commit_properties: Option, post_commithook_properties: Option, @@ -582,7 +581,7 @@ impl RawDeltaTable { self._table.log_store(), self._table.snapshot().map_err(PythonError::from)?.clone(), ) - .with_feature(Into::::into(feature)) + .with_feature(feature) .with_allow_protocol_versions_increase(allow_protocol_versions_increase); if let Some(commit_properties) = diff --git a/python/tests/conftest.py b/python/tests/conftest.py index cd3dec4627..8f85f4ab04 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -228,6 +228,13 @@ def sample_table(): ) +@pytest.fixture() +def existing_sample_table(tmp_path: pathlib.Path, sample_table: pa.Table): + path = str(tmp_path) + write_deltalake(path, sample_table) + return DeltaTable(path) + + @pytest.fixture() def sample_table_with_spaces_numbers(): nrows = 5 diff --git a/python/tests/test_alter.py b/python/tests/test_alter.py index 65ac7e07ac..dc688effe1 100644 --- a/python/tests/test_alter.py +++ b/python/tests/test_alter.py @@ -4,7 +4,7 @@ import pyarrow as pa import pytest -from deltalake import DeltaTable, write_deltalake +from deltalake import DeltaTable, TableFeatures, write_deltalake from deltalake.exceptions import DeltaError, DeltaProtocolError from deltalake.schema import Field, PrimitiveType, StructType from deltalake.table import CommitProperties @@ -375,3 +375,82 @@ def test_add_timestamp_ntz_column(tmp_path: pathlib.Path, sample_table: pa.Table assert new_protocol.min_writer_version == 7 assert new_protocol.reader_features == ["timestampNtz"] assert new_protocol.writer_features == ["timestampNtz"] + + +features = [ + TableFeatures.ChangeDataFeed, + TableFeatures.DeletionVectors, + TableFeatures.ColumnMapping, + TableFeatures.TimestampWithoutTimezone, + TableFeatures.V2Checkpoint, + TableFeatures.AppendOnly, + TableFeatures.AppendOnly, + TableFeatures.Invariants, + TableFeatures.CheckConstraints, + TableFeatures.GeneratedColumns, + TableFeatures.IdentityColumns, + TableFeatures.RowTracking, + TableFeatures.DomainMetadata, + TableFeatures.IcebergCompatV1, +] + +all_features = [] +all_features.extend(features) +all_features.append(features) + + +@pytest.mark.parametrize("feature", all_features) +def test_add_feature_variations(existing_table: DeltaTable, feature): + """Existing table already has timestampNtz so it's already at v3,7""" + existing_table.alter.add_feature( + feature=feature, + allow_protocol_versions_increase=False, + ) + + +def test_add_features_disallowed_protocol_increase(existing_sample_table: DeltaTable): + with pytest.raises( + DeltaError, + match="Generic DeltaTable error: Table feature enables writer feature, but min_writer is not v7. Set allow_protocol_versions_increase or increase version explicitly through set_tbl_properties", + ): + existing_sample_table.alter.add_feature( + feature=TableFeatures.ChangeDataFeed, + allow_protocol_versions_increase=False, + ) + with pytest.raises( + DeltaError, + match="Generic DeltaTable error: Table feature enables reader and writer feature, but reader is not v3, and writer not v7. Set allow_protocol_versions_increase or increase versions explicitly through set_tbl_properties", + ): + existing_sample_table.alter.add_feature( + feature=TableFeatures.DeletionVectors, + allow_protocol_versions_increase=False, + ) + + +def test_add_feautres(existing_sample_table: DeltaTable): + existing_sample_table.alter.add_feature( + feature=features, + allow_protocol_versions_increase=True, + ) + protocol = existing_sample_table.protocol() + + assert sorted(protocol.reader_features) == sorted( # type: ignore + ["v2Checkpoint", "columnMapping", "deletionVectors", "timestampNtz"] + ) + assert sorted(protocol.writer_features) == sorted( # type: ignore + [ + "appendOnly", + "changeDataFeed", + "checkConstraints", + "columnMapping", + "deletionVectors", + "domainMetadata", + "generatedColumns", + "icebergCompatV1", + "identityColumns", + "invariants", + "rowTracking", + "timestampNtz", + "v2Checkpoint", + ] + ) # type: ignore