diff --git a/benchmarks/src/bin/merge.rs b/benchmarks/src/bin/merge.rs new file mode 100644 index 0000000000..bb178a192d --- /dev/null +++ b/benchmarks/src/bin/merge.rs @@ -0,0 +1,655 @@ +use std::{ + sync::Arc, + time::{SystemTime, UNIX_EPOCH}, +}; + +use arrow::datatypes::Schema as ArrowSchema; +use arrow_array::{RecordBatch, StringArray, UInt32Array}; +use chrono::Duration; +use clap::{command, Args, Parser, Subcommand}; +use datafusion::{datasource::MemTable, prelude::DataFrame}; +use datafusion_common::DataFusionError; +use datafusion_expr::{cast, col, lit, random}; +use deltalake_core::protocol::SaveMode; +use deltalake_core::{ + arrow::{ + self, + datatypes::{DataType, Field}, + }, + datafusion::prelude::{CsvReadOptions, SessionContext}, + delta_datafusion::{DeltaScanConfig, DeltaTableProvider}, + operations::merge::{MergeBuilder, MergeMetrics}, + DeltaOps, DeltaTable, DeltaTableBuilder, DeltaTableError, ObjectStore, Path, +}; +use serde_json::json; +use tokio::time::Instant; + +/* Convert web_returns dataset from TPC DS's datagen utility into a Delta table + This table will be partitioned on `wr_returned_date_sk` +*/ +pub async fn convert_tpcds_web_returns(input_path: String, table_path: String) -> Result<(), ()> { + let ctx = SessionContext::new(); + + let schema = ArrowSchema::new(vec![ + Field::new("wr_returned_date_sk", DataType::Int64, true), + Field::new("wr_returned_time_sk", DataType::Int64, true), + Field::new("wr_item_sk", DataType::Int64, false), + Field::new("wr_refunded_customer_sk", DataType::Int64, true), + Field::new("wr_refunded_cdemo_sk", DataType::Int64, true), + Field::new("wr_refunded_hdemo_sk", DataType::Int64, true), + Field::new("wr_refunded_addr_sk", DataType::Int64, true), + Field::new("wr_returning_customer_sk", DataType::Int64, true), + Field::new("wr_returning_cdemo_sk", DataType::Int64, true), + Field::new("wr_returning_hdemo_sk", DataType::Int64, true), + Field::new("wr_returning_addr_sk", DataType::Int64, true), + Field::new("wr_web_page_sk", DataType::Int64, true), + Field::new("wr_reason_sk", DataType::Int64, true), + Field::new("wr_order_number", DataType::Int64, false), + Field::new("wr_return_quantity", DataType::Int32, true), + Field::new("wr_return_amt", DataType::Decimal128(7, 2), true), + Field::new("wr_return_tax", DataType::Decimal128(7, 2), true), + Field::new("wr_return_amt_inc_tax", DataType::Decimal128(7, 2), true), + Field::new("wr_fee", DataType::Decimal128(7, 2), true), + Field::new("wr_return_ship_cost", DataType::Decimal128(7, 2), true), + Field::new("wr_refunded_cash", DataType::Decimal128(7, 2), true), + Field::new("wr_reversed_charge", DataType::Decimal128(7, 2), true), + Field::new("wr_account_credit", DataType::Decimal128(7, 2), true), + Field::new("wr_net_loss", DataType::Decimal128(7, 2), true), + ]); + + let table = ctx + .read_csv( + input_path, + CsvReadOptions { + has_header: false, + delimiter: b'|', + file_extension: ".dat", + schema: Some(&schema), + ..Default::default() + }, + ) + .await + .unwrap(); + + DeltaOps::try_from_uri(table_path) + .await + .unwrap() + .write(table.collect().await.unwrap()) + .with_partition_columns(vec!["wr_returned_date_sk"]) + .await + .unwrap(); + + Ok(()) +} + +fn merge_upsert(source: DataFrame, table: DeltaTable) -> Result { + DeltaOps(table) + .merge(source, "source.wr_item_sk = target.wr_item_sk and source.wr_order_number = target.wr_order_number") + .with_source_alias("source") + .with_target_alias("target") + .when_matched_update(|update| { + update + .update("wr_returned_date_sk", "source.wr_returned_date_sk") + .update("wr_returned_time_sk", "source.wr_returned_time_sk") + .update("wr_item_sk", "source.wr_item_sk") + .update("wr_refunded_customer_sk", "source.wr_refunded_customer_sk") + .update("wr_refunded_cdemo_sk", "source.wr_refunded_cdemo_sk") + .update("wr_refunded_hdemo_sk", "source.wr_refunded_hdemo_sk") + .update("wr_refunded_addr_sk", "source.wr_refunded_addr_sk") + .update("wr_returning_customer_sk", "source.wr_returning_customer_sk") + .update("wr_returning_cdemo_sk", "source.wr_returning_cdemo_sk") + .update("wr_returning_hdemo_sk", "source.wr_returning_hdemo_sk") + .update("wr_returning_addr_sk", "source.wr_returning_addr_sk") + .update("wr_web_page_sk", "source.wr_web_page_sk") + .update("wr_reason_sk", "source.wr_reason_sk") + .update("wr_order_number", "source.wr_order_number") + .update("wr_return_quantity", "source.wr_return_quantity") + .update("wr_return_amt", "source.wr_return_amt") + .update("wr_return_tax", "source.wr_return_tax") + .update("wr_return_amt_inc_tax", "source.wr_return_amt_inc_tax") + .update("wr_fee", "source.wr_fee") + .update("wr_return_ship_cost", "source.wr_return_ship_cost") + .update("wr_refunded_cash", "source.wr_refunded_cash") + .update("wr_reversed_charge", "source.wr_reversed_charge") + .update("wr_account_credit", "source.wr_account_credit") + .update("wr_net_loss", "source.wr_net_loss") + })? + .when_not_matched_insert(|insert| { + insert + .set("wr_returned_date_sk", "source.wr_returned_date_sk") + .set("wr_returned_time_sk", "source.wr_returned_time_sk") + .set("wr_item_sk", "source.wr_item_sk") + .set("wr_refunded_customer_sk", "source.wr_refunded_customer_sk") + .set("wr_refunded_cdemo_sk", "source.wr_refunded_cdemo_sk") + .set("wr_refunded_hdemo_sk", "source.wr_refunded_hdemo_sk") + .set("wr_refunded_addr_sk", "source.wr_refunded_addr_sk") + .set("wr_returning_customer_sk", "source.wr_returning_customer_sk") + .set("wr_returning_cdemo_sk", "source.wr_returning_cdemo_sk") + .set("wr_returning_hdemo_sk", "source.wr_returning_hdemo_sk") + .set("wr_returning_addr_sk", "source.wr_returning_addr_sk") + .set("wr_web_page_sk", "source.wr_web_page_sk") + .set("wr_reason_sk", "source.wr_reason_sk") + .set("wr_order_number", "source.wr_order_number") + .set("wr_return_quantity", "source.wr_return_quantity") + .set("wr_return_amt", "source.wr_return_amt") + .set("wr_return_tax", "source.wr_return_tax") + .set("wr_return_amt_inc_tax", "source.wr_return_amt_inc_tax") + .set("wr_fee", "source.wr_fee") + .set("wr_return_ship_cost", "source.wr_return_ship_cost") + .set("wr_refunded_cash", "source.wr_refunded_cash") + .set("wr_reversed_charge", "source.wr_reversed_charge") + .set("wr_account_credit", "source.wr_account_credit") + .set("wr_net_loss", "source.wr_net_loss") + }) +} + +fn merge_insert(source: DataFrame, table: DeltaTable) -> Result { + DeltaOps(table) + .merge(source, "source.wr_item_sk = target.wr_item_sk and source.wr_order_number = target.wr_order_number") + .with_source_alias("source") + .with_target_alias("target") + .when_not_matched_insert(|insert| { + insert + .set("wr_returned_date_sk", "source.wr_returned_date_sk") + .set("wr_returned_time_sk", "source.wr_returned_time_sk") + .set("wr_item_sk", "source.wr_item_sk") + .set("wr_refunded_customer_sk", "source.wr_refunded_customer_sk") + .set("wr_refunded_cdemo_sk", "source.wr_refunded_cdemo_sk") + .set("wr_refunded_hdemo_sk", "source.wr_refunded_hdemo_sk") + .set("wr_refunded_addr_sk", "source.wr_refunded_addr_sk") + .set("wr_returning_customer_sk", "source.wr_returning_customer_sk") + .set("wr_returning_cdemo_sk", "source.wr_returning_cdemo_sk") + .set("wr_returning_hdemo_sk", "source.wr_returning_hdemo_sk") + .set("wr_returning_addr_sk", "source.wr_returning_addr_sk") + .set("wr_web_page_sk", "source.wr_web_page_sk") + .set("wr_reason_sk", "source.wr_reason_sk") + .set("wr_order_number", "source.wr_order_number") + .set("wr_return_quantity", "source.wr_return_quantity") + .set("wr_return_amt", "source.wr_return_amt") + .set("wr_return_tax", "source.wr_return_tax") + .set("wr_return_amt_inc_tax", "source.wr_return_amt_inc_tax") + .set("wr_fee", "source.wr_fee") + .set("wr_return_ship_cost", "source.wr_return_ship_cost") + .set("wr_refunded_cash", "source.wr_refunded_cash") + .set("wr_reversed_charge", "source.wr_reversed_charge") + .set("wr_account_credit", "source.wr_account_credit") + .set("wr_net_loss", "source.wr_net_loss") + }) +} + +fn merge_delete(source: DataFrame, table: DeltaTable) -> Result { + DeltaOps(table) + .merge(source, "source.wr_item_sk = target.wr_item_sk and source.wr_order_number = target.wr_order_number") + .with_source_alias("source") + .with_target_alias("target") + .when_matched_delete(|delete| { + delete + }) +} + +async fn benchmark_merge_tpcds( + path: String, + parameters: MergePerfParams, + merge: fn(DataFrame, DeltaTable) -> Result, +) -> Result<(core::time::Duration, MergeMetrics), DataFusionError> { + let table = DeltaTableBuilder::from_uri(path).load().await?; + let file_count = table.snapshot()?.files_count(); + + let provider = DeltaTableProvider::try_new( + table.snapshot()?.clone(), + table.log_store(), + DeltaScanConfig { + file_column_name: Some("file_path".to_string()), + ..Default::default() + }, + ) + .unwrap(); + + let ctx = SessionContext::new(); + ctx.register_table("t1", Arc::new(provider))?; + + let files = ctx + .sql("select file_path as file from t1 group by file") + .await? + .with_column("r", random())? + .filter(col("r").lt_eq(lit(parameters.sample_files)))?; + + let file_sample = files.collect_partitioned().await?; + let schema = file_sample.first().unwrap().first().unwrap().schema(); + let mem_table = Arc::new(MemTable::try_new(schema, file_sample)?); + ctx.register_table("file_sample", mem_table)?; + let file_sample_count = ctx.table("file_sample").await?.count().await?; + + let row_sample = ctx.table("t1").await?.join( + ctx.table("file_sample").await?, + datafusion_common::JoinType::Inner, + &["file_path"], + &["file"], + None, + )?; + + let matched = row_sample + .clone() + .filter(random().lt_eq(lit(parameters.sample_matched_rows)))?; + + let rand = cast(random() * lit(u32::MAX), DataType::Int64); + let not_matched = row_sample + .filter(random().lt_eq(lit(parameters.sample_not_matched_rows)))? + .with_column("wr_item_sk", rand.clone())? + .with_column("wr_order_number", rand)?; + + let source = matched.union(not_matched)?; + + let start = Instant::now(); + let (table, metrics) = merge(source, table)?.await?; + let end = Instant::now(); + + let duration = end.duration_since(start); + + println!("Total File count: {}", file_count); + println!("File sample count: {}", file_sample_count); + println!("{:?}", metrics); + println!("Seconds: {}", duration.as_secs_f32()); + + // Clean up and restore to original state. + let (table, _) = DeltaOps(table).restore().with_version_to_restore(0).await?; + let (table, _) = DeltaOps(table) + .vacuum() + .with_retention_period(Duration::seconds(0)) + .with_enforce_retention_duration(false) + .await?; + table + .object_store() + .delete(&Path::parse("_delta_log/00000000000000000001.json")?) + .await?; + table + .object_store() + .delete(&Path::parse("_delta_log/00000000000000000002.json")?) + .await?; + table + .object_store() + .delete(&Path::parse("_delta_log/00000000000000000003.json")?) + .await?; + let _ = table + .object_store() + .delete(&Path::parse("_delta_log/00000000000000000004.json")?) + .await; + + Ok((duration, metrics)) +} + +#[derive(Subcommand, Debug)] +enum Command { + Convert(Convert), + Bench(BenchArg), + Standard(Standard), + Compare(Compare), + Show(Show), +} + +#[derive(Debug, Args)] +struct Convert { + tpcds_path: String, + delta_path: String, +} + +#[derive(Debug, Args)] +struct Standard { + delta_path: String, + samples: Option, + output_path: Option, + group_id: Option, +} + +#[derive(Debug, Args)] +struct Compare { + before_path: String, + before_group_id: String, + after_path: String, + after_group_id: String, +} + +#[derive(Debug, Args)] +struct Show { + path: String, +} + +#[derive(Debug, Args)] +struct BenchArg { + table_path: String, + #[command(subcommand)] + name: MergeBench, +} + +struct Bench { + name: String, + op: fn(DataFrame, DeltaTable) -> Result, + params: MergePerfParams, +} + +impl Bench { + fn new( + name: S, + op: fn(DataFrame, DeltaTable) -> Result, + params: MergePerfParams, + ) -> Self { + Bench { + name: name.to_string(), + op, + params, + } + } +} + +#[derive(Debug, Args, Clone)] +struct MergePerfParams { + pub sample_files: f32, + pub sample_matched_rows: f32, + pub sample_not_matched_rows: f32, +} + +#[derive(Debug, Clone, Subcommand)] +enum MergeBench { + Upsert(MergePerfParams), + Delete(MergePerfParams), + Insert(MergePerfParams), +} + +#[derive(Parser, Debug)] +#[command(about)] +struct MergePrefArgs { + #[command(subcommand)] + command: Command, +} + +#[tokio::main] +async fn main() { + type MergeOp = fn(DataFrame, DeltaTable) -> Result; + match MergePrefArgs::parse().command { + Command::Convert(Convert { + tpcds_path, + delta_path, + }) => { + convert_tpcds_web_returns(tpcds_path, delta_path) + .await + .unwrap(); + } + + Command::Bench(BenchArg { table_path, name }) => { + let (merge_op, params): (MergeOp, MergePerfParams) = match name { + MergeBench::Upsert(params) => (merge_upsert, params), + MergeBench::Delete(params) => (merge_delete, params), + MergeBench::Insert(params) => (merge_insert, params), + }; + + benchmark_merge_tpcds(table_path, params, merge_op) + .await + .unwrap(); + } + Command::Standard(Standard { + delta_path, + samples, + output_path, + group_id, + }) => { + let benches = vec![Bench::new( + "delete_only_fileMatchedFraction_0.05_rowMatchedFraction_0.05", + merge_delete, + MergePerfParams { + sample_files: 0.05, + sample_matched_rows: 0.05, + sample_not_matched_rows: 0.0, + }, + ), + Bench::new( + "multiple_insert_only_fileMatchedFraction_0.05_rowNotMatchedFraction_0.05", + merge_insert, + MergePerfParams { + sample_files: 0.05, + sample_matched_rows: 0.00, + sample_not_matched_rows: 0.05, + }, + ), + Bench::new( + "multiple_insert_only_fileMatchedFraction_0.05_rowNotMatchedFraction_0.50", + merge_insert, + MergePerfParams { + sample_files: 0.05, + sample_matched_rows: 0.00, + sample_not_matched_rows: 0.50, + }, + ), + Bench::new( + "multiple_insert_only_fileMatchedFraction_0.05_rowNotMatchedFraction_1.0", + merge_insert, + MergePerfParams { + sample_files: 0.05, + sample_matched_rows: 0.00, + sample_not_matched_rows: 1.0, + }, + ), + Bench::new( + "upsert_fileMatchedFraction_0.05_rowMatchedFraction_0.01_rowNotMatchedFraction_0.1", + merge_upsert, + MergePerfParams { + sample_files: 0.05, + sample_matched_rows: 0.01, + sample_not_matched_rows: 0.1, + }, + ), + Bench::new( + "upsert_fileMatchedFraction_0.05_rowMatchedFraction_0.0_rowNotMatchedFraction_0.1", + merge_upsert, + MergePerfParams { + sample_files: 0.05, + sample_matched_rows: 0.00, + sample_not_matched_rows: 0.1, + }, + ), + Bench::new( + "upsert_fileMatchedFraction_0.05_rowMatchedFraction_0.1_rowNotMatchedFraction_0.0", + merge_upsert, + MergePerfParams { + sample_files: 0.05, + sample_matched_rows: 0.1, + sample_not_matched_rows: 0.0, + }, + ), + Bench::new( + "upsert_fileMatchedFraction_0.05_rowMatchedFraction_0.1_rowNotMatchedFraction_0.01", + merge_upsert, + MergePerfParams { + sample_files: 0.05, + sample_matched_rows: 0.1, + sample_not_matched_rows: 0.01, + }, + ), + Bench::new( + "upsert_fileMatchedFraction_0.05_rowMatchedFraction_0.5_rowNotMatchedFraction_0.001", + merge_upsert, + MergePerfParams { + sample_files: 0.05, + sample_matched_rows: 0.5, + sample_not_matched_rows: 0.001, + }, + ), + Bench::new( + "upsert_fileMatchedFraction_0.05_rowMatchedFraction_0.99_rowNotMatchedFraction_0.001", + merge_upsert, + MergePerfParams { + sample_files: 0.05, + sample_matched_rows: 0.99, + sample_not_matched_rows: 0.001, + }, + ), + Bench::new( + "upsert_fileMatchedFraction_0.05_rowMatchedFraction_1.0_rowNotMatchedFraction_0.001", + merge_upsert, + MergePerfParams { + sample_files: 0.05, + sample_matched_rows: 1.0, + sample_not_matched_rows: 0.001, + }, + ), + Bench::new( + "upsert_fileMatchedFraction_0.5_rowMatchedFraction_0.001_rowNotMatchedFraction_0.001", + merge_upsert, + MergePerfParams { + sample_files: 0.5, + sample_matched_rows: 0.001, + sample_not_matched_rows: 0.001, + }, + ), + Bench::new( + "upsert_fileMatchedFraction_1.0_rowMatchedFraction_0.001_rowNotMatchedFraction_0.001", + merge_upsert, + MergePerfParams { + sample_files: 1.0, + sample_matched_rows: 0.001, + sample_not_matched_rows: 0.001, + }, + ) + ]; + + let num_samples = samples.unwrap_or(1); + let group_id = group_id.unwrap_or( + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_millis() + .to_string(), + ); + let output = output_path.unwrap_or("data/benchmarks".into()); + + let mut group_ids = vec![]; + let mut name = vec![]; + let mut samples = vec![]; + let mut duration_ms = vec![]; + let mut data = vec![]; + + for bench in benches { + for sample in 0..num_samples { + println!("Test: {} Sample: {}", bench.name, sample); + let res = + benchmark_merge_tpcds(delta_path.clone(), bench.params.clone(), bench.op) + .await + .unwrap(); + + group_ids.push(group_id.clone()); + name.push(bench.name.clone()); + samples.push(sample); + duration_ms.push(res.0.as_millis() as u32); + data.push(json!(res.1).to_string()); + } + } + + let schema = Arc::new(ArrowSchema::new(vec![ + Field::new("group_id", DataType::Utf8, false), + Field::new("name", DataType::Utf8, false), + Field::new("sample", DataType::UInt32, false), + Field::new("duration_ms", DataType::UInt32, false), + Field::new("data", DataType::Utf8, true), + ])); + + let batch = RecordBatch::try_new( + schema, + vec![ + Arc::new(StringArray::from(group_ids)), + Arc::new(StringArray::from(name)), + Arc::new(UInt32Array::from(samples)), + Arc::new(UInt32Array::from(duration_ms)), + Arc::new(StringArray::from(data)), + ], + ) + .unwrap(); + + DeltaOps::try_from_uri(output) + .await + .unwrap() + .write(vec![batch]) + .with_save_mode(SaveMode::Append) + .await + .unwrap(); + } + Command::Compare(Compare { + before_path, + before_group_id, + after_path, + after_group_id, + }) => { + let before_table = DeltaTableBuilder::from_uri(before_path) + .load() + .await + .unwrap(); + let after_table = DeltaTableBuilder::from_uri(after_path) + .load() + .await + .unwrap(); + + let ctx = SessionContext::new(); + ctx.register_table("before", Arc::new(before_table)) + .unwrap(); + ctx.register_table("after", Arc::new(after_table)).unwrap(); + + let before_stats = ctx + .sql(&format!( + " + select name as before_name, + avg(cast(duration_ms as float)) as before_duration_avg + from before where group_id = {} + group by name + ", + before_group_id + )) + .await + .unwrap(); + + let after_stats = ctx + .sql(&format!( + " + select name as after_name, + avg(cast(duration_ms as float)) as after_duration_avg + from after where group_id = {} + group by name + ", + after_group_id + )) + .await + .unwrap(); + + before_stats + .join( + after_stats, + datafusion_common::JoinType::Inner, + &["before_name"], + &["after_name"], + None, + ) + .unwrap() + .select(vec![ + col("before_name").alias("name"), + col("before_duration_avg"), + col("after_duration_avg"), + (col("before_duration_avg") / (col("after_duration_avg"))), + ]) + .unwrap() + .sort(vec![col("name").sort(true, true)]) + .unwrap() + .show() + .await + .unwrap(); + } + Command::Show(Show { path }) => { + let stats = DeltaTableBuilder::from_uri(path).load().await.unwrap(); + let ctx = SessionContext::new(); + ctx.register_table("stats", Arc::new(stats)).unwrap(); + + ctx.sql("select * from stats") + .await + .unwrap() + .show() + .await + .unwrap(); + } + } +} diff --git a/crates/core/src/kernel/models/actions.rs b/crates/core/src/kernel/models/actions.rs index 61a58b98f3..407df8db2c 100644 --- a/crates/core/src/kernel/models/actions.rs +++ b/crates/core/src/kernel/models/actions.rs @@ -671,6 +671,26 @@ pub struct Txn { pub last_updated: Option, } +impl Txn { + /// Create a new application transactions. See [`Txn`] for details. + pub fn new(app_id: &dyn ToString, version: i64) -> Self { + Self::new_with_last_update(app_id, version, None) + } + + /// Create a new application transactions. See [`Txn`] for details. + pub fn new_with_last_update( + app_id: &dyn ToString, + version: i64, + last_updated: Option, + ) -> Self { + Txn { + app_id: app_id.to_string(), + version, + last_updated, + } + } +} + /// The commitInfo is a fairly flexible action within the delta specification, where arbitrary data can be stored. /// However the reference implementation as well as delta-rs store useful information that may for instance /// allow us to be more permissive in commit conflict resolution. diff --git a/crates/core/src/kernel/snapshot/log_segment.rs b/crates/core/src/kernel/snapshot/log_segment.rs index 76bab4b838..34baac5f05 100644 --- a/crates/core/src/kernel/snapshot/log_segment.rs +++ b/crates/core/src/kernel/snapshot/log_segment.rs @@ -30,9 +30,12 @@ lazy_static! { pub(super) static ref COMMIT_SCHEMA: StructType = StructType::new(vec![ ActionType::Add.schema_field().clone(), ActionType::Remove.schema_field().clone(), + ActionType::Txn.schema_field().clone(), + ]); + pub(super) static ref CHECKPOINT_SCHEMA: StructType = StructType::new(vec![ + ActionType::Add.schema_field().clone(), + ActionType::Txn.schema_field().clone(), ]); - pub(super) static ref CHECKPOINT_SCHEMA: StructType = - StructType::new(vec![ActionType::Add.schema_field().clone(),]); pub(super) static ref TOMBSTONE_SCHEMA: StructType = StructType::new(vec![ActionType::Remove.schema_field().clone(),]); } diff --git a/crates/core/src/kernel/snapshot/mod.rs b/crates/core/src/kernel/snapshot/mod.rs index 90fa112cd5..b36be5f9a7 100644 --- a/crates/core/src/kernel/snapshot/mod.rs +++ b/crates/core/src/kernel/snapshot/mod.rs @@ -197,10 +197,11 @@ impl Snapshot { } /// Get the files in the snapshot - pub fn files( + pub fn files<'a>( &self, store: Arc, - ) -> DeltaResult>>> { + visitors: Vec<&'a mut dyn ReplayVisitor>, + ) -> DeltaResult>>> { let log_stream = self.log_segment.commit_stream( store.clone(), &log_segment::COMMIT_SCHEMA, @@ -211,7 +212,7 @@ impl Snapshot { &log_segment::CHECKPOINT_SCHEMA, &self.config, ); - ReplayStream::try_new(log_stream, checkpoint_stream, self) + ReplayStream::try_new(log_stream, checkpoint_stream, self, visitors) } /// Get the commit infos in the snapshot @@ -331,6 +332,12 @@ impl Snapshot { } } +/// Allows hooking into the reading of commit files and checkpoints whenever a table is loaded or updated. +pub trait ReplayVisitor: Send { + /// Process a batch + fn visit_batch(&mut self, batch: &RecordBatch) -> DeltaResult<()>; +} + /// A snapshot of a Delta table that has been eagerly loaded into memory. #[derive(Debug, Clone, PartialEq)] pub struct EagerSnapshot { @@ -347,9 +354,20 @@ impl EagerSnapshot { store: Arc, config: DeltaTableConfig, version: Option, + ) -> DeltaResult { + Self::try_new_with_visitor(table_root, store, config, version, vec![]).await + } + + /// Create a new [`EagerSnapshot`] instance + pub async fn try_new_with_visitor( + table_root: &Path, + store: Arc, + config: DeltaTableConfig, + version: Option, + visitors: Vec<&mut dyn ReplayVisitor>, ) -> DeltaResult { let snapshot = Snapshot::try_new(table_root, store.clone(), config, version).await?; - let files = snapshot.files(store)?.try_collect().await?; + let files = snapshot.files(store, visitors)?.try_collect().await?; Ok(Self { snapshot, files }) } @@ -368,14 +386,16 @@ impl EagerSnapshot { } /// Update the snapshot to the given version - pub async fn update( + pub async fn update<'a>( &mut self, log_store: Arc, target_version: Option, + visitors: Vec<&'a mut dyn ReplayVisitor>, ) -> DeltaResult<()> { if Some(self.version()) == target_version { return Ok(()); } + let new_slice = self .snapshot .update_inner(log_store.clone(), target_version) @@ -399,10 +419,11 @@ impl EagerSnapshot { .boxed() }; let mapper = LogMapper::try_new(&self.snapshot)?; - let files = ReplayStream::try_new(log_stream, checkpoint_stream, &self.snapshot)? - .map(|batch| batch.and_then(|b| mapper.map_batch(b))) - .try_collect() - .await?; + let files = + ReplayStream::try_new(log_stream, checkpoint_stream, &self.snapshot, visitors)? + .map(|batch| batch.and_then(|b| mapper.map_batch(b))) + .try_collect() + .await?; self.files = files; } @@ -476,6 +497,7 @@ impl EagerSnapshot { pub fn advance<'a>( &mut self, commits: impl IntoIterator, + mut visitors: Vec<&'a mut dyn ReplayVisitor>, ) -> DeltaResult { let mut metadata = None; let mut protocol = None; @@ -506,7 +528,11 @@ impl EagerSnapshot { let mut scanner = LogReplayScanner::new(); for batch in actions { - files.push(scanner.process_files_batch(&batch?, true)?); + let batch = batch?; + files.push(scanner.process_files_batch(&batch, true)?); + for visitor in &mut visitors { + visitor.visit_batch(&batch)?; + } } let mapper = LogMapper::try_new(&self.snapshot)?; @@ -652,7 +678,7 @@ mod tests { assert_eq!(tombstones.len(), 31); let batches = snapshot - .files(store.clone())? + .files(store.clone(), vec![])? .try_collect::>() .await?; let expected = [ @@ -778,9 +804,10 @@ mod tests { predicate: None, }; - let actions = vec![CommitData::new(removes, operation, HashMap::new()).unwrap()]; + let actions = + vec![CommitData::new(removes, operation, HashMap::new(), Vec::new()).unwrap()]; - let new_version = snapshot.advance(&actions)?; + let new_version = snapshot.advance(&actions, vec![])?; assert_eq!(new_version, version + 1); let new_files = snapshot.file_actions()?.map(|f| f.path).collect::>(); diff --git a/crates/core/src/kernel/snapshot/replay.rs b/crates/core/src/kernel/snapshot/replay.rs index 61cdab4c09..fc95928d71 100644 --- a/crates/core/src/kernel/snapshot/replay.rs +++ b/crates/core/src/kernel/snapshot/replay.rs @@ -23,14 +23,17 @@ use crate::kernel::arrow::extract::{self as ex, ProvidesColumnByName}; use crate::kernel::arrow::json; use crate::{DeltaResult, DeltaTableConfig, DeltaTableError}; +use super::ReplayVisitor; use super::Snapshot; pin_project! { - pub struct ReplayStream { + pub struct ReplayStream<'a, S> { scanner: LogReplayScanner, mapper: Arc, + visitors: Vec<&'a mut dyn ReplayVisitor>, + #[pin] commits: S, @@ -39,8 +42,13 @@ pin_project! { } } -impl ReplayStream { - pub(super) fn try_new(commits: S, checkpoint: S, snapshot: &Snapshot) -> DeltaResult { +impl<'a, S> ReplayStream<'a, S> { + pub(super) fn try_new( + commits: S, + checkpoint: S, + snapshot: &Snapshot, + visitors: Vec<&'a mut dyn ReplayVisitor>, + ) -> DeltaResult { let stats_schema = Arc::new((&snapshot.stats_schema()?).try_into()?); let mapper = Arc::new(LogMapper { stats_schema, @@ -50,6 +58,7 @@ impl ReplayStream { commits, checkpoint, mapper, + visitors, scanner: LogReplayScanner::new(), }) } @@ -127,7 +136,7 @@ fn map_batch( Ok(batch) } -impl Stream for ReplayStream +impl<'a, S> Stream for ReplayStream<'a, S> where S: Stream>, { @@ -136,19 +145,33 @@ where fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.project(); let res = this.commits.poll_next(cx).map(|b| match b { - Some(Ok(batch)) => match this.scanner.process_files_batch(&batch, true) { - Ok(filtered) => Some(this.mapper.map_batch(filtered)), - Err(e) => Some(Err(e)), - }, + Some(Ok(batch)) => { + for visitor in this.visitors.iter_mut() { + if let Err(e) = visitor.visit_batch(&batch) { + return Some(Err(e)); + } + } + match this.scanner.process_files_batch(&batch, true) { + Ok(filtered) => Some(this.mapper.map_batch(filtered)), + Err(e) => Some(Err(e)), + } + } Some(Err(e)) => Some(Err(e)), None => None, }); if matches!(res, Poll::Ready(None)) { this.checkpoint.poll_next(cx).map(|b| match b { - Some(Ok(batch)) => match this.scanner.process_files_batch(&batch, false) { - Ok(filtered) => Some(this.mapper.map_batch(filtered)), - Err(e) => Some(Err(e)), - }, + Some(Ok(batch)) => { + for visitor in this.visitors.iter_mut() { + if let Err(e) = visitor.visit_batch(&batch) { + return Some(Err(e)); + } + } + match this.scanner.process_files_batch(&batch, false) { + Ok(filtered) => Some(this.mapper.map_batch(filtered)), + Err(e) => Some(Err(e)), + } + } Some(Err(e)) => Some(Err(e)), None => None, }) diff --git a/crates/core/src/operations/transaction/application.rs b/crates/core/src/operations/transaction/application.rs new file mode 100644 index 0000000000..81f6fb49dc --- /dev/null +++ b/crates/core/src/operations/transaction/application.rs @@ -0,0 +1,131 @@ +#[cfg(test)] +mod tests { + use crate::{ + checkpoints, kernel::Txn, operations::transaction::CommitProperties, protocol::SaveMode, + writer::test_utils::get_record_batch, DeltaOps, DeltaTableBuilder, + }; + + #[tokio::test] + async fn test_app_txn_workload() { + // Test that the transaction ids can be read from different scenarios + // 1. Write new table to storage + // 2. Read new table + // 3. Write to table a new txn id and then update a different table state that uses the same underlying table + // 4. Write a checkpoint and read that checkpoint. + + let tmp_dir = tempfile::tempdir().unwrap(); + let tmp_path = std::fs::canonicalize(tmp_dir.path()).unwrap(); + + let batch = get_record_batch(None, false); + let table = DeltaOps::try_from_uri(tmp_path.to_str().unwrap()) + .await + .unwrap() + .write(vec![batch.clone()]) + .with_save_mode(SaveMode::ErrorIfExists) + .with_partition_columns(["modified"]) + .with_commit_properties( + CommitProperties::default().with_application_transaction(Txn::new(&"my-app", 1)), + ) + .await + .unwrap(); + assert_eq!(table.version(), 0); + assert_eq!(table.get_files_count(), 2); + + let app_txns = table.get_app_transaction_version(); + assert_eq!(app_txns.len(), 1); + assert_eq!(app_txns.get("my-app"), Some(&1)); + + // Test Txn Id can be read from existing table + + let mut table2 = DeltaTableBuilder::from_uri(tmp_path.to_str().unwrap()) + .load() + .await + .unwrap(); + let app_txns2 = table2.get_app_transaction_version(); + + assert_eq!(app_txns2.len(), 1); + assert_eq!(app_txns2.get("my-app"), Some(&1)); + + // Write new data to the table and check that `update` functions work + + let table = DeltaOps::from(table) + .write(vec![get_record_batch(None, false)]) + .with_commit_properties( + CommitProperties::default().with_application_transaction(Txn::new(&"my-app", 3)), + ) + .await + .unwrap(); + + assert_eq!(table.version(), 1); + let app_txns = table.get_app_transaction_version(); + assert_eq!(app_txns.len(), 1); + assert_eq!(app_txns.get("my-app"), Some(&3)); + + table2.update_incremental(None).await.unwrap(); + assert_eq!(table2.version(), 1); + let app_txns2 = table2.get_app_transaction_version(); + assert_eq!(app_txns2.len(), 1); + assert_eq!(app_txns2.get("my-app"), Some(&3)); + + // Create a checkpoint and then load + checkpoints::create_checkpoint(&table).await.unwrap(); + let table3 = DeltaTableBuilder::from_uri(tmp_path.to_str().unwrap()) + .load() + .await + .unwrap(); + let app_txns3 = table2.get_app_transaction_version(); + assert_eq!(app_txns3.len(), 1); + assert_eq!(app_txns3.get("my-app"), Some(&3)); + assert_eq!(table3.version(), 1); + } + + #[tokio::test] + async fn test_app_txn_conflict() { + // A conflict must be raised whenever the same application id is used for two concurrent transactions + + let tmp_dir = tempfile::tempdir().unwrap(); + let tmp_path = std::fs::canonicalize(tmp_dir.path()).unwrap(); + + let batch = get_record_batch(None, false); + let table = DeltaOps::try_from_uri(tmp_path.to_str().unwrap()) + .await + .unwrap() + .write(vec![batch.clone()]) + .with_save_mode(SaveMode::ErrorIfExists) + .with_partition_columns(["modified"]) + .with_commit_properties( + CommitProperties::default().with_application_transaction(Txn::new(&"my-app", 1)), + ) + .await + .unwrap(); + assert_eq!(table.version(), 0); + + let table2 = DeltaTableBuilder::from_uri(tmp_path.to_str().unwrap()) + .load() + .await + .unwrap(); + assert_eq!(table2.version(), 0); + + let table = DeltaOps::from(table) + .write(vec![get_record_batch(None, false)]) + .with_commit_properties( + CommitProperties::default().with_application_transaction(Txn::new(&"my-app", 2)), + ) + .await + .unwrap(); + assert_eq!(table.version(), 1); + + let res = DeltaOps::from(table2) + .write(vec![get_record_batch(None, false)]) + .with_commit_properties( + CommitProperties::default().with_application_transaction(Txn::new(&"my-app", 3)), + ) + .await; + + let err = res.err().unwrap(); + assert_eq!( + err.to_string(), + "Transaction failed: Failed to commit transaction: Concurrent transaction failed." + ); + } +} diff --git a/crates/core/src/operations/transaction/conflict_checker.rs b/crates/core/src/operations/transaction/conflict_checker.rs index 9463a154b7..ef172abf6f 100644 --- a/crates/core/src/operations/transaction/conflict_checker.rs +++ b/crates/core/src/operations/transaction/conflict_checker.rs @@ -6,6 +6,7 @@ use super::CommitInfo; use crate::delta_datafusion::DataFusionMixins; use crate::errors::DeltaResult; use crate::kernel::EagerSnapshot; +use crate::kernel::Txn; use crate::kernel::{Action, Add, Metadata, Protocol, Remove}; use crate::logstore::{get_actions, LogStore}; use crate::protocol::DeltaOperation; @@ -121,10 +122,18 @@ impl<'a> TransactionInfo<'a> { let read_predicates = read_predicates .map(|pred| read_snapshot.parse_predicate_expression(pred, &session.state())) .transpose()?; + + let mut read_app_ids = HashSet::::new(); + for action in actions.iter() { + if let Action::Txn(Txn { app_id, .. }) = action { + read_app_ids.insert(app_id.clone()); + } + } + Ok(Self { txn_id: "".into(), read_predicates, - read_app_ids: Default::default(), + read_app_ids, actions, read_snapshot, read_whole_table, @@ -139,6 +148,12 @@ impl<'a> TransactionInfo<'a> { actions: &'a Vec, read_whole_table: bool, ) -> Self { + let mut read_app_ids = HashSet::::new(); + for action in actions.iter() { + if let Action::Txn(Txn { app_id, .. }) = action { + read_app_ids.insert(app_id.clone()); + } + } Self { txn_id: "".into(), read_predicates, @@ -156,10 +171,16 @@ impl<'a> TransactionInfo<'a> { actions: &'a Vec, read_whole_table: bool, ) -> DeltaResult { + let mut read_app_ids = HashSet::::new(); + for action in actions.iter() { + if let Action::Txn(Txn { app_id, .. }) = action { + read_app_ids.insert(app_id.clone()); + } + } Ok(Self { txn_id: "".into(), read_predicates, - read_app_ids: Default::default(), + read_app_ids, actions, read_snapshot, read_whole_table, diff --git a/crates/core/src/operations/transaction/mod.rs b/crates/core/src/operations/transaction/mod.rs index 6d5f7f731a..3a65a5e879 100644 --- a/crates/core/src/operations/transaction/mod.rs +++ b/crates/core/src/operations/transaction/mod.rs @@ -83,20 +83,22 @@ use serde_json::Value; use std::collections::HashMap; use self::conflict_checker::{CommitConflictError, TransactionInfo, WinningCommitSummary}; -use crate::checkpoints::create_checkpoint_for; +// use crate::checkpoints::create_checkpoint_for; use crate::errors::DeltaTableError; use crate::kernel::{ - Action, CommitInfo, EagerSnapshot, Metadata, Protocol, ReaderFeatures, WriterFeatures, + Action, CommitInfo, EagerSnapshot, Metadata, Protocol, ReaderFeatures, Txn, WriterFeatures, }; use crate::logstore::LogStoreRef; use crate::protocol::DeltaOperation; use crate::storage::ObjectStoreRetryExt; use crate::table::config::TableConfig; -use crate::table::state::DeltaTableState; +use crate::table::state::{DeltaTableState}; use crate::{crate_version, DeltaResult}; pub use self::protocol::INSTANCE as PROTOCOL; +#[cfg(test)] +pub(crate) mod application; mod conflict_checker; mod protocol; #[cfg(feature = "datafusion")] @@ -247,6 +249,7 @@ impl TableReference for DeltaTableState { } /// Data that was actually written to the log store. +#[derive(Debug)] pub struct CommitData { /// The actions pub actions: Vec, @@ -254,6 +257,8 @@ pub struct CommitData { pub operation: DeltaOperation, /// The Metadata pub app_metadata: HashMap, + /// Application specific transaction + pub app_transactions: Vec, } impl CommitData { @@ -262,6 +267,7 @@ impl CommitData { mut actions: Vec, operation: DeltaOperation, mut app_metadata: HashMap, + app_transactions: Vec, ) -> Result { if !actions.iter().any(|a| matches!(a, Action::CommitInfo(..))) { let mut commit_info = operation.get_commit_info(); @@ -274,10 +280,16 @@ impl CommitData { commit_info.info = app_metadata.clone(); actions.push(Action::CommitInfo(commit_info)) } + + for txn in &app_transactions { + actions.push(Action::Txn(txn.clone())) + } + Ok(CommitData { actions, operation, app_metadata, + app_transactions, }) } @@ -302,27 +314,29 @@ impl CommitData { } } -#[derive(Clone, Debug, Copy)] -/// Properties for post commit hook. -pub struct PostCommitHookProperties { - create_checkpoint: bool, -} +// #[derive(Clone, Debug, Copy)] +// /// Properties for post commit hook. +// pub struct PostCommitHookProperties { +// create_checkpoint: bool, +// } #[derive(Clone, Debug)] /// End user facing interface to be used by operations on the table. /// Enable controling commit behaviour and modifying metadata that is written during a commit. pub struct CommitProperties { pub(crate) app_metadata: HashMap, + pub(crate) app_transaction: Vec, max_retries: usize, - create_checkpoint: bool, + // create_checkpoint: bool, } impl Default for CommitProperties { fn default() -> Self { Self { app_metadata: Default::default(), + app_transaction: Vec::new(), max_retries: DEFAULT_RETRIES, - create_checkpoint: true, + // create_checkpoint: true, } } } @@ -331,15 +345,27 @@ impl CommitProperties { /// Specify metadata the be comitted pub fn with_metadata( mut self, - metadata: impl IntoIterator, + metadata: impl IntoIterator, ) -> Self { self.app_metadata = HashMap::from_iter(metadata); self } /// Specify if it should create a checkpoint when the commit interval condition is met - pub fn with_create_checkpoint(mut self, create_checkpoint: bool) -> Self { - self.create_checkpoint = create_checkpoint; + // pub fn with_create_checkpoint(mut self, create_checkpoint: bool) -> Self { + // self.create_checkpoint = create_checkpoint; + // self + // } + + /// Add an additonal application transaction to the commit + pub fn with_application_transaction(mut self, txn: Txn) -> Self { + self.app_transaction.push(txn); + self + } + + /// Override application transactions for the commit + pub fn with_application_transactions(mut self, txn: Vec) -> Self { + self.app_transaction = txn; self } } @@ -349,10 +375,11 @@ impl From for CommitBuilder { CommitBuilder { max_retries: value.max_retries, app_metadata: value.app_metadata, - post_commit_hook: PostCommitHookProperties { - create_checkpoint: value.create_checkpoint, - } - .into(), + app_transaction: value.app_transaction, + // post_commit_hook: PostCommitHookProperties { + // create_checkpoint: value.create_checkpoint, + // } + // .into(), ..Default::default() } } @@ -362,8 +389,9 @@ impl From for CommitBuilder { pub struct CommitBuilder { actions: Vec, app_metadata: HashMap, + app_transaction: Vec, max_retries: usize, - post_commit_hook: Option, + // post_commit_hook: Option, } impl Default for CommitBuilder { @@ -371,8 +399,9 @@ impl Default for CommitBuilder { CommitBuilder { actions: Vec::new(), app_metadata: HashMap::new(), + app_transaction: Vec::new(), max_retries: DEFAULT_RETRIES, - post_commit_hook: None, + // post_commit_hook: None, } } } @@ -397,10 +426,10 @@ impl<'a> CommitBuilder { } /// Specify all the post commit hook properties - pub fn with_post_commit_hook(mut self, post_commit_hook: PostCommitHookProperties) -> Self { - self.post_commit_hook = post_commit_hook.into(); - self - } + // pub fn with_post_commit_hook(mut self, post_commit_hook: PostCommitHookProperties) -> Self { + // self.post_commit_hook = post_commit_hook.into(); + // self + // } /// Prepare a Commit operation using the configured builder pub fn build( @@ -409,13 +438,18 @@ impl<'a> CommitBuilder { log_store: LogStoreRef, operation: DeltaOperation, ) -> Result, CommitBuilderError> { - let data = CommitData::new(self.actions, operation, self.app_metadata)?; + let data = CommitData::new( + self.actions, + operation, + self.app_metadata, + self.app_transaction, + )?; Ok(PreCommit { log_store, table_data, max_retries: self.max_retries, data, - post_commit_hook: self.post_commit_hook, + // post_commit_hook: self.post_commit_hook, }) } } @@ -426,7 +460,7 @@ pub struct PreCommit<'a> { table_data: Option<&'a dyn TableReference>, data: CommitData, max_retries: usize, - post_commit_hook: Option, + // post_commit_hook: Option, } impl<'a> std::future::IntoFuture for PreCommit<'a> { @@ -436,7 +470,7 @@ impl<'a> std::future::IntoFuture for PreCommit<'a> { fn into_future(self) -> Self::IntoFuture { let this = self; - Box::pin(async move { this.into_prepared_commit_future().await?.await?.await }) + Box::pin(async move { this.into_prepared_commit_future().await?.await }) } } @@ -466,7 +500,7 @@ impl<'a> PreCommit<'a> { table_data: this.table_data, max_retries: this.max_retries, data: this.data, - post_commit: this.post_commit_hook, + // post_commit: this.post_commit_hook, }) }) } @@ -479,7 +513,7 @@ pub struct PreparedCommit<'a> { data: CommitData, table_data: Option<&'a dyn TableReference>, max_retries: usize, - post_commit: Option, + // post_commit: Option, } impl<'a> PreparedCommit<'a> { @@ -490,7 +524,7 @@ impl<'a> PreparedCommit<'a> { } impl<'a> std::future::IntoFuture for PreparedCommit<'a> { - type Output = DeltaResult>; + type Output = DeltaResult; type IntoFuture = BoxFuture<'a, Self::Output>; fn into_future(self) -> Self::IntoFuture { @@ -501,12 +535,9 @@ impl<'a> std::future::IntoFuture for PreparedCommit<'a> { if this.table_data.is_none() { this.log_store.write_commit_entry(0, tmp_commit).await?; - return Ok(PostCommit { + return Ok(FinalizedCommit { version: 0, data: this.data, - create_checkpoint: false, - log_store: this.log_store, - table_data: this.table_data, }); } @@ -525,16 +556,10 @@ impl<'a> std::future::IntoFuture for PreparedCommit<'a> { let version = read_snapshot.version() + attempt_number as i64; match this.log_store.write_commit_entry(version, tmp_commit).await { Ok(()) => { - return Ok(PostCommit { + return Ok(FinalizedCommit { version, data: this.data, - create_checkpoint: this - .post_commit - .map(|v| v.create_checkpoint) - .unwrap_or_default(), - log_store: this.log_store, - table_data: this.table_data, - }); + }) } Err(TransactionError::VersionAlreadyExists(version)) => { let summary = WinningCommitSummary::try_new( @@ -542,7 +567,7 @@ impl<'a> std::future::IntoFuture for PreparedCommit<'a> { version - 1, version, ) - .await?; + .await?; let transaction_info = TransactionInfo::try_new( read_snapshot, this.data.operation.read_predicate(), @@ -583,54 +608,56 @@ impl<'a> std::future::IntoFuture for PreparedCommit<'a> { } /// Represents items for the post commit hook -pub struct PostCommit<'a> { - /// The winning version number of the commit - pub version: i64, - /// The data that was comitted to the log store - pub data: CommitData, - create_checkpoint: bool, - log_store: LogStoreRef, - table_data: Option<&'a dyn TableReference>, -} - -impl<'a> PostCommit<'a> { - /// Runs the post commit activities - async fn run_post_commit_hook( - &self, - version: i64, - commit_data: &CommitData, - ) -> DeltaResult<()> { - if self.create_checkpoint { - self.create_checkpoint(&self.table_data, &self.log_store, version, commit_data) - .await? - } - Ok(()) - } - async fn create_checkpoint( - &self, - table: &Option<&'a dyn TableReference>, - log_store: &LogStoreRef, - version: i64, - commit_data: &CommitData, - ) -> DeltaResult<()> { - if let Some(table) = table { - let checkpoint_interval = table.config().checkpoint_interval() as i64; - if ((version + 1) % checkpoint_interval) == 0 { - // We have to advance the snapshot otherwise we can't create a checkpoint - let mut snapshot = table.eager_snapshot().unwrap().clone(); - snapshot.advance(vec![commit_data])?; - let state = DeltaTableState { - app_transaction_version: HashMap::new(), - snapshot, - }; - create_checkpoint_for(version, &state, log_store.as_ref()).await? - } - } - Ok(()) - } -} +// pub struct PostCommit<'a> { +// /// The winning version number of the commit +// pub version: i64, +// /// The data that was comitted to the log store +// pub data: CommitData, +// create_checkpoint: bool, +// log_store: LogStoreRef, +// table_data: Option<&'a dyn TableReference>, +// } + +// impl<'a> PostCommit<'a> { +// /// Runs the post commit activities +// async fn run_post_commit_hook( +// &self, +// version: i64, +// commit_data: &CommitData, +// ) -> DeltaResult<()> { +// if self.create_checkpoint { +// // self.create_checkpoint(&self.table_data, &self.log_store, version, commit_data) +// // .await? +// } +// Ok(()) +// } +// // async fn create_checkpoint( +// // &self, +// // table: &Option<&'a dyn TableReference>, +// // log_store: &LogStoreRef, +// // version: i64, +// // commit_data: &CommitData, +// // ) -> DeltaResult<()> { +// // if let Some(table) = table { +// // let checkpoint_interval = table.config().checkpoint_interval() as i64; +// // if ((version + 1) % checkpoint_interval) == 0 { +// // // We have to advance the snapshot otherwise we can't create a checkpoint +// // let mut snapshot = table.eager_snapshot().unwrap().clone(); +// // let mut app_visitor = AppTransactionVisitor::new(); +// // snapshot.advance(vec![commit_data], vec![&mut app_visitor])?; +// // let state = DeltaTableState { +// // app_transaction_version: app_visitor.app_transaction_version, +// // snapshot, +// // }; +// // create_checkpoint_for(version, &state, log_store.as_ref()).await? +// // } +// // } +// // Ok(()) +// // } +// } /// A commit that successfully completed +#[derive(Debug)] pub struct FinalizedCommit { /// The winning version number of the commit pub version: i64, @@ -650,26 +677,26 @@ impl FinalizedCommit { } } -impl<'a> std::future::IntoFuture for PostCommit<'a> { - type Output = DeltaResult; - type IntoFuture = BoxFuture<'a, Self::Output>; - - fn into_future(self) -> Self::IntoFuture { - let this = self; - - Box::pin(async move { - match this.run_post_commit_hook(this.version, &this.data).await { - Ok(_) => { - return Ok(FinalizedCommit { - version: this.version, - data: this.data, - }) - } - Err(err) => return Err(err), - }; - }) - } -} +// impl<'a> std::future::IntoFuture for PostCommit<'a> { +// type Output = DeltaResult; +// type IntoFuture = BoxFuture<'a, Self::Output>; +// +// fn into_future(self) -> Self::IntoFuture { +// let this = self; +// +// Box::pin(async move { +// match this.run_post_commit_hook(this.version, &this.data).await { +// Ok(_) => { +// return Ok(FinalizedCommit { +// version: this.version, +// data: this.data, +// }) +// } +// Err(err) => return Err(err), +// }; +// }) +// } +// } #[cfg(test)] mod tests { diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs index 10b48a768c..0c1aeb41ad 100644 --- a/crates/core/src/operations/write.rs +++ b/crates/core/src/operations/write.rs @@ -813,6 +813,7 @@ impl std::future::IntoFuture for WriteBuilder { // then again, having only some tombstones may be misleading. if let Some(mut snapshot) = this.snapshot { snapshot.merge(commit.data.actions, &commit.data.operation, commit.version)?; + Ok(DeltaTable::new_with_state(this.log_store, snapshot)) } else { let mut table = DeltaTable::new(this.log_store, Default::default()); diff --git a/crates/core/src/table/state.rs b/crates/core/src/table/state.rs index ef2b97aa80..4566541ef9 100644 --- a/crates/core/src/table/state.rs +++ b/crates/core/src/table/state.rs @@ -3,6 +3,8 @@ use std::collections::HashMap; use std::sync::Arc; +use arrow::compute::{filter_record_batch, is_not_null}; +use arrow_array::{Array, Int64Array, StringArray, StructArray}; use chrono::Utc; use futures::TryStreamExt; use object_store::{path::Path, ObjectStore}; @@ -10,9 +12,10 @@ use serde::{Deserialize, Serialize}; use super::config::TableConfig; use super::{get_partition_col_data_types, DeltaTableConfig}; +use crate::kernel::arrow::extract as ex; use crate::kernel::{ Action, Add, DataType, EagerSnapshot, LogDataHandler, LogicalFile, Metadata, Protocol, Remove, - StructType, + ReplayVisitor, StructType, }; use crate::logstore::LogStore; use crate::operations::transaction::CommitData; @@ -20,6 +23,55 @@ use crate::partitions::{DeltaTablePartition, PartitionFilter}; use crate::protocol::DeltaOperation; use crate::{DeltaResult, DeltaTableError}; +pub(crate) struct AppTransactionVisitor { + app_transaction_version: HashMap, +} + +impl AppTransactionVisitor { + pub(crate) fn new() -> Self { + Self { + app_transaction_version: HashMap::new(), + } + } +} + +impl AppTransactionVisitor { + pub fn merge(self, map: &HashMap) -> HashMap { + let mut clone = map.clone(); + for (key, value) in self.app_transaction_version { + clone.insert(key, value); + } + + return clone; + } +} + +impl ReplayVisitor for AppTransactionVisitor { + fn visit_batch(&mut self, batch: &arrow_array::RecordBatch) -> DeltaResult<()> { + if batch.column_by_name("txn").is_none() { + return Ok(()); + } + + let txn_col = ex::extract_and_cast::(batch, "txn")?; + let filter = is_not_null(txn_col)?; + + let filtered = filter_record_batch(batch, &filter)?; + let arr = ex::extract_and_cast::(&filtered, "txn")?; + + let id = ex::extract_and_cast::(arr, "appId")?; + let version = ex::extract_and_cast::(arr, "version")?; + + for idx in 0..id.len() { + let app = ex::read_str(id, idx)?; + let version = ex::read_primitive(version, idx)?; + + self.app_transaction_version.insert(app.to_owned(), version); + } + + Ok(()) + } +} + /// State snapshot currently held by the Delta Table instance. #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] @@ -36,10 +88,19 @@ impl DeltaTableState { config: DeltaTableConfig, version: Option, ) -> DeltaResult { - let snapshot = EagerSnapshot::try_new(table_root, store.clone(), config, version).await?; + let mut app_visitor = AppTransactionVisitor::new(); + let visitors: Vec<&mut dyn ReplayVisitor> = vec![&mut app_visitor]; + let snapshot = EagerSnapshot::try_new_with_visitor( + table_root, + store.clone(), + config, + version, + visitors, + ) + .await?; Ok(Self { snapshot, - app_transaction_version: HashMap::new(), + app_transaction_version: app_visitor.app_transaction_version, }) } @@ -83,6 +144,7 @@ impl DeltaTableState { metadata: metadata.clone(), }, HashMap::new(), + Vec::new(), ) .unwrap()]; @@ -194,8 +256,16 @@ impl DeltaTableState { actions, operation: operation.clone(), app_metadata: HashMap::new(), + app_transactions: Vec::new(), }; - let new_version = self.snapshot.advance(&vec![commit_data])?; + + let mut app_txn_visitor = AppTransactionVisitor::new(); + let new_version = self + .snapshot + .advance(&vec![commit_data], vec![&mut app_txn_visitor])?; + + self.app_transaction_version = app_txn_visitor.merge(&self.app_transaction_version); + if new_version != version { return Err(DeltaTableError::Generic("Version mismatch".to_string())); } @@ -213,7 +283,11 @@ impl DeltaTableState { log_store: Arc, version: Option, ) -> Result<(), DeltaTableError> { - self.snapshot.update(log_store, version).await?; + let mut app_txn_visitor = AppTransactionVisitor::new(); + self.snapshot + .update(log_store, version, vec![&mut app_txn_visitor]) + .await?; + self.app_transaction_version = app_txn_visitor.merge(&self.app_transaction_version); Ok(()) }