diff --git a/datafusion/core/src/physical_plan/analyze.rs b/datafusion/core/src/physical_plan/analyze.rs
index 39d715761f89..9be68337b2b1 100644
--- a/datafusion/core/src/physical_plan/analyze.rs
+++ b/datafusion/core/src/physical_plan/analyze.rs
@@ -29,10 +29,9 @@ use crate::{
 };
 use arrow::{array::StringBuilder, datatypes::SchemaRef, record_batch::RecordBatch};
 use futures::StreamExt;
-use tokio::task::JoinSet;
 
 use super::expressions::PhysicalSortExpr;
-use super::stream::RecordBatchStreamAdapter;
+use super::stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter};
 use super::{Distribution, SendableRecordBatchStream};
 use datafusion_execution::TaskContext;
@@ -121,23 +120,15 @@ impl ExecutionPlan for AnalyzeExec {
         // Gather futures that will run each input partition in
         // parallel (on a separate tokio task) using a JoinSet to
         // cancel outstanding futures on drop
-        let mut set = JoinSet::new();
         let num_input_partitions = self.input.output_partitioning().partition_count();
+        let mut builder =
+            RecordBatchReceiverStream::builder(self.schema(), num_input_partitions);
 
         for input_partition in 0..num_input_partitions {
-            let input_stream = self.input.execute(input_partition, context.clone());
-
-            set.spawn(async move {
-                let mut total_rows = 0;
-                let mut input_stream = input_stream?;
-                while let Some(batch) = input_stream.next().await {
-                    let batch = batch?;
-                    total_rows += batch.num_rows();
-                }
-                Ok(total_rows) as Result<usize>
-            });
+            builder.run_input(self.input.clone(), input_partition, context.clone());
         }
 
+        // Create future that computes the final output
         let start = Instant::now();
         let captured_input = self.input.clone();
        let captured_schema = self.schema.clone();
@@ -146,18 +137,11 @@ impl ExecutionPlan for AnalyzeExec {
         // future that gathers the results from all the tasks in the
         // JoinSet that computes the overall row count and final
         // record batch
+        let mut input_stream = builder.build();
         let output = async move {
             let mut total_rows = 0;
-            while let Some(res) = set.join_next().await {
-                // translate join errors (aka task panics) into ExecutionErrors
-                match res {
-                    Ok(row_count) => total_rows += row_count?,
-                    Err(e) => {
-                        return Err(DataFusionError::Execution(format!(
-                            "Join error in AnalyzeExec: {e}"
-                        )))
-                    }
-                }
+            while let Some(batch) = input_stream.next().await.transpose()? {
+                total_rows += batch.num_rows();
             }
 
             let duration = Instant::now() - start;
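The rewritten loop leans on `Option::transpose`: `input_stream.next().await` yields `Option<Result<RecordBatch>>`, and `transpose()` converts it into `Result<Option<RecordBatch>>` so `?` can surface a batch error while `None` still terminates the loop. A minimal self-contained sketch of the same idiom on a synchronous iterator (`sum_all` is a hypothetical helper, not part of the patch):

```rust
/// Sums items, stopping at the first error -- the same `transpose()?`
/// idiom the AnalyzeExec loop above applies to an async stream.
fn sum_all(items: Vec<Result<usize, String>>) -> Result<usize, String> {
    let mut iter = items.into_iter();
    let mut total = 0;
    // `iter.next()` is `Option<Result<usize, String>>`; `transpose()`
    // turns it into `Result<Option<usize>, String>` so `?` applies.
    while let Some(n) = iter.next().transpose()? {
        total += n;
    }
    Ok(total)
}

fn main() {
    assert_eq!(sum_all(vec![Ok(1), Ok(2)]), Ok(3));
    assert_eq!(sum_all(vec![Ok(1), Err("boom".into())]), Err("boom".into()));
}
```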
diff --git a/datafusion/core/src/physical_plan/coalesce_partitions.rs b/datafusion/core/src/physical_plan/coalesce_partitions.rs
index 11d7021ca999..66700cd9e748 100644
--- a/datafusion/core/src/physical_plan/coalesce_partitions.rs
+++ b/datafusion/core/src/physical_plan/coalesce_partitions.rs
@@ -20,25 +20,19 @@
 use std::any::Any;
 use std::sync::Arc;
-use std::task::Poll;
-
-use futures::Stream;
-use tokio::sync::mpsc;
 
 use arrow::datatypes::SchemaRef;
-use arrow::record_batch::RecordBatch;
 
-use super::common::AbortOnDropMany;
 use super::expressions::PhysicalSortExpr;
 use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
-use super::{RecordBatchStream, Statistics};
+use super::stream::{ObservedStream, RecordBatchReceiverStream};
+use super::Statistics;
 use crate::error::{DataFusionError, Result};
 use crate::physical_plan::{
     DisplayFormatType, EquivalenceProperties, ExecutionPlan, Partitioning,
 };
 
 use super::SendableRecordBatchStream;
-use crate::physical_plan::common::spawn_execution;
 use datafusion_execution::TaskContext;
@@ -137,27 +131,17 @@ impl ExecutionPlan for CoalescePartitionsExec {
                 // use a stream that allows each sender to put in at
                 // least one result in an attempt to maximize
                 // parallelism.
-                let (sender, receiver) =
-                    mpsc::channel::<Result<RecordBatch>>(input_partitions);
+                let mut builder =
+                    RecordBatchReceiverStream::builder(self.schema(), input_partitions);
 
                 // spawn independent tasks whose resulting streams (of batches)
                 // are sent to the channel for consumption.
-                let mut join_handles = Vec::with_capacity(input_partitions);
                 for part_i in 0..input_partitions {
-                    join_handles.push(spawn_execution(
-                        self.input.clone(),
-                        sender.clone(),
-                        part_i,
-                        context.clone(),
-                    ));
+                    builder.run_input(self.input.clone(), part_i, context.clone());
                 }
 
-                Ok(Box::pin(MergeStream {
-                    input: receiver,
-                    schema: self.schema(),
-                    baseline_metrics,
-                    drop_helper: AbortOnDropMany(join_handles),
-                }))
+                let stream = builder.build();
+                Ok(Box::pin(ObservedStream::new(stream, baseline_metrics)))
             }
         }
     }
@@ -183,32 +167,6 @@ impl ExecutionPlan for CoalescePartitionsExec {
     }
 }
 
-struct MergeStream {
-    schema: SchemaRef,
-    input: mpsc::Receiver<Result<RecordBatch>>,
-    baseline_metrics: BaselineMetrics,
-    #[allow(unused)]
-    drop_helper: AbortOnDropMany<()>,
-}
-
-impl Stream for MergeStream {
-    type Item = Result<RecordBatch>;
-
-    fn poll_next(
-        mut self: std::pin::Pin<&mut Self>,
-        cx: &mut std::task::Context<'_>,
-    ) -> Poll<Option<Self::Item>> {
-        let poll = self.input.poll_recv(cx);
-        self.baseline_metrics.record_poll(poll)
-    }
-}
-
-impl RecordBatchStream for MergeStream {
-    fn schema(&self) -> SchemaRef {
-        self.schema.clone()
-    }
-}
-
 #[cfg(test)]
 mod tests {
@@ -218,7 +176,9 @@ mod tests {
     use super::*;
     use crate::physical_plan::{collect, common};
     use crate::prelude::SessionContext;
-    use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec};
+    use crate::test::exec::{
+        assert_strong_count_converges_to_zero, BlockingExec, PanicExec,
+    };
     use crate::test::{self, assert_is_pending};
 
     #[tokio::test]
@@ -270,4 +230,19 @@ mod tests {
 
         Ok(())
     }
+
+    #[tokio::test]
+    #[should_panic(expected = "PanickingStream did panic")]
+    async fn test_panic() {
+        let session_ctx = SessionContext::new();
+        let task_ctx = session_ctx.task_ctx();
+        let schema =
+            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));
+
+        let panicking_exec = Arc::new(PanicExec::new(Arc::clone(&schema), 2));
+        let coalesce_partitions_exec =
+            Arc::new(CoalescePartitionsExec::new(panicking_exec));
+
+        collect(coalesce_partitions_exec, task_ctx).await.unwrap();
+    }
 }
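The change above is the template the rest of this patch follows: build a `RecordBatchReceiverStream`, attach one task per input partition with `run_input`, and hand the merged stream to the caller. A hedged sketch of that shape using the APIs introduced in `stream.rs` below (it assumes the imports of the surrounding module; `execute_all_partitions` is a hypothetical helper, not part of the patch):

```rust
/// Hypothetical helper showing the builder pattern this patch introduces.
fn execute_all_partitions(
    input: Arc<dyn ExecutionPlan>,
    context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
    let input_partitions = input.output_partitioning().partition_count();

    // Channel is sized so each producer can buffer at least one batch.
    let mut builder =
        RecordBatchReceiverStream::builder(input.schema(), input_partitions);

    // One spawned task per partition. The tasks live in the builder's
    // JoinSet, so they are aborted if the returned stream is dropped,
    // and a panic in any task resurfaces when the stream is polled.
    for partition in 0..input_partitions {
        builder.run_input(input.clone(), partition, context.clone());
    }

    Ok(builder.build())
}
```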
diff --git a/datafusion/core/src/physical_plan/common.rs b/datafusion/core/src/physical_plan/common.rs
index 98239557cb53..2f296ce462f9 100644
--- a/datafusion/core/src/physical_plan/common.rs
+++ b/datafusion/core/src/physical_plan/common.rs
@@ -21,15 +21,13 @@ use super::SendableRecordBatchStream;
 use crate::error::{DataFusionError, Result};
 use crate::execution::memory_pool::MemoryReservation;
 use crate::physical_plan::stream::RecordBatchReceiverStream;
-use crate::physical_plan::{displayable, ColumnStatistics, ExecutionPlan, Statistics};
+use crate::physical_plan::{ColumnStatistics, ExecutionPlan, Statistics};
 use arrow::datatypes::Schema;
 use arrow::ipc::writer::{FileWriter, IpcWriteOptions};
 use arrow::record_batch::RecordBatch;
-use datafusion_execution::TaskContext;
 use datafusion_physical_expr::expressions::{BinaryExpr, Column};
 use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
 use futures::{Future, StreamExt, TryStreamExt};
-use log::debug;
 use parking_lot::Mutex;
 use pin_project_lite::pin_project;
 use std::fs;
@@ -37,7 +35,6 @@ use std::fs::{metadata, File};
 use std::path::{Path, PathBuf};
 use std::sync::Arc;
 use std::task::{Context, Poll};
-use tokio::sync::mpsc;
 use tokio::task::JoinHandle;
 
 /// [`MemoryReservation`] used across query execution streams
@@ -96,42 +93,6 @@ fn build_file_list_recurse(
     Ok(())
 }
 
-/// Spawns a task to the tokio threadpool and writes its outputs to the provided mpsc sender
-pub(crate) fn spawn_execution(
-    input: Arc<dyn ExecutionPlan>,
-    output: mpsc::Sender<Result<RecordBatch>>,
-    partition: usize,
-    context: Arc<TaskContext>,
-) -> JoinHandle<()> {
-    tokio::spawn(async move {
-        let mut stream = match input.execute(partition, context) {
-            Err(e) => {
-                // If send fails, plan being torn down,
-                // there is no place to send the error.
-                output.send(Err(e)).await.ok();
-                debug!(
-                    "Stopping execution: error executing input: {}",
-                    displayable(input.as_ref()).one_line()
-                );
-                return;
-            }
-            Ok(stream) => stream,
-        };
-
-        while let Some(item) = stream.next().await {
-            // If send fails, plan being torn down,
-            // there is no place to send the error.
-            if output.send(item).await.is_err() {
-                debug!(
-                    "Stopping execution: output is gone, plan cancelling: {}",
-                    displayable(input.as_ref()).one_line()
-                );
-                return;
-            }
-        }
-    })
-}
-
 /// If running in a tokio context spawns the execution of `stream` to a separate task
 /// allowing it to execute in parallel with an intermediate buffer of size `buffer`
 pub(crate) fn spawn_buffered(
@@ -139,14 +100,15 @@ pub(crate) fn spawn_buffered(
     buffer: usize,
 ) -> SendableRecordBatchStream {
     // Use tokio only if running from a tokio context (#2201)
-    let handle = match tokio::runtime::Handle::try_current() {
-        Ok(handle) => handle,
-        Err(_) => return input,
+    if tokio::runtime::Handle::try_current().is_err() {
+        return input;
     };
 
-    let schema = input.schema();
-    let (sender, receiver) = mpsc::channel(buffer);
-    let join = handle.spawn(async move {
+    let mut builder = RecordBatchReceiverStream::builder(input.schema(), buffer);
+
+    let sender = builder.tx();
+
+    builder.spawn(async move {
         while let Some(item) = input.next().await {
             if sender.send(item).await.is_err() {
                 return;
@@ -154,7 +116,7 @@ pub(crate) fn spawn_buffered(
             }
         }
     });
 
-    RecordBatchReceiverStream::create(&schema, receiver, join)
+    builder.build()
 }
 
 /// Computes the statistics for an in-memory RecordBatch
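`spawn_buffered` now bails out with a plain early return when there is no ambient runtime; `tokio::runtime::Handle::try_current()` is the standard probe for that. A tiny sketch of the guard in isolation (`can_spawn` is a hypothetical name):

```rust
/// Reports whether it is safe to spawn, mirroring the guard in
/// `spawn_buffered` above.
fn can_spawn() -> bool {
    tokio::runtime::Handle::try_current().is_ok()
}

#[tokio::main]
async fn main() {
    // Inside #[tokio::main] a runtime handle is always available.
    assert!(can_spawn());

    // On a plain OS thread that never entered a runtime, the probe fails.
    let outside = std::thread::spawn(can_spawn).join().unwrap();
    assert!(!outside);
}
```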
diff --git a/datafusion/core/src/physical_plan/sorts/sort.rs b/datafusion/core/src/physical_plan/sorts/sort.rs
index 3e3a79495be3..53177310cc9b 100644
--- a/datafusion/core/src/physical_plan/sorts/sort.rs
+++ b/datafusion/core/src/physical_plan/sorts/sort.rs
@@ -52,7 +52,7 @@ use std::io::BufReader;
 use std::path::{Path, PathBuf};
 use std::sync::Arc;
 use tempfile::NamedTempFile;
-use tokio::sync::mpsc::{Receiver, Sender};
+use tokio::sync::mpsc::Sender;
 use tokio::task;
 
 struct ExternalSorterMetrics {
@@ -373,18 +373,16 @@ fn read_spill_as_stream(
     path: NamedTempFile,
     schema: SchemaRef,
 ) -> Result<SendableRecordBatchStream> {
-    let (sender, receiver): (Sender<Result<RecordBatch>>, Receiver<Result<RecordBatch>>) =
-        tokio::sync::mpsc::channel(2);
-    let join_handle = task::spawn_blocking(move || {
+    let mut builder = RecordBatchReceiverStream::builder(schema, 2);
+    let sender = builder.tx();
+
+    builder.spawn_blocking(move || {
         if let Err(e) = read_spill(sender, path.path()) {
            error!("Failure while reading spill file: {:?}. Error: {}", path, e);
         }
     });
-    Ok(RecordBatchReceiverStream::create(
-        &schema,
-        receiver,
-        join_handle,
-    ))
+
+    Ok(builder.build())
 }
 
 fn write_sorted(
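`read_spill_as_stream` keeps the blocking file read off the async executor via the builder's `spawn_blocking`, pushing results through the same channel the stream reads from. A self-contained sketch of that producer/consumer shape in plain tokio (the string payload stands in for `read_spill`'s record batches):

```rust
use tokio::{sync::mpsc, task::JoinSet};

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::channel::<String>(2);

    // The closure runs on tokio's blocking thread pool, like the
    // builder's spawn_blocking; blocking file IO would go here.
    let mut join_set = JoinSet::new();
    join_set.spawn_blocking(move || {
        tx.blocking_send("spilled batch".to_string()).ok();
        // `tx` drops here, which ends the receiver loop below.
    });

    while let Some(item) = rx.recv().await {
        println!("got {item}");
    }
    while let Some(result) = join_set.join_next().await {
        result.unwrap();
    }
}
```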
Error: {}", path, e); } }); - Ok(RecordBatchReceiverStream::create( - &schema, - receiver, - join_handle, - )) + + Ok(builder.build()) } fn write_sorted( diff --git a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs index 95cc23a20cb5..eb2725ade1ef 100644 --- a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs +++ b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs @@ -792,9 +792,12 @@ mod tests { let mut streams = Vec::with_capacity(partition_count); for partition in 0..partition_count { - let (sender, receiver) = tokio::sync::mpsc::channel(1); + let mut builder = RecordBatchReceiverStream::builder(schema.clone(), 1); + + let sender = builder.tx(); + let mut stream = batches.execute(partition, task_ctx.clone()).unwrap(); - let join_handle = tokio::spawn(async move { + builder.spawn(async move { while let Some(batch) = stream.next().await { sender.send(batch).await.unwrap(); // This causes the MergeStream to wait for more input @@ -802,11 +805,7 @@ mod tests { } }); - streams.push(RecordBatchReceiverStream::create( - &schema, - receiver, - join_handle, - )); + streams.push(builder.build()); } let metrics = ExecutionPlanMetricsSet::new(); diff --git a/datafusion/core/src/physical_plan/stream.rs b/datafusion/core/src/physical_plan/stream.rs index 2190022bc505..75a0f45e1ee2 100644 --- a/datafusion/core/src/physical_plan/stream.rs +++ b/datafusion/core/src/physical_plan/stream.rs @@ -17,43 +17,205 @@ //! Stream wrappers for physical operators +use std::sync::Arc; + use crate::error::Result; +use crate::physical_plan::displayable; use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; -use futures::{Stream, StreamExt}; +use datafusion_common::DataFusionError; +use datafusion_execution::TaskContext; +use futures::stream::BoxStream; +use futures::{Future, Stream, StreamExt}; +use log::debug; use pin_project_lite::pin_project; -use tokio::task::JoinHandle; +use tokio::sync::mpsc::{Receiver, Sender}; +use tokio::task::JoinSet; use tokio_stream::wrappers::ReceiverStream; -use super::common::AbortOnDropSingle; -use super::{RecordBatchStream, SendableRecordBatchStream}; +use super::metrics::BaselineMetrics; +use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream}; -/// Adapter for a tokio [`ReceiverStream`] that implements the -/// [`SendableRecordBatchStream`] -/// interface -pub struct RecordBatchReceiverStream { +/// Builder for [`RecordBatchReceiverStream`] that propagates errors +/// and panic's correctly. +/// +/// [`RecordBatchReceiverStream`] is used to spawn one or more tasks +/// that produce `RecordBatch`es and send them to a single +/// `Receiver` which can improve parallelism. +/// +/// This also handles propagating panic`s and canceling the tasks. 
diff --git a/datafusion/core/src/physical_plan/stream.rs b/datafusion/core/src/physical_plan/stream.rs
index 2190022bc505..75a0f45e1ee2 100644
--- a/datafusion/core/src/physical_plan/stream.rs
+++ b/datafusion/core/src/physical_plan/stream.rs
@@ -17,43 +17,205 @@
 
 //! Stream wrappers for physical operators
 
+use std::sync::Arc;
+
 use crate::error::Result;
+use crate::physical_plan::displayable;
 use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
-use futures::{Stream, StreamExt};
+use datafusion_common::DataFusionError;
+use datafusion_execution::TaskContext;
+use futures::stream::BoxStream;
+use futures::{Future, Stream, StreamExt};
+use log::debug;
 use pin_project_lite::pin_project;
-use tokio::task::JoinHandle;
+use tokio::sync::mpsc::{Receiver, Sender};
+use tokio::task::JoinSet;
 use tokio_stream::wrappers::ReceiverStream;
 
-use super::common::AbortOnDropSingle;
-use super::{RecordBatchStream, SendableRecordBatchStream};
+use super::metrics::BaselineMetrics;
+use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream};
 
-/// Adapter for a tokio [`ReceiverStream`] that implements the
-/// [`SendableRecordBatchStream`]
-/// interface
-pub struct RecordBatchReceiverStream {
+/// Builder for [`RecordBatchReceiverStream`] that propagates errors
+/// and panics correctly.
+///
+/// [`RecordBatchReceiverStream`] is used to spawn one or more tasks
+/// that produce `RecordBatch`es and send them to a single
+/// `Receiver`, which can improve parallelism.
+///
+/// This also handles propagating panics and canceling the tasks.
+pub struct RecordBatchReceiverStreamBuilder {
+    tx: Sender<Result<RecordBatch>>,
+    rx: Receiver<Result<RecordBatch>>,
     schema: SchemaRef,
+    join_set: JoinSet<()>,
+}
+
+impl RecordBatchReceiverStreamBuilder {
+    /// Create new channels with the specified buffer size
+    pub fn new(schema: SchemaRef, capacity: usize) -> Self {
+        let (tx, rx) = tokio::sync::mpsc::channel(capacity);
 
-    inner: ReceiverStream<Result<RecordBatch>>,
+        Self {
+            tx,
+            rx,
+            schema,
+            join_set: JoinSet::new(),
+        }
+    }
+
+    /// Get a handle for sending [`RecordBatch`]es to the output
+    pub fn tx(&self) -> Sender<Result<RecordBatch>> {
+        self.tx.clone()
+    }
+
+    /// Spawn a task that will be aborted if this builder (or the stream
+    /// built from it) is dropped
+    ///
+    /// This is often used to spawn tasks that write to the sender
+    /// retrieved from `Self::tx`
+    pub fn spawn<F>(&mut self, task: F)
+    where
+        F: Future<Output = ()>,
+        F: Send + 'static,
+    {
+        self.join_set.spawn(task);
+    }
+
+    /// Spawn a blocking task that will be aborted if this builder (or the stream
+    /// built from it) is dropped
+    ///
+    /// This is often used to spawn tasks that write to the sender
+    /// retrieved from `Self::tx`
+    pub fn spawn_blocking<F>(&mut self, f: F)
+    where
+        F: FnOnce(),
+        F: Send + 'static,
+    {
+        self.join_set.spawn_blocking(f);
+    }
+
+    /// Runs the specified `partition` of the `input` ExecutionPlan on the
+    /// tokio threadpool and writes its outputs to this stream
+    ///
+    /// If the input partition produces an error, the error will be
+    /// sent to the output stream and no further results are sent.
+    pub(crate) fn run_input(
+        &mut self,
+        input: Arc<dyn ExecutionPlan>,
+        partition: usize,
+        context: Arc<TaskContext>,
+    ) {
+        let output = self.tx();
+
+        self.spawn(async move {
+            let mut stream = match input.execute(partition, context) {
+                Err(e) => {
+                    // If send fails, the plan is being torn down; there
+                    // is no place to send the error and no reason to continue.
+                    output.send(Err(e)).await.ok();
+                    debug!(
+                        "Stopping execution: error executing input: {}",
+                        displayable(input.as_ref()).one_line()
+                    );
+                    return;
+                }
+                Ok(stream) => stream,
+            };
+
+            // Transfer batches from inner stream to the output tx
+            // immediately.
+            while let Some(item) = stream.next().await {
+                let is_err = item.is_err();
+
+                // If send fails, the plan is being torn down; there is no
+                // place to send the error and no reason to continue.
+                if output.send(item).await.is_err() {
+                    debug!(
+                        "Stopping execution: output is gone, plan cancelling: {}",
+                        displayable(input.as_ref()).one_line()
+                    );
+                    return;
+                }
+
+                // Stop after the first error is encountered (don't
+                // drive all streams to completion)
+                if is_err {
+                    debug!(
+                        "Stopping execution: plan returned error: {}",
+                        displayable(input.as_ref()).one_line()
+                    );
+                    return;
+                }
+            }
+        });
+    }
+
+    /// Create a stream of all `RecordBatch`es written to `tx`
+    pub fn build(self) -> SendableRecordBatchStream {
+        let Self {
+            tx,
+            rx,
+            schema,
+            mut join_set,
+        } = self;
 
-    #[allow(dead_code)]
-    drop_helper: AbortOnDropSingle<()>,
+        // don't need tx
+        drop(tx);
+
+        // future that checks the result of the join set, and propagates panic if seen
+        let check = async move {
+            while let Some(result) = join_set.join_next().await {
+                match result {
+                    Ok(()) => continue, // nothing to report
+                    // This means a tokio task error, likely a panic
+                    Err(e) => {
+                        if e.is_panic() {
+                            // resume the panic on the main thread
+                            std::panic::resume_unwind(e.into_panic());
+                        } else {
+                            // This should only occur if the task is
+                            // cancelled, which would only occur if
+                            // the JoinSet were aborted, which in turn
+                            // would imply that the receiver has been
+                            // dropped and this code is not running
+                            return Some(Err(DataFusionError::Internal(format!(
+                                "Non Panic Task error: {e}"
+                            ))));
+                        }
+                    }
+                }
+            }
+            None
+        };
+
+        let check_stream = futures::stream::once(check)
+            // unwrap Option / only return the error
+            .filter_map(|item| async move { item });
+
+        // Merge the streams together so whichever is ready first
+        // produces the batch
+        let inner =
+            futures::stream::select(ReceiverStream::new(rx), check_stream).boxed();
+
+        Box::pin(RecordBatchReceiverStream { schema, inner })
+    }
+}
+
+/// Adapter for a tokio [`ReceiverStream`] that implements the
+/// [`SendableRecordBatchStream`] interface and propagates panics and
+/// errors. Use [`Self::builder`] to construct one.
+pub struct RecordBatchReceiverStream {
+    schema: SchemaRef,
+    inner: BoxStream<'static, Result<RecordBatch>>,
 }
 
 impl RecordBatchReceiverStream {
-    /// Construct a new [`RecordBatchReceiverStream`] which will send
-    /// batches of the specified schema from `inner`
-    pub fn create(
-        schema: &SchemaRef,
-        rx: tokio::sync::mpsc::Receiver<Result<RecordBatch>>,
-        join_handle: JoinHandle<()>,
-    ) -> SendableRecordBatchStream {
-        let schema = schema.clone();
-        let inner = ReceiverStream::new(rx);
-        Box::pin(Self {
-            schema,
-            inner,
-            drop_helper: AbortOnDropSingle::new(join_handle),
-        })
+    /// Create a builder with an internal buffer of `capacity` batches.
+    pub fn builder(
+        schema: SchemaRef,
+        capacity: usize,
+    ) -> RecordBatchReceiverStreamBuilder {
+        RecordBatchReceiverStreamBuilder::new(schema, capacity)
     }
 }
@@ -126,3 +288,173 @@ where
         self.schema.clone()
     }
 }
+
+/// Stream wrapper that records `BaselineMetrics` for a particular
+/// [`SendableRecordBatchStream`] (likely a partition)
+pub(crate) struct ObservedStream {
+    inner: SendableRecordBatchStream,
+    baseline_metrics: BaselineMetrics,
+}
+
+impl ObservedStream {
+    pub fn new(
+        inner: SendableRecordBatchStream,
+        baseline_metrics: BaselineMetrics,
+    ) -> Self {
+        Self {
+            inner,
+            baseline_metrics,
+        }
+    }
+}
+
+impl RecordBatchStream for ObservedStream {
+    fn schema(&self) -> arrow::datatypes::SchemaRef {
+        self.inner.schema()
+    }
+}
+
+impl futures::Stream for ObservedStream {
+    type Item = Result<RecordBatch>;
+
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Option<Self::Item>> {
+        let poll = self.inner.poll_next_unpin(cx);
+        self.baseline_metrics.record_poll(poll)
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use arrow_schema::{DataType, Field, Schema};
+
+    use crate::{
+        execution::context::SessionContext,
+        test::exec::{
+            assert_strong_count_converges_to_zero, BlockingExec, MockExec, PanicExec,
+        },
+    };
+
+    fn schema() -> SchemaRef {
+        Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]))
+    }
+
+    #[tokio::test]
+    #[should_panic(expected = "PanickingStream did panic")]
+    async fn record_batch_receiver_stream_propagates_panics() {
+        let schema = schema();
+
+        let num_partitions = 10;
+        let input = PanicExec::new(schema.clone(), num_partitions);
+        consume(input, 10).await
+    }
+
+    #[tokio::test]
+    #[should_panic(expected = "PanickingStream did panic: 1")]
+    async fn record_batch_receiver_stream_propagates_panics_early_shutdown() {
+        let schema = schema();
+
+        // make 2 partitions, second partition panics before the first
+        let num_partitions = 2;
+        let input = PanicExec::new(schema.clone(), num_partitions)
+            .with_partition_panic(0, 10)
+            .with_partition_panic(1, 3); // partition 1 should panic first (after 3 batches)
+
+        // ensure that the panic results in an early shutdown (that
+        // everything stops after the first panic).
+
+        // Since the stream reads every other batch: (0,1,0,1,0,panic)
+        // it should not exceed 5 batches prior to the panic
+        let max_batches = 5;
+        consume(input, max_batches).await
+    }
+
+    #[tokio::test]
+    async fn record_batch_receiver_stream_drop_cancel() {
+        let session_ctx = SessionContext::new();
+        let task_ctx = session_ctx.task_ctx();
+        let schema = schema();
+
+        // Make an input that never proceeds
+        let input = BlockingExec::new(schema.clone(), 1);
+        let refs = input.refs();
+
+        // Configure a RecordBatchReceiverStream to consume the input
+        let mut builder = RecordBatchReceiverStream::builder(schema, 2);
+        builder.run_input(Arc::new(input), 0, task_ctx.clone());
+        let stream = builder.build();
+
+        // input should still be present
+        assert!(std::sync::Weak::strong_count(&refs) > 0);
+
+        // drop the stream, ensure the refs go to zero
+        drop(stream);
+        assert_strong_count_converges_to_zero(refs).await;
+    }
+
+    #[tokio::test]
+    /// Ensure that if an error is received in one stream, the
+    /// `RecordBatchReceiverStream` stops early and does not drive
+    /// other streams to completion.
+    async fn record_batch_receiver_stream_error_does_not_drive_completion() {
+        let session_ctx = SessionContext::new();
+        let task_ctx = session_ctx.task_ctx();
+        let schema = schema();
+
+        // make an input that will error twice
+        let error_stream = MockExec::new(
+            vec![
+                Err(DataFusionError::Execution("Test1".to_string())),
+                Err(DataFusionError::Execution("Test2".to_string())),
+            ],
+            schema.clone(),
+        )
+        .with_use_task(false);
+
+        let mut builder = RecordBatchReceiverStream::builder(schema, 2);
+        builder.run_input(Arc::new(error_stream), 0, task_ctx.clone());
+        let mut stream = builder.build();
+
+        // get the first result, which should be an error
+        let first_batch = stream.next().await.unwrap();
+        let first_err = first_batch.unwrap_err();
+        assert_eq!(first_err.to_string(), "Execution error: Test1");
+
+        // There should be no more batches produced (should not get the second error)
+        assert!(stream.next().await.is_none());
+    }
+
+    /// Consumes all the input's partitions into a
+    /// RecordBatchReceiverStream and runs it to completion
+    ///
+    /// Panics if more than `max_batches` are seen
+    async fn consume(input: PanicExec, max_batches: usize) {
+        let session_ctx = SessionContext::new();
+        let task_ctx = session_ctx.task_ctx();
+
+        let input = Arc::new(input);
+        let num_partitions = input.output_partitioning().partition_count();
+
+        // Configure a RecordBatchReceiverStream to consume all the input partitions
+        let mut builder =
+            RecordBatchReceiverStream::builder(input.schema(), num_partitions);
+        for partition in 0..num_partitions {
+            builder.run_input(input.clone(), partition, task_ctx.clone());
+        }
+        let mut stream = builder.build();
+
+        // drain the stream until it is complete, panicking on error
+        let mut num_batches = 0;
+        while let Some(next) = stream.next().await {
+            next.unwrap();
+            num_batches += 1;
+            assert!(
+                num_batches < max_batches,
+                "Got the limit of {num_batches} batches before seeing panic"
+            );
+        }
+    }
+}
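The heart of `build()` above is tokio's `JoinSet`: a panicked task surfaces as a `JoinError`, `into_panic()` recovers the original panic payload, and `std::panic::resume_unwind` rethrows it on the thread polling the stream. A self-contained sketch of that mechanism in isolation:

```rust
use tokio::task::JoinSet;

#[tokio::main]
async fn main() {
    let mut join_set: JoinSet<()> = JoinSet::new();
    join_set.spawn(async { panic!("worker exploded") });

    while let Some(result) = join_set.join_next().await {
        if let Err(e) = result {
            if e.is_panic() {
                // Rethrow the worker's panic here, as build() does.
                std::panic::resume_unwind(e.into_panic());
            }
        }
    }
}
```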
diff --git a/datafusion/core/src/physical_plan/union.rs b/datafusion/core/src/physical_plan/union.rs
index 5cf25fbe021c..f2b936cf532b 100644
--- a/datafusion/core/src/physical_plan/union.rs
+++ b/datafusion/core/src/physical_plan/union.rs
@@ -30,7 +30,7 @@ use arrow::{
     record_batch::RecordBatch,
 };
 use datafusion_common::{DFSchemaRef, DataFusionError};
-use futures::{Stream, StreamExt};
+use futures::Stream;
 use itertools::Itertools;
 use log::{debug, trace, warn};
 
@@ -41,6 +41,7 @@ use super::{
     SendableRecordBatchStream, Statistics,
 };
 use crate::physical_plan::common::get_meet_of_orderings;
+use crate::physical_plan::stream::ObservedStream;
 use crate::{
     error::Result,
     physical_plan::{expressions, metrics::BaselineMetrics},
@@ -560,40 +561,6 @@ impl Stream for CombinedRecordBatchStream {
     }
 }
 
-/// Stream wrapper that records `BaselineMetrics` for a particular
-/// partition
-struct ObservedStream {
-    inner: SendableRecordBatchStream,
-    baseline_metrics: BaselineMetrics,
-}
-
-impl ObservedStream {
-    fn new(inner: SendableRecordBatchStream, baseline_metrics: BaselineMetrics) -> Self {
-        Self {
-            inner,
-            baseline_metrics,
-        }
-    }
-}
-
-impl RecordBatchStream for ObservedStream {
-    fn schema(&self) -> arrow::datatypes::SchemaRef {
-        self.inner.schema()
-    }
-}
-
-impl futures::Stream for ObservedStream {
-    type Item = Result<RecordBatch>;
-
-    fn poll_next(
-        mut self: std::pin::Pin<&mut Self>,
-        cx: &mut std::task::Context<'_>,
-    ) -> std::task::Poll<Option<Self::Item>> {
-        let poll = self.inner.poll_next_unpin(cx);
-        self.baseline_metrics.record_poll(poll)
-    }
-}
-
 fn col_stats_union(
     mut left: ColumnStatistics,
     right: ColumnStatistics,
diff --git a/datafusion/core/src/test/exec.rs b/datafusion/core/src/test/exec.rs
index bce7d08a5c56..41a0a1b4d084 100644
--- a/datafusion/core/src/test/exec.rs
+++ b/datafusion/core/src/test/exec.rs
@@ -31,7 +31,6 @@ use arrow::{
 };
 use futures::Stream;
 
-use crate::execution::context::TaskContext;
 use crate::physical_plan::expressions::PhysicalSortExpr;
 use crate::physical_plan::{
     common, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream,
@@ -41,6 +40,9 @@ use crate::{
     error::{DataFusionError, Result},
     physical_plan::stream::RecordBatchReceiverStream,
 };
+use crate::{
+    execution::context::TaskContext, physical_plan::stream::RecordBatchStreamAdapter,
+};
 
 /// Index into the data that has been returned so far
 #[derive(Debug, Default, Clone)]
@@ -114,22 +116,40 @@ impl RecordBatchStream for TestStream {
     }
 }
 
-/// A Mock ExecutionPlan that can be used for writing tests of other ExecutionPlans
-///
+/// A Mock ExecutionPlan that can be used for writing tests of other
+/// ExecutionPlans
 #[derive(Debug)]
 pub struct MockExec {
     /// the results to send back
     data: Vec<Result<RecordBatch>>,
     schema: SchemaRef,
+    /// if true (the default), sends data using a separate task to ensure the
+    /// batches are not available without this stream yielding first
+    use_task: bool,
 }
 
 impl MockExec {
-    /// Create a new exec with a single partition that returns the
-    /// record batches in this Exec. Note the batches are not produced
-    /// immediately (the caller has to actually yield and another task
-    /// must run) to ensure any poll loops are correct.
+    /// Create a new `MockExec` with a single partition that returns
+    /// the specified `Result`s.
+    ///
+    /// By default, the batches are not produced immediately (the
+    /// caller has to actually yield and another task must run) to
+    /// ensure any poll loops are correct. This behavior can be
+    /// changed with `with_use_task`
     pub fn new(data: Vec<Result<RecordBatch>>, schema: SchemaRef) -> Self {
-        Self { data, schema }
+        Self {
+            data,
+            schema,
+            use_task: true,
+        }
+    }
+
+    /// If `use_task` is true (the default) then the batches are sent
+    /// back using a separate task to ensure the underlying stream is
+    /// not immediately ready
+    pub fn with_use_task(mut self, use_task: bool) -> Self {
+        self.use_task = use_task;
+        self
     }
 }
@@ -179,26 +199,30 @@ impl ExecutionPlan for MockExec {
             })
             .collect();
 
-        let (tx, rx) = tokio::sync::mpsc::channel(2);
-
-        // task simply sends data in order but in a separate
-        // thread (to ensure the batches are not available without the
-        // DelayedStream yielding).
-        let join_handle = tokio::task::spawn(async move {
-            for batch in data {
-                println!("Sending batch via delayed stream");
-                if let Err(e) = tx.send(batch).await {
-                    println!("ERROR batch via delayed stream: {e}");
+        if self.use_task {
+            let mut builder = RecordBatchReceiverStream::builder(self.schema(), 2);
+            // send data in order but in a separate task (to ensure
+            // the batches are not available without the stream
+            // yielding).
+            let tx = builder.tx();
+            builder.spawn(async move {
+                for batch in data {
+                    println!("Sending batch via delayed stream");
+                    if let Err(e) = tx.send(batch).await {
+                        println!("ERROR batch via delayed stream: {e}");
+                    }
                 }
-            }
-        });
-
-        // returned stream simply reads off the rx stream
-        Ok(RecordBatchReceiverStream::create(
-            &self.schema,
-            rx,
-            join_handle,
-        ))
+            });
+            // returned stream simply reads off the rx stream
+            Ok(builder.build())
+        } else {
+            // make an input that will error
+            let stream = futures::stream::iter(data);
+            Ok(Box::pin(RecordBatchStreamAdapter::new(
+                self.schema(),
+                stream,
+            )))
+        }
     }
 
     fn fmt_as(
@@ -307,12 +331,13 @@ impl ExecutionPlan for BarrierExec {
     ) -> Result<SendableRecordBatchStream> {
         assert!(partition < self.data.len());
 
-        let (tx, rx) = tokio::sync::mpsc::channel(2);
+        let mut builder = RecordBatchReceiverStream::builder(self.schema(), 2);
 
         // task simply sends data in order after barrier is reached
         let data = self.data[partition].clone();
         let b = self.barrier.clone();
-        let join_handle = tokio::task::spawn(async move {
+        let tx = builder.tx();
+        builder.spawn(async move {
             println!("Partition {partition} waiting on barrier");
             b.wait().await;
             for batch in data {
@@ -324,11 +349,7 @@ impl ExecutionPlan for BarrierExec {
         });
 
         // returned stream simply reads off the rx stream
-        Ok(RecordBatchReceiverStream::create(
-            &self.schema,
-            rx,
-            join_handle,
-        ))
+        Ok(builder.build())
     }
 
     fn fmt_as(
@@ -643,3 +664,144 @@ pub async fn assert_strong_count_converges_to_zero<T>(refs: Weak<T>) {
         .await
         .unwrap();
 }
+
+/// Execution plan that emits streams that panic.
+///
+/// This is useful to test panic handling of certain execution plans.
+#[derive(Debug)]
+pub struct PanicExec {
+    /// Schema that is mocked by this plan.
+    schema: SchemaRef,
+
+    /// Number of output partitions. Each partition will produce this
+    /// many empty output record batches prior to panicking
+    batches_until_panics: Vec<usize>,
+}
+
+impl PanicExec {
+    /// Create new [`PanicExec`] with a given schema and number of
+    /// partitions, which will each panic immediately.
+    pub fn new(schema: SchemaRef, n_partitions: usize) -> Self {
+        Self {
+            schema,
+            batches_until_panics: vec![0; n_partitions],
+        }
+    }
+
+    /// Set the number of batches prior to panic for a partition
+    pub fn with_partition_panic(mut self, partition: usize, count: usize) -> Self {
+        self.batches_until_panics[partition] = count;
+        self
+    }
+}
+
+impl ExecutionPlan for PanicExec {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+
+    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
+        // this is a leaf node and has no children
+        vec![]
+    }
+
+    fn output_partitioning(&self) -> Partitioning {
+        let num_partitions = self.batches_until_panics.len();
+        Partitioning::UnknownPartitioning(num_partitions)
+    }
+
+    fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
+        None
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        _: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        Err(DataFusionError::Internal(format!(
+            "Children cannot be replaced in {:?}",
+            self
+        )))
+    }
+
+    fn execute(
+        &self,
+        partition: usize,
+        _context: Arc<TaskContext>,
+    ) -> Result<SendableRecordBatchStream> {
+        Ok(Box::pin(PanicStream {
+            partition,
+            batches_until_panic: self.batches_until_panics[partition],
+            schema: Arc::clone(&self.schema),
+            ready: false,
+        }))
+    }
+
+    fn fmt_as(
+        &self,
+        t: DisplayFormatType,
+        f: &mut std::fmt::Formatter,
+    ) -> std::fmt::Result {
+        match t {
+            DisplayFormatType::Default => {
+                write!(f, "PanickingExec",)
+            }
+        }
+    }
+
+    fn statistics(&self) -> Statistics {
+        unimplemented!()
+    }
+}
+
+/// A [`RecordBatchStream`] that yields every other batch and panics
+/// after `batches_until_panic` batches have been produced.
+///
+/// Useful for testing the behavior of streams on panic
+#[derive(Debug)]
+struct PanicStream {
+    /// Which partition this stream is for
+    partition: usize,
+    /// How many batches will be produced before panicking
+    batches_until_panic: usize,
+    /// Schema mocked by this stream.
+    schema: SchemaRef,
+    /// Should we return ready?
+    ready: bool,
+}
+
+impl Stream for PanicStream {
+    type Item = Result<RecordBatch>;
+
+    fn poll_next(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        if self.batches_until_panic > 0 {
+            if self.ready {
+                self.batches_until_panic -= 1;
+                self.ready = false;
+                let batch = RecordBatch::new_empty(self.schema.clone());
+                return Poll::Ready(Some(Ok(batch)));
+            } else {
+                self.ready = true;
+                // get called again
+                cx.waker().clone().wake();
+                return Poll::Pending;
+            }
+        }
+        panic!("PanickingStream did panic: {}", self.partition)
+    }
+}
+
+impl RecordBatchStream for PanicStream {
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+}
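A closing note on the `futures::stream::select` call in `build()`: `select` polls both halves fairly, which is what lets the panic/error check run even while batches are still flowing from the receiver. A minimal sketch of that merge behavior (plain strings stand in for `Result<RecordBatch>` items):

```rust
use futures::stream::{self, StreamExt};

#[tokio::main]
async fn main() {
    let batches = stream::iter(vec!["batch-0", "batch-1"]);
    // Stands in for the `check_stream` that reports task failures.
    let check = stream::once(async { "check: all tasks joined cleanly" });

    // Both halves are polled; whichever is ready first yields, so a
    // failure report is not starved by a busy batch stream.
    let mut merged = stream::select(batches, check);
    while let Some(item) = merged.next().await {
        println!("{item}");
    }
}
```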