diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs
index 15de7c9ad..d63fd7078 100644
--- a/native/core/src/execution/datafusion/planner.rs
+++ b/native/core/src/execution/datafusion/planner.rs
@@ -96,8 +96,8 @@ use datafusion_comet_proto::{
     spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning},
 };
 use datafusion_comet_spark_expr::{
-    Cast, CreateNamedStruct, DateTruncExpr, GetStructField, HourExpr, IfExpr, ListExtract,
-    MinuteExpr, RLike, SecondExpr, TimestampTruncExpr, ToJson,
+    Cast, CreateNamedStruct, DateTruncExpr, GetArrayStructFields, GetStructField, HourExpr, IfExpr,
+    ListExtract, MinuteExpr, RLike, SecondExpr, TimestampTruncExpr, ToJson,
 };
 use datafusion_common::scalar::ScalarStructBuilder;
 use datafusion_common::{
@@ -680,6 +680,15 @@ impl PhysicalPlanner {
                     expr.fail_on_error,
                 )))
             }
+            ExprStruct::GetArrayStructFields(expr) => {
+                let child =
+                    self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?;
+
+                Ok(Arc::new(GetArrayStructFields::new(
+                    child,
+                    expr.ordinal as usize,
+                )))
+            }
             expr => Err(ExecutionError::GeneralError(format!(
                 "Not implemented: {:?}",
                 expr
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index 88940f386..1a3e3c9fc 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -81,6 +81,7 @@ message Expr {
     GetStructField get_struct_field = 54;
     ToJson to_json = 55;
     ListExtract list_extract = 56;
+    GetArrayStructFields get_array_struct_fields = 57;
   }
 }

@@ -517,6 +518,11 @@ message ListExtract {
   bool fail_on_error = 5;
 }

+message GetArrayStructFields {
+  Expr child = 1;
+  int32 ordinal = 2;
+}
+
 enum SortDirection {
   Ascending = 0;
   Descending = 1;
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index c4b1c99ba..cc22dfcbc 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -38,7 +38,7 @@ mod xxhash64;
 pub use cast::{spark_cast, Cast};
 pub use error::{SparkError, SparkResult};
 pub use if_expr::IfExpr;
-pub use list::ListExtract;
+pub use list::{GetArrayStructFields, ListExtract};
 pub use regexp::RLike;
 pub use structs::{CreateNamedStruct, GetStructField};
 pub use temporal::{DateTruncExpr, HourExpr, MinuteExpr, SecondExpr, TimestampTruncExpr};
diff --git a/native/spark-expr/src/list.rs b/native/spark-expr/src/list.rs
index 0b85a8424..a376198db 100644
--- a/native/spark-expr/src/list.rs
+++ b/native/spark-expr/src/list.rs
@@ -16,7 +16,7 @@
 // under the License.

 use arrow::{array::MutableArrayData, datatypes::ArrowNativeType, record_batch::RecordBatch};
-use arrow_array::{Array, GenericListArray, Int32Array, OffsetSizeTrait};
+use arrow_array::{Array, GenericListArray, Int32Array, OffsetSizeTrait, StructArray};
 use arrow_schema::{DataType, FieldRef, Schema};
 use datafusion::logical_expr::ColumnarValue;
 use datafusion::physical_expr_common::physical_expr::down_cast_any_ref;
@@ -275,6 +275,144 @@ impl PartialEq<dyn Any> for ListExtract {
     }
 }

+#[derive(Debug, Hash)]
+pub struct GetArrayStructFields {
+    child: Arc<dyn PhysicalExpr>,
+    ordinal: usize,
+}
+
+impl GetArrayStructFields {
+    pub fn new(child: Arc<dyn PhysicalExpr>, ordinal: usize) -> Self {
+        Self { child, ordinal }
+    }
+
+    fn list_field(&self, input_schema: &Schema) -> DataFusionResult<FieldRef> {
+        match self.child.data_type(input_schema)? {
+            DataType::List(field) | DataType::LargeList(field) => Ok(field),
+            data_type => Err(DataFusionError::Internal(format!(
+                "Unexpected data type in GetArrayStructFields: {:?}",
+                data_type
+            ))),
+        }
+    }
+
+    fn child_field(&self, input_schema: &Schema) -> DataFusionResult<FieldRef> {
+        match self.list_field(input_schema)?.data_type() {
+            DataType::Struct(fields) => Ok(Arc::clone(&fields[self.ordinal])),
+            data_type => Err(DataFusionError::Internal(format!(
+                "Unexpected data type in GetArrayStructFields: {:?}",
+                data_type
+            ))),
+        }
+    }
+}
+
+impl PhysicalExpr for GetArrayStructFields {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn data_type(&self, input_schema: &Schema) -> DataFusionResult<DataType> {
+        let struct_field = self.child_field(input_schema)?;
+        match self.child.data_type(input_schema)? {
+            DataType::List(_) => Ok(DataType::List(struct_field)),
+            DataType::LargeList(_) => Ok(DataType::LargeList(struct_field)),
+            data_type => Err(DataFusionError::Internal(format!(
+                "Unexpected data type in GetArrayStructFields: {:?}",
+                data_type
+            ))),
+        }
+    }
+
+    fn nullable(&self, input_schema: &Schema) -> DataFusionResult<bool> {
+        Ok(self.list_field(input_schema)?.is_nullable()
+            || self.child_field(input_schema)?.is_nullable())
+    }
+
+    fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
+        let child_value = self.child.evaluate(batch)?.into_array(batch.num_rows())?;
+
+        match child_value.data_type() {
+            DataType::List(_) => {
+                let list_array = as_list_array(&child_value)?;
+
+                get_array_struct_fields(list_array, self.ordinal)
+            }
+            DataType::LargeList(_) => {
+                let list_array = as_large_list_array(&child_value)?;
+
+                get_array_struct_fields(list_array, self.ordinal)
+            }
+            data_type => Err(DataFusionError::Internal(format!(
+                "Unexpected child type for GetArrayStructFields: {:?}",
+                data_type
+            ))),
+        }
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+        vec![&self.child]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn PhysicalExpr>>,
+    ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
+        match children.len() {
+            1 => Ok(Arc::new(GetArrayStructFields::new(
+                Arc::clone(&children[0]),
+                self.ordinal,
+            ))),
+            _ => internal_err!("GetArrayStructFields should have exactly one child"),
+        }
+    }
+
+    fn dyn_hash(&self, state: &mut dyn Hasher) {
+        let mut s = state;
+        self.child.hash(&mut s);
+        self.ordinal.hash(&mut s);
+        self.hash(&mut s);
+    }
+}
+
+fn get_array_struct_fields<O: OffsetSizeTrait>(
+    list_array: &GenericListArray<O>,
+    ordinal: usize,
+) -> DataFusionResult<ColumnarValue> {
+    let values = list_array
+        .values()
+        .as_any()
+        .downcast_ref::<StructArray>()
+        .expect("A struct is expected");
+
+    let column = Arc::clone(values.column(ordinal));
+    let field = Arc::clone(&values.fields()[ordinal]);
+
+    let offsets = list_array.offsets();
+    let array = GenericListArray::new(field, offsets.clone(), column, list_array.nulls().cloned());
+
+    Ok(ColumnarValue::Array(Arc::new(array)))
+}
+
+impl Display for GetArrayStructFields {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "GetArrayStructFields [child: {:?}, ordinal: {:?}]",
+            self.child, self.ordinal
+        )
+    }
+}
+
+impl PartialEq<dyn Any> for GetArrayStructFields {
+    fn eq(&self, other: &dyn Any) -> bool {
+        down_cast_any_ref(other)
+            .downcast_ref::<Self>()
+            .map(|x| self.child.eq(&x.child) && self.ordinal.eq(&x.ordinal))
+            .unwrap_or(false)
+    }
+}
+
 #[cfg(test)]
 mod test {
     use crate::list::{list_extract, zero_based_index};
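
Reviewer note, not part of the patch: the heart of get_array_struct_fields above is a zero-copy re-wrap. It keeps the list's offset and null buffers as-is and only swaps the struct values array for the one column being projected. Below is a minimal standalone sketch of the same trick in plain arrow-rs; the field names and values are invented for illustration.

use std::sync::Arc;

use arrow_array::{Array, ArrayRef, Int32Array, ListArray, StringArray, StructArray};
use arrow_buffer::OffsetBuffer;
use arrow_schema::{DataType, Field, Fields};

fn main() {
    // Four struct values {a: int32, b: utf8}, laid out back to back.
    let fields = Fields::from(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Utf8, true),
    ]);
    let values = StructArray::new(
        fields.clone(),
        vec![
            Arc::new(Int32Array::from(vec![1, 2, 3, 4])) as ArrayRef,
            Arc::new(StringArray::from(vec!["w", "x", "y", "z"])) as ArrayRef,
        ],
        None,
    );

    // Two list rows over those values: offsets [0, 2, 4] split them 2 and 2.
    let list = ListArray::new(
        Arc::new(Field::new("item", DataType::Struct(fields.clone()), true)),
        OffsetBuffer::new(vec![0, 2, 4].into()),
        Arc::new(values),
        None,
    );

    // The projection: keep the offsets and null buffer, swap the struct
    // values for the single column at `ordinal`. No element data is copied.
    let ordinal = 0;
    let structs = list
        .values()
        .as_any()
        .downcast_ref::<StructArray>()
        .expect("list of structs");
    let projected = ListArray::new(
        Arc::clone(&fields[ordinal]),
        list.offsets().clone(),
        Arc::clone(structs.column(ordinal)),
        list.nulls().cloned(),
    );

    // Prints a list<int32> equivalent to [[1, 2], [3, 4]].
    println!("{:?}", projected);
}

Because only buffer handles change hands, the per-batch cost of the kernel is independent of how many fields the struct elements carry.
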
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 51b32b7df..02b845e7c 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2542,6 +2542,25 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           None
         }

+      case GetArrayStructFields(child, _, ordinal, _, _) =>
+        val childExpr = exprToProto(child, inputs, binding)
+
+        if (childExpr.isDefined) {
+          val arrayStructFieldsBuilder = ExprOuterClass.GetArrayStructFields
+            .newBuilder()
+            .setChild(childExpr.get)
+            .setOrdinal(ordinal)
+
+          Some(
+            ExprOuterClass.Expr
+              .newBuilder()
+              .setGetArrayStructFields(arrayStructFieldsBuilder)
+              .build())
+        } else {
+          withInfo(expr, "unsupported arguments for GetArrayStructFields", child)
+          None
+        }
+
       case _ =>
         withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
         None
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 16bc15b84..da22df402 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -2271,4 +2271,26 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       }
     }
   }
+
+  test("GetArrayStructFields") {
+    Seq(true, false).foreach { dictionaryEnabled =>
+      withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> SimplifyExtractValueOps.ruleName) {
+        withTempDir { dir =>
+          val path = new Path(dir.toURI.toString, "test.parquet")
+          makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000)
+          val df = spark.read
+            .parquet(path.toString)
+            .select(
+              array(struct(col("_2"), col("_3"), col("_4"), col("_8")), lit(null)).alias("arr"))
+          checkSparkAnswerAndOperator(df.select("arr._2", "arr._3", "arr._4"))
+
+          val complex = spark.read
+            .parquet(path.toString)
+            .select(array(struct(struct(col("_4"), col("_8")).alias("nested"))).alias("arr"))
+
+          checkSparkAnswerAndOperator(complex.select(col("arr.nested._4")))
+        }
+      }
+    }
+  }
 }
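
A final reviewer note, also not part of the patch: the native behavior that the Scala test pins down via checkSparkAnswerAndOperator can be sketched directly against the new expression through DataFusion's PhysicalExpr API. The sketch below assumes the datafusion and arrow crates the workspace already depends on; the column name and data are invented. It includes a NULL list row to show the other half of the contract: because the null buffer is shared, a NULL row stays NULL after projection, matching nullable() above.

use std::sync::Arc;

use arrow_array::{Array, ArrayRef, Int32Array, ListArray, RecordBatch, StringArray, StructArray};
use arrow_buffer::{NullBuffer, OffsetBuffer};
use arrow_schema::{DataType, Field, Fields, Schema};
use datafusion::physical_expr::expressions::Column;
use datafusion::physical_expr::PhysicalExpr;
use datafusion_comet_spark_expr::GetArrayStructFields;

fn main() -> datafusion_common::Result<()> {
    // One list<struct<a: int32, b: utf8>> column "arr" with two rows:
    // [{a: 1, b: "x"}] and NULL.
    let fields = Fields::from(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Utf8, true),
    ]);
    let values = StructArray::new(
        fields.clone(),
        vec![
            Arc::new(Int32Array::from(vec![1])) as ArrayRef,
            Arc::new(StringArray::from(vec!["x"])) as ArrayRef,
        ],
        None,
    );
    let arr = ListArray::new(
        Arc::new(Field::new("item", DataType::Struct(fields), true)),
        OffsetBuffer::new(vec![0, 1, 1].into()),
        Arc::new(values),
        Some(NullBuffer::from(vec![true, false])),
    );
    let schema = Arc::new(Schema::new(vec![Field::new(
        "arr",
        arr.data_type().clone(),
        true,
    )]));
    let batch = RecordBatch::try_new(schema, vec![Arc::new(arr) as ArrayRef])?;

    // The equivalent of Spark's `arr.a`: project struct field 0 out of
    // every array element of column 0.
    let expr = GetArrayStructFields::new(Arc::new(Column::new("arr", 0)), 0);
    let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;

    // The shared null buffer keeps the second row NULL: prints [[1], null].
    println!("{:?}", result);
    Ok(())
}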