From 6a1addb4686ca5559d3ab711dc8a3c2964697abc Mon Sep 17 00:00:00 2001
From: "R. Tyler Croy"
Date: Wed, 24 Apr 2024 06:19:48 +0000
Subject: [PATCH] feat: introduce CDC write-side support for the Update
 operations

This change introduces a `CDCTracker` which helps collect changes during
merge and update operations. This is admittedly rather inefficient, but my
hope is that this provides a place to start iterating and improving upon the
writer code.

There is still additional work which needs to be done to handle table
features properly for other code paths (see the middleware discussion we
have had in Slack), but this produces CDC files for Update operations.

Fixes #604
Fixes #2095
---
 crates/core/Cargo.toml                          |   3 +-
 crates/core/src/delta_datafusion/mod.rs         |   6 +
 crates/core/src/operations/cdc.rs               | 316 +++++++++++++++++
 crates/core/src/operations/delete.rs            |   1 +
 crates/core/src/operations/merge/mod.rs         |   1 +
 crates/core/src/operations/mod.rs               |   2 +
 .../src/operations/transaction/protocol.rs      |   5 +
 crates/core/src/operations/update.rs            | 334 +++++++++++++++++-
 crates/core/src/operations/write.rs             |  27 +-
 9 files changed, 675 insertions(+), 20 deletions(-)
 create mode 100644 crates/core/src/operations/cdc.rs

diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml
index 4e6c552fb0..de5a38b4a7 100644
--- a/crates/core/Cargo.toml
+++ b/crates/core/Cargo.toml
@@ -115,7 +115,8 @@ tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
 utime = "0.3"
 
 [features]
-default = []
+cdf = []
+default = ["cdf"]
 datafusion = [
     "dep:datafusion",
     "datafusion-expr",
diff --git a/crates/core/src/delta_datafusion/mod.rs b/crates/core/src/delta_datafusion/mod.rs
index f3f9b5d6cc..68ad67ee44 100644
--- a/crates/core/src/delta_datafusion/mod.rs
+++ b/crates/core/src/delta_datafusion/mod.rs
@@ -509,6 +509,12 @@ impl<'a> DeltaScanBuilder<'a> {
         self
     }
 
+    /// Use the provided [SchemaRef] for the [DeltaScan]
+    pub fn with_schema(mut self, schema: SchemaRef) -> Self {
+        self.schema = Some(schema);
+        self
+    }
+
     pub async fn build(self) -> DeltaResult<DeltaScan> {
         let config = self.config;
         let schema = match self.schema {
diff --git a/crates/core/src/operations/cdc.rs b/crates/core/src/operations/cdc.rs
new file mode 100644
index 0000000000..cb593b6b9a
--- /dev/null
+++ b/crates/core/src/operations/cdc.rs
@@ -0,0 +1,316 @@
+//!
+//! The CDC module contains private tools for managing CDC files
+//!
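+//!
+//! In rough terms, an operation that wants to emit CDC feeds the tracker the
+//! batches it read (the pre-image) and the batches it rewrote (the
+//! post-image), then collects the diff as `_change_type`-annotated batches.
+//! A sketch of the intended flow (variable names here are illustrative, not
+//! a literal test):
+//!
+//! ```ignore
+//! let tracker = CDCTracker::new(schema.clone());
+//! // batches read from the table, before any modification
+//! tracker.pre_sender().send(read_batch).await?;
+//! // the same rows after the operation has rewritten them
+//! tracker.post_sender().send(updated_batch).await?;
+//! // diff the two sides; surviving rows gain a `_change_type` column with
+//! // `update_preimage` / `update_postimage` values
+//! let cdc_batches = tracker.collect().await?;
+//! ```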
+
+use crate::DeltaResult;
+
+use arrow::array::{Array, StringArray};
+use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+use arrow::record_batch::RecordBatch;
+use datafusion::error::Result as DataFusionResult;
+use datafusion::physical_plan::{
+    metrics::MetricsSet, DisplayAs, ExecutionPlan, RecordBatchStream, SendableRecordBatchStream,
+};
+use datafusion::prelude::*;
+use futures::{Stream, StreamExt};
+use std::sync::Arc;
+use tokio::sync::mpsc::*;
+use tracing::log::*;
+
+/// Maximum in-memory channel size for the tracker to use
+const MAX_CHANNEL_SIZE: usize = 1024;
+
+/// The CDCTracker is useful for hooking reads/writes in a manner necessary to create CDC files
+/// associated with commits
+pub(crate) struct CDCTracker {
+    schema: SchemaRef,
+    pre_sender: Sender<RecordBatch>,
+    pre_receiver: Receiver<RecordBatch>,
+    post_sender: Sender<RecordBatch>,
+    post_receiver: Receiver<RecordBatch>,
+}
+
+impl CDCTracker {
+    /// Construct a new [CDCTracker] with the provided [SchemaRef]
+    pub(crate) fn new(schema: SchemaRef) -> Self {
+        let (pre_sender, pre_receiver) = channel(MAX_CHANNEL_SIZE);
+        let (post_sender, post_receiver) = channel(MAX_CHANNEL_SIZE);
+        Self {
+            schema,
+            pre_sender,
+            pre_receiver,
+            post_sender,
+            post_receiver,
+        }
+    }
+
+    /// Return an owned [Sender] for the caller to use when sending read but not altered batches
+    pub(crate) fn pre_sender(&self) -> Sender<RecordBatch> {
+        self.pre_sender.clone()
+    }
+
+    /// Return an owned [Sender] for the caller to use when sending altered batches
+    pub(crate) fn post_sender(&self) -> Sender<RecordBatch> {
+        self.post_sender.clone()
+    }
+
+    pub(crate) async fn collect(mut self) -> DeltaResult<Vec<RecordBatch>> {
+        debug!("Collecting all the batches for diffing");
+        let ctx = SessionContext::new();
+        let mut pre = vec![];
+        let mut post = vec![];
+
+        while !self.pre_receiver.is_empty() {
+            if let Ok(batch) = self.pre_receiver.try_recv() {
+                pre.push(batch);
+            } else {
+                warn!("Error when receiving on the pre-receiver");
+            }
+        }
+
+        while !self.post_receiver.is_empty() {
+            if let Ok(batch) = self.post_receiver.try_recv() {
+                post.push(batch);
+            } else {
+                warn!("Error when receiving on the post-receiver");
+            }
+        }
+
+        // Collect _all_ the batches for consideration
+        let pre = ctx.read_batches(pre)?;
+        let post = ctx.read_batches(post)?;
+
+        // There is certainly a better way to do this other than stupidly cloning data for diffing
+        // purposes, but this is the quickest and easiest way to "diff" the two sets of batches
+        let preimage = pre.clone().except(post.clone())?;
+        let postimage = post.except(pre)?;
+
+        // Create a new schema which represents the input batch along with the CDC
+        // columns
+        let mut fields: Vec<Arc<Field>> = self.schema.fields().to_vec().clone();
+        fields.push(Arc::new(Field::new("_change_type", DataType::Utf8, true)));
+        let schema = Arc::new(Schema::new(fields));
+
+        let mut batches = vec![];
+
+        let mut pre_stream = preimage.execute_stream().await?;
+        let mut post_stream = postimage.execute_stream().await?;
+
+        // Fill up on pre image batches
+        while let Some(Ok(batch)) = pre_stream.next().await {
+            let batch = crate::operations::cast::cast_record_batch(
+                &batch,
+                self.schema.clone(),
+                true,
+                false,
+            )?;
+            let new_column = Arc::new(StringArray::from(vec![
+                Some("update_preimage");
+                batch.num_rows()
+            ]));
+            let mut columns: Vec<Arc<dyn Array>> = batch.columns().to_vec();
+            columns.push(new_column);
+
+            let batch = RecordBatch::try_new(schema.clone(), columns)?;
+            batches.push(batch);
+        }
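+
+        // NOTE: since the EXCEPT diff above already removed rows that are
+        // identical on both sides, unchanged rows appear in neither image and
+        // produce no CDC records; only genuinely modified rows are stamped
+        // with a `_change_type` value.
+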
+        // Fill up on the post-image batches
+        while let Some(Ok(batch)) = post_stream.next().await {
+            let batch = crate::operations::cast::cast_record_batch(
+                &batch,
+                self.schema.clone(),
+                true,
+                false,
+            )?;
+            let new_column = Arc::new(StringArray::from(vec![
+                Some("update_postimage");
+                batch.num_rows()
+            ]));
+            let mut columns: Vec<Arc<dyn Array>> = batch.columns().to_vec();
+            columns.push(new_column);
+
+            let batch = RecordBatch::try_new(schema.clone(), columns)?;
+            batches.push(batch);
+        }
+
+        debug!("Found {} batches to consider `CDC` data", batches.len());
+
+        // At this point the batches should just contain the changes
+        Ok(batches)
+    }
+}
+
+/// A DataFusion observer to help pick up on pre-image changes
+pub(crate) struct CDCObserver {
+    parent: Arc<dyn ExecutionPlan>,
+    id: String,
+    sender: Sender<RecordBatch>,
+}
+
+impl CDCObserver {
+    pub(crate) fn new(
+        id: String,
+        sender: Sender<RecordBatch>,
+        parent: Arc<dyn ExecutionPlan>,
+    ) -> Self {
+        Self { id, sender, parent }
+    }
+}
+
+impl std::fmt::Debug for CDCObserver {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("CDCObserver").field("id", &self.id).finish()
+    }
+}
+
+impl DisplayAs for CDCObserver {
+    fn fmt_as(
+        &self,
+        _: datafusion::physical_plan::DisplayFormatType,
+        f: &mut std::fmt::Formatter,
+    ) -> std::fmt::Result {
+        write!(f, "CDCObserver id={}", self.id)
+    }
+}
+
+impl ExecutionPlan for CDCObserver {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        self.parent.schema()
+    }
+
+    fn properties(&self) -> &datafusion::physical_plan::PlanProperties {
+        self.parent.properties()
+    }
+
+    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
+        vec![self.parent.clone()]
+    }
+
+    fn execute(
+        &self,
+        partition: usize,
+        context: Arc<datafusion::execution::context::TaskContext>,
+    ) -> datafusion_common::Result<SendableRecordBatchStream> {
+        let res = self.parent.execute(partition, context)?;
+        Ok(Box::pin(CDCObserverStream {
+            schema: self.schema(),
+            input: res,
+            sender: self.sender.clone(),
+        }))
+    }
+
+    fn statistics(&self) -> DataFusionResult<datafusion_common::Statistics> {
+        self.parent.statistics()
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
+        if let Some(parent) = children.first() {
+            Ok(Arc::new(CDCObserver {
+                id: self.id.clone(),
+                sender: self.sender.clone(),
+                parent: parent.clone(),
+            }))
+        } else {
+            Err(datafusion_common::DataFusionError::Internal(
+                "Failed to handle CDCObserver".into(),
+            ))
+        }
+    }
+
+    fn metrics(&self) -> Option<MetricsSet> {
+        self.parent.metrics()
+    }
+}
+
+/// The CDCObserverStream simply acts to help observe the stream of data being
+/// read by DataFusion to capture the pre-image versions of data
+pub(crate) struct CDCObserverStream {
+    schema: SchemaRef,
+    input: SendableRecordBatchStream,
+    sender: Sender<RecordBatch>,
+}
+
+impl Stream for CDCObserverStream {
+    type Item = DataFusionResult<RecordBatch>;
+
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Option<Self::Item>> {
+        self.input.poll_next_unpin(cx).map(|x| match x {
+            Some(Ok(batch)) => {
+                let _ = self.sender.try_send(batch.clone());
+                Some(Ok(batch))
+            }
+            other => other,
+        })
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        self.input.size_hint()
+    }
+}
+
+impl RecordBatchStream for CDCObserverStream {
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::Int32Array;
+    use datafusion::assert_batches_sorted_eq;
+
+    #[tokio::test]
+    async fn test_sanity_check() {
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "value",
+            DataType::Int32,
+            true,
+        )]));
+        let tracker = CDCTracker::new(schema.clone());
+
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)]))],
+        )
+        .unwrap();
+        let
updated_batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(vec![Some(1), Some(12), Some(3)]))], + ) + .unwrap(); + + let _ = tracker.pre_sender().send(batch).await; + let _ = tracker.post_sender().send(updated_batch).await; + + match tracker.collect().await { + Ok(batches) => { + let _ = arrow::util::pretty::print_batches(&batches); + assert_eq!(batches.len(), 2); + assert_batches_sorted_eq! {[ + "+-------+------------------+", + "| value | _change_type |", + "+-------+------------------+", + "| 2 | update_preimage |", + "| 12 | update_postimage |", + "+-------+------------------+", + ], &batches } + } + Err(err) => { + println!("err: {err:#?}"); + panic!("Should have never reached this assertion"); + } + } + } +} diff --git a/crates/core/src/operations/delete.rs b/crates/core/src/operations/delete.rs index 5a666b924f..cc4e1e3cf9 100644 --- a/crates/core/src/operations/delete.rs +++ b/crates/core/src/operations/delete.rs @@ -174,6 +174,7 @@ async fn excute_non_empty_expr( false, None, writer_stats_config, + None, ) .await? .into_iter() diff --git a/crates/core/src/operations/merge/mod.rs b/crates/core/src/operations/merge/mod.rs index ce1ee7b223..196a405034 100644 --- a/crates/core/src/operations/merge/mod.rs +++ b/crates/core/src/operations/merge/mod.rs @@ -1389,6 +1389,7 @@ async fn execute( safe_cast, None, writer_stats_config, + None, ) .await?; diff --git a/crates/core/src/operations/mod.rs b/crates/core/src/operations/mod.rs index a2e000ae60..325cd798f1 100644 --- a/crates/core/src/operations/mod.rs +++ b/crates/core/src/operations/mod.rs @@ -38,6 +38,8 @@ use arrow::record_batch::RecordBatch; use optimize::OptimizeBuilder; use restore::RestoreBuilder; +#[cfg(all(feature = "cdf", feature = "datafusion"))] +mod cdc; #[cfg(feature = "datafusion")] pub mod constraints; #[cfg(feature = "datafusion")] diff --git a/crates/core/src/operations/transaction/protocol.rs b/crates/core/src/operations/transaction/protocol.rs index 95a0e22d66..4c169c70f3 100644 --- a/crates/core/src/operations/transaction/protocol.rs +++ b/crates/core/src/operations/transaction/protocol.rs @@ -228,6 +228,11 @@ pub static INSTANCE: Lazy = Lazy::new(|| { let mut writer_features = HashSet::new(); writer_features.insert(WriterFeatures::AppendOnly); writer_features.insert(WriterFeatures::TimestampWithoutTimezone); + #[cfg(feature = "cdf")] + { + writer_features.insert(WriterFeatures::ChangeDataFeed); + writer_features.insert(WriterFeatures::GeneratedColumns); + } #[cfg(feature = "datafusion")] { writer_features.insert(WriterFeatures::Invariants); diff --git a/crates/core/src/operations/update.rs b/crates/core/src/operations/update.rs index 29a1495e47..53b47b8d2b 100644 --- a/crates/core/src/operations/update.rs +++ b/crates/core/src/operations/update.rs @@ -38,8 +38,10 @@ use datafusion_physical_expr::{ PhysicalExpr, }; use futures::future::BoxFuture; +use object_store::prefix::PrefixStore; use parquet::file::properties::WriterProperties; use serde::Serialize; +use tracing::log::*; use super::write::write_execution_plan; use super::{ @@ -52,12 +54,17 @@ use crate::delta_datafusion::{ DataFusionMixins, DeltaColumn, DeltaSessionContext, }; use crate::delta_datafusion::{find_files, register_store, DeltaScanBuilder}; -use crate::kernel::{Action, Remove}; +use crate::kernel::{Action, AddCDCFile, Remove}; use crate::logstore::LogStoreRef; +use crate::operations::writer::{DeltaWriter, WriterConfig}; use crate::protocol::DeltaOperation; use crate::table::state::DeltaTableState; +use 
crate::operations::cdc::*;
 use crate::{DeltaResult, DeltaTable};
 
+/// Custom column name used for marking internal [RecordBatch] rows as updated
+pub(crate) const UPDATE_PREDICATE_COLNAME: &str = "__delta_rs_update_predicate";
+
 /// Updates records in the Delta Table.
 /// See this module's documentation for more information
 pub struct UpdateBuilder {
@@ -222,6 +229,10 @@ async fn execute(
 
     let predicate = predicate.unwrap_or(Expr::Literal(ScalarValue::Boolean(Some(true))));
 
+    // Create a projection for a new column with the predicate evaluated
+    let input_schema = snapshot.input_schema()?;
+    let tracker = CDCTracker::new(input_schema.clone());
+
     let execution_props = state.execution_props();
     // For each rewrite evaluate the predicate and then modify each expression
     // to either compute the new value or obtain the old one then write these batches
@@ -231,15 +242,23 @@ async fn execute(
         .await?;
     let scan = Arc::new(scan);
 
-    // Create a projection for a new column with the predicate evaluated
-    let input_schema = snapshot.input_schema()?;
+    // Wrap the scan with a CDCObserver if CDC has been enabled so that the tracker can
+    // later be used to produce the CDC files
+    let scan: Arc<dyn ExecutionPlan> = match snapshot.table_config().enable_change_data_feed() {
+        true => Arc::new(CDCObserver::new(
+            "cdc-update-observer".into(),
+            tracker.pre_sender(),
+            scan.clone(),
+        )),
+        false => scan,
+    };
 
     let mut fields = Vec::new();
     for field in input_schema.fields.iter() {
         fields.push(field.to_owned());
     }
     fields.push(Arc::new(Field::new(
-        "__delta_rs_update_predicate",
+        UPDATE_PREDICATE_COLNAME,
         arrow_schema::DataType::Boolean,
         true,
     )));
@@ -265,16 +284,16 @@ async fn execute(
         when(predicate.clone(), lit(true)).otherwise(lit(ScalarValue::Boolean(None)))?;
     let predicate_expr =
         create_physical_expr_fix(predicate_null, &input_dfschema, execution_props)?;
-    expressions.push((predicate_expr, "__delta_rs_update_predicate".to_string()));
+    expressions.push((predicate_expr, UPDATE_PREDICATE_COLNAME.to_string()));
 
     let projection_predicate: Arc<dyn ExecutionPlan> =
-        Arc::new(ProjectionExec::try_new(expressions, scan)?);
+        Arc::new(ProjectionExec::try_new(expressions, scan.clone())?);
 
     let count_plan = Arc::new(MetricObserverExec::new(
         "update_count".into(),
         projection_predicate.clone(),
         |batch, metrics| {
-            let array = batch.column_by_name("__delta_rs_update_predicate").unwrap();
+            let array = batch.column_by_name(UPDATE_PREDICATE_COLNAME).unwrap();
             let copied_rows = array.null_count();
             let num_updated = array.len() - copied_rows;
 
@@ -305,10 +324,10 @@ async fn execute(
     // Maintain a map from the original column name to its temporary column index
     let mut map = HashMap::<String, usize>::new();
     let mut control_columns = HashSet::<String>::new();
-    control_columns.insert("__delta_rs_update_predicate".to_owned());
+    control_columns.insert(UPDATE_PREDICATE_COLNAME.to_string());
 
     for (column, expr) in updates {
-        let expr = case(col("__delta_rs_update_predicate"))
+        let expr = case(col(UPDATE_PREDICATE_COLNAME))
             .when(lit(true), expr.to_owned())
             .otherwise(col(column.to_owned()))?;
         let predicate_expr = create_physical_expr_fix(expr, &input_dfschema, execution_props)?;
@@ -324,6 +343,7 @@ async fn execute(
     // Project again to remove __delta_rs columns and rename update columns to their original name
     let mut expressions: Vec<(Arc<dyn PhysicalExpr>, String)> = Vec::new();
     let scan_schema = projection_update.schema();
+
     for (i, field) in scan_schema.fields().into_iter().enumerate() {
         if !control_columns.contains(field.name()) {
             match map.get(field.name()) {
@@ -364,10 +384,11 @@
         log_store.object_store().clone(),
         Some(snapshot.table_config().target_file_size() as usize),
         None,
-        writer_properties,
+        writer_properties.clone(),
         safe_cast,
         None,
         writer_stats_config,
+        Some(tracker.post_sender()),
     )
     .await?;
 
@@ -422,6 +443,50 @@
         serde_json::to_value(&metrics)?,
     );
 
+    match tracker.collect().await {
+        Ok(batches) => {
+            if !batches.is_empty() {
+                debug!(
+                    "Collected {} batches to write as part of this transaction:",
+                    batches.len()
+                );
+                let config = WriterConfig::new(
+                    batches[0].schema().clone(),
+                    snapshot.metadata().partition_columns.clone(),
+                    writer_properties.clone(),
+                    None,
+                    None,
+                    0,
+                    None,
+                );
+
+                let store = Arc::new(PrefixStore::new(
+                    log_store.object_store().clone(),
+                    "_change_data",
+                ));
+                let mut writer = DeltaWriter::new(store, config);
+                for batch in batches {
+                    writer.write(&batch).await?;
+                }
+                // Add the AddCDCFile actions that exist to the commit
+                actions.extend(writer.close().await?.into_iter().map(|add| {
+                    Action::Cdc(AddCDCFile {
+                        // This is a gnarly hack, but the action needs the nested path, not the
+                        // path inside the prefixed store
+                        path: format!("_change_data/{}", add.path),
+                        size: add.size,
+                        partition_values: add.partition_values,
+                        data_change: false,
+                        tags: add.tags,
+                    })
+                }));
+            }
+        }
+        Err(err) => {
+            error!("Failed to collect CDC batches: {err:#?}");
+        }
+    };
+
     let commit = CommitBuilder::from(commit_properties)
         .with_actions(actions)
         .build(Some(&snapshot), log_store, operation)?
@@ -472,10 +537,12 @@ impl std::future::IntoFuture for UpdateBuilder {
 
 #[cfg(test)]
 mod tests {
+    use super::*;
+
+    use crate::delta_datafusion::cdf::DeltaCdfScan;
     use crate::kernel::DataType as DeltaDataType;
-    use crate::kernel::PrimitiveType;
-    use crate::kernel::StructField;
-    use crate::kernel::StructType;
+    use crate::kernel::{Action, PrimitiveType, Protocol, StructField, StructType};
+    use crate::operations::collect_sendable_stream;
     use crate::operations::DeltaOps;
     use crate::writer::test_utils::datafusion::get_data;
     use crate::writer::test_utils::datafusion::write_batch;
@@ -484,12 +551,13 @@ mod tests {
     };
     use crate::DeltaConfigKey;
     use crate::DeltaTable;
+    use arrow::array::{Int32Array, StringArray};
     use arrow::datatypes::Schema as ArrowSchema;
     use arrow::datatypes::{Field, Schema};
     use arrow::record_batch::RecordBatch;
-    use arrow_array::Int32Array;
     use arrow_schema::DataType;
     use datafusion::assert_batches_sorted_eq;
+    use datafusion::physical_plan::ExecutionPlan;
     use datafusion::prelude::*;
     use serde_json::json;
     use std::sync::Arc;
@@ -969,4 +1037,242 @@ mod tests {
         .await;
         assert!(res.is_err());
     }
+
+    #[tokio::test]
+    async fn test_no_cdc_on_older_tables() {
+        let table = prepare_values_table().await;
+        assert_eq!(table.version(), 0);
+        assert_eq!(table.get_files_count(), 1);
+
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "value",
+            arrow::datatypes::DataType::Int32,
+            true,
+        )]));
+        let batch = RecordBatch::try_new(
+            schema,
+            vec![Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)]))],
+        )
+        .unwrap();
+        let table = DeltaOps(table)
+            .write(vec![batch])
+            .await
+            .expect("Failed to write first batch");
+        assert_eq!(table.version(), 1);
+
+        let (table, _metrics) = DeltaOps(table)
+            .update()
+            .with_predicate(col("value").eq(lit(2)))
+            .with_update("value", lit(12))
+            .await
+            .unwrap();
+        assert_eq!(table.version(), 2);
+
+        // NOTE: This currently doesn't really assert anything because cdc_files() is not
+        // reading actions correctly
+        let cdc_files = table.state.clone().unwrap().cdc_files();
+        assert!(cdc_files.is_ok());
+        assert_eq!(cdc_files.unwrap().len(), 0);
+
+        // Too close for missiles, switching to guns. Just checking that the data wasn't actually
+        // written instead!
+        if let Ok(files) = crate::storage::utils::flatten_list_stream(
+            &table.object_store(),
+            Some(&object_store::path::Path::from("_change_data")),
+        )
+        .await
+        {
+            assert_eq!(
+                0,
+                files.len(),
+                "This test should not find any written CDC files! {files:#?}"
+            );
+        }
+    }
+
+    #[tokio::test]
+    async fn test_update_cdc_enabled() {
+        // Currently you cannot pass EnableChangeDataFeed through `with_configuration_property`
+        // so the only way to create a truly CDC enabled table is by shoving the Protocol
+        // directly into the actions list
+        let actions = vec![Action::Protocol(Protocol::new(1, 4))];
+        let table: DeltaTable = DeltaOps::new_in_memory()
+            .create()
+            .with_column(
+                "value",
+                DeltaDataType::Primitive(PrimitiveType::Integer),
+                true,
+                None,
+            )
+            .with_actions(actions)
+            .with_configuration_property(DeltaConfigKey::EnableChangeDataFeed, Some("true"))
+            .await
+            .unwrap();
+        assert_eq!(table.version(), 0);
+
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "value",
+            arrow::datatypes::DataType::Int32,
+            true,
+        )]));
+
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)]))],
+        )
+        .unwrap();
+        let table = DeltaOps(table)
+            .write(vec![batch])
+            .await
+            .expect("Failed to write first batch");
+        assert_eq!(table.version(), 1);
+
+        let (table, _metrics) = DeltaOps(table)
+            .update()
+            .with_predicate(col("value").eq(lit(2)))
+            .with_update("value", lit(12))
+            .await
+            .unwrap();
+        assert_eq!(table.version(), 2);
+
+        let ctx = SessionContext::new();
+        let table = DeltaOps(table)
+            .load_cdf()
+            .with_session_ctx(ctx.clone())
+            .with_starting_version(0)
+            .build()
+            .await
+            .expect("Failed to load CDF");
+
+        let mut batches = collect_batches(
+            table.properties().output_partitioning().partition_count(),
+            table,
+            ctx,
+        )
+        .await
+        .expect("Failed to collect batches");
+
+        // The batches will contain a current _commit_timestamp which shouldn't be checked,
+        // so drop that column
+        let _: Vec<_> = batches.iter_mut().map(|b| b.remove_column(3)).collect();
+
+        assert_batches_sorted_eq! {[
+            "+-------+------------------+-----------------+",
+            "| value | _change_type     | _commit_version |",
+            "+-------+------------------+-----------------+",
+            "| 1     | insert           | 1               |",
+            "| 2     | insert           | 1               |",
+            "| 2     | update_preimage  | 2               |",
+            "| 12    | update_postimage | 2               |",
+            "| 3     | insert           | 1               |",
+            "+-------+------------------+-----------------+",
+        ], &batches }
+    }
+
+    #[tokio::test]
+    async fn test_update_cdc_enabled_partitions() {
+        //let _ = pretty_env_logger::try_init();
+        // Currently you cannot pass EnableChangeDataFeed through `with_configuration_property`
+        // so the only way to create a truly CDC enabled table is by shoving the Protocol
+        // directly into the actions list
+        let actions = vec![Action::Protocol(Protocol::new(1, 4))];
+        let table: DeltaTable = DeltaOps::new_in_memory()
+            .create()
+            .with_column(
+                "year",
+                DeltaDataType::Primitive(PrimitiveType::String),
+                true,
+                None,
+            )
+            .with_column(
+                "value",
+                DeltaDataType::Primitive(PrimitiveType::Integer),
+                true,
+                None,
+            )
+            .with_partition_columns(vec!["year"])
+            .with_actions(actions)
+            .with_configuration_property(DeltaConfigKey::EnableChangeDataFeed, Some("true"))
+            .await
+            .unwrap();
+        assert_eq!(table.version(), 0);
+
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("year", DataType::Utf8, true),
+            Field::new("value", DataType::Int32, true),
+        ]));
+
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![
+                Arc::new(StringArray::from(vec![
+                    Some("2020"),
+                    Some("2020"),
+                    Some("2024"),
+                ])),
+                Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)])),
+            ],
+        )
+        .unwrap();
+        let table = DeltaOps(table)
+            .write(vec![batch])
+            .await
+            .expect("Failed to write first batch");
+        assert_eq!(table.version(), 1);
+
+        let (table, _metrics) = DeltaOps(table)
+            .update()
+            .with_predicate(col("value").eq(lit(2)))
+            .with_update("year", "2024")
+            .await
+            .unwrap();
+        assert_eq!(table.version(), 2);
+
+        let ctx = SessionContext::new();
+        let table = DeltaOps(table)
+            .load_cdf()
+            .with_session_ctx(ctx.clone())
+            .with_starting_version(0)
+            .build()
+            .await
+            .expect("Failed to load CDF");
+
+        let mut batches = collect_batches(
+            table.properties().output_partitioning().partition_count(),
+            table,
+            ctx,
+        )
+        .await
+        .expect("Failed to collect batches");
+
+        let _ = arrow::util::pretty::print_batches(&batches);
+
+        // The batches will contain a current _commit_timestamp which shouldn't be checked,
+        // so drop that column
+        let _: Vec<_> = batches.iter_mut().map(|b| b.remove_column(3)).collect();
+
+        assert_batches_sorted_eq! {[
+            "+-------+------------------+-----------------+------+",
+            "| value | _change_type     | _commit_version | year |",
+            "+-------+------------------+-----------------+------+",
+            "| 1     | insert           | 1               | 2020 |",
+            "| 2     | insert           | 1               | 2020 |",
+            "| 2     | update_preimage  | 2               | 2020 |",
+            "| 2     | update_postimage | 2               | 2024 |",
+            "| 3     | insert           | 1               | 2024 |",
+            "+-------+------------------+-----------------+------+",
+        ], &batches }
+    }
+
+    async fn collect_batches(
+        num_partitions: usize,
+        stream: DeltaCdfScan,
+        ctx: SessionContext,
+    ) -> Result<Vec<RecordBatch>, Box<dyn std::error::Error>> {
+        let mut batches = vec![];
+        for p in 0..num_partitions {
+            let data: Vec<RecordBatch> =
+                collect_sendable_stream(stream.execute(p, ctx.task_ctx())?).await?;
+            batches.extend_from_slice(&data);
+        }
+        Ok(batches)
+    }
 }
diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs
index f87037fa16..3594c02d24 100644
--- a/crates/core/src/operations/write.rs
+++ b/crates/core/src/operations/write.rs
@@ -41,6 +41,7 @@ use datafusion_expr::Expr;
 use futures::future::BoxFuture;
 use futures::StreamExt;
 use parquet::file::properties::WriterProperties;
+use tracing::log::*;
 
 use super::datafusion_utils::Expression;
 use super::transaction::{CommitBuilder, CommitProperties, TableReference, PROTOCOL};
@@ -64,6 +65,8 @@ use crate::table::Constraint as DeltaConstraint;
 use crate::writer::record_batch::divide_by_partition_values;
 use crate::DeltaTable;
 
+use tokio::sync::mpsc::Sender;
+
 #[derive(thiserror::Error, Debug)]
 enum WriteError {
     #[error("No data source supplied to write command.")]
@@ -371,6 +374,7 @@ async fn write_execution_plan_with_predicate(
     safe_cast: bool,
     schema_mode: Option<SchemaMode>,
     writer_stats_config: WriterStatsConfig,
+    sender: Option<Sender<RecordBatch>>,
 ) -> DeltaResult<Vec<Action>> {
     let schema: ArrowSchemaRef = if schema_mode.is_some() {
         plan.schema()
     } else {
         snapshot
             .and_then(|s| s.input_schema().ok())
             .unwrap_or(plan.schema())
     };
-
     let checker = if let Some(snapshot) = snapshot {
         DeltaDataChecker::new(snapshot)
     } else {
@@ -411,11 +414,15 @@
     );
     let mut writer = DeltaWriter::new(object_store.clone(), config);
     let checker_stream = checker.clone();
+    let sender_stream = sender.clone();
     let mut stream = inner_plan.execute(i, task_ctx)?;
-    let handle: tokio::task::JoinHandle<DeltaResult<Vec<Action>>> =
-        tokio::task::spawn(async move {
+
+    let handle: tokio::task::JoinHandle<DeltaResult<Vec<Action>>> = tokio::task::spawn(
+        async move {
+            let sendable = sender_stream.clone();
             while let Some(maybe_batch) = stream.next().await {
                 let batch = maybe_batch?;
+
                 checker_stream.check_batch(&batch).await?;
                 let arr = super::cast::cast_record_batch(
                     &batch,
                     inner_schema.clone(),
                     safe_cast,
                     schema_mode == Some(SchemaMode::Merge),
                 )?;
+
+                if let Some(s) = sendable.as_ref() {
+                    let _ = s.send(arr.clone()).await;
+                } else {
+                    debug!("write_execution_plan_with_predicate did not send any batches, no sender.");
+                }
                 writer.write(&arr).await?;
             }
             let add_actions = writer.close().await;
             match add_actions {
                 Ok(actions) => Ok(actions.into_iter().map(Action::Add).collect::<Vec<_>>()),
                 Err(err) => Err(err),
             }
-        });
+        },
+    );
     tasks.push(handle);
 }
@@ -461,6 +475,7 @@ pub(crate) async fn write_execution_plan(
     safe_cast: bool,
     schema_mode: Option<SchemaMode>,
     writer_stats_config: WriterStatsConfig,
+    sender: Option<Sender<RecordBatch>>,
 ) -> DeltaResult<Vec<Action>> {
     write_execution_plan_with_predicate(
         None,
         snapshot,
         state,
         plan,
         partition_columns,
         object_store,
         target_file_size,
         write_batch_size,
         writer_properties,
         safe_cast,
         schema_mode,
writer_stats_config, + sender, ) .await } @@ -523,6 +539,7 @@ async fn execute_non_empty_expr( false, None, writer_stats_config, + None, ) .await?; @@ -779,6 +796,7 @@ impl std::future::IntoFuture for WriteBuilder { this.safe_cast, this.schema_mode, writer_stats_config.clone(), + None, ) .await?; actions.extend(add_actions); @@ -1301,7 +1319,6 @@ mod tests { ], ) .unwrap(); - println!("new_batch: {:?}", new_batch.schema()); let table = DeltaOps(table) .write(vec![new_batch]) .with_save_mode(SaveMode::Append)