feat: Add GetStructField expression (apache#731)
* Add GetStructField support

* Add custom types to CometBatchScanExec

* Remove test explain

* Rust fmt

* Fix struct type support checks

* Support converting StructArray to native

* fix style

* Attempt to fix scalar subquery issue

* Fix other unit test

* Cleanup

* Default query plan supporting complex type to false

* Migrate struct expressions to spark-expr

* Update shouldApplyRowToColumnar comment

* Add nulls to test

* Rename to allowStruct

* Add DataTypeSupport trait

* Fix parquet datatype test
Kimahriman authored Aug 3, 2024
1 parent ab2dcaa commit 5b5142b
Showing 15 changed files with 313 additions and 79 deletions.
@@ -28,6 +28,7 @@ import scala.collection.JavaConverters._
import org.apache.arrow.c.CDataDictionaryProvider
import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, DecimalVector, FieldVector, FixedSizeBinaryVector, Float4Vector, Float8Vector, IntVector, SmallIntVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, ValueVector, VarBinaryVector, VarCharVector, VectorSchemaRoot}
import org.apache.arrow.vector.complex.MapVector
import org.apache.arrow.vector.complex.StructVector
import org.apache.arrow.vector.dictionary.DictionaryProvider
import org.apache.arrow.vector.ipc.ArrowStreamWriter
import org.apache.arrow.vector.types._
@@ -258,7 +259,7 @@ object Utils {
case v @ (_: BitVector | _: TinyIntVector | _: SmallIntVector | _: IntVector |
_: BigIntVector | _: Float4Vector | _: Float8Vector | _: VarCharVector |
_: DecimalVector | _: DateDayVector | _: TimeStampMicroTZVector | _: VarBinaryVector |
_: FixedSizeBinaryVector | _: TimeStampMicroVector) =>
_: FixedSizeBinaryVector | _: TimeStampMicroVector | _: StructVector) =>
v.asInstanceOf[FieldVector]
case _ =>
throw new SparkException(s"Unsupported Arrow Vector for $reason: ${valueVector.getClass}")
1 change: 0 additions & 1 deletion native/core/src/execution/datafusion/expressions/mod.rs
@@ -29,7 +29,6 @@ pub mod bloom_filter_might_contain;
pub mod comet_scalar_funcs;
pub mod correlation;
pub mod covariance;
pub mod create_named_struct;
pub mod negative;
pub mod stats;
pub mod stddev;
9 changes: 7 additions & 2 deletions native/core/src/execution/datafusion/planner.rs
@@ -94,7 +94,7 @@ use crate::{
},
};

use super::expressions::{create_named_struct::CreateNamedStruct, EvalMode};
use super::expressions::EvalMode;
use crate::execution::datafusion::expressions::comet_scalar_funcs::create_comet_physical_fun;
use datafusion_comet_proto::{
spark_expression::{
@@ -109,7 +109,8 @@ use datafusion_comet_proto::{
spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning},
};
use datafusion_comet_spark_expr::{
Cast, DateTruncExpr, HourExpr, IfExpr, MinuteExpr, RLike, SecondExpr, TimestampTruncExpr,
Cast, CreateNamedStruct, DateTruncExpr, GetStructField, HourExpr, IfExpr, MinuteExpr, RLike,
SecondExpr, TimestampTruncExpr,
};

// For clippy error on type_complexity.
@@ -619,6 +620,10 @@ impl PhysicalPlanner {
let data_type = to_arrow_datatype(expr.datatype.as_ref().unwrap());
Ok(Arc::new(CreateNamedStruct::new(values, data_type)))
}
ExprStruct::GetStructField(expr) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema.clone())?;
Ok(Arc::new(GetStructField::new(child, expr.ordinal as usize)))
}
expr => Err(ExecutionError::GeneralError(format!(
"Not implemented: {:?}",
expr
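
A minimal sketch, not part of this diff, of how nested field accesses compose once the planner has deserialized them: each GetStructField wraps a child physical expression plus a zero-based ordinal into that child's struct type, so reading a doubly nested field is just two chained expressions. The column name, ordinals, and helper function below are hypothetical.

use std::sync::Arc;

use datafusion_comet_spark_expr::GetStructField;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::PhysicalExpr;

/// Hypothetical helper: build the physical expression for `s.inner.x`, assuming
/// `s` is column 0 of the input schema, `inner` is field 0 of `s`, and `x` is
/// field 1 of `inner`.
fn nested_struct_access() -> Arc<dyn PhysicalExpr> {
    let s: Arc<dyn PhysicalExpr> = Arc::new(Column::new("s", 0));
    let inner = Arc::new(GetStructField::new(s, 0)); // s.inner
    Arc::new(GetStructField::new(inner, 1)) // s.inner.x
}
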
6 changes: 6 additions & 0 deletions native/proto/src/proto/expr.proto
@@ -78,6 +78,7 @@ message Expr {
UnboundReference unbound = 51;
BloomFilterMightContain bloom_filter_might_contain = 52;
CreateNamedStruct create_named_struct = 53;
GetStructField get_struct_field = 54;
}
}

@@ -492,6 +493,11 @@ message CreateNamedStruct {
DataType datatype = 2;
}

message GetStructField {
Expr child = 1;
int32 ordinal = 2;
}

enum SortDirection {
Ascending = 0;
Descending = 1;
2 changes: 2 additions & 0 deletions native/spark-expr/src/lib.rs
@@ -23,6 +23,7 @@ mod kernels;
mod regexp;
pub mod scalar_funcs;
pub mod spark_hash;
mod structs;
mod temporal;
pub mod timezone;
pub mod utils;
@@ -32,6 +33,7 @@ pub use cast::{spark_cast, Cast};
pub use error::{SparkError, SparkResult};
pub use if_expr::IfExpr;
pub use regexp::RLike;
pub use structs::{CreateNamedStruct, GetStructField};
pub use temporal::{DateTruncExpr, HourExpr, MinuteExpr, SecondExpr, TimestampTruncExpr};

/// Spark supports three evaluation modes when evaluating expressions, which affect
native/core/src/execution/datafusion/expressions/create_named_struct.rs → native/spark-expr/src/structs.rs
@@ -19,9 +19,9 @@ use arrow::compute::take;
use arrow::record_batch::RecordBatch;
use arrow_array::types::Int32Type;
use arrow_array::{Array, DictionaryArray, StructArray};
use arrow_schema::{DataType, Schema};
use arrow_schema::{DataType, Field, Schema};
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{DataFusionError, Result as DataFusionResult};
use datafusion_common::{DataFusionError, Result as DataFusionResult, ScalarValue};
use datafusion_physical_expr::PhysicalExpr;
use std::{
any::Any,
@@ -30,7 +30,7 @@ use std::{
sync::Arc,
};

use crate::execution::datafusion::expressions::utils::down_cast_any_ref;
use crate::utils::down_cast_any_ref;

#[derive(Debug, Hash)]
pub struct CreateNamedStruct {
@@ -142,6 +142,106 @@ impl PartialEq<dyn Any> for CreateNamedStruct {
}
}

#[derive(Debug, Hash)]
pub struct GetStructField {
child: Arc<dyn PhysicalExpr>,
ordinal: usize,
}

impl GetStructField {
pub fn new(child: Arc<dyn PhysicalExpr>, ordinal: usize) -> Self {
Self { child, ordinal }
}

fn child_field(&self, input_schema: &Schema) -> DataFusionResult<Arc<Field>> {
match self.child.data_type(input_schema)? {
DataType::Struct(fields) => Ok(fields[self.ordinal].clone()),
data_type => Err(DataFusionError::Plan(format!(
"Expect struct field, got {:?}",
data_type
))),
}
}
}

impl PhysicalExpr for GetStructField {
fn as_any(&self) -> &dyn Any {
self
}

fn data_type(&self, input_schema: &Schema) -> DataFusionResult<DataType> {
Ok(self.child_field(input_schema)?.data_type().clone())
}

fn nullable(&self, input_schema: &Schema) -> DataFusionResult<bool> {
Ok(self.child_field(input_schema)?.is_nullable())
}

fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
let child_value = self.child.evaluate(batch)?;

match child_value {
ColumnarValue::Array(array) => {
let struct_array = array
.as_any()
.downcast_ref::<StructArray>()
.expect("A struct is expected");

Ok(ColumnarValue::Array(
struct_array.column(self.ordinal).clone(),
))
}
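// A scalar struct (e.g. produced by a scalar subquery) arrives as a
// `ScalarValue::Struct` wrapping a one-row StructArray, so the selected
// child is returned as a one-row array rather than a scalar value.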
ColumnarValue::Scalar(ScalarValue::Struct(struct_array)) => Ok(ColumnarValue::Array(
struct_array.column(self.ordinal).clone(),
)),
value => Err(DataFusionError::Execution(format!(
"Expected a struct array, got {:?}",
value
))),
}
}

fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
vec![&self.child]
}

fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn PhysicalExpr>>,
) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
Ok(Arc::new(GetStructField::new(
children[0].clone(),
self.ordinal,
)))
}

fn dyn_hash(&self, state: &mut dyn Hasher) {
let mut s = state;
self.child.hash(&mut s);
self.ordinal.hash(&mut s);
self.hash(&mut s);
}
}

impl Display for GetStructField {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(
f,
"GetStructField [child: {:?}, ordinal: {:?}]",
self.child, self.ordinal
)
}
}

impl PartialEq<dyn Any> for GetStructField {
fn eq(&self, other: &dyn Any) -> bool {
down_cast_any_ref(other)
.downcast_ref::<Self>()
.map(|x| self.child.eq(&x.child) && self.ordinal.eq(&x.ordinal))
.unwrap_or(false)
}
}
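
A hedged test sketch, not part of this commit, showing how the new expression evaluates against a small batch; it could be added to the existing `mod test` below. It assumes only the arrow and DataFusion APIs already imported by this file plus `datafusion_physical_expr::expressions::Column`; the field names and values are made up.

#[test]
fn get_struct_field_selects_child_column() -> datafusion_common::Result<()> {
    use std::sync::Arc;

    use arrow::record_batch::RecordBatch;
    use arrow_array::{Array, ArrayRef, Int32Array, StringArray, StructArray};
    use arrow_schema::{DataType, Field, Fields, Schema};
    use datafusion::logical_expr::ColumnarValue;
    use datafusion_physical_expr::expressions::Column;
    use datafusion_physical_expr::PhysicalExpr;

    // Input column `s` with type struct<a: int32, b: utf8>.
    let fields = Fields::from(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Utf8, true),
    ]);
    let a: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
    let b: ArrayRef = Arc::new(StringArray::from(vec![Some("x"), Some("y"), None]));
    let struct_array = StructArray::new(fields.clone(), vec![a, b], None);

    let schema = Arc::new(Schema::new(vec![Field::new(
        "s",
        DataType::Struct(fields),
        true,
    )]));
    let batch =
        RecordBatch::try_new(schema.clone(), vec![Arc::new(struct_array) as ArrayRef])?;

    // Ordinal 1 selects field `b`, so the reported type is Utf8.
    let expr = GetStructField::new(Arc::new(Column::new("s", 0)), 1);
    assert_eq!(expr.data_type(&schema)?, DataType::Utf8);

    match expr.evaluate(&batch)? {
        ColumnarValue::Array(array) => {
            let strings = array.as_any().downcast_ref::<StringArray>().unwrap();
            assert_eq!(strings.value(0), "x");
            assert_eq!(strings.value(1), "y");
            assert!(strings.is_null(2));
        }
        other => panic!("expected an array, got {:?}", other),
    }
    Ok(())
}
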

#[cfg(test)]
mod test {
use super::CreateNamedStruct;
spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -42,11 +42,10 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExc
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

import org.apache.comet.CometConf._
import org.apache.comet.CometExplainInfo.getActualPlan
import org.apache.comet.CometSparkSessionExtensions.{createMessage, getCometBroadcastNotEnabledReason, getCometShuffleNotEnabledReason, isANSIEnabled, isCometBroadCastForceEnabled, isCometEnabled, isCometExecEnabled, isCometJVMShuffleMode, isCometNativeShuffleMode, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSchemaSupported, isSpark34Plus, isSpark40Plus, shouldApplyRowToColumnar, withInfo, withInfos}
import org.apache.comet.CometSparkSessionExtensions.{createMessage, getCometBroadcastNotEnabledReason, getCometShuffleNotEnabledReason, isANSIEnabled, isCometBroadCastForceEnabled, isCometEnabled, isCometExecEnabled, isCometJVMShuffleMode, isCometNativeShuffleMode, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSpark34Plus, isSpark40Plus, shouldApplyRowToColumnar, withInfo, withInfos}
import org.apache.comet.parquet.{CometParquetScan, SupportsComet}
import org.apache.comet.serde.OperatorOuterClass.Operator
import org.apache.comet.serde.QueryPlanSerde
@@ -95,8 +94,10 @@ class CometSparkSessionExtensions
// data source V2
case scanExec: BatchScanExec
if scanExec.scan.isInstanceOf[ParquetScan] &&
isSchemaSupported(scanExec.scan.asInstanceOf[ParquetScan].readDataSchema) &&
isSchemaSupported(scanExec.scan.asInstanceOf[ParquetScan].readPartitionSchema) &&
CometBatchScanExec.isSchemaSupported(
scanExec.scan.asInstanceOf[ParquetScan].readDataSchema) &&
CometBatchScanExec.isSchemaSupported(
scanExec.scan.asInstanceOf[ParquetScan].readPartitionSchema) &&
// Comet does not support pushedAggregate
scanExec.scan.asInstanceOf[ParquetScan].pushedAggregate.isEmpty =>
val cometScan = CometParquetScan(scanExec.scan.asInstanceOf[ParquetScan])
@@ -110,11 +111,11 @@ class CometSparkSessionExtensions
case scanExec: BatchScanExec if scanExec.scan.isInstanceOf[ParquetScan] =>
val requiredSchema = scanExec.scan.asInstanceOf[ParquetScan].readDataSchema
val info1 = createMessage(
!isSchemaSupported(requiredSchema),
!CometBatchScanExec.isSchemaSupported(requiredSchema),
s"Schema $requiredSchema is not supported")
val readPartitionSchema = scanExec.scan.asInstanceOf[ParquetScan].readPartitionSchema
val info2 = createMessage(
!isSchemaSupported(readPartitionSchema),
!CometBatchScanExec.isSchemaSupported(readPartitionSchema),
s"Partition schema $readPartitionSchema is not supported")
// Comet does not support pushedAggregate
val info3 = createMessage(
@@ -129,7 +130,7 @@ class CometSparkSessionExtensions
// Iceberg scan, supported cases
case s: SupportsComet
if s.isCometEnabled &&
isSchemaSupported(scanExec.scan.readSchema()) =>
CometBatchScanExec.isSchemaSupported(scanExec.scan.readSchema()) =>
logInfo(s"Comet extension enabled for ${scanExec.scan.getClass.getSimpleName}")
// When reading from Iceberg, we automatically enable type promotion
SQLConf.get.setConfString(COMET_SCHEMA_EVOLUTION_ENABLED.key, "true")
@@ -144,7 +145,7 @@ class CometSparkSessionExtensions
"Comet extension is not enabled for " +
s"${scanExec.scan.getClass.getSimpleName}: not enabled on data source side")
val info2 = createMessage(
!isSchemaSupported(scanExec.scan.readSchema()),
!CometBatchScanExec.isSchemaSupported(scanExec.scan.readSchema()),
"Comet extension is not enabled for " +
s"${scanExec.scan.getClass.getSimpleName}: Schema not supported")
withInfos(scanExec, Seq(info1, info2).flatten.toSet)
@@ -166,7 +167,9 @@ class CometSparkSessionExtensions
_,
_,
_,
_) if isSchemaSupported(requiredSchema) && isSchemaSupported(partitionSchema) =>
_)
if CometScanExec.isSchemaSupported(requiredSchema)
&& CometScanExec.isSchemaSupported(partitionSchema) =>
logInfo("Comet extension enabled for v1 Scan")
CometScanExec(scanExec, session)

@@ -182,10 +185,10 @@ class CometSparkSessionExtensions
_,
_) =>
val info1 = createMessage(
!isSchemaSupported(requiredSchema),
!CometScanExec.isSchemaSupported(requiredSchema),
s"Schema $requiredSchema is not supported")
val info2 = createMessage(
!isSchemaSupported(partitionSchema),
!CometScanExec.isSchemaSupported(partitionSchema),
s"Partition schema $partitionSchema is not supported")
withInfo(scanExec, Seq(info1, info2).flatten.mkString(","))
scanExec
@@ -1109,28 +1112,17 @@ object CometSparkSessionExtensions extends Logging {
COMET_EXEC_ALL_OPERATOR_ENABLED.get(conf)
}

private[comet] def isSchemaSupported(schema: StructType): Boolean =
schema.map(_.dataType).forall(isTypeSupported)

private[comet] def isTypeSupported(dt: DataType): Boolean = dt match {
case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType |
BinaryType | StringType | _: DecimalType | DateType | TimestampType =>
true
case t: DataType if t.typeName == "timestamp_ntz" => true
case dt =>
logInfo(s"Comet extension is disabled because data type $dt is not supported")
false
}

def isCometScan(op: SparkPlan): Boolean = {
op.isInstanceOf[CometBatchScanExec] || op.isInstanceOf[CometScanExec]
}

private def shouldApplyRowToColumnar(conf: SQLConf, op: SparkPlan): Boolean = {
// Only consider converting leaf nodes to columnar currently, so that all the following
// operators can have a chance to be converted to columnar.
// operators can have a chance to be converted to columnar. Leaf operators that output
// columnar batches, such as Spark's vectorized readers, will also be converted to native
// comet batches.
// TODO: consider converting other intermediate operators to columnar.
op.isInstanceOf[LeafExecNode] && !op.supportsColumnar && isSchemaSupported(op.schema) &&
op.isInstanceOf[LeafExecNode] && CometRowToColumnarExec.isSchemaSupported(op.schema) &&
COMET_ROW_TO_COLUMNAR_ENABLED.get(conf) && {
val simpleClassName = Utils.getSimpleName(op.getClass)
val nodeName = simpleClassName.replaceAll("Exec$", "")