diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml
index c5c9f37322..ab0e3900ec 100644
--- a/crates/core/Cargo.toml
+++ b/crates/core/Cargo.toml
@@ -115,7 +115,8 @@ tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
 utime = "0.3"
 
 [features]
-default = []
+cdf = []
+default = ["cdf"]
 datafusion = [
     "dep:datafusion",
     "datafusion-expr",
diff --git a/crates/core/src/delta_datafusion/mod.rs b/crates/core/src/delta_datafusion/mod.rs
index f3f9b5d6cc..68ad67ee44 100644
--- a/crates/core/src/delta_datafusion/mod.rs
+++ b/crates/core/src/delta_datafusion/mod.rs
@@ -509,6 +509,12 @@ impl<'a> DeltaScanBuilder<'a> {
         self
     }
 
+    /// Use the provided [SchemaRef] for the [DeltaScan]
+    pub fn with_schema(mut self, schema: SchemaRef) -> Self {
+        self.schema = Some(schema);
+        self
+    }
+
     pub async fn build(self) -> DeltaResult<DeltaScan> {
         let config = self.config;
         let schema = match self.schema {
diff --git a/crates/core/src/operations/cdc.rs b/crates/core/src/operations/cdc.rs
new file mode 100644
index 0000000000..cb593b6b9a
--- /dev/null
+++ b/crates/core/src/operations/cdc.rs
@@ -0,0 +1,316 @@
+//!
+//! The CDC module contains private tools for managing CDC files
+//!
+
+use crate::DeltaResult;
+
+use arrow::array::{Array, StringArray};
+use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+use arrow::record_batch::RecordBatch;
+use datafusion::error::Result as DataFusionResult;
+use datafusion::physical_plan::{
+    metrics::MetricsSet, DisplayAs, ExecutionPlan, RecordBatchStream, SendableRecordBatchStream,
+};
+use datafusion::prelude::*;
+use futures::{Stream, StreamExt};
+use std::sync::Arc;
+use tokio::sync::mpsc::*;
+use tracing::log::*;
+
+/// Maximum in-memory channel size for the tracker to use
+const MAX_CHANNEL_SIZE: usize = 1024;
+
+/// The CDCTracker is useful for hooking reads/writes in a manner necessary to create CDC files
+/// associated with commits
+pub(crate) struct CDCTracker {
+    schema: SchemaRef,
+    pre_sender: Sender<RecordBatch>,
+    pre_receiver: Receiver<RecordBatch>,
+    post_sender: Sender<RecordBatch>,
+    post_receiver: Receiver<RecordBatch>,
+}
+
+impl CDCTracker {
+    /// Construct a new [CDCTracker] for the given schema
+    pub(crate) fn new(schema: SchemaRef) -> Self {
+        let (pre_sender, pre_receiver) = channel(MAX_CHANNEL_SIZE);
+        let (post_sender, post_receiver) = channel(MAX_CHANNEL_SIZE);
+        Self {
+            schema,
+            pre_sender,
+            pre_receiver,
+            post_sender,
+            post_receiver,
+        }
+    }
+
+    /// Return an owned [Sender] for the caller to use when sending read but not altered batches
+    pub(crate) fn pre_sender(&self) -> Sender<RecordBatch> {
+        self.pre_sender.clone()
+    }
+
+    /// Return an owned [Sender] for the caller to use when sending altered batches
+    pub(crate) fn post_sender(&self) -> Sender<RecordBatch> {
+        self.post_sender.clone()
+    }
+
+    pub(crate) async fn collect(mut self) -> DeltaResult<Vec<RecordBatch>> {
+        debug!("Collecting all the batches for diffing");
+        let ctx = SessionContext::new();
+        let mut pre = vec![];
+        let mut post = vec![];
+
+        while !self.pre_receiver.is_empty() {
+            if let Ok(batch) = self.pre_receiver.try_recv() {
+                pre.push(batch);
+            } else {
+                warn!("Error when receiving on the pre-receiver");
+            }
+        }
+
+        while !self.post_receiver.is_empty() {
+            if let Ok(batch) = self.post_receiver.try_recv() {
+                post.push(batch);
+            } else {
+                warn!("Error when receiving on the post-receiver");
+            }
+        }
+
+        // Collect _all_ the batches for consideration
+        let pre = ctx.read_batches(pre)?;
+        let post = ctx.read_batches(post)?;
+
+        // There is certainly a better way to do this other than stupidly cloning data for diffing
+        // purposes, but this is the quickest and easiest way to "diff" the two sets of batches
+        let preimage = pre.clone().except(post.clone())?;
+        let postimage = post.except(pre)?;
+
+        // Create a new schema which represents the input batch along with the CDC
+        // columns
+        let mut fields: Vec<Arc<Field>> = self.schema.fields().to_vec().clone();
+        fields.push(Arc::new(Field::new("_change_type", DataType::Utf8, true)));
+        let schema = Arc::new(Schema::new(fields));
+
+        let mut batches = vec![];
+
+        let mut pre_stream = preimage.execute_stream().await?;
+        let mut post_stream = postimage.execute_stream().await?;
+
+        // Fill up on pre image batches
+        while let Some(Ok(batch)) = pre_stream.next().await {
+            let batch = crate::operations::cast::cast_record_batch(
+                &batch,
+                self.schema.clone(),
+                true,
+                false,
+            )?;
+            let new_column = Arc::new(StringArray::from(vec![
+                Some("update_preimage");
+                batch.num_rows()
+            ]));
+            let mut columns: Vec<Arc<dyn Array>> = batch.columns().to_vec();
+            columns.push(new_column);
+
+            let batch = RecordBatch::try_new(schema.clone(), columns)?;
+            batches.push(batch);
+        }
+
+        // Fill up on the post-image batches
+        while let Some(Ok(batch)) = post_stream.next().await {
+            let batch = crate::operations::cast::cast_record_batch(
+                &batch,
+                self.schema.clone(),
+                true,
+                false,
+            )?;
+            let new_column = Arc::new(StringArray::from(vec![
+                Some("update_postimage");
+                batch.num_rows()
+            ]));
+            let mut columns: Vec<Arc<dyn Array>> = batch.columns().to_vec();
+            columns.push(new_column);
+
+            let batch = RecordBatch::try_new(schema.clone(), columns)?;
+            batches.push(batch);
+        }
+
+        debug!("Found {} batches to consider `CDC` data", batches.len());
+
+        // At this point the batches should just contain the changes
+        Ok(batches)
+    }
+}
+
+/// A DataFusion observer to help pick up on pre-image changes
+pub(crate) struct CDCObserver {
+    parent: Arc<dyn ExecutionPlan>,
+    id: String,
+    sender: Sender<RecordBatch>,
+}
+
+impl CDCObserver {
+    pub(crate) fn new(
+        id: String,
+        sender: Sender<RecordBatch>,
+        parent: Arc<dyn ExecutionPlan>,
+    ) -> Self {
+        Self { id, sender, parent }
+    }
+}
+
+impl std::fmt::Debug for CDCObserver {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("CDCObserver").field("id", &self.id).finish()
+    }
+}
+
+impl DisplayAs for CDCObserver {
+    fn fmt_as(
+        &self,
+        _: datafusion::physical_plan::DisplayFormatType,
+        f: &mut std::fmt::Formatter,
+    ) -> std::fmt::Result {
+        write!(f, "CDCObserver id={}", self.id)
+    }
+}
+
+impl ExecutionPlan for CDCObserver {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        self.parent.schema()
+    }
+
+    fn properties(&self) -> &datafusion::physical_plan::PlanProperties {
+        self.parent.properties()
+    }
+
+    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
+        vec![self.parent.clone()]
+    }
+
+    fn execute(
+        &self,
+        partition: usize,
+        context: Arc<datafusion::execution::context::TaskContext>,
+    ) -> datafusion_common::Result<SendableRecordBatchStream> {
+        let res = self.parent.execute(partition, context)?;
+        Ok(Box::pin(CDCObserverStream {
+            schema: self.schema(),
+            input: res,
+            sender: self.sender.clone(),
+        }))
+    }
+
+    fn statistics(&self) -> DataFusionResult<datafusion_common::Statistics> {
+        self.parent.statistics()
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
+        if let Some(parent) = children.first() {
+            Ok(Arc::new(CDCObserver {
+                id: self.id.clone(),
+                sender: self.sender.clone(),
+                parent: parent.clone(),
+            }))
+        } else {
+            Err(datafusion_common::DataFusionError::Internal(
+                "Failed to handle CDCObserver".into(),
+            ))
+        }
+    }
+
+    fn metrics(&self) -> Option<MetricsSet> {
+        self.parent.metrics()
+    }
+}
+
+/// The CDCObserverStream simply acts to help observe the stream of data being
+/// read by DataFusion to capture the pre-image versions of data
+pub(crate) struct CDCObserverStream {
+    schema: SchemaRef,
+    input: SendableRecordBatchStream,
+    sender: Sender<RecordBatch>,
+}
+
+impl Stream for CDCObserverStream {
+    type Item = DataFusionResult<RecordBatch>;
+
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Option<Self::Item>> {
+        self.input.poll_next_unpin(cx).map(|x| match x {
+            Some(Ok(batch)) => {
+                let _ = self.sender.try_send(batch.clone());
+                Some(Ok(batch))
+            }
+            other => other,
+        })
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        self.input.size_hint()
+    }
+}
+
+impl RecordBatchStream for CDCObserverStream {
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::Int32Array;
+    use datafusion::assert_batches_sorted_eq;
+
+    #[tokio::test]
+    async fn test_sanity_check() {
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "value",
+            DataType::Int32,
+            true,
+        )]));
+        let tracker = CDCTracker::new(schema.clone());
+
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)]))],
+        )
+        .unwrap();
+        let updated_batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![Arc::new(Int32Array::from(vec![Some(1), Some(12), Some(3)]))],
+        )
+        .unwrap();
+
+        let _ = tracker.pre_sender().send(batch).await;
+        let _ = tracker.post_sender().send(updated_batch).await;
+
+        match tracker.collect().await {
+            Ok(batches) => {
+                let _ = arrow::util::pretty::print_batches(&batches);
+                assert_eq!(batches.len(), 2);
+                assert_batches_sorted_eq! {[
+                    "+-------+------------------+",
+                    "| value | _change_type     |",
+                    "+-------+------------------+",
+                    "| 2     | update_preimage  |",
+                    "| 12    | update_postimage |",
+                    "+-------+------------------+",
+                ], &batches }
+            }
+            Err(err) => {
+                println!("err: {err:#?}");
+                panic!("Should have never reached this assertion");
+            }
+        }
+    }
+}
diff --git a/crates/core/src/operations/delete.rs b/crates/core/src/operations/delete.rs
index e20005c69d..29d3ae1f3d 100644
--- a/crates/core/src/operations/delete.rs
+++ b/crates/core/src/operations/delete.rs
@@ -164,6 +164,7 @@ async fn excute_non_empty_expr(
         writer_properties,
         false,
         None,
+        None,
     )
     .await?
     .into_iter()
diff --git a/crates/core/src/operations/merge/mod.rs b/crates/core/src/operations/merge/mod.rs
index dd5d433ebd..07496e30ba 100644
--- a/crates/core/src/operations/merge/mod.rs
+++ b/crates/core/src/operations/merge/mod.rs
@@ -1380,6 +1380,7 @@ async fn execute(
         writer_properties,
         safe_cast,
         None,
+        None,
     )
     .await?;
 
diff --git a/crates/core/src/operations/mod.rs b/crates/core/src/operations/mod.rs
index a2e000ae60..325cd798f1 100644
--- a/crates/core/src/operations/mod.rs
+++ b/crates/core/src/operations/mod.rs
@@ -38,6 +38,8 @@ use arrow::record_batch::RecordBatch;
 use optimize::OptimizeBuilder;
 use restore::RestoreBuilder;
 
+#[cfg(all(feature = "cdf", feature = "datafusion"))]
+mod cdc;
 #[cfg(feature = "datafusion")]
 pub mod constraints;
 #[cfg(feature = "datafusion")]
diff --git a/crates/core/src/operations/transaction/protocol.rs b/crates/core/src/operations/transaction/protocol.rs
index 95a0e22d66..4c169c70f3 100644
--- a/crates/core/src/operations/transaction/protocol.rs
+++ b/crates/core/src/operations/transaction/protocol.rs
@@ -228,6 +228,11 @@ pub static INSTANCE: Lazy<ProtocolChecker> = Lazy::new(|| {
     let mut writer_features = HashSet::new();
     writer_features.insert(WriterFeatures::AppendOnly);
     writer_features.insert(WriterFeatures::TimestampWithoutTimezone);
+    #[cfg(feature = "cdf")]
+    {
+        writer_features.insert(WriterFeatures::ChangeDataFeed);
+        writer_features.insert(WriterFeatures::GeneratedColumns);
+    }
     #[cfg(feature = "datafusion")]
     {
         writer_features.insert(WriterFeatures::Invariants);
diff --git a/crates/core/src/operations/update.rs b/crates/core/src/operations/update.rs
index 700e23a411..56c6155bb4 100644
--- a/crates/core/src/operations/update.rs
+++ b/crates/core/src/operations/update.rs
@@ -38,8 +38,10 @@ use datafusion_physical_expr::{
     PhysicalExpr,
 };
 use futures::future::BoxFuture;
+use object_store::prefix::PrefixStore;
 use parquet::file::properties::WriterProperties;
 use serde::Serialize;
+use tracing::log::*;
 
 use super::transaction::PROTOCOL;
 use super::write::write_execution_plan;
@@ -47,17 +49,27 @@ use super::{
     datafusion_utils::Expression,
     transaction::{CommitBuilder, CommitProperties},
 };
-use crate::delta_datafusion::{
-    create_physical_expr_fix, expr::fmt_expr_to_sql, physical::MetricObserverExec,
-    DataFusionMixins, DeltaColumn, DeltaSessionContext,
-};
-use crate::delta_datafusion::{find_files, register_store, DeltaScanBuilder};
 use crate::kernel::{Action, Remove};
 use crate::logstore::LogStoreRef;
+use crate::operations::writer::{DeltaWriter, WriterConfig};
 use crate::protocol::DeltaOperation;
 use crate::table::state::DeltaTableState;
+use crate::{
+    delta_datafusion::{
+        create_physical_expr_fix, expr::fmt_expr_to_sql, physical::MetricObserverExec,
+        DataFusionMixins, DeltaColumn, DeltaSessionContext,
+    },
+    operations::cdc::*,
+};
+use crate::{
+    delta_datafusion::{find_files, register_store, DeltaScanBuilder},
+    kernel::AddCDCFile,
+};
 use crate::{DeltaResult, DeltaTable};
 
+/// Custom column name used for marking internal [RecordBatch] rows as updated
+pub(crate) const UPDATE_PREDICATE_COLNAME: &str = "__delta_rs_update_predicate";
+
 /// Updates records in the Delta Table.
 /// See this module's documentation for more information
 pub struct UpdateBuilder {
@@ -222,6 +234,10 @@ async fn execute(
     let predicate = predicate.unwrap_or(Expr::Literal(ScalarValue::Boolean(Some(true))));
 
+    // Create a projection for a new column with the predicate evaluated
+    let input_schema = snapshot.input_schema()?;
+    let tracker = CDCTracker::new(input_schema.clone());
+
     let execution_props = state.execution_props();
     // For each rewrite evaluate the predicate and then modify each expression
     // to either compute the new value or obtain the old one then write these batches
 
@@ -231,15 +247,23 @@ async fn execute(
         .await?;
     let scan = Arc::new(scan);
 
-    // Create a projection for a new column with the predicate evaluated
-    let input_schema = snapshot.input_schema()?;
+    // Wrap the scan with a CDCObserver if CDC has been enabled so that the tracker can
+    // later be used to produce the CDC files
+    let scan: Arc<dyn ExecutionPlan> = match snapshot.table_config().enable_change_data_feed() {
+        true => Arc::new(CDCObserver::new(
+            "cdc-update-observer".into(),
+            tracker.pre_sender(),
+            scan.clone(),
+        )),
+        false => scan,
+    };
     let mut fields = Vec::new();
     for field in input_schema.fields.iter() {
         fields.push(field.to_owned());
     }
     fields.push(Arc::new(Field::new(
-        "__delta_rs_update_predicate",
+        UPDATE_PREDICATE_COLNAME,
         arrow_schema::DataType::Boolean,
         true,
     )));
 
@@ -265,16 +289,16 @@ async fn execute(
         when(predicate.clone(), lit(true)).otherwise(lit(ScalarValue::Boolean(None)))?;
     let predicate_expr =
         create_physical_expr_fix(predicate_null, &input_dfschema, execution_props)?;
-    expressions.push((predicate_expr, "__delta_rs_update_predicate".to_string()));
+    expressions.push((predicate_expr, UPDATE_PREDICATE_COLNAME.to_string()));
 
     let projection_predicate: Arc<dyn ExecutionPlan> =
-        Arc::new(ProjectionExec::try_new(expressions, scan)?);
+        Arc::new(ProjectionExec::try_new(expressions, scan.clone())?);
 
     let count_plan = Arc::new(MetricObserverExec::new(
         "update_count".into(),
         projection_predicate.clone(),
         |batch, metrics| {
-            let array = batch.column_by_name("__delta_rs_update_predicate").unwrap();
+            let array = batch.column_by_name(UPDATE_PREDICATE_COLNAME).unwrap();
             let copied_rows = array.null_count();
             let num_updated = array.len() - copied_rows;
 
@@ -305,10 +329,10 @@ async fn execute(
     // Maintain a map from the original column name to its temporary column index
     let mut map = HashMap::<String, usize>::new();
     let mut control_columns = HashSet::<String>::new();
-    control_columns.insert("__delta_rs_update_predicate".to_owned());
+    control_columns.insert(UPDATE_PREDICATE_COLNAME.to_string());
 
     for (column, expr) in updates {
-        let expr = case(col("__delta_rs_update_predicate"))
+        let expr = case(col(UPDATE_PREDICATE_COLNAME))
             .when(lit(true), expr.to_owned())
             .otherwise(col(column.to_owned()))?;
         let predicate_expr = create_physical_expr_fix(expr, &input_dfschema, execution_props)?;
@@ -324,6 +348,7 @@ async fn execute(
     // Project again to remove __delta_rs columns and rename update columns to their original name
     let mut expressions: Vec<(Arc<dyn PhysicalExpr>, String)> = Vec::new();
     let scan_schema = projection_update.schema();
+
     for (i, field) in scan_schema.fields().into_iter().enumerate() {
         if !control_columns.contains(field.name()) {
             match map.get(field.name()) {
@@ -356,9 +381,10 @@ async fn execute(
         log_store.object_store().clone(),
         Some(snapshot.table_config().target_file_size() as usize),
         None,
-        writer_properties,
+        writer_properties.clone(),
         safe_cast,
         None,
+        Some(tracker.post_sender()),
     )
     .await?;
 
@@ -413,6 +439,48 @@ async fn execute(
         serde_json::to_value(&metrics)?,
     );
 
+    match tracker.collect().await {
+        Ok(batches) => {
+            if !batches.is_empty() {
+                debug!(
+                    "Collected {} batches to write as part of this transaction:",
+                    batches.len()
+                );
+                let config = WriterConfig::new(
+                    batches[0].schema().clone(),
+                    snapshot.metadata().partition_columns.clone(),
+                    writer_properties.clone(),
+                    None,
+                    None,
+                );
+
+                let store = Arc::new(PrefixStore::new(
+                    log_store.object_store().clone(),
+                    "_change_data",
+                ));
+                let mut writer = DeltaWriter::new(store, config);
+                for batch in batches {
+                    writer.write(&batch).await?;
+                }
+                // Add the AddCDCFile actions that exist to the commit
+                actions.extend(writer.close().await?.into_iter().map(|add| {
+                    Action::Cdc(AddCDCFile {
+                        // This is a gnarly hack, but the action needs the nested path, not the
+                        // path inside the prefixed store
+                        path: format!("_change_data/{}", add.path),
+                        size: add.size,
+                        partition_values: add.partition_values,
+                        data_change: false,
+                        tags: add.tags,
+                    })
+                }));
+            }
+        }
+        Err(err) => {
+            error!("Failed to collect CDC batches: {err:#?}");
+        }
+    };
+
     let commit = CommitBuilder::from(commit_properties)
         .with_actions(actions)
         .build(Some(&snapshot), log_store, operation)?
@@ -463,10 +531,12 @@ impl std::future::IntoFuture for UpdateBuilder {
 
 #[cfg(test)]
 mod tests {
+    use super::*;
+
+    use crate::delta_datafusion::cdf::DeltaCdfScan;
     use crate::kernel::DataType as DeltaDataType;
-    use crate::kernel::PrimitiveType;
-    use crate::kernel::StructField;
-    use crate::kernel::StructType;
+    use crate::kernel::{Action, PrimitiveType, Protocol, StructField, StructType};
+    use crate::operations::collect_sendable_stream;
     use crate::operations::DeltaOps;
     use crate::writer::test_utils::datafusion::get_data;
     use crate::writer::test_utils::datafusion::write_batch;
@@ -475,12 +545,13 @@ mod tests {
     };
     use crate::DeltaConfigKey;
     use crate::DeltaTable;
+    use arrow::array::{Int32Array, StringArray};
    use arrow::datatypes::Schema as ArrowSchema;
     use arrow::datatypes::{Field, Schema};
     use arrow::record_batch::RecordBatch;
-    use arrow_array::Int32Array;
     use arrow_schema::DataType;
     use datafusion::assert_batches_sorted_eq;
+    use datafusion::physical_plan::ExecutionPlan;
     use datafusion::prelude::*;
     use serde_json::json;
     use std::sync::Arc;
@@ -960,4 +1031,242 @@ mod tests {
             .await;
         assert!(res.is_err());
     }
+
+    #[tokio::test]
+    async fn test_no_cdc_on_older_tables() {
+        let table = prepare_values_table().await;
+        assert_eq!(table.version(), 0);
+        assert_eq!(table.get_files_count(), 1);
+
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "value",
+            arrow::datatypes::DataType::Int32,
+            true,
+        )]));
+        let batch = RecordBatch::try_new(
+            schema,
+            vec![Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)]))],
+        )
+        .unwrap();
+        let table = DeltaOps(table)
+            .write(vec![batch])
+            .await
+            .expect("Failed to write first batch");
+        assert_eq!(table.version(), 1);
+
+        let (table, _metrics) = DeltaOps(table)
+            .update()
+            .with_predicate(col("value").eq(lit(2)))
+            .with_update("value", lit(12))
+            .await
+            .unwrap();
+        assert_eq!(table.version(), 2);
+
+        // NOTE: This currently doesn't really assert anything because cdc_files() is not reading
+        // actions correctly
+        let cdc_files = table.state.clone().unwrap().cdc_files();
+        assert!(cdc_files.is_ok());
+        assert_eq!(cdc_files.unwrap().len(), 0);
+
+        // Too close for missiles, switching to guns. Just checking that the data wasn't actually
+        // written instead!
+        if let Ok(files) = crate::storage::utils::flatten_list_stream(
+            &table.object_store(),
+            Some(&object_store::path::Path::from("_change_data")),
+        )
+        .await
+        {
+            assert_eq!(
+                0,
+                files.len(),
+                "This test should not find any written CDC files! {files:#?}"
+            );
+        }
+    }
+
+    #[tokio::test]
+    async fn test_update_cdc_enabled() {
+        // Currently you cannot pass EnableChangeDataFeed through `with_configuration_property`
+        // so the only way to create a truly CDC enabled table is by shoving the Protocol
+        // directly into the actions list
+        let actions = vec![Action::Protocol(Protocol::new(1, 4))];
+        let table: DeltaTable = DeltaOps::new_in_memory()
+            .create()
+            .with_column(
+                "value",
+                DeltaDataType::Primitive(PrimitiveType::Integer),
+                true,
+                None,
+            )
+            .with_actions(actions)
+            .with_configuration_property(DeltaConfigKey::EnableChangeDataFeed, Some("true"))
+            .await
+            .unwrap();
+        assert_eq!(table.version(), 0);
+
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "value",
+            arrow::datatypes::DataType::Int32,
+            true,
+        )]));
+
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)]))],
+        )
+        .unwrap();
+        let table = DeltaOps(table)
+            .write(vec![batch])
+            .await
+            .expect("Failed to write first batch");
+        assert_eq!(table.version(), 1);
+
+        let (table, _metrics) = DeltaOps(table)
+            .update()
+            .with_predicate(col("value").eq(lit(2)))
+            .with_update("value", lit(12))
+            .await
+            .unwrap();
+        assert_eq!(table.version(), 2);
+
+        let ctx = SessionContext::new();
+        let table = DeltaOps(table)
+            .load_cdf()
+            .with_session_ctx(ctx.clone())
+            .with_starting_version(0)
+            .build()
+            .await
+            .expect("Failed to load CDF");
+
+        let mut batches = collect_batches(
+            table.properties().output_partitioning().partition_count(),
+            table,
+            ctx,
+        )
+        .await
+        .expect("Failed to collect batches");
+
+        // The batches will contain a current _commit_timestamp which shouldn't be checked
+        let _: Vec<_> = batches.iter_mut().map(|b| b.remove_column(3)).collect();
+
+        assert_batches_sorted_eq! {[
+            "+-------+------------------+-----------------+",
+            "| value | _change_type     | _commit_version |",
+            "+-------+------------------+-----------------+",
+            "| 1     | insert           | 1               |",
+            "| 2     | insert           | 1               |",
+            "| 2     | update_preimage  | 2               |",
+            "| 12    | update_postimage | 2               |",
+            "| 3     | insert           | 1               |",
+            "+-------+------------------+-----------------+",
+        ], &batches }
+    }
+
+    #[tokio::test]
+    async fn test_update_cdc_enabled_partitions() {
+        //let _ = pretty_env_logger::try_init();
+        // Currently you cannot pass EnableChangeDataFeed through `with_configuration_property`
+        // so the only way to create a truly CDC enabled table is by shoving the Protocol
+        // directly into the actions list
+        let actions = vec![Action::Protocol(Protocol::new(1, 4))];
+        let table: DeltaTable = DeltaOps::new_in_memory()
+            .create()
+            .with_column(
+                "year",
+                DeltaDataType::Primitive(PrimitiveType::String),
+                true,
+                None,
+            )
+            .with_column(
+                "value",
+                DeltaDataType::Primitive(PrimitiveType::Integer),
+                true,
+                None,
+            )
+            .with_partition_columns(vec!["year"])
+            .with_actions(actions)
+            .with_configuration_property(DeltaConfigKey::EnableChangeDataFeed, Some("true"))
+            .await
+            .unwrap();
+        assert_eq!(table.version(), 0);
+
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("year", DataType::Utf8, true),
+            Field::new("value", DataType::Int32, true),
+        ]));
+
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![
+                Arc::new(StringArray::from(vec![
+                    Some("2020"),
+                    Some("2020"),
+                    Some("2024"),
+                ])),
+                Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)])),
+            ],
+        )
+        .unwrap();
+        let table = DeltaOps(table)
+            .write(vec![batch])
+            .await
+            .expect("Failed to write first batch");
+        assert_eq!(table.version(), 1);
+
+        let (table, _metrics) = DeltaOps(table)
+            .update()
+            .with_predicate(col("value").eq(lit(2)))
+            .with_update("year", "2024")
+            .await
+            .unwrap();
+        assert_eq!(table.version(), 2);
+
+        let ctx = SessionContext::new();
+        let table = DeltaOps(table)
+            .load_cdf()
+            .with_session_ctx(ctx.clone())
+            .with_starting_version(0)
+            .build()
+            .await
+            .expect("Failed to load CDF");
+
+        let mut batches = collect_batches(
+            table.properties().output_partitioning().partition_count(),
+            table,
+            ctx,
+        )
+        .await
+        .expect("Failed to collect batches");
+
+        let _ = arrow::util::pretty::print_batches(&batches);
+
+        // The batches will contain a current _commit_timestamp which shouldn't be checked
+        let _: Vec<_> = batches.iter_mut().map(|b| b.remove_column(3)).collect();
+
+        assert_batches_sorted_eq! {[
+            "+-------+------------------+-----------------+------+",
+            "| value | _change_type     | _commit_version | year |",
+            "+-------+------------------+-----------------+------+",
+            "| 1     | insert           | 1               | 2020 |",
+            "| 2     | insert           | 1               | 2020 |",
+            "| 2     | update_preimage  | 2               | 2020 |",
+            "| 2     | update_postimage | 2               | 2024 |",
+            "| 3     | insert           | 1               | 2024 |",
+            "+-------+------------------+-----------------+------+",
+        ], &batches }
+    }
+
+    async fn collect_batches(
+        num_partitions: usize,
+        stream: DeltaCdfScan,
+        ctx: SessionContext,
+    ) -> Result<Vec<RecordBatch>, Box<dyn std::error::Error>> {
+        let mut batches = vec![];
+        for p in 0..num_partitions {
+            let data: Vec<RecordBatch> =
+                collect_sendable_stream(stream.execute(p, ctx.task_ctx())?).await?;
+            batches.extend_from_slice(&data);
+        }
+        Ok(batches)
+    }
 }
diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs
index 8ecfb3078b..f657923b90 100644
--- a/crates/core/src/operations/write.rs
+++ b/crates/core/src/operations/write.rs
@@ -41,6 +41,7 @@ use datafusion_expr::Expr;
 use futures::future::BoxFuture;
 use futures::StreamExt;
 use parquet::file::properties::WriterProperties;
+use tracing::log::*;
 
 use super::datafusion_utils::Expression;
 use super::transaction::{CommitBuilder, CommitProperties, TableReference, PROTOCOL};
@@ -63,6 +64,8 @@ use crate::table::Constraint as DeltaConstraint;
 use crate::writer::record_batch::divide_by_partition_values;
 use crate::DeltaTable;
 
+use tokio::sync::mpsc::Sender;
+
 #[derive(thiserror::Error, Debug)]
 enum WriteError {
     #[error("No data source supplied to write command.")]
@@ -351,6 +354,7 @@ async fn write_execution_plan_with_predicate(
     writer_properties: Option<WriterProperties>,
     safe_cast: bool,
     schema_mode: Option<SchemaMode>,
+    sender: Option<Sender<RecordBatch>>,
 ) -> DeltaResult<Vec<Action>> {
     let schema: ArrowSchemaRef = if schema_mode.is_some() {
         plan.schema()
@@ -359,7 +363,6 @@ async fn write_execution_plan_with_predicate(
             .and_then(|s| s.input_schema().ok())
             .unwrap_or(plan.schema())
     };
-
     let checker = if let Some(snapshot) = snapshot {
         DeltaDataChecker::new(snapshot)
     } else {
@@ -389,11 +392,15 @@ async fn write_execution_plan_with_predicate(
         );
         let mut writer = DeltaWriter::new(object_store.clone(), config);
         let checker_stream = checker.clone();
+        let sender_stream = sender.clone();
         let mut stream = inner_plan.execute(i, task_ctx)?;
-        let handle: tokio::task::JoinHandle<DeltaResult<Vec<Action>>> =
-            tokio::task::spawn(async move {
+
+        let handle: tokio::task::JoinHandle<DeltaResult<Vec<Action>>> = tokio::task::spawn(
+            async move {
+                let sendable = sender_stream.clone();
                 while let Some(maybe_batch) = stream.next().await {
                     let batch = maybe_batch?;
+
                     checker_stream.check_batch(&batch).await?;
                     let arr = super::cast::cast_record_batch(
                         &batch,
@@ -401,6 +408,12 @@ async fn write_execution_plan_with_predicate(
                         safe_cast,
                         schema_mode == Some(SchemaMode::Merge),
                     )?;
+
+                    if let Some(s) = sendable.as_ref() {
+                        let _ = s.send(arr.clone()).await;
+                    } else {
+                        debug!("write_execution_plan_with_predicate did not send any batches, no sender.");
+                    }
                     writer.write(&arr).await?;
                 }
                 let add_actions = writer.close().await;
@@ -408,7 +421,8 @@ async fn write_execution_plan_with_predicate(
                     Ok(actions) => Ok(actions.into_iter().map(Action::Add).collect::<Vec<_>>()),
                     Err(err) => Err(err),
                 }
-            });
+            },
+        );
         tasks.push(handle);
     }
 
@@ -438,6 +452,7 @@ pub(crate) async fn write_execution_plan(
     writer_properties: Option<WriterProperties>,
     safe_cast: bool,
     schema_mode: Option<SchemaMode>,
+    sender: Option<Sender<RecordBatch>>,
 ) -> DeltaResult<Vec<Action>> {
     write_execution_plan_with_predicate(
         None,
@@ -451,6 +466,7 @@ pub(crate) async fn write_execution_plan(
         writer_properties,
         safe_cast,
         schema_mode,
+        sender,
     )
     .await
 }
@@ -496,6 +512,7 @@ async fn execute_non_empty_expr(
         writer_properties,
         false,
         None,
+        None,
     )
     .await?;
 
@@ -736,6 +753,7 @@ impl std::future::IntoFuture for WriteBuilder {
                     this.writer_properties.clone(),
                     this.safe_cast,
                     this.schema_mode,
+                    None,
                 )
                 .await?;
                 actions.extend(add_actions);
@@ -1230,7 +1248,6 @@ mod tests {
             ],
         )
         .unwrap();
-        println!("new_batch: {:?}", new_batch.schema());
         let table = DeltaOps(table)
            .write(vec![new_batch])
            .with_save_mode(SaveMode::Append)
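
As a quick orientation for reviewers, the sketch below condenses the end-to-end flow this patch enables, distilled from the test_update_cdc_enabled test in this diff. The crate paths (deltalake_core::...) and the in-memory table setup are illustrative assumptions for the sketch, not APIs introduced by the patch:

use datafusion::prelude::{col, lit, SessionContext};
use deltalake_core::kernel::{Action, DataType as DeltaDataType, PrimitiveType, Protocol};
use deltalake_core::operations::DeltaOps;
use deltalake_core::{DeltaConfigKey, DeltaTable};

async fn update_with_cdf() -> Result<(), Box<dyn std::error::Error>> {
    // Writer protocol 4 plus delta.enableChangeDataFeed=true is what makes the update
    // builder wrap its scan in the CDCObserver and emit AddCDCFile actions on commit.
    let table: DeltaTable = DeltaOps::new_in_memory()
        .create()
        .with_column(
            "value",
            DeltaDataType::Primitive(PrimitiveType::Integer),
            true,
            None,
        )
        .with_actions(vec![Action::Protocol(Protocol::new(1, 4))])
        .with_configuration_property(DeltaConfigKey::EnableChangeDataFeed, Some("true"))
        .await?;

    // (Writing an initial batch of data is elided here; see the tests above.)

    // An UPDATE on the CDC-enabled table also writes pre/post-image files under _change_data/.
    let (table, _metrics) = DeltaOps(table)
        .update()
        .with_predicate(col("value").eq(lit(2)))
        .with_update("value", lit(12))
        .await?;

    // The change feed can then be read back through the load_cdf builder.
    let ctx = SessionContext::new();
    let _scan = DeltaOps(table)
        .load_cdf()
        .with_session_ctx(ctx.clone())
        .with_starting_version(0)
        .build()
        .await?;
    Ok(())
}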