From 0fbc3dd4fc731628803091d8ff8dec27fd62c973 Mon Sep 17 00:00:00 2001 From: Nicolae Vartolomei Date: Thu, 25 May 2023 21:31:22 +0100 Subject: [PATCH 01/14] Propagate panics Another try for fixing #3104. RepartitionExec might need a similar fix. --- .../src/physical_plan/coalesce_partitions.rs | 55 +++++++-- datafusion/core/src/physical_plan/common.rs | 9 +- datafusion/core/src/test/exec.rs | 105 ++++++++++++++++++ 3 files changed, 156 insertions(+), 13 deletions(-) diff --git a/datafusion/core/src/physical_plan/coalesce_partitions.rs b/datafusion/core/src/physical_plan/coalesce_partitions.rs index fe667d1e6e2a..0bc0858c6d2e 100644 --- a/datafusion/core/src/physical_plan/coalesce_partitions.rs +++ b/datafusion/core/src/physical_plan/coalesce_partitions.rs @@ -19,16 +19,17 @@ //! into a single partition use std::any::Any; +use std::panic; use std::sync::Arc; use std::task::Poll; -use futures::Stream; +use futures::{FutureExt, Stream}; use tokio::sync::mpsc; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; +use tokio::task::JoinSet; -use super::common::AbortOnDropMany; use super::expressions::PhysicalSortExpr; use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use super::{RecordBatchStream, Statistics}; @@ -142,21 +143,22 @@ impl ExecutionPlan for CoalescePartitionsExec { // spawn independent tasks whose resulting streams (of batches) // are sent to the channel for consumption. 
- let mut join_handles = Vec::with_capacity(input_partitions); + let mut tasks = JoinSet::new(); for part_i in 0..input_partitions { - join_handles.push(spawn_execution( + spawn_execution( + &mut tasks, self.input.clone(), sender.clone(), part_i, context.clone(), - )); + ); } Ok(Box::pin(MergeStream { input: receiver, schema: self.schema(), baseline_metrics, - drop_helper: AbortOnDropMany(join_handles), + tasks, })) } } @@ -187,8 +189,7 @@ struct MergeStream { schema: SchemaRef, input: mpsc::Receiver>, baseline_metrics: BaselineMetrics, - #[allow(unused)] - drop_helper: AbortOnDropMany<()>, + tasks: JoinSet<()>, } impl Stream for MergeStream { @@ -199,6 +200,25 @@ impl Stream for MergeStream { cx: &mut std::task::Context<'_>, ) -> Poll> { let poll = self.input.poll_recv(cx); + + // If the input stream is done, wait for all tasks to finish and return + // the failure if any. + if let Poll::Ready(None) = poll { + match Box::pin(self.tasks.join_next()).poll_unpin(cx) { + Poll::Ready(task_poll) => { + if let Some(Err(e)) = task_poll { + if e.is_panic() { + panic::resume_unwind(e.into_panic()); + } + return Poll::Ready(Some(Err(DataFusionError::Execution( + format!("{e:?}"), + )))); + } + } + Poll::Pending => {} + } + } + self.baseline_metrics.record_poll(poll) } } @@ -218,7 +238,9 @@ mod tests { use super::*; use crate::physical_plan::{collect, common}; use crate::prelude::SessionContext; - use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; + use crate::test::exec::{ + assert_strong_count_converges_to_zero, BlockingExec, PanickingExec, + }; use crate::test::{self, assert_is_pending}; #[tokio::test] @@ -270,4 +292,19 @@ mod tests { Ok(()) } + + #[tokio::test] + #[should_panic(expected = "PanickingStream did panic")] + async fn test_panic() -> () { + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); + let schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); + + let panicking_exec = 
Arc::new(PanickingExec::new(Arc::clone(&schema), 2)); + let coalesce_partitions_exec = + Arc::new(CoalescePartitionsExec::new(panicking_exec)); + + collect(coalesce_partitions_exec, task_ctx).await.unwrap(); + } } diff --git a/datafusion/core/src/physical_plan/common.rs b/datafusion/core/src/physical_plan/common.rs index e766c225b51c..a9c267f123a8 100644 --- a/datafusion/core/src/physical_plan/common.rs +++ b/datafusion/core/src/physical_plan/common.rs @@ -38,7 +38,7 @@ use std::path::{Path, PathBuf}; use std::sync::Arc; use std::task::{Context, Poll}; use tokio::sync::mpsc; -use tokio::task::JoinHandle; +use tokio::task::{JoinHandle, JoinSet}; /// [`MemoryReservation`] used across query execution streams pub(crate) type SharedMemoryReservation = Arc>; @@ -98,12 +98,13 @@ fn build_file_list_recurse( /// Spawns a task to the tokio threadpool and writes its outputs to the provided mpsc sender pub(crate) fn spawn_execution( + join_set: &mut JoinSet<()>, input: Arc, output: mpsc::Sender>, partition: usize, context: Arc, -) -> JoinHandle<()> { - tokio::spawn(async move { +) { + join_set.spawn(async move { let mut stream = match input.execute(partition, context) { Err(e) => { // If send fails, plan being torn down, @@ -129,7 +130,7 @@ pub(crate) fn spawn_execution( return; } } - }) + }); } /// If running in a tokio context spawns the execution of `stream` to a separate task diff --git a/datafusion/core/src/test/exec.rs b/datafusion/core/src/test/exec.rs index bce7d08a5c56..13f3dc6a16c8 100644 --- a/datafusion/core/src/test/exec.rs +++ b/datafusion/core/src/test/exec.rs @@ -643,3 +643,108 @@ pub async fn assert_strong_count_converges_to_zero(refs: Weak) { .await .unwrap(); } + +/// Execution plan that emits streams that panics. +/// +/// This is useful to test panic handling of certain execution plans. +#[derive(Debug)] +pub struct PanickingExec { + /// Schema that is mocked by this plan. + schema: SchemaRef, + + /// Number of output partitions. 
+ n_partitions: usize, +} + +impl PanickingExec { + /// Create new [`PanickingExec`] with a give schema and number of partitions. + pub fn new(schema: SchemaRef, n_partitions: usize) -> Self { + Self { + schema, + n_partitions, + } + } +} + +impl ExecutionPlan for PanickingExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn children(&self) -> Vec> { + // this is a leaf node and has no children + vec![] + } + + fn output_partitioning(&self) -> Partitioning { + Partitioning::UnknownPartitioning(self.n_partitions) + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + None + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> Result> { + Err(DataFusionError::Internal(format!( + "Children cannot be replaced in {:?}", + self + ))) + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + Ok(Box::pin(PanickingStream { + schema: Arc::clone(&self.schema), + })) + } + + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default => { + write!(f, "PanickingExec",) + } + } + } + + fn statistics(&self) -> Statistics { + unimplemented!() + } +} + +/// A [`RecordBatchStream`] that panics on first poll. +#[derive(Debug)] +pub struct PanickingStream { + /// Schema mocked by this stream. 
+ schema: SchemaRef, +} + +impl Stream for PanickingStream { + type Item = Result; + + fn poll_next( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + panic!("PanickingStream did panic") + } +} + +impl RecordBatchStream for PanickingStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +} From 742597a0de97fdc06d50bc3a41d6c736731e5587 Mon Sep 17 00:00:00 2001 From: Nicolae Vartolomei Date: Fri, 26 May 2023 13:38:57 +0100 Subject: [PATCH 02/14] avoid allocation by pinning on the stack instead --- datafusion/core/src/physical_plan/coalesce_partitions.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/datafusion/core/src/physical_plan/coalesce_partitions.rs b/datafusion/core/src/physical_plan/coalesce_partitions.rs index 0bc0858c6d2e..3fcae0174114 100644 --- a/datafusion/core/src/physical_plan/coalesce_partitions.rs +++ b/datafusion/core/src/physical_plan/coalesce_partitions.rs @@ -23,7 +23,7 @@ use std::panic; use std::sync::Arc; use std::task::Poll; -use futures::{FutureExt, Stream}; +use futures::{Future, Stream}; use tokio::sync::mpsc; use arrow::datatypes::SchemaRef; @@ -204,7 +204,10 @@ impl Stream for MergeStream { // If the input stream is done, wait for all tasks to finish and return // the failure if any. 
if let Poll::Ready(None) = poll { - match Box::pin(self.tasks.join_next()).poll_unpin(cx) { + let fut = self.tasks.join_next(); + tokio::pin!(fut); + + match fut.poll(cx) { Poll::Ready(task_poll) => { if let Some(Err(e)) = task_poll { if e.is_panic() { @@ -295,7 +298,7 @@ mod tests { #[tokio::test] #[should_panic(expected = "PanickingStream did panic")] - async fn test_panic() -> () { + async fn test_panic() { let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); let schema = From e1c827ad24ab6ff371a61d99b6ab5b87a3ce1d7a Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 31 May 2023 15:26:53 -0400 Subject: [PATCH 03/14] Consolidate panic propagation into RecordBatchReceiverStream --- datafusion/core/src/physical_plan/analyze.rs | 33 +-- .../src/physical_plan/coalesce_partitions.rs | 83 +----- datafusion/core/src/physical_plan/common.rs | 59 +--- .../core/src/physical_plan/sorts/sort.rs | 16 +- .../sorts/sort_preserving_merge.rs | 13 +- datafusion/core/src/physical_plan/stream.rs | 274 ++++++++++++++++-- datafusion/core/src/physical_plan/union.rs | 37 +-- datafusion/core/src/test/exec.rs | 87 ++++-- 8 files changed, 348 insertions(+), 254 deletions(-) diff --git a/datafusion/core/src/physical_plan/analyze.rs b/datafusion/core/src/physical_plan/analyze.rs index 84d74c512b54..52bf16629b4d 100644 --- a/datafusion/core/src/physical_plan/analyze.rs +++ b/datafusion/core/src/physical_plan/analyze.rs @@ -29,10 +29,9 @@ use crate::{ }; use arrow::{array::StringBuilder, datatypes::SchemaRef, record_batch::RecordBatch}; use futures::StreamExt; -use tokio::task::JoinSet; use super::expressions::PhysicalSortExpr; -use super::stream::RecordBatchStreamAdapter; +use super::stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter}; use super::{Distribution, SendableRecordBatchStream}; use crate::execution::context::TaskContext; @@ -121,23 +120,15 @@ impl ExecutionPlan for AnalyzeExec { // Gather futures that will run each input partition in // 
parallel (on a separate tokio task) using a JoinSet to // cancel outstanding futures on drop - let mut set = JoinSet::new(); let num_input_partitions = self.input.output_partitioning().partition_count(); + let mut builder = + RecordBatchReceiverStream::builder(self.schema(), num_input_partitions); for input_partition in 0..num_input_partitions { - let input_stream = self.input.execute(input_partition, context.clone()); - - set.spawn(async move { - let mut total_rows = 0; - let mut input_stream = input_stream?; - while let Some(batch) = input_stream.next().await { - let batch = batch?; - total_rows += batch.num_rows(); - } - Ok(total_rows) as Result - }); + builder.run_input(self.input.clone(), input_partition, context.clone()); } + // Create future that computes thefinal output let start = Instant::now(); let captured_input = self.input.clone(); let captured_schema = self.schema.clone(); @@ -146,18 +137,12 @@ impl ExecutionPlan for AnalyzeExec { // future that gathers the results from all the tasks in the // JoinSet that computes the overall row count and final // record batch + let mut input_stream = builder.build(); let output = async move { let mut total_rows = 0; - while let Some(res) = set.join_next().await { - // translate join errors (aka task panic's) into ExecutionErrors - match res { - Ok(row_count) => total_rows += row_count?, - Err(e) => { - return Err(DataFusionError::Execution(format!( - "Join error in AnalyzeExec: {e}" - ))) - } - } + while let Some(batch) = input_stream.next().await { + let batch = batch?; + total_rows += batch.num_rows(); } let duration = Instant::now() - start; diff --git a/datafusion/core/src/physical_plan/coalesce_partitions.rs b/datafusion/core/src/physical_plan/coalesce_partitions.rs index 3fcae0174114..920058bbb41a 100644 --- a/datafusion/core/src/physical_plan/coalesce_partitions.rs +++ b/datafusion/core/src/physical_plan/coalesce_partitions.rs @@ -19,20 +19,14 @@ //! 
into a single partition use std::any::Any; -use std::panic; use std::sync::Arc; -use std::task::Poll; - -use futures::{Future, Stream}; -use tokio::sync::mpsc; use arrow::datatypes::SchemaRef; -use arrow::record_batch::RecordBatch; -use tokio::task::JoinSet; use super::expressions::PhysicalSortExpr; use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; -use super::{RecordBatchStream, Statistics}; +use super::stream::{ObservedStream, RecordBatchReceiverStream}; +use super::Statistics; use crate::error::{DataFusionError, Result}; use crate::physical_plan::{ DisplayFormatType, EquivalenceProperties, ExecutionPlan, Partitioning, @@ -40,7 +34,6 @@ use crate::physical_plan::{ use super::SendableRecordBatchStream; use crate::execution::context::TaskContext; -use crate::physical_plan::common::spawn_execution; /// Merge execution plan executes partitions in parallel and combines them into a single /// partition. No guarantees are made about the order of the resulting partition. @@ -138,28 +131,17 @@ impl ExecutionPlan for CoalescePartitionsExec { // use a stream that allows each sender to put in at // least one result in an attempt to maximize // parallelism. - let (sender, receiver) = - mpsc::channel::>(input_partitions); + let mut builder = + RecordBatchReceiverStream::builder(self.schema(), input_partitions); // spawn independent tasks whose resulting streams (of batches) // are sent to the channel for consumption. 
- let mut tasks = JoinSet::new(); for part_i in 0..input_partitions { - spawn_execution( - &mut tasks, - self.input.clone(), - sender.clone(), - part_i, - context.clone(), - ); + builder.run_input(self.input.clone(), part_i, context.clone()); } - Ok(Box::pin(MergeStream { - input: receiver, - schema: self.schema(), - baseline_metrics, - tasks, - })) + let stream = builder.build(); + Ok(Box::pin(ObservedStream::new(stream, baseline_metrics))) } } } @@ -185,53 +167,6 @@ impl ExecutionPlan for CoalescePartitionsExec { } } -struct MergeStream { - schema: SchemaRef, - input: mpsc::Receiver>, - baseline_metrics: BaselineMetrics, - tasks: JoinSet<()>, -} - -impl Stream for MergeStream { - type Item = Result; - - fn poll_next( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - let poll = self.input.poll_recv(cx); - - // If the input stream is done, wait for all tasks to finish and return - // the failure if any. - if let Poll::Ready(None) = poll { - let fut = self.tasks.join_next(); - tokio::pin!(fut); - - match fut.poll(cx) { - Poll::Ready(task_poll) => { - if let Some(Err(e)) = task_poll { - if e.is_panic() { - panic::resume_unwind(e.into_panic()); - } - return Poll::Ready(Some(Err(DataFusionError::Execution( - format!("{e:?}"), - )))); - } - } - Poll::Pending => {} - } - } - - self.baseline_metrics.record_poll(poll) - } -} - -impl RecordBatchStream for MergeStream { - fn schema(&self) -> SchemaRef { - self.schema.clone() - } -} - #[cfg(test)] mod tests { @@ -242,7 +177,7 @@ mod tests { use crate::physical_plan::{collect, common}; use crate::prelude::SessionContext; use crate::test::exec::{ - assert_strong_count_converges_to_zero, BlockingExec, PanickingExec, + assert_strong_count_converges_to_zero, BlockingExec, PanicingExec, }; use crate::test::{self, assert_is_pending}; @@ -304,7 +239,7 @@ mod tests { let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); - let panicking_exec = 
Arc::new(PanickingExec::new(Arc::clone(&schema), 2)); + let panicking_exec = Arc::new(PanicingExec::new(Arc::clone(&schema), 2)); let coalesce_partitions_exec = Arc::new(CoalescePartitionsExec::new(panicking_exec)); diff --git a/datafusion/core/src/physical_plan/common.rs b/datafusion/core/src/physical_plan/common.rs index a9c267f123a8..2f296ce462f9 100644 --- a/datafusion/core/src/physical_plan/common.rs +++ b/datafusion/core/src/physical_plan/common.rs @@ -19,17 +19,15 @@ use super::SendableRecordBatchStream; use crate::error::{DataFusionError, Result}; -use crate::execution::context::TaskContext; use crate::execution::memory_pool::MemoryReservation; use crate::physical_plan::stream::RecordBatchReceiverStream; -use crate::physical_plan::{displayable, ColumnStatistics, ExecutionPlan, Statistics}; +use crate::physical_plan::{ColumnStatistics, ExecutionPlan, Statistics}; use arrow::datatypes::Schema; use arrow::ipc::writer::{FileWriter, IpcWriteOptions}; use arrow::record_batch::RecordBatch; use datafusion_physical_expr::expressions::{BinaryExpr, Column}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; use futures::{Future, StreamExt, TryStreamExt}; -use log::debug; use parking_lot::Mutex; use pin_project_lite::pin_project; use std::fs; @@ -37,8 +35,7 @@ use std::fs::{metadata, File}; use std::path::{Path, PathBuf}; use std::sync::Arc; use std::task::{Context, Poll}; -use tokio::sync::mpsc; -use tokio::task::{JoinHandle, JoinSet}; +use tokio::task::JoinHandle; /// [`MemoryReservation`] used across query execution streams pub(crate) type SharedMemoryReservation = Arc>; @@ -96,43 +93,6 @@ fn build_file_list_recurse( Ok(()) } -/// Spawns a task to the tokio threadpool and writes its outputs to the provided mpsc sender -pub(crate) fn spawn_execution( - join_set: &mut JoinSet<()>, - input: Arc, - output: mpsc::Sender>, - partition: usize, - context: Arc, -) { - join_set.spawn(async move { - let mut stream = match input.execute(partition, context) { - 
Err(e) => { - // If send fails, plan being torn down, - // there is no place to send the error. - output.send(Err(e)).await.ok(); - debug!( - "Stopping execution: error executing input: {}", - displayable(input.as_ref()).one_line() - ); - return; - } - Ok(stream) => stream, - }; - - while let Some(item) = stream.next().await { - // If send fails, plan being torn down, - // there is no place to send the error. - if output.send(item).await.is_err() { - debug!( - "Stopping execution: output is gone, plan cancelling: {}", - displayable(input.as_ref()).one_line() - ); - return; - } - } - }); -} - /// If running in a tokio context spawns the execution of `stream` to a separate task /// allowing it to execute in parallel with an intermediate buffer of size `buffer` pub(crate) fn spawn_buffered( @@ -140,14 +100,15 @@ pub(crate) fn spawn_buffered( buffer: usize, ) -> SendableRecordBatchStream { // Use tokio only if running from a tokio context (#2201) - let handle = match tokio::runtime::Handle::try_current() { - Ok(handle) => handle, - Err(_) => return input, + if tokio::runtime::Handle::try_current().is_err() { + return input; }; - let schema = input.schema(); - let (sender, receiver) = mpsc::channel(buffer); - let join = handle.spawn(async move { + let mut builder = RecordBatchReceiverStream::builder(input.schema(), buffer); + + let sender = builder.tx(); + + builder.spawn(async move { while let Some(item) = input.next().await { if sender.send(item).await.is_err() { return; @@ -155,7 +116,7 @@ pub(crate) fn spawn_buffered( } }); - RecordBatchReceiverStream::create(&schema, receiver, join) + builder.build() } /// Computes the statistics for an in-memory RecordBatch diff --git a/datafusion/core/src/physical_plan/sorts/sort.rs b/datafusion/core/src/physical_plan/sorts/sort.rs index 35dac19b27c6..9a066ee097bf 100644 --- a/datafusion/core/src/physical_plan/sorts/sort.rs +++ b/datafusion/core/src/physical_plan/sorts/sort.rs @@ -52,7 +52,7 @@ use std::io::BufReader; use 
std::path::{Path, PathBuf}; use std::sync::Arc; use tempfile::NamedTempFile; -use tokio::sync::mpsc::{Receiver, Sender}; +use tokio::sync::mpsc::Sender; use tokio::task; struct ExternalSorterMetrics { @@ -373,18 +373,16 @@ fn read_spill_as_stream( path: NamedTempFile, schema: SchemaRef, ) -> Result { - let (sender, receiver): (Sender>, Receiver>) = - tokio::sync::mpsc::channel(2); - let join_handle = task::spawn_blocking(move || { + let mut builder = RecordBatchReceiverStream::builder(schema, 2); + let sender = builder.tx(); + + builder.spawn_blocking(move || { if let Err(e) = read_spill(sender, path.path()) { error!("Failure while reading spill file: {:?}. Error: {}", path, e); } }); - Ok(RecordBatchReceiverStream::create( - &schema, - receiver, - join_handle, - )) + + Ok(builder.build()) } fn write_sorted( diff --git a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs index e346eccbeecb..c09ddd9d5bd8 100644 --- a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs +++ b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs @@ -792,9 +792,12 @@ mod tests { let mut streams = Vec::with_capacity(partition_count); for partition in 0..partition_count { - let (sender, receiver) = tokio::sync::mpsc::channel(1); + let mut builder = RecordBatchReceiverStream::builder(schema.clone(), 1); + + let sender = builder.tx(); + let mut stream = batches.execute(partition, task_ctx.clone()).unwrap(); - let join_handle = tokio::spawn(async move { + builder.spawn(async move { while let Some(batch) = stream.next().await { sender.send(batch).await.unwrap(); // This causes the MergeStream to wait for more input @@ -802,11 +805,7 @@ mod tests { } }); - streams.push(RecordBatchReceiverStream::create( - &schema, - receiver, - join_handle, - )); + streams.push(builder.build()); } let metrics = ExecutionPlanMetricsSet::new(); diff --git a/datafusion/core/src/physical_plan/stream.rs 
b/datafusion/core/src/physical_plan/stream.rs index 2190022bc505..464b72aee955 100644 --- a/datafusion/core/src/physical_plan/stream.rs +++ b/datafusion/core/src/physical_plan/stream.rs @@ -17,43 +17,173 @@ //! Stream wrappers for physical operators +use std::sync::Arc; + use crate::error::Result; +use crate::physical_plan::displayable; use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; -use futures::{Stream, StreamExt}; +use datafusion_common::DataFusionError; +use datafusion_execution::TaskContext; +use futures::stream::BoxStream; +use futures::{Future, Stream, StreamExt}; +use log::debug; use pin_project_lite::pin_project; -use tokio::task::JoinHandle; +use tokio::task::JoinSet; use tokio_stream::wrappers::ReceiverStream; -use super::common::AbortOnDropSingle; -use super::{RecordBatchStream, SendableRecordBatchStream}; +use super::metrics::BaselineMetrics; +use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream}; -/// Adapter for a tokio [`ReceiverStream`] that implements the -/// [`SendableRecordBatchStream`] -/// interface -pub struct RecordBatchReceiverStream { +/// Builder for [`RecordBatchReceiverStream`] that propagates errors +/// and panic's correctly. 
+pub struct RecordBatchReceiverStreamBuilder { + tx: tokio::sync::mpsc::Sender>, + rx: tokio::sync::mpsc::Receiver>, schema: SchemaRef, + join_set: JoinSet<()>, +} + +impl RecordBatchReceiverStreamBuilder { + /// create new channels with the specified buffer size + pub fn new(schema: SchemaRef, capacity: usize) -> Self { + let (tx, rx) = tokio::sync::mpsc::channel(capacity); - inner: ReceiverStream>, + Self { + tx, + rx, + schema, + join_set: JoinSet::new(), + } + } + + /// Get a handle for sending [`RecordBatch`]es to the output + pub fn tx(&self) -> tokio::sync::mpsc::Sender> { + self.tx.clone() + } + + /// Spawn task that will be aborted if this builder (or the stream + /// built from it) are dropped + /// + /// this is often used to spawn tasks that write to the sender + /// retrieved from `Self::tx` + pub fn spawn(&mut self, task: F) + where + F: Future, + F: Send + 'static, + { + self.join_set.spawn(task); + } + + /// Spawn a blocking task that will be aborted if this builder (or the stream + /// built from it) are dropped + /// + /// this is often used to spawn tasks that write to the sender + /// retrieved from `Self::tx` + pub fn spawn_blocking(&mut self, f: F) + where + F: FnOnce(), + F: Send + 'static, + { + self.join_set.spawn_blocking(f); + } + + /// runs the input_partition of the `input` ExecutionPlan on the + /// tokio threadpool and writes its outputs to this stream + pub(crate) fn run_input( + &mut self, + input: Arc, + partition: usize, + context: Arc, + ) { + let output = self.tx(); + + self.spawn(async move { + let mut stream = match input.execute(partition, context) { + Err(e) => { + // If send fails, plan being torn down, + // there is no place to send the error. 
+ output.send(Err(e)).await.ok(); + debug!( + "Stopping execution: error executing input: {}", + displayable(input.as_ref()).one_line() + ); + return; + } + Ok(stream) => stream, + }; + + while let Some(item) = stream.next().await { + // If send fails, plan being torn down, + // there is no place to send the error. + if output.send(item).await.is_err() { + debug!( + "Stopping execution: output is gone, plan cancelling: {}", + displayable(input.as_ref()).one_line() + ); + return; + } + } + }); + } - #[allow(dead_code)] - drop_helper: AbortOnDropSingle<()>, + /// Create a stream of all `RecordBatch`es written to `tx` + pub fn build(self) -> SendableRecordBatchStream { + let Self { + tx, + rx, + schema, + mut join_set, + } = self; + + // don't need tx + drop(tx); + + // future that checks the result of the join set + let check = async move { + while let Some(result) = join_set.join_next().await { + match result { + Ok(()) => continue, // nothing to report + // This means a tokio task error, likely a panic + Err(e) => { + if e.is_panic() { + // resume on the main thread + std::panic::resume_unwind(e.into_panic()); + } else { + return Some(Err(DataFusionError::Execution(format!( + "Task error: {e}" + )))); + } + } + } + } + None + }; + + let check_stream = futures::stream::once(check) + // unwrap Option / only return the error + .filter_map(|item| async move { item }); + + let inner = ReceiverStream::new(rx).chain(check_stream).boxed(); + + Box::pin(RecordBatchReceiverStream { schema, inner }) + } +} + +/// Adapter for a tokio [`ReceiverStream`] that implements the +/// [`SendableRecordBatchStream`] interface. Use +/// [`RecordBatchReceiverStreamBuilder`] to construct one. 
+pub struct RecordBatchReceiverStream { + schema: SchemaRef, + inner: BoxStream<'static, Result>, } impl RecordBatchReceiverStream { - /// Construct a new [`RecordBatchReceiverStream`] which will send - /// batches of the specified schema from `inner` - pub fn create( - schema: &SchemaRef, - rx: tokio::sync::mpsc::Receiver>, - join_handle: JoinHandle<()>, - ) -> SendableRecordBatchStream { - let schema = schema.clone(); - let inner = ReceiverStream::new(rx); - Box::pin(Self { - schema, - inner, - drop_helper: AbortOnDropSingle::new(join_handle), - }) + /// Create a builder with an internal buffer of capacity batches. + pub fn builder( + schema: SchemaRef, + capacity: usize, + ) -> RecordBatchReceiverStreamBuilder { + RecordBatchReceiverStreamBuilder::new(schema, capacity) } } @@ -126,3 +256,97 @@ where self.schema.clone() } } + +/// Stream wrapper that records `BaselineMetrics` for a particular +/// partition +pub(crate) struct ObservedStream { + inner: SendableRecordBatchStream, + baseline_metrics: BaselineMetrics, +} + +impl ObservedStream { + pub fn new( + inner: SendableRecordBatchStream, + baseline_metrics: BaselineMetrics, + ) -> Self { + Self { + inner, + baseline_metrics, + } + } +} + +impl RecordBatchStream for ObservedStream { + fn schema(&self) -> arrow::datatypes::SchemaRef { + self.inner.schema() + } +} + +impl futures::Stream for ObservedStream { + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let poll = self.inner.poll_next_unpin(cx); + self.baseline_metrics.record_poll(poll) + } +} + +#[cfg(test)] +mod test { + use super::*; + use arrow_schema::{DataType, Field, Schema}; + + use crate::{execution::context::SessionContext, test::exec::PanicingExec}; + + #[tokio::test] + #[should_panic(expected = "PanickingStream did panic")] + async fn record_batch_receiver_stream_propagates_panics() { + let schema = + Arc::new(Schema::new(vec![Field::new("a", 
DataType::Float32, true)])); + + let num_partitions = 10; + let input = PanicingExec::new(schema.clone(), num_partitions); + consume(input).await + } + + #[tokio::test] + #[should_panic(expected = "PanickingStream did panic: 1")] + async fn record_batch_receiver_stream_propagates_panics_one() { + let schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); + + // make 2 partitions, second panics before the first + let num_partitions = 2; + let input = PanicingExec::new(schema.clone(), num_partitions) + .with_partition_panic(0, 10) + .with_partition_panic(1, 3); // partition 1 should panic first (after 3 ) + + consume(input).await + } + + /// Consumes all the input's partitions into a + /// RecordBatchReceiverStream and runs it to completion + async fn consume(input: PanicingExec) { + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); + + let input = Arc::new(input); + let num_partitions = input.output_partitioning().partition_count(); + + // Configure a RecordBatchReceiverStream to consume all the input partitions + let mut builder = + RecordBatchReceiverStream::builder(input.schema(), num_partitions); + for partition in 0..num_partitions { + builder.run_input(input.clone(), partition, task_ctx.clone()); + } + let mut stream = builder.build(); + + // drain the stream until it is complete, panic'ing on error + while let Some(next) = stream.next().await { + next.unwrap(); + } + } +} diff --git a/datafusion/core/src/physical_plan/union.rs b/datafusion/core/src/physical_plan/union.rs index d1f5ec0c29b0..9c8ba4d29b6f 100644 --- a/datafusion/core/src/physical_plan/union.rs +++ b/datafusion/core/src/physical_plan/union.rs @@ -30,7 +30,7 @@ use arrow::{ record_batch::RecordBatch, }; use datafusion_common::{DFSchemaRef, DataFusionError}; -use futures::{Stream, StreamExt}; +use futures::Stream; use itertools::Itertools; use log::{debug, trace, warn}; @@ -42,6 +42,7 @@ use super::{ }; use 
crate::execution::context::TaskContext; use crate::physical_plan::common::get_meet_of_orderings; +use crate::physical_plan::stream::ObservedStream; use crate::{ error::Result, physical_plan::{expressions, metrics::BaselineMetrics}, @@ -560,40 +561,6 @@ impl Stream for CombinedRecordBatchStream { } } -/// Stream wrapper that records `BaselineMetrics` for a particular -/// partition -struct ObservedStream { - inner: SendableRecordBatchStream, - baseline_metrics: BaselineMetrics, -} - -impl ObservedStream { - fn new(inner: SendableRecordBatchStream, baseline_metrics: BaselineMetrics) -> Self { - Self { - inner, - baseline_metrics, - } - } -} - -impl RecordBatchStream for ObservedStream { - fn schema(&self) -> arrow::datatypes::SchemaRef { - self.inner.schema() - } -} - -impl futures::Stream for ObservedStream { - type Item = Result; - - fn poll_next( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - let poll = self.inner.poll_next_unpin(cx); - self.baseline_metrics.record_poll(poll) - } -} - fn col_stats_union( mut left: ColumnStatistics, right: ColumnStatistics, diff --git a/datafusion/core/src/test/exec.rs b/datafusion/core/src/test/exec.rs index 13f3dc6a16c8..5bd8c5539caa 100644 --- a/datafusion/core/src/test/exec.rs +++ b/datafusion/core/src/test/exec.rs @@ -179,12 +179,13 @@ impl ExecutionPlan for MockExec { }) .collect(); - let (tx, rx) = tokio::sync::mpsc::channel(2); + let mut builder = RecordBatchReceiverStream::builder(self.schema(), 2); // task simply sends data in order but in a separate // thread (to ensure the batches are not available without the // DelayedStream yielding). 
- let join_handle = tokio::task::spawn(async move { + let tx = builder.tx(); + builder.spawn(async move { for batch in data { println!("Sending batch via delayed stream"); if let Err(e) = tx.send(batch).await { @@ -194,11 +195,7 @@ impl ExecutionPlan for MockExec { }); // returned stream simply reads off the rx stream - Ok(RecordBatchReceiverStream::create( - &self.schema, - rx, - join_handle, - )) + Ok(builder.build()) } fn fmt_as( @@ -307,12 +304,13 @@ impl ExecutionPlan for BarrierExec { ) -> Result { assert!(partition < self.data.len()); - let (tx, rx) = tokio::sync::mpsc::channel(2); + let mut builder = RecordBatchReceiverStream::builder(self.schema(), 2); // task simply sends data in order after barrier is reached let data = self.data[partition].clone(); let b = self.barrier.clone(); - let join_handle = tokio::task::spawn(async move { + let tx = builder.tx(); + builder.spawn(async move { println!("Partition {partition} waiting on barrier"); b.wait().await; for batch in data { @@ -324,11 +322,7 @@ impl ExecutionPlan for BarrierExec { }); // returned stream simply reads off the rx stream - Ok(RecordBatchReceiverStream::create( - &self.schema, - rx, - join_handle, - )) + Ok(builder.build()) } fn fmt_as( @@ -648,25 +642,33 @@ pub async fn assert_strong_count_converges_to_zero(refs: Weak) { /// /// This is useful to test panic handling of certain execution plans. #[derive(Debug)] -pub struct PanickingExec { +pub struct PanicingExec { /// Schema that is mocked by this plan. schema: SchemaRef, - /// Number of output partitions. - n_partitions: usize, + /// Number of output partitions. Each partition will produce this + /// many empty output record batches prior to panicing + batches_until_panics: Vec, } -impl PanickingExec { - /// Create new [`PanickingExec`] with a give schema and number of partitions. +impl PanicingExec { + /// Create new [`PanickingExec`] with a give schema and number of + /// partitions, which will each panic immediately. 
pub fn new(schema: SchemaRef, n_partitions: usize) -> Self { Self { schema, - n_partitions, + batches_until_panics: vec![0; n_partitions], } } + + /// Set the number of batches prior to panic for a partition + pub fn with_partition_panic(mut self, partition: usize, count: usize) -> Self { + self.batches_until_panics[partition] = count; + self + } } -impl ExecutionPlan for PanickingExec { +impl ExecutionPlan for PanicingExec { fn as_any(&self) -> &dyn Any { self } @@ -681,7 +683,8 @@ impl ExecutionPlan for PanickingExec { } fn output_partitioning(&self) -> Partitioning { - Partitioning::UnknownPartitioning(self.n_partitions) + let num_partitions = self.batches_until_panics.len(); + Partitioning::UnknownPartitioning(num_partitions) } fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { @@ -700,11 +703,14 @@ impl ExecutionPlan for PanickingExec { fn execute( &self, - _partition: usize, + partition: usize, _context: Arc, ) -> Result { - Ok(Box::pin(PanickingStream { + Ok(Box::pin(PanicingStream { + partition, + batches_until_panic: self.batches_until_panics[partition], schema: Arc::clone(&self.schema), + ready: false, })) } @@ -725,25 +731,44 @@ impl ExecutionPlan for PanickingExec { } } -/// A [`RecordBatchStream`] that panics on first poll. +/// A [`RecordBatchStream`] that yields every other batch and panics after `batches_until_panic` batches have been produced #[derive(Debug)] -pub struct PanickingStream { +struct PanicingStream { + /// Which partition was this + partition: usize, + /// How may batches will be produced until panic + batches_until_panic: usize, /// Schema mocked by this stream. schema: SchemaRef, + /// Should we return ready ? 
+ ready: bool, } -impl Stream for PanickingStream { +impl Stream for PanicingStream { type Item = Result; fn poll_next( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, ) -> Poll> { - panic!("PanickingStream did panic") + if self.batches_until_panic > 0 { + if self.ready { + self.batches_until_panic -= 1; + self.ready = false; + let batch = RecordBatch::new_empty(self.schema.clone()); + return Poll::Ready(Some(Ok(batch))); + } else { + self.ready = true; + // get called again + cx.waker().clone().wake(); + return Poll::Pending; + } + } + panic!("PanickingStream did panic: {}", self.partition) } } -impl RecordBatchStream for PanickingStream { +impl RecordBatchStream for PanicingStream { fn schema(&self) -> SchemaRef { Arc::clone(&self.schema) } From 8ffc015d3a5979dcbc83fa87071b4ba80c1b5ee9 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 31 May 2023 17:59:19 -0400 Subject: [PATCH 04/14] Update docs / cleanup/ --- datafusion/core/src/physical_plan/stream.rs | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/datafusion/core/src/physical_plan/stream.rs b/datafusion/core/src/physical_plan/stream.rs index 464b72aee955..b89747f38d2c 100644 --- a/datafusion/core/src/physical_plan/stream.rs +++ b/datafusion/core/src/physical_plan/stream.rs @@ -28,6 +28,7 @@ use futures::stream::BoxStream; use futures::{Future, Stream, StreamExt}; use log::debug; use pin_project_lite::pin_project; +use tokio::sync::mpsc::{Receiver, Sender}; use tokio::task::JoinSet; use tokio_stream::wrappers::ReceiverStream; @@ -36,9 +37,13 @@ use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream}; /// Builder for [`RecordBatchReceiverStream`] that propagates errors /// and panic's correctly. +/// +/// [`RecordBatchReceiverStream`] can be used when there are one or +/// more tasks spawned which produce RecordBatches and send them to a +/// single `Receiver`. 
pub struct RecordBatchReceiverStreamBuilder { - tx: tokio::sync::mpsc::Sender>, - rx: tokio::sync::mpsc::Receiver>, + tx: Sender>, + rx: Receiver>, schema: SchemaRef, join_set: JoinSet<()>, } @@ -57,7 +62,7 @@ impl RecordBatchReceiverStreamBuilder { } /// Get a handle for sending [`RecordBatch`]es to the output - pub fn tx(&self) -> tokio::sync::mpsc::Sender> { + pub fn tx(&self) -> Sender> { self.tx.clone() } @@ -170,8 +175,8 @@ impl RecordBatchReceiverStreamBuilder { } /// Adapter for a tokio [`ReceiverStream`] that implements the -/// [`SendableRecordBatchStream`] interface. Use -/// [`RecordBatchReceiverStreamBuilder`] to construct one. +/// [`SendableRecordBatchStream`] interface and propagates panics and +/// errors. Use [`Self::builder`] to construct one. pub struct RecordBatchReceiverStream { schema: SchemaRef, inner: BoxStream<'static, Result>, From e5c4e030b54e6036b1d7ce92da20046a9bbf50a9 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 2 Jun 2023 16:45:29 -0400 Subject: [PATCH 05/14] Apply suggestions from code review Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> --- datafusion/core/src/physical_plan/analyze.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/datafusion/core/src/physical_plan/analyze.rs b/datafusion/core/src/physical_plan/analyze.rs index 52bf16629b4d..51de8b3ad5c6 100644 --- a/datafusion/core/src/physical_plan/analyze.rs +++ b/datafusion/core/src/physical_plan/analyze.rs @@ -140,8 +140,7 @@ impl ExecutionPlan for AnalyzeExec { let mut input_stream = builder.build(); let output = async move { let mut total_rows = 0; - while let Some(batch) = input_stream.next().await { - let batch = batch?; + while let Some(batch) = input_stream.next().await.transpose()? 
{ total_rows += batch.num_rows(); } From 3f80690f7354343c16c1b6ae5317f0a2aeaf0438 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 2 Jun 2023 16:47:44 -0400 Subject: [PATCH 06/14] rename to be consistent and not deal with English peculiarities --- datafusion/core/src/physical_plan/coalesce_partitions.rs | 4 ++-- datafusion/core/src/physical_plan/stream.rs | 8 ++++---- datafusion/core/src/test/exec.rs | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/datafusion/core/src/physical_plan/coalesce_partitions.rs b/datafusion/core/src/physical_plan/coalesce_partitions.rs index ef283b543d46..66700cd9e748 100644 --- a/datafusion/core/src/physical_plan/coalesce_partitions.rs +++ b/datafusion/core/src/physical_plan/coalesce_partitions.rs @@ -177,7 +177,7 @@ mod tests { use crate::physical_plan::{collect, common}; use crate::prelude::SessionContext; use crate::test::exec::{ - assert_strong_count_converges_to_zero, BlockingExec, PanicingExec, + assert_strong_count_converges_to_zero, BlockingExec, PanicExec, }; use crate::test::{self, assert_is_pending}; @@ -239,7 +239,7 @@ mod tests { let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); - let panicking_exec = Arc::new(PanicingExec::new(Arc::clone(&schema), 2)); + let panicking_exec = Arc::new(PanicExec::new(Arc::clone(&schema), 2)); let coalesce_partitions_exec = Arc::new(CoalescePartitionsExec::new(panicking_exec)); diff --git a/datafusion/core/src/physical_plan/stream.rs b/datafusion/core/src/physical_plan/stream.rs index b89747f38d2c..c60c7baafc0f 100644 --- a/datafusion/core/src/physical_plan/stream.rs +++ b/datafusion/core/src/physical_plan/stream.rs @@ -304,7 +304,7 @@ mod test { use super::*; use arrow_schema::{DataType, Field, Schema}; - use crate::{execution::context::SessionContext, test::exec::PanicingExec}; + use crate::{execution::context::SessionContext, test::exec::PanicExec}; #[tokio::test] #[should_panic(expected = "PanickingStream did panic")] @@ -313,7
+313,7 @@ mod test { Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); let num_partitions = 10; - let input = PanicingExec::new(schema.clone(), num_partitions); + let input = PanicExec::new(schema.clone(), num_partitions); consume(input).await } @@ -325,7 +325,7 @@ mod test { // make 2 partitions, second panics before the first let num_partitions = 2; - let input = PanicingExec::new(schema.clone(), num_partitions) + let input = PanicExec::new(schema.clone(), num_partitions) .with_partition_panic(0, 10) .with_partition_panic(1, 3); // partition 1 should panic first (after 3 ) @@ -334,7 +334,7 @@ mod test { /// Consumes all the input's partitions into a /// RecordBatchReceiverStream and runs it to completion - async fn consume(input: PanicingExec) { + async fn consume(input: PanicExec) { let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); diff --git a/datafusion/core/src/test/exec.rs b/datafusion/core/src/test/exec.rs index 5bd8c5539caa..0311bdc87c7a 100644 --- a/datafusion/core/src/test/exec.rs +++ b/datafusion/core/src/test/exec.rs @@ -642,7 +642,7 @@ pub async fn assert_strong_count_converges_to_zero(refs: Weak) { /// /// This is useful to test panic handling of certain execution plans. #[derive(Debug)] -pub struct PanicingExec { +pub struct PanicExec { /// Schema that is mocked by this plan. schema: SchemaRef, @@ -651,7 +651,7 @@ pub struct PanicingExec { batches_until_panics: Vec, } -impl PanicingExec { +impl PanicExec { /// Create new [`PanickingExec`] with a give schema and number of /// partitions, which will each panic immediately. 
pub fn new(schema: SchemaRef, n_partitions: usize) -> Self { @@ -668,7 +668,7 @@ impl PanicingExec { } } -impl ExecutionPlan for PanicingExec { +impl ExecutionPlan for PanicExec { fn as_any(&self) -> &dyn Any { self } From 76270f0adbe49053613a80ccb2cebef12a573df8 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 2 Jun 2023 17:27:43 -0400 Subject: [PATCH 07/14] Add a test and comments --- datafusion/core/src/physical_plan/stream.rs | 24 ++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/datafusion/core/src/physical_plan/stream.rs b/datafusion/core/src/physical_plan/stream.rs index c60c7baafc0f..c55e2723126e 100644 --- a/datafusion/core/src/physical_plan/stream.rs +++ b/datafusion/core/src/physical_plan/stream.rs @@ -154,8 +154,13 @@ impl RecordBatchReceiverStreamBuilder { // resume on the main thread std::panic::resume_unwind(e.into_panic()); } else { - return Some(Err(DataFusionError::Execution(format!( - "Task error: {e}" + // This should only occur if the task is + // cancelled, which would only occur if + // the JoinSet were aborted, which in turn + // would imply that the receiver has been + // dropped and this code is not running + return Some(Err(DataFusionError::Internal(format!( + "Non Panic Task error: {e}" )))); } } @@ -314,7 +319,7 @@ mod test { let num_partitions = 10; let input = PanicExec::new(schema.clone(), num_partitions); - consume(input).await + consume(input, 10).await } #[tokio::test] @@ -329,12 +334,15 @@ mod test { .with_partition_panic(0, 10) .with_partition_panic(1, 3); // partition 1 should panic first (after 3 ) - consume(input).await + let max_batches = 5; // expect to read every other batch: (0,1,0,1,0,panic) + consume(input, max_batches).await } /// Consumes all the input's partitions into a /// RecordBatchReceiverStream and runs it to completion - async fn consume(input: PanicExec) { + /// + /// panic's if more than max_batches is seen, + async fn consume(input: PanicExec, max_batches: usize) { let 
session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); @@ -350,8 +358,14 @@ mod test { let mut stream = builder.build(); // drain the stream until it is complete, panic'ing on error + let mut num_batches = 0; while let Some(next) = stream.next().await { next.unwrap(); + num_batches += 1; + assert!( + num_batches < max_batches, + "Got the limit of {num_batches} batches before seeing panic" + ); } } } From 5ae36209f8df0755e2d77d3e9b4a5de8c2bd68fd Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 5 Jun 2023 17:31:12 -0400 Subject: [PATCH 08/14] write test for drop cancel --- datafusion/core/src/physical_plan/stream.rs | 38 ++++++++++++++++++--- datafusion/core/src/test/exec.rs | 3 ++ 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/datafusion/core/src/physical_plan/stream.rs b/datafusion/core/src/physical_plan/stream.rs index c55e2723126e..f31c34163cf9 100644 --- a/datafusion/core/src/physical_plan/stream.rs +++ b/datafusion/core/src/physical_plan/stream.rs @@ -309,13 +309,17 @@ mod test { use super::*; use arrow_schema::{DataType, Field, Schema}; - use crate::{execution::context::SessionContext, test::exec::PanicExec}; + use crate::{execution::context::SessionContext, test::{exec::{PanicExec, BlockingExec, assert_strong_count_converges_to_zero}}}; + + fn schema() -> SchemaRef { + Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])) + } + #[tokio::test] #[should_panic(expected = "PanickingStream did panic")] async fn record_batch_receiver_stream_propagates_panics() { - let schema = - Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); + let schema = schema(); let num_partitions = 10; let input = PanicExec::new(schema.clone(), num_partitions); @@ -325,8 +329,7 @@ mod test { #[tokio::test] #[should_panic(expected = "PanickingStream did panic: 1")] async fn record_batch_receiver_stream_propagates_panics_one() { - let schema = - Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); 
+ let schema = schema(); // make 2 partitions, second panics before the first let num_partitions = 2; @@ -338,6 +341,31 @@ mod test { consume(input, max_batches).await } + #[tokio::test] + async fn record_batch_receiver_stream_drop_cancel() { + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); + let schema = schema(); + + // Make an input that will not proceed + let input = BlockingExec::new(schema.clone(), 1); + let refs = input.refs(); + + // Configure a RecordBatchReceiverStream to consume the input + let mut builder = + RecordBatchReceiverStream::builder(schema, 2); + builder.run_input(Arc::new(input), 0, task_ctx.clone()); + + let stream = builder.build(); + + // input stream should be present + assert!(std::sync::Weak::strong_count(&refs) > 0); + + // drop the stream, ensure the refs go to zero + drop(stream); + assert_strong_count_converges_to_zero(refs).await; + } + /// Consumes all the input's partitions into a /// RecordBatchReceiverStream and runs it to completion /// diff --git a/datafusion/core/src/test/exec.rs b/datafusion/core/src/test/exec.rs index 0311bdc87c7a..6b274ad85b87 100644 --- a/datafusion/core/src/test/exec.rs +++ b/datafusion/core/src/test/exec.rs @@ -638,6 +638,9 @@ pub async fn assert_strong_count_converges_to_zero(refs: Weak) { .unwrap(); } +/// + + /// Execution plan that emits streams that panics. /// /// This is useful to test panic handling of certain execution plans. 
From a531afe4f278bfbe6a791f3fa3806384353d6c90 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 6 Jun 2023 07:32:31 -0400 Subject: [PATCH 09/14] Add test for not driving to completion --- datafusion/core/src/physical_plan/stream.rs | 58 ++++++++++++++++-- datafusion/core/src/test/exec.rs | 66 ++++++++++++++------- 2 files changed, 100 insertions(+), 24 deletions(-) diff --git a/datafusion/core/src/physical_plan/stream.rs b/datafusion/core/src/physical_plan/stream.rs index f31c34163cf9..83d83d0e8806 100644 --- a/datafusion/core/src/physical_plan/stream.rs +++ b/datafusion/core/src/physical_plan/stream.rs @@ -309,13 +309,17 @@ mod test { use super::*; use arrow_schema::{DataType, Field, Schema}; - use crate::{execution::context::SessionContext, test::{exec::{PanicExec, BlockingExec, assert_strong_count_converges_to_zero}}}; + use crate::{ + execution::context::SessionContext, + test::exec::{ + assert_strong_count_converges_to_zero, BlockingExec, MockExec, PanicExec, + }, + }; fn schema() -> SchemaRef { Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])) } - #[tokio::test] #[should_panic(expected = "PanickingStream did panic")] async fn record_batch_receiver_stream_propagates_panics() { @@ -352,8 +356,7 @@ mod test { let refs = input.refs(); // Configure a RecordBatchReceiverStream to consume the input - let mut builder = - RecordBatchReceiverStream::builder(schema, 2); + let mut builder = RecordBatchReceiverStream::builder(schema, 2); builder.run_input(Arc::new(input), 0, task_ctx.clone()); let stream = builder.build(); @@ -366,6 +369,53 @@ mod test { assert_strong_count_converges_to_zero(refs).await; } + #[tokio::test] + /// Ensure that if an error is received in one stream, the + /// `RecordBatchReceiverStream` stops early and does not drive + /// other streams to completion.
+ async fn record_batch_receiver_stream_error_does_not_drive_completion() { + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); + let schema = schema(); + + // Make an input that will not proceed + let blocking_input = BlockingExec::new(schema.clone(), 1); + let refs = blocking_input.refs(); + + // make an input that will error + let error_stream = MockExec::new( + vec![ + Err(DataFusionError::Execution("Test1".to_string())), + Err(DataFusionError::Execution("Test2".to_string())), + ], + schema.clone(), + ) + .with_use_task(false); + + // Configure a RecordBatchReceiverStream to consume the + // blocking input (which will never advance) and the stream + // that will error. + + let mut builder = RecordBatchReceiverStream::builder(schema, 2); + builder.run_input(Arc::new(blocking_input), 0, task_ctx.clone()); + builder.run_input(Arc::new(error_stream), 0, task_ctx.clone()); + let mut stream = builder.build(); + + // first input input should be present + assert!(std::sync::Weak::strong_count(&refs) > 0); + + // get the first result, which should be an error + let first_batch = stream.next().await.unwrap(); + let first_err = first_batch.unwrap_err(); + assert_eq!(first_err.to_string(), "Execution error: Test1"); + + // There should be no more batches produced (should not get the second error) + assert!(stream.next().await.is_none()); + + // And the other inputs should be cleaned up (even before stream is dropped) + assert_strong_count_converges_to_zero(refs).await; + } + /// Consumes all the input's partitions into a /// RecordBatchReceiverStream and runs it to completion /// diff --git a/datafusion/core/src/test/exec.rs b/datafusion/core/src/test/exec.rs index 6b274ad85b87..2ce102114dcc 100644 --- a/datafusion/core/src/test/exec.rs +++ b/datafusion/core/src/test/exec.rs @@ -31,7 +31,6 @@ use arrow::{ }; use futures::Stream; -use crate::execution::context::TaskContext; use crate::physical_plan::expressions::PhysicalSortExpr; use 
crate::physical_plan::{ common, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, @@ -41,6 +40,9 @@ use crate::{ error::{DataFusionError, Result}, physical_plan::stream::RecordBatchReceiverStream, }; +use crate::{ + execution::context::TaskContext, physical_plan::stream::RecordBatchStreamAdapter, +}; /// Index into the data that has been returned so far #[derive(Debug, Default, Clone)] @@ -121,6 +123,9 @@ pub struct MockExec { /// the results to send back data: Vec>, schema: SchemaRef, + /// if true (the default), sends data using a separate task to to ensure the + /// batches are not available without this stream yielding first + use_task: bool, } impl MockExec { @@ -129,7 +134,19 @@ impl MockExec { /// immediately (the caller has to actually yield and another task /// must run) to ensure any poll loops are correct. pub fn new(data: Vec>, schema: SchemaRef) -> Self { - Self { data, schema } + Self { + data, + schema, + use_task: false, + } + } + + /// If `use_task` is true (the default) then the batches are sent + /// back using a separate task to ensure the underlying stream is + /// not immediately ready + pub fn with_use_task(mut self, use_task: bool) -> Self { + self.use_task = use_task; + self } } @@ -179,23 +196,30 @@ impl ExecutionPlan for MockExec { }) .collect(); - let mut builder = RecordBatchReceiverStream::builder(self.schema(), 2); - - // task simply sends data in order but in a separate - // thread (to ensure the batches are not available without the - // DelayedStream yielding). - let tx = builder.tx(); - builder.spawn(async move { - for batch in data { - println!("Sending batch via delayed stream"); - if let Err(e) = tx.send(batch).await { - println!("ERROR batch via delayed stream: {e}"); + if self.use_task { + let mut builder = RecordBatchReceiverStream::builder(self.schema(), 2); + // send data in order but in a separate + // thread (to ensure the batches are not available without the + // DelayedStream yielding). 
+ let tx = builder.tx(); + builder.spawn(async move { + for batch in data { + println!("Sending batch via delayed stream"); + if let Err(e) = tx.send(batch).await { + println!("ERROR batch via delayed stream: {e}"); + } } - } - }); - - // returned stream simply reads off the rx stream - Ok(builder.build()) + }); + // returned stream simply reads off the rx stream + Ok(builder.build()) + } else { + // make an input that will error + let stream = futures::stream::iter(data); + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + stream, + ))) + } } fn fmt_as( @@ -640,7 +664,6 @@ pub async fn assert_strong_count_converges_to_zero(refs: Weak) { /// - /// Execution plan that emits streams that panics. /// /// This is useful to test panic handling of certain execution plans. @@ -734,7 +757,10 @@ impl ExecutionPlan for PanicExec { } } -/// A [`RecordBatchStream`] that yields every other batch and panics after `batches_until_panic` batches have been produced +/// A [`RecordBatchStream`] that yields every other batch and panics +/// after `batches_until_panic` batches have been produced. 
+/// +/// Useful for testing the behavior of streams on panic #[derive(Debug)] struct PanicingStream { /// Which partition was this From 56a26eb6b4036db57bebba2a1e237a1f6beba544 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 6 Jun 2023 08:41:21 -0400 Subject: [PATCH 10/14] Do not drive all streams to error --- datafusion/core/src/physical_plan/stream.rs | 38 ++++++++++----------- datafusion/core/src/test/exec.rs | 23 +++++++------ 2 files changed, 32 insertions(+), 29 deletions(-) diff --git a/datafusion/core/src/physical_plan/stream.rs b/datafusion/core/src/physical_plan/stream.rs index 83d83d0e8806..31bfaccf1328 100644 --- a/datafusion/core/src/physical_plan/stream.rs +++ b/datafusion/core/src/physical_plan/stream.rs @@ -117,9 +117,13 @@ impl RecordBatchReceiverStreamBuilder { Ok(stream) => stream, }; + // Transfer batches from inner stream to the output tx + // immediately. while let Some(item) = stream.next().await { - // If send fails, plan being torn down, - // there is no place to send the error. + let is_err = item.is_err(); + + // If send fails, plan being torn down, there is no + // place to send the error and no reason to continue. 
if output.send(item).await.is_err() { debug!( "Stopping execution: output is gone, plan cancelling: {}", @@ -127,6 +131,12 @@ impl RecordBatchReceiverStreamBuilder { ); return; } + + // stop after the first error is encontered (don't + // drive all streams to completion) + if is_err { + return; + } } }); } @@ -332,7 +342,7 @@ mod test { #[tokio::test] #[should_panic(expected = "PanickingStream did panic: 1")] - async fn record_batch_receiver_stream_propagates_panics_one() { + async fn record_batch_receiver_stream_propagates_panics_early_shutdown() { let schema = schema(); // make 2 partitions, second panics before the first @@ -341,7 +351,12 @@ mod test { .with_partition_panic(0, 10) .with_partition_panic(1, 3); // partition 1 should panic first (after 3 ) - let max_batches = 5; // expect to read every other batch: (0,1,0,1,0,panic) + // ensure that the panic results in an early shutdown (that + // everything stops after the first panic). + + // Since the stream reads every other batch: (0,1,0,1,0,panic) + // so should not exceed 5 batches prior to the panic + let max_batches = 5; consume(input, max_batches).await } @@ -378,10 +393,6 @@ mod test { let task_ctx = session_ctx.task_ctx(); let schema = schema(); - // Make an input that will not proceed - let blocking_input = BlockingExec::new(schema.clone(), 1); - let refs = blocking_input.refs(); - // make an input that will error let error_stream = MockExec::new( vec![ @@ -392,18 +403,10 @@ mod test { ) .with_use_task(false); - // Configure a RecordBatchReceiverStream to consume the - // blocking input (which will never advance) and the stream - // that will error. 
- let mut builder = RecordBatchReceiverStream::builder(schema, 2); - builder.run_input(Arc::new(blocking_input), 0, task_ctx.clone()); builder.run_input(Arc::new(error_stream), 0, task_ctx.clone()); let mut stream = builder.build(); - // first input input should be present - assert!(std::sync::Weak::strong_count(&refs) > 0); - // get the first result, which should be an error let first_batch = stream.next().await.unwrap(); let first_err = first_batch.unwrap_err(); @@ -411,9 +414,6 @@ mod test { // There should be no more batches produced (should not get the second error) assert!(stream.next().await.is_none()); - - // And the other inputs should be cleaned up (even before stream is dropped) - assert_strong_count_converges_to_zero(refs).await; } /// Consumes all the input's partitions into a diff --git a/datafusion/core/src/test/exec.rs b/datafusion/core/src/test/exec.rs index 2ce102114dcc..e04e3a7f61f7 100644 --- a/datafusion/core/src/test/exec.rs +++ b/datafusion/core/src/test/exec.rs @@ -116,8 +116,8 @@ impl RecordBatchStream for TestStream { } } -/// A Mock ExecutionPlan that can be used for writing tests of other ExecutionPlans -/// +/// A Mock ExecutionPlan that can be used for writing tests of other +/// ExecutionPlans #[derive(Debug)] pub struct MockExec { /// the results to send back @@ -129,15 +129,18 @@ pub struct MockExec { } impl MockExec { - /// Create a new exec with a single partition that returns the - /// record batches in this Exec. Note the batches are not produced - /// immediately (the caller has to actually yield and another task - /// must run) to ensure any poll loops are correct. + /// Create a new `MockExec` with a single partition that returns + /// the specified `Results`s. + /// + /// By default, the batches are not produced immediately (the + /// caller has to actually yield and another task must run) to + /// ensure any poll loops are correct. 
This behavior can be + /// changed with `with_use_task` pub fn new(data: Vec>, schema: SchemaRef) -> Self { Self { data, schema, - use_task: false, + use_task: true, } } @@ -198,9 +201,9 @@ impl ExecutionPlan for MockExec { if self.use_task { let mut builder = RecordBatchReceiverStream::builder(self.schema(), 2); - // send data in order but in a separate - // thread (to ensure the batches are not available without the - // DelayedStream yielding). + // send data in order but in a separate task (to ensure + // the batches are not available without the stream + // yielding). let tx = builder.tx(); builder.spawn(async move { for batch in data { From b1a817ce8e6c4eb210b40030f345ffb246184ac3 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 6 Jun 2023 09:10:31 -0400 Subject: [PATCH 11/14] terminate early on panic --- datafusion/core/src/physical_plan/stream.rs | 26 +++++++++++++++------ 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/datafusion/core/src/physical_plan/stream.rs b/datafusion/core/src/physical_plan/stream.rs index 31bfaccf1328..8b6acd847dc1 100644 --- a/datafusion/core/src/physical_plan/stream.rs +++ b/datafusion/core/src/physical_plan/stream.rs @@ -38,9 +38,11 @@ use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream}; /// Builder for [`RecordBatchReceiverStream`] that propagates errors /// and panic's correctly. /// -/// [`RecordBatchReceiverStream`] can be used when there are one or -/// more tasks spawned which produce RecordBatches and send them to a -/// single `Receiver`. +/// [`RecordBatchReceiverStream`] is used to spawn one or more tasks +/// that produce `RecordBatch`es and send them to a single +/// `Receiver` which can improve parallelism. +/// +/// This also handles propagating panic`s and canceling the tasks. 
pub struct RecordBatchReceiverStreamBuilder { tx: Sender>, rx: Receiver>, @@ -94,6 +96,9 @@ impl RecordBatchReceiverStreamBuilder { /// runs the input_partition of the `input` ExecutionPlan on the /// tokio threadpool and writes its outputs to this stream + /// + /// If the input partition produces an error, the error will be + /// sent to the output stream and no further results are sent. pub(crate) fn run_input( &mut self, input: Arc, @@ -105,8 +110,8 @@ impl RecordBatchReceiverStreamBuilder { self.spawn(async move { let mut stream = match input.execute(partition, context) { Err(e) => { - // If send fails, plan being torn down, - // there is no place to send the error. + // If send fails, the plan being torn down, there + // is no place to send the error and no reason to continue. output.send(Err(e)).await.ok(); debug!( "Stopping execution: error executing input: {}", @@ -135,6 +140,10 @@ impl RecordBatchReceiverStreamBuilder { // stop after the first error is encontered (don't // drive all streams to completion) if is_err { + debug!( + "Stopping execution: plan returned error: {}", + displayable(input.as_ref()).one_line() + ); return; } } @@ -153,7 +162,7 @@ impl RecordBatchReceiverStreamBuilder { // don't need tx drop(tx); - // future that checks the result of the join set + // future that checks the result of the join set, and propagates panic if seen let check = async move { while let Some(result) = join_set.join_next().await { match result { @@ -183,7 +192,10 @@ impl RecordBatchReceiverStreamBuilder { // unwrap Option / only return the error .filter_map(|item| async move { item }); - let inner = ReceiverStream::new(rx).chain(check_stream).boxed(); + // Merge the streams together (but futures::stream:StreamExt + // is already in scope, so call it explicitly) + let inner = + tokio_stream::StreamExt::merge(ReceiverStream::new(rx), check_stream).boxed(); Box::pin(RecordBatchReceiverStream { schema, inner }) } From 79fcbfac3eaf2d509d0389297f484cdc83ad425c Mon Sep 
17 00:00:00 2001 From: Andrew Lamb Date: Tue, 6 Jun 2023 09:15:00 -0400 Subject: [PATCH 12/14] tweak comments --- datafusion/core/src/physical_plan/stream.rs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/datafusion/core/src/physical_plan/stream.rs b/datafusion/core/src/physical_plan/stream.rs index 8b6acd847dc1..a31bf6e162f4 100644 --- a/datafusion/core/src/physical_plan/stream.rs +++ b/datafusion/core/src/physical_plan/stream.rs @@ -192,8 +192,9 @@ impl RecordBatchReceiverStreamBuilder { // unwrap Option / only return the error .filter_map(|item| async move { item }); - // Merge the streams together (but futures::stream:StreamExt - // is already in scope, so call it explicitly) + // Merge the streams together so whichever is ready first + // produces the batch (since futures::stream:StreamExt is + // already in scope, need to call it explicitly) let inner = tokio_stream::StreamExt::merge(ReceiverStream::new(rx), check_stream).boxed(); @@ -290,7 +291,7 @@ where } /// Stream wrapper that records `BaselineMetrics` for a particular -/// partition +/// `[SendableRecordBatchStream]` (likely a partition) pub(crate) struct ObservedStream { inner: SendableRecordBatchStream, baseline_metrics: BaselineMetrics, @@ -357,7 +358,7 @@ mod test { async fn record_batch_receiver_stream_propagates_panics_early_shutdown() { let schema = schema(); - // make 2 partitions, second panics before the first + // make 2 partitions, second partition panics before the first let num_partitions = 2; let input = PanicExec::new(schema.clone(), num_partitions) .with_partition_panic(0, 10) @@ -378,17 +379,16 @@ mod test { let task_ctx = session_ctx.task_ctx(); let schema = schema(); - // Make an input that will not proceed + // Make an input that never proceeds let input = BlockingExec::new(schema.clone(), 1); let refs = input.refs(); // Configure a RecordBatchReceiverStream to consume the input let mut builder = RecordBatchReceiverStream::builder(schema, 2); 
builder.run_input(Arc::new(input), 0, task_ctx.clone()); - let stream = builder.build(); - // input stream should be present + // input should still be present assert!(std::sync::Weak::strong_count(&refs) > 0); // drop the stream, ensure the refs go to zero @@ -405,7 +405,7 @@ mod test { let task_ctx = session_ctx.task_ctx(); let schema = schema(); - // make an input that will error + // make an input that will error twice let error_stream = MockExec::new( vec![ Err(DataFusionError::Execution("Test1".to_string())), From 70a3d573b5ad0067880ed831699760f03e210536 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 6 Jun 2023 09:16:21 -0400 Subject: [PATCH 13/14] tweak comments --- datafusion/core/src/test/exec.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/datafusion/core/src/test/exec.rs b/datafusion/core/src/test/exec.rs index e04e3a7f61f7..41a0a1b4d084 100644 --- a/datafusion/core/src/test/exec.rs +++ b/datafusion/core/src/test/exec.rs @@ -735,7 +735,7 @@ impl ExecutionPlan for PanicExec { partition: usize, _context: Arc, ) -> Result { - Ok(Box::pin(PanicingStream { + Ok(Box::pin(PanicStream { partition, batches_until_panic: self.batches_until_panics[partition], schema: Arc::clone(&self.schema), @@ -765,7 +765,7 @@ impl ExecutionPlan for PanicExec { /// /// Useful for testing the behavior of streams on panic #[derive(Debug)] -struct PanicingStream { +struct PanicStream { /// Which partition was this partition: usize, /// How may batches will be produced until panic @@ -776,7 +776,7 @@ struct PanicingStream { ready: bool, } -impl Stream for PanicingStream { +impl Stream for PanicStream { type Item = Result; fn poll_next( @@ -800,7 +800,7 @@ impl Stream for PanicingStream { } } -impl RecordBatchStream for PanicingStream { +impl RecordBatchStream for PanicStream { fn schema(&self) -> SchemaRef { Arc::clone(&self.schema) } From fb17af86863addc83e23591d31102687a7e161ca Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 6 Jun 2023 
10:36:27 -0400 Subject: [PATCH 14/14] use futures::stream --- datafusion/core/src/physical_plan/stream.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/datafusion/core/src/physical_plan/stream.rs b/datafusion/core/src/physical_plan/stream.rs index a31bf6e162f4..75a0f45e1ee2 100644 --- a/datafusion/core/src/physical_plan/stream.rs +++ b/datafusion/core/src/physical_plan/stream.rs @@ -193,10 +193,9 @@ impl RecordBatchReceiverStreamBuilder { .filter_map(|item| async move { item }); // Merge the streams together so whichever is ready first - // produces the batch (since futures::stream:StreamExt is - // already in scope, need to call it explicitly) + // produces the batch let inner = - tokio_stream::StreamExt::merge(ReceiverStream::new(rx), check_stream).boxed(); + futures::stream::select(ReceiverStream::new(rx), check_stream).boxed(); Box::pin(RecordBatchReceiverStream { schema, inner }) }