diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs index c3aad61e07..122a24ed3d 100644 --- a/native/core/src/execution/datafusion/planner.rs +++ b/native/core/src/execution/datafusion/planner.rs @@ -19,6 +19,7 @@ use super::expressions::EvalMode; use crate::execution::datafusion::expressions::comet_scalar_funcs::create_comet_physical_fun; +use crate::execution::operators::CopyMode; use crate::{ errors::ExpressionError, execution::{ @@ -859,7 +860,11 @@ impl PhysicalPlanner { let fetch = sort.fetch.map(|num| num as usize); - let copy_exec = Arc::new(CopyExec::new(child)); + let copy_exec = if can_reuse_input_batch(&child) { + Arc::new(CopyExec::new(child, CopyMode::UnpackOrDeepCopy)) + } else { + Arc::new(CopyExec::new(child, CopyMode::UnpackOrClone)) + }; Ok(( scans, @@ -949,8 +954,8 @@ impl PhysicalPlanner { // the data corruption. Note that we only need to copy the input batch // if the child operator is `ScanExec`, because other operators after `ScanExec` // will create new arrays for the output batch. - let child = if child.as_any().is::() { - Arc::new(CopyExec::new(child)) + let child = if can_reuse_input_batch(&child) { + Arc::new(CopyExec::new(child, CopyMode::UnpackOrDeepCopy)) } else { child }; @@ -1205,15 +1210,15 @@ impl PhysicalPlanner { // to copy the input batch to avoid the data corruption from reusing the input // batch. let left = if can_reuse_input_batch(&left) { - Arc::new(CopyExec::new(left)) + Arc::new(CopyExec::new(left, CopyMode::UnpackOrDeepCopy)) } else { - left + Arc::new(CopyExec::new(left, CopyMode::UnpackOrClone)) }; let right = if can_reuse_input_batch(&right) { - Arc::new(CopyExec::new(right)) + Arc::new(CopyExec::new(right, CopyMode::UnpackOrDeepCopy)) } else { - right + Arc::new(CopyExec::new(right, CopyMode::UnpackOrClone)) }; Ok(( @@ -1775,10 +1780,14 @@ impl From for DataFusionError { /// modification. This is used to determine if we need to copy the input batch to avoid /// data corruption from reusing the input batch. fn can_reuse_input_batch(op: &Arc) -> bool { - op.as_any().is::() + if op.as_any().is::() || op.as_any().is::() - || op.as_any().is::() || op.as_any().is::() + { + can_reuse_input_batch(op.children()[0]) + } else { + op.as_any().is::() + } } /// Collects the indices of the columns in the input schema that are used in the expression diff --git a/native/core/src/execution/operators/copy.rs b/native/core/src/execution/operators/copy.rs index b5b1491ed1..0705a3b7c8 100644 --- a/native/core/src/execution/operators/copy.rs +++ b/native/core/src/execution/operators/copy.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +use arrow::compute::{cast_with_options, CastOptions}; +use futures::{Stream, StreamExt}; use std::{ any::Any, pin::Pin, @@ -22,17 +24,16 @@ use std::{ task::{Context, Poll}, }; -use futures::{Stream, StreamExt}; - -use arrow_array::{ArrayRef, RecordBatch, RecordBatchOptions}; -use arrow_schema::{DataType, Field, FieldRef, Schema, SchemaRef}; +use arrow_array::{ + downcast_dictionary_array, make_array, Array, ArrayRef, RecordBatch, RecordBatchOptions, +}; +use arrow_data::transform::MutableArrayData; +use arrow_schema::{ArrowError, DataType, Field, FieldRef, Schema, SchemaRef}; use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use datafusion::{execution::TaskContext, physical_expr::*, physical_plan::*}; use datafusion_common::{arrow_datafusion_err, DataFusionError, Result as DataFusionResult}; -use super::copy_or_cast_array; - /// An utility execution node which makes deep copies of input batches. /// /// In certain scenarios like sort, DF execution nodes only make shallow copy of input batches. @@ -44,10 +45,20 @@ pub struct CopyExec { schema: SchemaRef, cache: PlanProperties, metrics: ExecutionPlanMetricsSet, + mode: CopyMode, +} + +#[derive(Debug, PartialEq, Clone)] +pub enum CopyMode { + UnpackOrDeepCopy, + UnpackOrClone, } impl CopyExec { - pub fn new(input: Arc) -> Self { + pub fn new(input: Arc, mode: CopyMode) -> Self { + // change schema to remove dictionary types because CopyExec always unpacks + // dictionaries + let fields: Vec = input .schema() .fields @@ -73,6 +84,7 @@ impl CopyExec { schema, cache, metrics: ExecutionPlanMetricsSet::default(), + mode, } } } @@ -81,7 +93,7 @@ impl DisplayAs for CopyExec { fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!(f, "CopyExec") + write!(f, "CopyExec [{:?}]", self.mode) } } } @@ -111,6 +123,7 @@ impl ExecutionPlan for CopyExec { schema: self.schema.clone(), cache: self.cache.clone(), metrics: self.metrics.clone(), + mode: self.mode.clone(), })) } @@ -125,6 +138,7 @@ impl ExecutionPlan for CopyExec { self.schema(), child_stream, partition, + self.mode.clone(), ))) } @@ -149,6 +163,7 @@ struct CopyStream { schema: SchemaRef, child_stream: SendableRecordBatchStream, baseline_metrics: BaselineMetrics, + mode: CopyMode, } impl CopyStream { @@ -157,11 +172,13 @@ impl CopyStream { schema: SchemaRef, child_stream: SendableRecordBatchStream, partition: usize, + mode: CopyMode, ) -> Self { Self { schema, child_stream, baseline_metrics: BaselineMetrics::new(&exec.metrics, partition), + mode, } } @@ -172,7 +189,7 @@ impl CopyStream { let vectors = batch .columns() .iter() - .map(|v| copy_or_cast_array(v)) + .map(|v| copy_or_unpack_array(v, &self.mode)) .collect::, _>>()?; let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows())); @@ -200,3 +217,56 @@ impl RecordBatchStream for CopyStream { self.schema.clone() } } + +/// Copy an Arrow Array +fn copy_array(array: &dyn Array) -> ArrayRef { + let capacity = array.len(); + let data = array.to_data(); + + let mut mutable = MutableArrayData::new(vec![&data], false, capacity); + + mutable.extend(0, 0, capacity); + + if matches!(array.data_type(), DataType::Dictionary(_, _)) { + let copied_dict = make_array(mutable.freeze()); + let ref_copied_dict = &copied_dict; + + downcast_dictionary_array!( + ref_copied_dict => { + // Copying dictionary value array + let values = ref_copied_dict.values(); + let data = values.to_data(); + + let mut mutable = MutableArrayData::new(vec![&data], false, values.len()); + mutable.extend(0, 0, values.len()); + + let copied_dict = ref_copied_dict.with_values(make_array(mutable.freeze())); + Arc::new(copied_dict) + } + t => unreachable!("Should not reach here: {}", t) + ) + } else { + make_array(mutable.freeze()) + } +} + +/// Copy an Arrow Array or cast to primitive type if it is a dictionary array. +/// This is used for `CopyExec` to copy/cast the input array. If the input array +/// is a dictionary array, we will cast the dictionary array to primitive type +/// (i.e., unpack the dictionary array) and copy the primitive array. If the input +/// array is a primitive array, we simply copy the array. +fn copy_or_unpack_array(array: &Arc, mode: &CopyMode) -> Result { + match array.data_type() { + DataType::Dictionary(_, value_type) => { + let options = CastOptions::default(); + cast_with_options(array, value_type.as_ref(), &options) + } + _ => { + if mode == &CopyMode::UnpackOrDeepCopy { + Ok(copy_array(array)) + } else { + Ok(Arc::clone(array)) + } + } + } +} diff --git a/native/core/src/execution/operators/mod.rs b/native/core/src/execution/operators/mod.rs index d0cc7ac681..09e05ef262 100644 --- a/native/core/src/execution/operators/mod.rs +++ b/native/core/src/execution/operators/mod.rs @@ -17,22 +17,15 @@ //! Operators -use arrow::{ - array::{make_array, Array, ArrayRef, MutableArrayData}, - datatypes::DataType, - downcast_dictionary_array, -}; +use std::fmt::Debug; -use arrow::compute::{cast_with_options, CastOptions}; -use arrow_schema::ArrowError; use jni::objects::GlobalRef; -use std::{fmt::Debug, sync::Arc}; -mod scan; +pub use copy::*; pub use scan::*; mod copy; -pub use copy::*; +mod scan; /// Error returned during executing operators. #[derive(thiserror::Error, Debug)] @@ -61,52 +54,3 @@ pub enum ExecutionError { throwable: GlobalRef, }, } - -/// Copy an Arrow Array -pub fn copy_array(array: &dyn Array) -> ArrayRef { - let capacity = array.len(); - let data = array.to_data(); - - let mut mutable = MutableArrayData::new(vec![&data], false, capacity); - - mutable.extend(0, 0, capacity); - - if matches!(array.data_type(), DataType::Dictionary(_, _)) { - let copied_dict = make_array(mutable.freeze()); - let ref_copied_dict = &copied_dict; - - downcast_dictionary_array!( - ref_copied_dict => { - // Copying dictionary value array - let values = ref_copied_dict.values(); - let data = values.to_data(); - - let mut mutable = MutableArrayData::new(vec![&data], false, values.len()); - mutable.extend(0, 0, values.len()); - - let copied_dict = ref_copied_dict.with_values(make_array(mutable.freeze())); - Arc::new(copied_dict) - } - t => unreachable!("Should not reach here: {}", t) - ) - } else { - make_array(mutable.freeze()) - } -} - -/// Copy an Arrow Array or cast to primitive type if it is a dictionary array. -/// This is used for `CopyExec` to copy/cast the input array. If the input array -/// is a dictionary array, we will cast the dictionary array to primitive type -/// (i.e., unpack the dictionary array) and copy the primitive array. If the input -/// array is a primitive array, we simply copy the array. -pub fn copy_or_cast_array(array: &dyn Array) -> Result { - match array.data_type() { - DataType::Dictionary(_, value_type) => { - let options = CastOptions::default(); - let casted = cast_with_options(array, value_type.as_ref(), &options); - - casted.and_then(|a| copy_or_cast_array(a.as_ref())) - } - _ => Ok(copy_array(array)), - } -}