diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index 8507bed33a..461e66db58 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -94,13 +94,13 @@ tracing = { workspace = true } rand = "0.8" z85 = "3.0.5" maplit = "1" +sqlparser = { version = "0.49" } # Unity reqwest = { version = "0.11.18", default-features = false, features = [ "rustls-tls", "json", ], optional = true } -sqlparser = { version = "0.49", optional = true } [dev-dependencies] criterion = "0.5" @@ -130,7 +130,6 @@ datafusion = [ "datafusion-sql", "datafusion-functions", "datafusion-functions-aggregate", - "sqlparser", ] datafusion-ext = ["datafusion"] json = ["parquet/json"] diff --git a/crates/core/src/kernel/snapshot/mod.rs b/crates/core/src/kernel/snapshot/mod.rs index d4a8a671a7..1781e81cc1 100644 --- a/crates/core/src/kernel/snapshot/mod.rs +++ b/crates/core/src/kernel/snapshot/mod.rs @@ -315,7 +315,7 @@ impl Snapshot { let stats_fields = if let Some(stats_cols) = self.table_config().stats_columns() { stats_cols .iter() - .map(|col| match schema.field(col) { + .map(|col| match get_stats_field(schema, col) { Some(field) => match field.data_type() { DataType::Map(_) | DataType::Array(_) | &DataType::BINARY => { Err(DeltaTableError::Generic(format!( @@ -763,6 +763,45 @@ mod datafusion { } } +/// Retrieves a specific field from the schema based on the provided field name. +/// It handles cases where the field name is nested or enclosed in backticks. +fn get_stats_field<'a>(schema: &'a StructType, stats_field_name: &str) -> Option<&'a StructField> { + let dialect = sqlparser::dialect::GenericDialect {}; + match sqlparser::parser::Parser::new(&dialect).try_with_sql(stats_field_name) { + Ok(mut parser) => match parser.parse_multipart_identifier() { + Ok(parts) => find_nested_field(schema, &parts), + Err(_) => schema.field(stats_field_name), + }, + Err(_) => schema.field(stats_field_name), + } +} + +fn find_nested_field<'a>( + schema: &'a StructType, + parts: &[sqlparser::ast::Ident], +) -> Option<&'a StructField> { + if parts.is_empty() { + return None; + } + let part_name = &parts[0].value; + match schema.field(part_name) { + Some(field) => { + if parts.len() == 1 { + Some(field) + } else { + match field.data_type() { + DataType::Struct(struct_schema) => { + find_nested_field(struct_schema, &parts[1..]) + } + // Any part before the end must be a struct + _ => None, + } + } + } + None => None, + } +} + #[cfg(test)] mod tests { use std::collections::HashMap; @@ -1049,4 +1088,100 @@ mod tests { assert_eq!(snapshot.partitions_schema(None).unwrap(), None); } + + #[test] + fn test_field_with_name() { + let schema = StructType::new(vec![ + StructField::new("a", DataType::STRING, true), + StructField::new("b", DataType::INTEGER, true), + ]); + let field = get_stats_field(&schema, "b").unwrap(); + assert_eq!(*field, StructField::new("b", DataType::INTEGER, true)); + } + + #[test] + fn test_field_with_name_escaped() { + let schema = StructType::new(vec![ + StructField::new("a", DataType::STRING, true), + StructField::new("b.b", DataType::INTEGER, true), + ]); + let field = get_stats_field(&schema, "`b.b`").unwrap(); + assert_eq!(*field, StructField::new("b.b", DataType::INTEGER, true)); + } + + #[test] + fn test_field_does_not_exist() { + let schema = StructType::new(vec![ + StructField::new("a", DataType::STRING, true), + StructField::new("b", DataType::INTEGER, true), + ]); + let field = get_stats_field(&schema, "c"); + assert!(field.is_none()); + } + + #[test] + fn test_field_part_is_not_struct() { + let schema = StructType::new(vec![ + StructField::new("a", DataType::STRING, true), + StructField::new("b", DataType::INTEGER, true), + ]); + let field = get_stats_field(&schema, "b.c"); + assert!(field.is_none()); + } + + #[test] + fn test_field_name_does_not_parse() { + let schema = StructType::new(vec![ + StructField::new("a", DataType::STRING, true), + StructField::new("b", DataType::INTEGER, true), + ]); + let field = get_stats_field(&schema, "b."); + assert!(field.is_none()); + } + + #[test] + fn test_field_with_name_nested() { + let nested = StructType::new(vec![StructField::new( + "nested_struct", + DataType::BOOLEAN, + true, + )]); + let schema = StructType::new(vec![ + StructField::new("a", DataType::STRING, true), + StructField::new("b", DataType::Struct(Box::new(nested)), true), + ]); + + let field = get_stats_field(&schema, "b.nested_struct").unwrap(); + + assert_eq!( + *field, + StructField::new("nested_struct", DataType::BOOLEAN, true) + ); + } + + #[test] + fn test_field_with_last_name_nested_backticks() { + let nested = StructType::new(vec![StructField::new("pr!me", DataType::BOOLEAN, true)]); + let schema = StructType::new(vec![ + StructField::new("a", DataType::STRING, true), + StructField::new("b", DataType::Struct(Box::new(nested)), true), + ]); + + let field = get_stats_field(&schema, "b.`pr!me`").unwrap(); + + assert_eq!(*field, StructField::new("pr!me", DataType::BOOLEAN, true)); + } + + #[test] + fn test_field_with_name_nested_backticks() { + let nested = StructType::new(vec![StructField::new("pr", DataType::BOOLEAN, true)]); + let schema = StructType::new(vec![ + StructField::new("a", DataType::STRING, true), + StructField::new("b&b", DataType::Struct(Box::new(nested)), true), + ]); + + let field = get_stats_field(&schema, "`b&b`.pr").unwrap(); + + assert_eq!(*field, StructField::new("pr", DataType::BOOLEAN, true)); + } }