Skip to content

Commit

Permalink
fix: more tests and PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
alexwilcoxson-rel committed Aug 14, 2024
1 parent 2fef2fe commit 36fece4
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 24 deletions.
3 changes: 1 addition & 2 deletions crates/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,13 @@ tracing = { workspace = true }
rand = "0.8"
z85 = "3.0.5"
maplit = "1"
sqlparser = { version = "0.49" }

# Unity
reqwest = { version = "0.11.18", default-features = false, features = [
"rustls-tls",
"json",
], optional = true }
sqlparser = { version = "0.49", optional = true }

[dev-dependencies]
criterion = "0.5"
Expand Down Expand Up @@ -130,7 +130,6 @@ datafusion = [
"datafusion-sql",
"datafusion-functions",
"datafusion-functions-aggregate",
"sqlparser",
]
datafusion-ext = ["datafusion"]
json = ["parquet/json"]
Expand Down
85 changes: 63 additions & 22 deletions crates/core/src/kernel/snapshot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -763,33 +763,34 @@ mod datafusion {
}
}

/// Checks if provided string is enclosed in char_to_remove and extracts the enclosed string.
/// If the string is not enclosed, it returns the original string.
fn extract_enclosed_string(string_to_extract: &str, enclosing_char: char) -> &str {
if string_to_extract.len() > 1
&& string_to_extract.starts_with(enclosing_char)
&& string_to_extract.ends_with(enclosing_char)
{
&string_to_extract[1..string_to_extract.len() - 1]
} else {
string_to_extract
}
}

/// Retrieves a specific field from the schema based on the provided field name.
/// It handles cases where the field name is nested or enclosed in backticks.
fn get_stats_field<'a>(schema: &'a StructType, stats_field_name: &str) -> Option<&'a StructField> {
match stats_field_name.split_once('.') {
Some((parent, children)) => {
let parent_field = schema.fields.get(extract_enclosed_string(parent, '`'))?;
match parent_field.data_type() {
DataType::Struct(inner) => get_stats_field(inner.as_ref(), children),
_ => None,
let dialect = sqlparser::dialect::GenericDialect {};
match sqlparser::parser::Parser::new(&dialect).try_with_sql(stats_field_name) {
Ok(mut parser) => match parser.parse_multipart_identifier() {
Ok(parts) => find_nested_field(schema, &parts, 0),
Err(_) => schema.field(stats_field_name),
}
Err(_) => schema.field(stats_field_name),
}
}

fn find_nested_field<'a>(schema: &'a StructType, parts: &[sqlparser::ast::Ident], part_idx: usize) -> Option<&'a StructField> {
let part_name = &parts[part_idx].value;
match schema.field(part_name) {
Some(field) => {
if part_idx == parts.len() - 1 {
Some(field)
} else {
match field.data_type() {
DataType::Struct(struct_schema) => find_nested_field(struct_schema, parts, part_idx + 1),
// Any part before the end must be a struct
_ => None,
}
}
}
None => schema
.fields
.get(extract_enclosed_string(stats_field_name, '`')),
None => None,
}
}

Expand Down Expand Up @@ -1090,6 +1091,46 @@ mod tests {
assert_eq!(*field, StructField::new("b", DataType::INTEGER, true));
}

#[test]
fn test_field_with_name_escaped() {
let schema = StructType::new(vec![
StructField::new("a", DataType::STRING, true),
StructField::new("b.b", DataType::INTEGER, true),
]);
let field = get_stats_field(&schema, "`b.b`").unwrap();
assert_eq!(*field, StructField::new("b.b", DataType::INTEGER, true));
}

#[test]
fn test_field_does_not_exist() {
let schema = StructType::new(vec![
StructField::new("a", DataType::STRING, true),
StructField::new("b", DataType::INTEGER, true),
]);
let field = get_stats_field(&schema, "c");
assert!(field.is_none());
}

#[test]
fn test_field_part_is_not_struct() {
let schema = StructType::new(vec![
StructField::new("a", DataType::STRING, true),
StructField::new("b", DataType::INTEGER, true),
]);
let field = get_stats_field(&schema, "b.c");
assert!(field.is_none());
}

#[test]
fn test_field_name_does_not_parse() {
let schema = StructType::new(vec![
StructField::new("a", DataType::STRING, true),
StructField::new("b", DataType::INTEGER, true),
]);
let field = get_stats_field(&schema, "b.");
assert!(field.is_none());
}

#[test]
fn test_field_with_name_nested() {
let nested = StructType::new(vec![StructField::new(
Expand Down

0 comments on commit 36fece4

Please sign in to comment.