diff --git a/datafusion/core/src/physical_plan/analyze.rs b/datafusion/core/src/physical_plan/analyze.rs
index 84d74c512b542..fd691973608ee 100644
--- a/datafusion/core/src/physical_plan/analyze.rs
+++ b/datafusion/core/src/physical_plan/analyze.rs
@@ -29,10 +29,9 @@ use crate::{
 };
 use arrow::{array::StringBuilder, datatypes::SchemaRef, record_batch::RecordBatch};
 use futures::StreamExt;
-use tokio::task::JoinSet;
 
 use super::expressions::PhysicalSortExpr;
-use super::stream::RecordBatchStreamAdapter;
+use super::stream::{RecordBatchReceiverStreamBuilder, RecordBatchStreamAdapter};
 use super::{Distribution, SendableRecordBatchStream};
 use crate::execution::context::TaskContext;
 
@@ -121,23 +120,15 @@ impl ExecutionPlan for AnalyzeExec {
         // Gather futures that will run each input partition in
         // parallel (on a separate tokio task) using a JoinSet to
         // cancel outstanding futures on drop
-        let mut set = JoinSet::new();
         let num_input_partitions = self.input.output_partitioning().partition_count();
+        let mut builder =
+            RecordBatchReceiverStreamBuilder::new(self.schema(), num_input_partitions);
 
         for input_partition in 0..num_input_partitions {
-            let input_stream = self.input.execute(input_partition, context.clone());
-
-            set.spawn(async move {
-                let mut total_rows = 0;
-                let mut input_stream = input_stream?;
-                while let Some(batch) = input_stream.next().await {
-                    let batch = batch?;
-                    total_rows += batch.num_rows();
-                }
-                Ok(total_rows) as Result<usize>
-            });
+            builder.run_input(self.input.clone(), input_partition, context.clone());
         }
 
+        // Create future that computes the final output
         let start = Instant::now();
         let captured_input = self.input.clone();
         let captured_schema = self.schema.clone();
@@ -146,18 +137,12 @@ impl ExecutionPlan for AnalyzeExec {
         // future that gathers the results from all the tasks in the
         // JoinSet that computes the overall row count and final
         // record batch
+        let mut input_stream = builder.build();
         let output = async move {
             let mut total_rows = 0;
-            while let Some(res) = set.join_next().await {
-                // translate join errors (aka task panic's) into ExecutionErrors
-                match res {
-                    Ok(row_count) => total_rows += row_count?,
-                    Err(e) => {
-                        return Err(DataFusionError::Execution(format!(
-                            "Join error in AnalyzeExec: {e}"
-                        )))
-                    }
-                }
+            while let Some(batch) = input_stream.next().await {
+                let batch = batch?;
+                total_rows += batch.num_rows();
             }
 
             let duration = Instant::now() - start;
diff --git a/datafusion/core/src/physical_plan/coalesce_partitions.rs b/datafusion/core/src/physical_plan/coalesce_partitions.rs
index 85cc375bdb535..ffc3cb5e4da4f 100644
--- a/datafusion/core/src/physical_plan/coalesce_partitions.rs
+++ b/datafusion/core/src/physical_plan/coalesce_partitions.rs
@@ -34,7 +34,6 @@ use crate::physical_plan::{
 
 use super::SendableRecordBatchStream;
 use crate::execution::context::TaskContext;
-use crate::physical_plan::common::spawn_execution;
 
 /// Merge execution plan executes partitions in parallel and combines them into a single
 /// partition. No guarantees are made about the order of the resulting partition.
@@ -140,14 +139,7 @@ impl ExecutionPlan for CoalescePartitionsExec {
         // spawn independent tasks whose resulting streams (of batches)
         // are sent to the channel for consumption.
         for part_i in 0..input_partitions {
-            let sender = builder.tx();
-            spawn_execution(
-                builder.join_set_mut(),
-                self.input.clone(),
-                sender,
-                part_i,
-                context.clone(),
-            );
+            builder.run_input(self.input.clone(), part_i, context.clone());
         }
 
         Ok(builder.build())
diff --git a/datafusion/core/src/physical_plan/common.rs b/datafusion/core/src/physical_plan/common.rs
index a9c267f123a8c..90b1e8ea4ccbe 100644
--- a/datafusion/core/src/physical_plan/common.rs
+++ b/datafusion/core/src/physical_plan/common.rs
@@ -19,17 +19,15 @@
 
 use super::SendableRecordBatchStream;
 use crate::error::{DataFusionError, Result};
-use crate::execution::context::TaskContext;
 use crate::execution::memory_pool::MemoryReservation;
 use crate::physical_plan::stream::RecordBatchReceiverStream;
-use crate::physical_plan::{displayable, ColumnStatistics, ExecutionPlan, Statistics};
+use crate::physical_plan::{ColumnStatistics, ExecutionPlan, Statistics};
 use arrow::datatypes::Schema;
 use arrow::ipc::writer::{FileWriter, IpcWriteOptions};
 use arrow::record_batch::RecordBatch;
 use datafusion_physical_expr::expressions::{BinaryExpr, Column};
 use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
 use futures::{Future, StreamExt, TryStreamExt};
-use log::debug;
 use parking_lot::Mutex;
 use pin_project_lite::pin_project;
 use std::fs;
@@ -38,7 +36,7 @@ use std::path::{Path, PathBuf};
 use std::sync::Arc;
 use std::task::{Context, Poll};
 use tokio::sync::mpsc;
-use tokio::task::{JoinHandle, JoinSet};
+use tokio::task::JoinHandle;
 
 /// [`MemoryReservation`] used across query execution streams
 pub(crate) type SharedMemoryReservation = Arc<Mutex<MemoryReservation>>;
@@ -96,43 +94,6 @@ fn build_file_list_recurse(
     Ok(())
 }
 
-/// Spawns a task to the tokio threadpool and writes its outputs to the provided mpsc sender
-pub(crate) fn spawn_execution(
-    join_set: &mut JoinSet<()>,
-    input: Arc<dyn ExecutionPlan>,
-    output: mpsc::Sender<Result<RecordBatch>>,
-    partition: usize,
-    context: Arc<TaskContext>,
-) {
-    join_set.spawn(async move {
-        let mut stream = match input.execute(partition, context) {
-            Err(e) => {
-                // If send fails, plan being torn down,
-                // there is no place to send the error.
-                output.send(Err(e)).await.ok();
-                debug!(
-                    "Stopping execution: error executing input: {}",
-                    displayable(input.as_ref()).one_line()
-                );
-                return;
-            }
-            Ok(stream) => stream,
-        };
-
-        while let Some(item) = stream.next().await {
-            // If send fails, plan being torn down,
-            // there is no place to send the error.
-            if output.send(item).await.is_err() {
-                debug!(
-                    "Stopping execution: output is gone, plan cancelling: {}",
-                    displayable(input.as_ref()).one_line()
-                );
-                return;
-            }
-        }
-    });
-}
-
 /// If running in a tokio context spawns the execution of `stream` to a separate task
 /// allowing it to execute in parallel with an intermediate buffer of size `buffer`
 pub(crate) fn spawn_buffered(
diff --git a/datafusion/core/src/physical_plan/stream.rs b/datafusion/core/src/physical_plan/stream.rs
index a239591bd9543..0a9f53a3422bc 100644
--- a/datafusion/core/src/physical_plan/stream.rs
+++ b/datafusion/core/src/physical_plan/stream.rs
@@ -17,20 +17,26 @@
 
 //! Stream wrappers for physical operators
 
+use std::sync::Arc;
+
 use crate::error::Result;
+use crate::physical_plan::displayable;
 use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
 use datafusion_common::DataFusionError;
+use datafusion_execution::TaskContext;
 use futures::stream::BoxStream;
 use futures::{Future, Stream, StreamExt};
+use log::debug;
 use pin_project_lite::pin_project;
 use tokio::task::{JoinHandle, JoinSet};
 use tokio_stream::wrappers::ReceiverStream;
 
 use super::common::AbortOnDropSingle;
-use super::{RecordBatchStream, SendableRecordBatchStream};
+use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream};
 
-/// Builder for [`RecordBatchReceiverStream`]
-pub struct RecordBatchReceiverStreamBuilder {
+/// Builder for [`RecordBatchReceiverStream`] that propagates errors
+/// and panics correctly.
+pub(crate) struct RecordBatchReceiverStreamBuilder {
     tx: tokio::sync::mpsc::Sender<Result<RecordBatch>>,
     rx: tokio::sync::mpsc::Receiver<Result<RecordBatch>>,
     schema: SchemaRef,
@@ -55,11 +61,6 @@ impl RecordBatchReceiverStreamBuilder {
         self.tx.clone()
     }
 
-    /// Get a handle to the `JoinSet` on which tasks are launched
-    pub fn join_set_mut(&mut self) -> &mut JoinSet<()> {
-        &mut self.join_set
-    }
-
     /// Spawn task that will be aborted if this builder (or the stream
     /// built from it) are dropped
     ///
@@ -73,6 +74,45 @@ impl RecordBatchReceiverStreamBuilder {
         self.join_set.spawn(task);
     }
 
+    /// Runs the input_partition of the `input` ExecutionPlan on the
+    /// tokio threadpool and writes its outputs to this stream
+    pub(crate) fn run_input(
+        &mut self,
+        input: Arc<dyn ExecutionPlan>,
+        partition: usize,
+        context: Arc<TaskContext>,
+    ) {
+        let output = self.tx();
+
+        self.spawn(async move {
+            let mut stream = match input.execute(partition, context) {
+                Err(e) => {
+                    // If send fails, plan being torn down,
+                    // there is no place to send the error.
+                    output.send(Err(e)).await.ok();
+                    debug!(
+                        "Stopping execution: error executing input: {}",
+                        displayable(input.as_ref()).one_line()
+                    );
+                    return;
+                }
+                Ok(stream) => stream,
+            };
+
+            while let Some(item) = stream.next().await {
+                // If send fails, plan being torn down,
+                // there is no place to send the error.
+                if output.send(item).await.is_err() {
+                    debug!(
+                        "Stopping execution: output is gone, plan cancelling: {}",
+                        displayable(input.as_ref()).one_line()
+                    );
+                    return;
+                }
+            }
+        });
+    }
+
     /// Create a stream of all `RecordBatch`es written to `tx`
     pub fn build(self) -> SendableRecordBatchStream {
         let Self {
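
For orientation (not part of the patch): a minimal sketch of how an operator's `execute` path can drive every input partition through the consolidated builder, using the `new`, `run_input`, and `build` signatures introduced above. The helper name `stream_all_partitions` is hypothetical, and the paths are written as they would appear inside the crate, since `RecordBatchReceiverStreamBuilder` is `pub(crate)`.

```rust
use std::sync::Arc;

use arrow::datatypes::SchemaRef;

use crate::execution::context::TaskContext;
use crate::physical_plan::stream::RecordBatchReceiverStreamBuilder;
use crate::physical_plan::{ExecutionPlan, SendableRecordBatchStream};

/// Hypothetical helper: merge all partitions of `input` into one stream.
fn stream_all_partitions(
    input: Arc<dyn ExecutionPlan>,
    schema: SchemaRef,
    context: Arc<TaskContext>,
) -> SendableRecordBatchStream {
    let num_partitions = input.output_partitioning().partition_count();
    // One channel slot per input partition, mirroring AnalyzeExec above.
    let mut builder = RecordBatchReceiverStreamBuilder::new(schema, num_partitions);
    for partition in 0..num_partitions {
        // Each call spawns a tokio task on the builder's JoinSet; execution
        // errors are forwarded into the channel, and the tasks are aborted
        // if the returned stream is dropped.
        builder.run_input(input.clone(), partition, context.clone());
    }
    builder.build()
}
```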