From b7f75ddb52da6f3ebf083e9eff4efdd718975afd Mon Sep 17 00:00:00 2001 From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> Date: Mon, 13 Jan 2025 17:04:24 +0100 Subject: [PATCH] chore: tests Signed-off-by: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> --- crates/core/src/kernel/models/schema.rs | 44 ++++- crates/core/src/operations/merge/mod.rs | 7 +- python/tests/test_generated_columns.py | 220 ++++++++++++++++++++++++ 3 files changed, 267 insertions(+), 4 deletions(-) create mode 100644 python/tests/test_generated_columns.py diff --git a/crates/core/src/kernel/models/schema.rs b/crates/core/src/kernel/models/schema.rs index 7c6be8a7dd..bd76f0b3e9 100644 --- a/crates/core/src/kernel/models/schema.rs +++ b/crates/core/src/kernel/models/schema.rs @@ -100,7 +100,6 @@ impl StructTypeExt for StructType { }; } } - dbg!(generated_cols.clone()); Ok(generated_cols) } @@ -183,6 +182,49 @@ mod tests { use serde_json; use serde_json::json; + #[test] + fn test_get_generated_columns() { + let schema: StructType = serde_json::from_value(json!( + { + "type":"struct", + "fields":[ + {"name":"id","type":"integer","nullable":true,"metadata":{}}, + {"name":"gc","type":"integer","nullable":true,"metadata":{}}] + } + )) + .unwrap(); + let cols = schema.get_generated_columns().unwrap(); + assert_eq!(cols.len(), 0); + + let schema: StructType = serde_json::from_value(json!( + { + "type":"struct", + "fields":[ + {"name":"id","type":"integer","nullable":true,"metadata":{}}, + {"name":"gc","type":"integer","nullable":true,"metadata":{"delta.generationExpression":"\"5\""}}] + } + )).unwrap(); + let cols = schema.get_generated_columns().unwrap(); + assert_eq!(cols.len(), 1); + assert_eq!(cols[0].data_type, DataType::INTEGER); + assert_eq!( + cols[0].validation_expr, + "gc = 5 OR (gc IS NULL AND 5 IS NULL)" + ); + + let schema: StructType = serde_json::from_value(json!( + { + "type":"struct", + "fields":[ + {"name":"id","type":"integer","nullable":true,"metadata":{}}, + {"name":"gc","type":"integer","nullable":true,"metadata":{"delta.generationExpression":"\"5\""}}, + {"name":"id2","type":"integer","nullable":true,"metadata":{"delta.generationExpression":"\"id * 10\""}},] + } + )).unwrap(); + let cols = schema.get_generated_columns().unwrap(); + assert_eq!(cols.len(), 2); + } + #[test] fn test_get_invariants() { let schema: StructType = serde_json::from_value(json!({ diff --git a/crates/core/src/operations/merge/mod.rs b/crates/core/src/operations/merge/mod.rs index 0e2541349c..e58fd22664 100644 --- a/crates/core/src/operations/merge/mod.rs +++ b/crates/core/src/operations/merge/mod.rs @@ -761,11 +761,12 @@ async fn execute( for generated_col in generated_cols { let col_name = generated_col.get_name(); - if !df + if df .clone() .schema() - .field_names() - .contains(&col_name.to_string()) + .field_with_unqualified_name(&col_name.to_string()) + .is_err() + // implies it doesn't exist { debug!( "Adding missing generated column {} in source as placeholder", diff --git a/python/tests/test_generated_columns.py b/python/tests/test_generated_columns.py new file mode 100644 index 0000000000..b329c948be --- /dev/null +++ b/python/tests/test_generated_columns.py @@ -0,0 +1,220 @@ +import pyarrow as pa +import pytest + +from deltalake import DeltaTable, Field, Schema, write_deltalake +from deltalake.exceptions import DeltaError, SchemaMismatchError +from deltalake.schema import PrimitiveType + + +@pytest.fixture +def gc_schema() -> Schema: + return Schema( + [ + Field(name="id", type=PrimitiveType("integer")), + Field( + name="gc", + type=PrimitiveType("integer"), + metadata={"delta.generationExpression": "'5'"}, + ), + ] + ) + + +@pytest.fixture +def valid_gc_data() -> pa.Table: + id_col = pa.field("id", pa.int32()) + gc = pa.field("gc", pa.int32()).with_metadata({"delta.generationExpression": "10"}) + data = pa.Table.from_pydict( + {"id": [1, 2], "gc": [10, 10]}, schema=pa.schema([id_col, gc]) + ) + return data + + +@pytest.fixture +def data_without_gc() -> pa.Table: + id_col = pa.field("id", pa.int32()) + data = pa.Table.from_pydict({"id": [1, 2]}, schema=pa.schema([id_col])) + return data + + +@pytest.fixture +def invalid_gc_data() -> pa.Table: + id_col = pa.field("id", pa.int32()) + gc = pa.field("gc", pa.int32()).with_metadata({"delta.generationExpression": "10"}) + data = pa.Table.from_pydict( + {"id": [1, 2], "gc": [5, 10]}, schema=pa.schema([id_col, gc]) + ) + return data + + +@pytest.fixture +def table_with_gc(tmp_path, gc_schema) -> DeltaTable: + dt = DeltaTable.create( + tmp_path, + schema=gc_schema, + ) + return dt + + +def test_create_table_with_generated_columns(tmp_path, gc_schema: Schema): + dt = DeltaTable.create( + tmp_path, + schema=gc_schema, + ) + protocol = dt.protocol() + assert protocol.min_writer_version == 4 + + dt = DeltaTable.create( + tmp_path, + schema=gc_schema, + mode="overwrite", + configuration={"delta.minWriterVersion": "7"}, + ) + protocol = dt.protocol() + + assert dt.version() == 1 + assert protocol.writer_features is not None + assert "generatedColumns" in protocol.writer_features + + +def test_write_with_gc(tmp_path, valid_gc_data): + write_deltalake(tmp_path, mode="append", data=valid_gc_data) + dt = DeltaTable(tmp_path) + + assert dt.protocol().min_writer_version == 4 + assert dt.to_pyarrow_table() == valid_gc_data + + +def test_write_with_gc_higher_writer_version(tmp_path, valid_gc_data): + write_deltalake( + tmp_path, + mode="append", + data=valid_gc_data, + configuration={"delta.minWriterVersion": "7"}, + ) + dt = DeltaTable(tmp_path) + protocol = dt.protocol() + assert protocol.min_writer_version == 7 + assert protocol.writer_features is not None + assert "generatedColumns" in protocol.writer_features + assert dt.to_pyarrow_table() == valid_gc_data + + +def test_write_with_invalid_gc(tmp_path, invalid_gc_data): + import re + + with pytest.raises( + DeltaError, + match=re.escape( + 'Invariant violations: ["Check or Invariant (gc = 10 OR (gc IS NULL AND 10 IS NULL)) violated by value in row: [5]"]' + ), + ): + write_deltalake(tmp_path, mode="append", data=invalid_gc_data) + + +def test_write_with_invalid_gc_to_table(table_with_gc, invalid_gc_data): + import re + + with pytest.raises( + DeltaError, + match=re.escape( + "Invariant violations: [\"Check or Invariant (gc = '5' OR (gc IS NULL AND '5' IS NULL)) violated by value in row: [10]\"]" + ), + ): + write_deltalake(table_with_gc, mode="append", data=invalid_gc_data) + + +def test_write_to_table_generating_data(table_with_gc: DeltaTable): + id_col = pa.field("id", pa.int32()) + data = pa.Table.from_pydict({"id": [1, 2]}, schema=pa.schema([id_col])) + write_deltalake(table_with_gc, mode="append", data=data) + + id_col = pa.field("id", pa.int32()) + gc = pa.field("gc", pa.int32()) + expected_data = pa.Table.from_pydict( + {"id": [1, 2], "gc": [5, 5]}, schema=pa.schema([id_col, gc]) + ) + + assert table_with_gc.version() == 1 + assert table_with_gc.to_pyarrow_table() == expected_data + + +def test_raise_when_gc_passed_during_schema_evolution( + tmp_path, data_without_gc, valid_gc_data +): + write_deltalake( + tmp_path, + mode="append", + data=data_without_gc, + ) + dt = DeltaTable(tmp_path) + assert dt.protocol().min_writer_version == 2 + + with pytest.raises( + SchemaMismatchError, + match="Schema evolved fields cannot have generated expressions. Recreate the table to achieve this.", + ): + write_deltalake( + dt, + mode="append", + data=valid_gc_data, + schema_mode="merge", + ) + + +def test_raise_when_gc_passed_during_adding_new_columns(tmp_path, data_without_gc): + write_deltalake( + tmp_path, + mode="append", + data=data_without_gc, + ) + dt = DeltaTable(tmp_path) + assert dt.protocol().min_writer_version == 2 + + with pytest.raises(DeltaError, match="New columns cannot be a generated column"): + dt.alter.add_columns( + fields=[ + Field( + name="gc", + type=PrimitiveType("integer"), + metadata={"delta.generationExpression": "'5'"}, + ) + ] + ) + + +def test_merge_with_gc(table_with_gc: DeltaTable, data_without_gc): + ( + table_with_gc.merge( + data_without_gc, predicate="s.id = t.id", source_alias="s", target_alias="t" + ) + .when_not_matched_insert_all() + .execute() + ) + id_col = pa.field("id", pa.int32()) + gc = pa.field("gc", pa.int32()) + expected_data = pa.Table.from_pydict( + {"id": [1, 2], "gc": [5, 5]}, schema=pa.schema([id_col, gc]) + ) + assert table_with_gc.to_pyarrow_table() == expected_data + + +def test_merge_with_gc_invalid(table_with_gc: DeltaTable, invalid_gc_data): + import re + + with pytest.raises( + DeltaError, + match=re.escape( + "Invariant violations: [\"Check or Invariant (gc = '5' OR (gc IS NULL AND '5' IS NULL)) violated by value in row: [10]\"]" + ), + ): + ( + table_with_gc.merge( + invalid_gc_data, + predicate="s.id = t.id", + source_alias="s", + target_alias="t", + ) + .when_not_matched_insert_all() + .execute() + )