Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support GetArrayStructFields expression #993

Merged
merged 8 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions native/core/src/execution/datafusion/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ use datafusion_comet_proto::{
spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning},
};
use datafusion_comet_spark_expr::{
Cast, CreateNamedStruct, DateTruncExpr, GetStructField, HourExpr, IfExpr, ListExtract,
MinuteExpr, RLike, SecondExpr, TimestampTruncExpr, ToJson,
Cast, CreateNamedStruct, DateTruncExpr, GetArrayStructFields, GetStructField, HourExpr, IfExpr,
ListExtract, MinuteExpr, RLike, SecondExpr, TimestampTruncExpr, ToJson,
};
use datafusion_common::scalar::ScalarStructBuilder;
use datafusion_common::{
Expand Down Expand Up @@ -680,6 +680,15 @@ impl PhysicalPlanner {
expr.fail_on_error,
)))
}
ExprStruct::GetArrayStructFields(expr) => {
let child =
self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?;

Ok(Arc::new(GetArrayStructFields::new(
child,
expr.ordinal as usize,
)))
}
expr => Err(ExecutionError::GeneralError(format!(
"Not implemented: {:?}",
expr
Expand Down
6 changes: 6 additions & 0 deletions native/proto/src/proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ message Expr {
GetStructField get_struct_field = 54;
ToJson to_json = 55;
ListExtract list_extract = 56;
GetArrayStructFields get_array_struct_fields = 57;
}
}

Expand Down Expand Up @@ -517,6 +518,11 @@ message ListExtract {
bool fail_on_error = 5;
}

// Extracts a single field from every struct element of an array column,
// mirroring Spark's GetArrayStructFields expression. Spark's extra fields
// (field, numFields, containsNull) are not needed for execution.
message GetArrayStructFields {
// The array-of-structs input expression.
Expr child = 1;
// Zero-based index of the struct field to extract.
int32 ordinal = 2;
}

enum SortDirection {
Ascending = 0;
Descending = 1;
Expand Down
2 changes: 1 addition & 1 deletion native/spark-expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ mod xxhash64;
pub use cast::{spark_cast, Cast};
pub use error::{SparkError, SparkResult};
pub use if_expr::IfExpr;
pub use list::ListExtract;
pub use list::{GetArrayStructFields, ListExtract};
pub use regexp::RLike;
pub use structs::{CreateNamedStruct, GetStructField};
pub use temporal::{DateTruncExpr, HourExpr, MinuteExpr, SecondExpr, TimestampTruncExpr};
Expand Down
140 changes: 139 additions & 1 deletion native/spark-expr/src/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
// under the License.

use arrow::{array::MutableArrayData, datatypes::ArrowNativeType, record_batch::RecordBatch};
use arrow_array::{Array, GenericListArray, Int32Array, OffsetSizeTrait};
use arrow_array::{Array, GenericListArray, Int32Array, OffsetSizeTrait, StructArray};
use arrow_schema::{DataType, FieldRef, Schema};
use datafusion::logical_expr::ColumnarValue;
use datafusion::physical_expr_common::physical_expr::down_cast_any_ref;
Expand Down Expand Up @@ -275,6 +275,144 @@ impl PartialEq<dyn Any> for ListExtract {
}
}

/// Physical expression that, given a list-of-structs column, extracts one
/// field from each struct element, producing a list of that field's values
/// (the native counterpart of Spark's `GetArrayStructFields`).
#[derive(Debug, Hash)]
pub struct GetArrayStructFields {
    // The list-of-structs input expression.
    child: Arc<dyn PhysicalExpr>,
    // Zero-based index of the struct field to extract.
    ordinal: usize,
}

impl GetArrayStructFields {
    /// Creates an expression extracting struct field `ordinal` from each
    /// element of the `child` list expression.
    pub fn new(child: Arc<dyn PhysicalExpr>, ordinal: usize) -> Self {
        Self { child, ordinal }
    }

    /// Returns the element field of the child's `List`/`LargeList` type, or an
    /// internal error if the child is not a list at all.
    fn list_field(&self, input_schema: &Schema) -> DataFusionResult<FieldRef> {
        let child_type = self.child.data_type(input_schema)?;
        match child_type {
            DataType::List(ref field) | DataType::LargeList(ref field) => Ok(Arc::clone(field)),
            _ => Err(DataFusionError::Internal(format!(
                "Unexpected data type in GetArrayStructFields: {:?}",
                child_type
            ))),
        }
    }

    /// Returns the struct field selected by `self.ordinal`, or an internal
    /// error if the list elements are not structs.
    fn child_field(&self, input_schema: &Schema) -> DataFusionResult<FieldRef> {
        let element_field = self.list_field(input_schema)?;
        match element_field.data_type() {
            DataType::Struct(fields) => Ok(Arc::clone(&fields[self.ordinal])),
            other => Err(DataFusionError::Internal(format!(
                "Unexpected data type in GetArrayStructFields: {:?}",
                other
            ))),
        }
    }
}

impl PhysicalExpr for GetArrayStructFields {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Result type keeps the list flavor (`List` vs `LargeList`) of the child
    /// and swaps the element type for the selected struct field.
    fn data_type(&self, input_schema: &Schema) -> DataFusionResult<DataType> {
        let struct_field = self.child_field(input_schema)?;
        match self.child.data_type(input_schema)? {
            DataType::List(_) => Ok(DataType::List(struct_field)),
            DataType::LargeList(_) => Ok(DataType::LargeList(struct_field)),
            data_type => Err(DataFusionError::Internal(format!(
                "Unexpected data type in GetArrayStructFields: {:?}",
                data_type
            ))),
        }
    }

    /// Nullable if either the list elements or the extracted field may be null.
    fn nullable(&self, input_schema: &Schema) -> DataFusionResult<bool> {
        Ok(self.list_field(input_schema)?.is_nullable()
            || self.child_field(input_schema)?.is_nullable())
    }

    /// Evaluates the child, then projects the selected struct field out of the
    /// list array (dispatching on i32 vs i64 offsets).
    fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
        let child_value = self.child.evaluate(batch)?.into_array(batch.num_rows())?;

        match child_value.data_type() {
            DataType::List(_) => {
                let list_array = as_list_array(&child_value)?;

                get_array_struct_fields(list_array, self.ordinal)
            }
            DataType::LargeList(_) => {
                let list_array = as_large_list_array(&child_value)?;

                get_array_struct_fields(list_array, self.ordinal)
            }
            data_type => Err(DataFusionError::Internal(format!(
                // Fixed copy-paste: this message previously said "ListExtract".
                "Unexpected child type for GetArrayStructFields: {:?}",
                data_type
            ))),
        }
    }

    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        vec![&self.child]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
        match children.len() {
            1 => Ok(Arc::new(GetArrayStructFields::new(
                Arc::clone(&children[0]),
                self.ordinal,
            ))),
            _ => internal_err!("GetArrayStructFields should have exactly one child"),
        }
    }

    fn dyn_hash(&self, state: &mut dyn Hasher) {
        let mut s = state;
        // The derived `Hash` impl already covers `child` and `ordinal`;
        // hashing them individually first (as before) was redundant.
        self.hash(&mut s);
    }
}

fn get_array_struct_fields<O: OffsetSizeTrait>(
list_array: &GenericListArray<O>,
ordinal: usize,
) -> DataFusionResult<ColumnarValue> {
let values = list_array
.values()
.as_any()
.downcast_ref::<StructArray>()
.expect("A struct is expected");

let column = Arc::clone(values.column(ordinal));
let field = Arc::clone(&values.fields()[ordinal]);

let offsets = list_array.offsets();
let array = GenericListArray::new(field, offsets.clone(), column, list_array.nulls().cloned());

Ok(ColumnarValue::Array(Arc::new(array)))
}

impl Display for GetArrayStructFields {
    /// Renders as `GetArrayStructFields [child: ..., ordinal: ...]`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Self { child, ordinal } = self;
        write!(f, "GetArrayStructFields [child: {:?}, ordinal: {:?}]", child, ordinal)
    }
}

impl PartialEq<dyn Any> for GetArrayStructFields {
    /// Two expressions are equal when their children and ordinals both match;
    /// anything that is not a `GetArrayStructFields` compares unequal.
    fn eq(&self, other: &dyn Any) -> bool {
        match down_cast_any_ref(other).downcast_ref::<Self>() {
            Some(that) => self.child.eq(&that.child) && self.ordinal == that.ordinal,
            None => false,
        }
    }
}

#[cfg(test)]
mod test {
use crate::list::{list_extract, zero_based_index};
Expand Down
19 changes: 19 additions & 0 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2542,6 +2542,25 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
None
}

case GetArrayStructFields(child, _, ordinal, _, _) =>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To help other reviewers, here is Spark's definition of GetArrayStructFields:

case class GetArrayStructFields(
    child: Expression,
    field: StructField,
    ordinal: Int,
    numFields: Int,
    containsNull: Boolean)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah those other three fields aren't relevant for actual execution

val childExpr = exprToProto(child, inputs, binding)

if (childExpr.isDefined) {
val arrayStructFieldsBuilder = ExprOuterClass.GetArrayStructFields
.newBuilder()
.setChild(childExpr.get)
.setOrdinal(ordinal)

Some(
ExprOuterClass.Expr
.newBuilder()
.setGetArrayStructFields(arrayStructFieldsBuilder)
.build())
} else {
withInfo(expr, "unsupported arguments for GetArrayStructFields", child)
None
}

case _ =>
withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
None
Expand Down
22 changes: 22 additions & 0 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2271,4 +2271,26 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
}

// Verifies native execution of GetArrayStructFields matches Spark, for both
// flat and nested struct fields inside arrays.
test("GetArrayStructFields") {
// Cover both Parquet read paths: dictionary-encoded and plain.
Seq(true, false).foreach { dictionaryEnabled =>
// Disable SimplifyExtractValueOps so Spark keeps the GetArrayStructFields
// expression in the plan instead of optimizing the extraction away.
withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> SimplifyExtractValueOps.ruleName) {
withTempDir { dir =>
val path = new Path(dir.toURI.toString, "test.parquet")
makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000)
// Array of structs with a trailing null element, to exercise list-level nulls.
val df = spark.read
.parquet(path.toString)
.select(
array(struct(col("_2"), col("_3"), col("_4"), col("_8")), lit(null)).alias("arr"))
checkSparkAnswerAndOperator(df.select("arr._2", "arr._3", "arr._4"))

// Array of structs containing a nested struct, to exercise chained extraction.
val complex = spark.read
.parquet(path.toString)
.select(array(struct(struct(col("_4"), col("_8")).alias("nested"))).alias("arr"))

checkSparkAnswerAndOperator(complex.select(col("arr.nested._4")))
}
}
}
}
}
Loading