From 7a972278655a114df63fe1e5c237f47f9451b1f3 Mon Sep 17 00:00:00 2001 From: Marijn Valk Date: Sat, 28 Jan 2023 18:47:31 +0100 Subject: [PATCH] add test for left_larger_than_right (#1110) Signed-off-by: Marijn Valk # Description Adds a test for the `left_larger_than_right` function and rewrites the function match expression to match on both the `left` and `right` argument # Related Issue(s) # Documentation --------- Signed-off-by: Marijn Valk --- rust/src/delta_datafusion.rs | 90 +++++++++++++++++++++++------------- 1 file changed, 57 insertions(+), 33 deletions(-) diff --git a/rust/src/delta_datafusion.rs b/rust/src/delta_datafusion.rs index dfd7472a3d..29ec01d121 100644 --- a/rust/src/delta_datafusion.rs +++ b/rust/src/delta_datafusion.rs @@ -686,39 +686,15 @@ fn correct_scalar_value_type(value: ScalarValue, field_dt: &ArrowDataType) -> Op } fn left_larger_than_right(left: ScalarValue, right: ScalarValue) -> Option { - match left { - ScalarValue::Float64(Some(v)) => { - let f_right = f64::try_from(right).ok()?; - Some(v > f_right) - } - ScalarValue::Float32(Some(v)) => { - let f_right = f32::try_from(right).ok()?; - Some(v > f_right) - } - ScalarValue::Int8(Some(v)) => { - let i_right = i8::try_from(right).ok()?; - Some(v > i_right) - } - ScalarValue::Int16(Some(v)) => { - let i_right = i16::try_from(right).ok()?; - Some(v > i_right) - } - ScalarValue::Int32(Some(v)) => { - let i_right = i32::try_from(right).ok()?; - Some(v > i_right) - } - ScalarValue::Int64(Some(v)) => { - let i_right = i64::try_from(right).ok()?; - Some(v > i_right) - } - ScalarValue::Boolean(Some(v)) => { - let b_right = bool::try_from(right).ok()?; - Some(v & !b_right) - } - ScalarValue::Utf8(Some(v)) => match right { - ScalarValue::Utf8(Some(s_right)) => Some(v > s_right), - _ => None, - }, + match (&left, &right) { + (ScalarValue::Float64(Some(l)), ScalarValue::Float64(Some(r))) => Some(l > r), + (ScalarValue::Float32(Some(l)), ScalarValue::Float32(Some(r))) => Some(l > r), + (ScalarValue::Int8(Some(l)), ScalarValue::Int8(Some(r))) => Some(l > r), + (ScalarValue::Int16(Some(l)), ScalarValue::Int16(Some(r))) => Some(l > r), + (ScalarValue::Int32(Some(l)), ScalarValue::Int32(Some(r))) => Some(l > r), + (ScalarValue::Int64(Some(l)), ScalarValue::Int64(Some(r))) => Some(l > r), + (ScalarValue::Utf8(Some(l)), ScalarValue::Utf8(Some(r))) => Some(l > r), + (ScalarValue::Boolean(Some(l)), ScalarValue::Boolean(Some(r))) => Some(l & !r), _ => { log::error!( "Scalar value comparison unimplemented for {:?} and {:?}", @@ -1004,6 +980,54 @@ mod tests { } } + #[test] + fn test_left_larger_than_right() { + let correct_reference_pairs = vec![ + ( + ScalarValue::Float64(Some(1.0)), + ScalarValue::Float64(Some(2.0)), + ), + ( + ScalarValue::Float32(Some(1.0)), + ScalarValue::Float32(Some(2.0)), + ), + (ScalarValue::Int8(Some(1)), ScalarValue::Int8(Some(2))), + (ScalarValue::Int16(Some(1)), ScalarValue::Int16(Some(2))), + (ScalarValue::Int32(Some(1)), ScalarValue::Int32(Some(2))), + (ScalarValue::Int64(Some(1)), ScalarValue::Int64(Some(2))), + ( + ScalarValue::Boolean(Some(false)), + ScalarValue::Boolean(Some(true)), + ), + ( + ScalarValue::Utf8(Some(String::from("1"))), + ScalarValue::Utf8(Some(String::from("2"))), + ), + ]; + for (smaller_val, larger_val) in correct_reference_pairs { + assert_eq!( + left_larger_than_right(smaller_val.clone(), larger_val.clone()), + Some(false) + ); + assert_eq!(left_larger_than_right(larger_val, smaller_val), Some(true)); + } + + let incorrect_reference_pairs = vec![ + ( + ScalarValue::Float64(Some(1.0)), + ScalarValue::Float32(Some(2.0)), + ), + (ScalarValue::Int32(Some(1)), ScalarValue::Float32(Some(2.0))), + ( + ScalarValue::Boolean(Some(true)), + ScalarValue::Float32(Some(2.0)), + ), + ]; + for (left, right) in incorrect_reference_pairs { + assert_eq!(left_larger_than_right(left, right), None); + } + } + #[test] fn test_partitioned_file_from_action() { let mut partition_values = std::collections::HashMap::new();