diff --git a/Cargo.lock b/Cargo.lock index dbe361720..b3937a8ec 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4751,15 +4751,20 @@ version = "0.10.0" dependencies = [ "arrow-array", "arrow-schema", + "arrow-select", "derive_more", "error-stack", + "futures", "itertools 0.11.0", "smallvec", "sparrow-api", "sparrow-compiler", + "sparrow-plan", "sparrow-runtime", "sparrow-syntax", "static_init", + "tokio", + "tokio-stream", "uuid 1.4.1", ] diff --git a/crates/sparrow-compiler/src/data_context.rs b/crates/sparrow-compiler/src/data_context.rs index 58afec392..3e9b704a6 100644 --- a/crates/sparrow-compiler/src/data_context.rs +++ b/crates/sparrow-compiler/src/data_context.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use anyhow::Context; use arrow::datatypes::{DataType, SchemaRef}; +use arrow::record_batch::RecordBatch; use sparrow_api::kaskada::v1alpha::slice_plan::Slice; use sparrow_api::kaskada::v1alpha::{compute_table, ComputeTable, PreparedFile, TableConfig}; use sparrow_core::context_code; @@ -17,7 +18,7 @@ use crate::AstDfgRef; /// /// Specifically, this holds the information about the tables /// available to the compilation. -#[derive(Default, Debug)] +#[derive(Default, Debug, Clone)] pub struct DataContext { /// Information about the groupings in the context. group_info: Vec, @@ -60,6 +61,10 @@ impl DataContext { self.table_info.get(&id) } + pub fn table_info_mut(&mut self, id: TableId) -> Option<&mut TableInfo> { + self.table_info.get_mut(&id) + } + pub fn tables_for_grouping(&self, id: GroupId) -> impl Iterator { self.table_info .iter() @@ -296,7 +301,7 @@ impl DataContext { } /// Information about groups. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct GroupInfo { name: String, key_type: DataType, @@ -314,6 +319,8 @@ pub struct TableInfo { /// Each file set corresponds to the files for the table with a specific /// slice configuration. file_sets: Vec, + /// An in-memory record batch for the contents of the table. + pub in_memory: Option, } impl TableInfo { @@ -335,6 +342,7 @@ impl TableInfo { schema, config, file_sets, + in_memory: None, }) } diff --git a/crates/sparrow-compiler/src/dfg.rs b/crates/sparrow-compiler/src/dfg.rs index 9d074c37e..2adebdd0b 100644 --- a/crates/sparrow-compiler/src/dfg.rs +++ b/crates/sparrow-compiler/src/dfg.rs @@ -364,7 +364,7 @@ impl Dfg { } /// Runs simplifications on the graph. - pub(crate) fn run_simplifications(&mut self, options: &CompilerOptions) -> anyhow::Result<()> { + pub fn run_simplifications(&mut self, options: &CompilerOptions) -> anyhow::Result<()> { let span = info_span!("Running simplifications"); let _enter = span.enter(); @@ -375,7 +375,7 @@ impl Dfg { } /// Extract the simplest representation of the node `id` from the graph. - pub(crate) fn extract_simplest(&self, id: Id) -> DfgExpr { + pub fn extract_simplest(&self, id: Id) -> DfgExpr { let span = info_span!("Extracting simplest DFG"); let _enter = span.enter(); @@ -447,7 +447,7 @@ impl Dfg { /// Remove nodes that aren't needed for the `output` from the graph. /// /// Returns the new ID of the `output`. - pub(crate) fn prune(&mut self, output: Id) -> anyhow::Result { + pub fn prune(&mut self, output: Id) -> anyhow::Result { // The implementation is somewhat painful -- we extract a `RecExpr`, and then // recreate the EGraph. This has the desired property -- only referenced nodes // are extracted. But, it may cause the IDs to change. 
diff --git a/crates/sparrow-compiler/src/lib.rs b/crates/sparrow-compiler/src/lib.rs index e881b26be..ce74de1ab 100644 --- a/crates/sparrow-compiler/src/lib.rs +++ b/crates/sparrow-compiler/src/lib.rs @@ -41,7 +41,7 @@ mod frontend; mod functions; mod nearest_matches; mod options; -mod plan; +pub mod plan; mod time_domain; mod types; diff --git a/crates/sparrow-compiler/src/plan.rs b/crates/sparrow-compiler/src/plan.rs index 28dab43bb..e991094e5 100644 --- a/crates/sparrow-compiler/src/plan.rs +++ b/crates/sparrow-compiler/src/plan.rs @@ -24,7 +24,7 @@ const DBG_PRINT_PLAN: bool = false; /// TODO: The `DataContext` is used to get the table name from an ID, which is /// only necessary to create the `slice_plan` because it uses a name instead of /// an ID. -pub(super) fn extract_plan_proto( +pub fn extract_plan_proto( data_context: &DataContext, expr: DfgExpr, per_entity_behavior: PerEntityBehavior, diff --git a/crates/sparrow-main/tests/e2e/fixture/query_fixture.rs b/crates/sparrow-main/tests/e2e/fixture/query_fixture.rs index e65ff9bb9..3b57daf34 100644 --- a/crates/sparrow-main/tests/e2e/fixture/query_fixture.rs +++ b/crates/sparrow-main/tests/e2e/fixture/query_fixture.rs @@ -270,7 +270,6 @@ impl QueryFixture { plan: Some(plan), destination: Some(output_to), tables: data.tables(), - ..self.execute_request.clone() }; diff --git a/crates/sparrow-runtime/src/execute.rs b/crates/sparrow-runtime/src/execute.rs index b20b2c2b7..b8a6012fe 100644 --- a/crates/sparrow-runtime/src/execute.rs +++ b/crates/sparrow-runtime/src/execute.rs @@ -1,18 +1,18 @@ use std::sync::Arc; use chrono::NaiveDateTime; +use enum_map::EnumMap; use error_stack::{IntoReport, IntoReportCompat, ResultExt}; use futures::Stream; use prost_wkt_types::Timestamp; use sparrow_api::kaskada::v1alpha::execute_request::Limits; use sparrow_api::kaskada::v1alpha::{ - ComputePlan, ComputeTable, ExecuteRequest, ExecuteResponse, LateBoundValue, PerEntityBehavior, + ComputePlan, ComputeSnapshotConfig, ComputeTable, ExecuteRequest, ExecuteResponse, + LateBoundValue, PerEntityBehavior, PlanHash, }; use sparrow_arrow::scalar_value::ScalarValue; use sparrow_compiler::{hash_compute_plan_proto, DataContext}; -use sparrow_instructions::ComputeStore; use sparrow_qfr::kaskada::sparrow::v1alpha::FlightRecordHeader; -use tracing::Instrument; use crate::execute::error::Error; use crate::execute::key_hash_inverse::{KeyHashInverse, ThreadSafeKeyHashInverse}; @@ -23,6 +23,7 @@ use crate::RuntimeOptions; mod checkpoints; mod compute_executor; +mod compute_store_guard; pub mod error; pub(crate) mod key_hash_inverse; pub(crate) mod operation; @@ -31,9 +32,6 @@ mod progress_reporter; mod spawner; pub use compute_executor::*; -// The path prefix to the local compute store db. -const STORE_PATH_PREFIX: &str = "compute_snapshot_"; - /// The main method for executing a Fenl query. /// /// The `request` proto contains the execution plan as well as @@ -55,112 +53,136 @@ pub async fn execute( let destination = Destination::try_from(destination).change_context(Error::InvalidDestination)?; - let changed_since_time = request.changed_since.unwrap_or(Timestamp { - seconds: 0, - nanos: 0, - }); + let data_context = DataContext::try_from_tables(request.tables.to_vec()) + .into_report() + .change_context(Error::internal_msg("create data context"))?; - // Create and populate the late bindings. - // We don't use the `enum_map::enum_map!(...)` initialization because it would - // require looping over (and cloning) the scalar value unnecessarily. 
- let mut late_bindings = enum_map::enum_map! { - _ => None - }; - late_bindings[LateBoundValue::ChangedSinceTime] = Some(ScalarValue::timestamp( - changed_since_time.seconds, - changed_since_time.nanos, - None, - )); - - let output_at_time = if let Some(output_at_time) = request.final_result_time { - late_bindings[LateBoundValue::FinalAtTime] = Some(ScalarValue::timestamp( - output_at_time.seconds, - output_at_time.nanos, - None, - )); - Some(output_at_time) - } else { - late_bindings[LateBoundValue::FinalAtTime] = None; - None + let options = ExecutionOptions { + bounded_lateness_ns, + changed_since_time: request.changed_since.unwrap_or_default(), + final_at_time: request.final_result_time, + compute_snapshot_config: request.compute_snapshot_config, + limits: request.limits, + ..ExecutionOptions::default() }; - let mut data_context = DataContext::try_from_tables(request.tables.to_vec()) - .into_report() - .change_context(Error::internal_msg("create data context"))?; + // let output_at_time = request.final_result_time; - let object_stores = Arc::new(ObjectStoreRegistry::default()); + execute_new(plan, destination, data_context, options).await +} - // If the snapshot config exists, sparrow should attempt to resume from state, - // and store new state. Create a new storage path for the local store to - // exist. - let storage_dir = if let Some(config) = &request.compute_snapshot_config { - let dir = tempfile::Builder::new() - .prefix(&STORE_PATH_PREFIX) - .tempdir() - .into_report() - .change_context(Error::internal_msg("create snapshot dir"))?; - - // If a `resume_from` path is specified, download the existing state from s3. - if let Some(resume_from) = &config.resume_from { - checkpoints::download(resume_from, object_stores.as_ref(), dir.path(), config) - .instrument(tracing::info_span!("Downloading checkpoint files")) - .await - .change_context(Error::internal_msg("download snapshot"))?; +#[derive(Default, Debug)] +pub struct ExecutionOptions { + changed_since_time: Timestamp, + final_at_time: Option, + bounded_lateness_ns: Option, + compute_snapshot_config: Option, + limits: Option, + stop_signal_rx: Option>, +} + +impl ExecutionOptions { + pub fn late_bindings(&self) -> EnumMap> { + enum_map::enum_map! { + LateBoundValue::ChangedSinceTime => Some(ScalarValue::timestamp( + self.changed_since_time.seconds, + self.changed_since_time.nanos, + None, + )), + LateBoundValue::FinalAtTime => self.final_at_time.as_ref().map(|t| ScalarValue::timestamp( + t.seconds, + t.nanos, + None, + )), + _ => None + } + } + + pub fn set_changed_since(&mut self, changed_since: Timestamp) { + self.changed_since_time = changed_since; + } + + pub fn set_final_at_time(&mut self, final_at_time: Timestamp) { + self.final_at_time = Some(final_at_time); + } + + async fn compute_store( + &self, + object_stores: &ObjectStoreRegistry, + per_entity_behavior: PerEntityBehavior, + plan_hash: &PlanHash, + ) -> error_stack::Result, Error> { + // If the snapshot config exists, sparrow should attempt to resume from state, + // and store new state. Create a new storage path for the local store to + // exist. 
+ if let Some(config) = self.compute_snapshot_config.clone() { + let max_allowed_max_event_time = match per_entity_behavior { + PerEntityBehavior::Unspecified => { + error_stack::bail!(Error::UnspecifiedPerEntityBehavior) + } + PerEntityBehavior::All => { + // For all results, we need a snapshot with a maximum event time + // no larger than the changed_since time, since we need to replay + // (and recompute the results for) all events after the changed + // since time. + self.changed_since_time.clone() + } + PerEntityBehavior::Final => { + // This is a bit confusing. Right now, the manager is responsible for + // choosing a valid snapshot to resume from. Thus, the work of choosing + // a valid snapshot with regard to any new input data is already done. + // However, the engine does a sanity check here to ensure the snapshot's + // max event time is before the allowed max event time the engine supports, + // dependent on the entity behavior of the query. + // + // For FinalResults, the snapshot can have a max event time of "any time", + // so we set this to Timestamp::MAX. This is because we just need to be able + // to produce results once after all new events have been processed, and + // we can already assume a valid snapshot is chosen and the correct input + // files are being processed. + Timestamp { + seconds: i64::MAX, + nanos: i32::MAX, + } + } + PerEntityBehavior::FinalAtTime => { + self.final_at_time.clone().expect("final at time") + } + }; + + let guard = compute_store_guard::ComputeStoreGuard::try_new( + config, + object_stores, + max_allowed_max_event_time, + plan_hash, + ) + .await?; + + Ok(Some(guard)) } else { - tracing::info!("No snapshot set to resume from. Using empty compute store."); + tracing::info!("No snapshot config; not creating compute store."); + Ok(None) } + } +} - Some(dir) - } else { - tracing::info!("No snapshot config; not creating compute store."); - None - }; +pub async fn execute_new( + plan: ComputePlan, + destination: Destination, + mut data_context: DataContext, + options: ExecutionOptions, +) -> error_stack::Result>, Error> { + let object_stores = Arc::new(ObjectStoreRegistry::default()); let plan_hash = hash_compute_plan_proto(&plan); - let compute_store = if let Some(dir) = &storage_dir { - let max_allowed_max_event_time = match plan.per_entity_behavior() { - PerEntityBehavior::Unspecified => { - error_stack::bail!(Error::UnspecifiedPerEntityBehavior) - } - PerEntityBehavior::All => { - // For all results, we need a snapshot with a maximum event time - // no larger than the changed_since time, since we need to replay - // (and recompute the results for) all events after the changed - // since time. - changed_since_time.clone() - } - PerEntityBehavior::Final => { - // This is a bit confusing. Right now, the manager is responsible for - // choosing a valid snapshot to resume from. Thus, the work of choosing - // a valid snapshot with regard to any new input data is already done. - // However, the engine does a sanity check here to ensure the snapshot's - // max event time is before the allowed max event time the engine supports, - // dependent on the entity behavior of the query. - // - // For FinalResults, the snapshot can have a max event time of "any time", - // so we set this to Timestamp::MAX. This is because we just need to be able - // to produce results once after all new events have been processed, and - // we can already assume a valid snapshot is chosen and the correct input - // files are being processed. 
- Timestamp { - seconds: i64::MAX, - nanos: i32::MAX, - } - } - PerEntityBehavior::FinalAtTime => { - output_at_time.as_ref().expect("final at time").clone() - } - }; - - Some( - ComputeStore::try_new(dir.path(), &max_allowed_max_event_time, &plan_hash) - .into_report() - .change_context(Error::internal_msg("loading compute store"))?, + let compute_store = options + .compute_store( + object_stores.as_ref(), + plan.per_entity_behavior(), + &plan_hash, ) - } else { - None - }; + .await?; let primary_grouping_key_type = plan .primary_grouping_key_type @@ -170,13 +192,14 @@ pub async fn execute( arrow::datatypes::DataType::try_from(&primary_grouping_key_type) .into_report() .change_context(Error::internal_msg("decode primary_grouping_key_type"))?; - let mut key_hash_inverse = KeyHashInverse::from_data_type(primary_grouping_key_type.clone()); - if let Some(compute_store) = compute_store.to_owned() { - if let Ok(restored) = KeyHashInverse::restore_from(&compute_store) { + let mut key_hash_inverse = KeyHashInverse::from_data_type(primary_grouping_key_type.clone()); + if let Some(compute_store) = &compute_store { + if let Ok(restored) = KeyHashInverse::restore_from(compute_store.store_ref()) { key_hash_inverse = restored } } + let primary_group_id = data_context .get_or_create_group_id(&plan.primary_grouping, &primary_grouping_key_type) .into_report() @@ -192,52 +215,50 @@ pub async fn execute( let (progress_updates_tx, progress_updates_rx) = tokio::sync::mpsc::channel(29.max(plan.operations.len() * 2)); - let output_datetime = if let Some(t) = output_at_time { - Some( + let output_at_time = options + .final_at_time + .as_ref() + .map(|t| { NaiveDateTime::from_timestamp_opt(t.seconds, t.nanos as u32) - .ok_or_else(|| Error::internal_msg("expected valid timestamp"))?, - ) - } else { - None - }; + .ok_or_else(|| Error::internal_msg("expected valid timestamp")) + }) + .transpose()?; - // We use the plan hash for validating the snapshot is as expected. - // Rather than accepting it as input (which could lead to us getting - // a correct hash but an incorrect plan) we re-hash the plan. let context = OperationContext { plan, - plan_hash, object_stores, data_context, - compute_store, + compute_store: compute_store.as_ref().map(|g| g.store()), key_hash_inverse, max_event_in_snapshot: None, progress_updates_tx, - output_at_time: output_datetime, - bounded_lateness_ns, + output_at_time, + bounded_lateness_ns: options.bounded_lateness_ns, }; // Start executing the query. We pass the response channel to the // execution layer so it can periodically report progress. tracing::debug!("Starting query execution"); + let late_bindings = options.late_bindings(); let runtime_options = RuntimeOptions { - limits: request.limits.unwrap_or_default(), + limits: options.limits.unwrap_or_default(), flight_record_path: None, }; let compute_executor = ComputeExecutor::try_spawn( context, + plan_hash, &late_bindings, &runtime_options, progress_updates_rx, destination, - None, + options.stop_signal_rx, ) .await .change_context(Error::internal_msg("spawn compute executor"))?; - Ok(compute_executor.execute_with_progress(storage_dir, request.compute_snapshot_config)) + Ok(compute_executor.execute_with_progress(compute_store)) } /// The main method for starting a materialization process. 
@@ -251,102 +272,27 @@ pub async fn materialize( bounded_lateness_ns: Option, stop_signal_rx: tokio::sync::watch::Receiver, ) -> error_stack::Result>, Error> { - // TODO: Unimplemented feature - changed_since_time - let changed_since_time = Timestamp { - seconds: 0, - nanos: 0, - }; - - // Create and populate the late bindings. - // We don't use the `enum_map::enum_map!(...)` initialization because it would - // require looping over (and cloning) the scalar value unnecessarily. - let mut late_bindings = enum_map::enum_map! { - _ => None + let options = ExecutionOptions { + bounded_lateness_ns, + // TODO: Unimplemented feature - changed_since_time + changed_since_time: Timestamp { + seconds: 0, + nanos: 0, + }, + // Unsupported: not allowed to materialize at a specific time + final_at_time: None, + // TODO: Resuming from state is unimplemented + compute_snapshot_config: None, + stop_signal_rx: Some(stop_signal_rx), + ..ExecutionOptions::default() }; - late_bindings[LateBoundValue::ChangedSinceTime] = Some(ScalarValue::timestamp( - changed_since_time.seconds, - changed_since_time.nanos, - None, - )); - - // Unsupported: not allowed to materialize at a specific time - late_bindings[LateBoundValue::FinalAtTime] = None; - let output_at_time = None; - let mut data_context = DataContext::try_from_tables(tables) + let data_context = DataContext::try_from_tables(tables) .into_report() .change_context(Error::internal_msg("create data context"))?; - // TODO: Resuming from state is unimplemented - let storage_dir = None; - let snapshot_compute_store = None; - - let plan_hash = hash_compute_plan_proto(&plan); - - let primary_grouping_key_type = plan - .primary_grouping_key_type - .to_owned() - .ok_or(Error::MissingField("primary_grouping_key_type"))?; - let primary_grouping_key_type = - arrow::datatypes::DataType::try_from(&primary_grouping_key_type) - .into_report() - .change_context(Error::internal_msg("decode primary_grouping_key_type"))?; - let mut key_hash_inverse = KeyHashInverse::from_data_type(primary_grouping_key_type.clone()); - - let primary_group_id = data_context - .get_or_create_group_id(&plan.primary_grouping, &primary_grouping_key_type) - .into_report() - .change_context(Error::internal_msg("get primary grouping ID"))?; - - let object_stores = Arc::new(ObjectStoreRegistry::default()); - key_hash_inverse - .add_from_data_context(&data_context, primary_group_id, &object_stores) - .await - .change_context(Error::internal_msg("initialize key hash inverse"))?; - let key_hash_inverse = Arc::new(ThreadSafeKeyHashInverse::new(key_hash_inverse)); - - // Channel for the output stats. - let (progress_updates_tx, progress_updates_rx) = - tokio::sync::mpsc::channel(29.max(plan.operations.len() * 2)); - - // We use the plan hash for validating the snapshot is as expected. - // Rather than accepting it as input (which could lead to us getting - // a correct hash but an incorrect plan) we re-hash the plan. - let context = OperationContext { - plan, - plan_hash, - object_stores, - data_context, - compute_store: snapshot_compute_store, - key_hash_inverse, - max_event_in_snapshot: None, - progress_updates_tx, - output_at_time, - bounded_lateness_ns, - }; - - // Start executing the query. We pass the response channel to the - // execution layer so it can periodically report progress. 
- tracing::debug!("Starting query execution"); - - let runtime_options = RuntimeOptions { - limits: Limits::default(), - flight_record_path: None, - }; - - let compute_executor = ComputeExecutor::try_spawn( - context, - &late_bindings, - &runtime_options, - progress_updates_rx, - destination, - Some(stop_signal_rx), - ) - .await - .change_context(Error::internal_msg("spawn compute executor"))?; - // TODO: the `execute_with_progress` method contains a lot of additional logic that is theoretically not needed, // as the materialization does not exit, and should not need to handle cleanup tasks that regular // queries do. We should likely refactor this to use a separate `materialize_with_progress` method. - Ok(compute_executor.execute_with_progress(storage_dir, None)) + execute_new(plan, destination, data_context, options).await } diff --git a/crates/sparrow-runtime/src/execute/compute_executor.rs b/crates/sparrow-runtime/src/execute/compute_executor.rs index e5b2ba2a8..a580bebe5 100644 --- a/crates/sparrow-runtime/src/execute/compute_executor.rs +++ b/crates/sparrow-runtime/src/execute/compute_executor.rs @@ -7,18 +7,16 @@ use futures::stream::{FuturesUnordered, PollNext}; use futures::{FutureExt, Stream, TryFutureExt}; use prost_wkt_types::Timestamp; use sparrow_api::kaskada::v1alpha::ComputeSnapshot; -use sparrow_api::kaskada::v1alpha::ComputeSnapshotConfig; use sparrow_api::kaskada::v1alpha::{ExecuteResponse, LateBoundValue, PlanHash}; use sparrow_arrow::scalar_value::ScalarValue; -use sparrow_instructions::ComputeStore; use sparrow_qfr::io::writer::FlightRecordWriter; use sparrow_qfr::kaskada::sparrow::v1alpha::FlightRecordHeader; use sparrow_qfr::FlightRecorderFactory; -use tempfile::TempDir; use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::StreamExt; use tracing::{error, info, info_span, Instrument}; +use crate::execute::compute_store_guard::ComputeStoreGuard; use crate::execute::operation::{OperationContext, OperationExecutor}; use crate::execute::output::Destination; use crate::execute::progress_reporter::{progress_stream, ProgressUpdate}; @@ -31,7 +29,6 @@ use crate::{Batch, RuntimeOptions}; pub(crate) struct ComputeExecutor { object_stores: Arc, - compute_store: Option>, plan_hash: PlanHash, futures: FuturesUnordered>, progress_updates_rx: tokio::sync::mpsc::Receiver, @@ -53,6 +50,7 @@ impl ComputeExecutor { /// Spawns the compute tasks using the new operation based executor. pub async fn try_spawn( mut context: OperationContext, + plan_hash: PlanHash, late_bindings: &EnumMap>, runtime_options: &RuntimeOptions, progress_updates_rx: tokio::sync::mpsc::Receiver, @@ -155,8 +153,7 @@ impl ComputeExecutor { Ok(Self { object_stores: context.object_stores, - compute_store: context.compute_store, - plan_hash: context.plan_hash, + plan_hash, futures: spawner.finish(), progress_updates_rx, max_event_time_rx, @@ -167,14 +164,12 @@ impl ComputeExecutor { /// /// The `finish` function is called after the final compute result has been /// created, but before progress information stops being streamed. - pub fn execute_with_progress( + pub(super) fn execute_with_progress( self, - storage_dir: Option, - compute_snapshot_config: Option, + store: Option, ) -> impl Stream> { let Self { object_stores, - compute_store, plan_hash, futures, progress_updates_rx, @@ -201,40 +196,16 @@ impl ComputeExecutor { }; let compute_result = compute_result.expect("ok"); - if let Some(compute_store) = compute_store { - // Write the max input time to the store. 
- if let Err(e) = compute_store - .put_max_event_time(&compute_result.max_input_timestamp) - .into_report() - { - return ProgressUpdate::ExecutionFailed { - error: e - .change_context(Error::Internal("failed to report max event time")), - }; - } - - // Now that everything has completed, we attempt to get the compute store out. - // This lets us explicitly drop the store here. - match Arc::try_unwrap(compute_store) { - Ok(owned_compute_store) => std::mem::drop(owned_compute_store), - Err(_) => panic!("unable to reclaim compute store"), - }; - } - - let compute_snapshots = upload_compute_snapshots( - object_stores, - storage_dir, - compute_snapshot_config, - compute_result, - ) - .instrument(tracing::info_span!("Uploading checkpoint files")) - .await - .unwrap_or_else(|e| { - // Log, but don't fail if we couldn't upload snapshots. - // We can still produce valid answers, but won't perform an incremental query. - error!("Failed to upload compute snapshot(s):\n{:?}", e); - Vec::new() - }); + let compute_snapshots = + upload_compute_snapshots(object_stores.as_ref(), store, compute_result) + .instrument(tracing::info_span!("Uploading checkpoint files")) + .await + .unwrap_or_else(|e| { + // Log, but don't fail if we couldn't upload snapshots. + // We can still produce valid answers, but won't perform an incremental query. + error!("Failed to upload compute snapshot(s):\n{:?}", e); + Vec::new() + }); Ok(ProgressUpdate::ExecutionComplete { compute_snapshots }) }; @@ -270,30 +241,14 @@ fn select_biased( } async fn upload_compute_snapshots( - object_stores: Arc, - storage_dir: Option, - compute_snapshot_config: Option, + object_stores: &ObjectStoreRegistry, + store: Option, compute_result: ComputeResult, ) -> error_stack::Result, Error> { let mut snapshots = Vec::new(); - // If a snapshot config exists, let's assume for now that this - // indicates we want to upload snapshots. - // - // There may be situations where we want to resume from a snapshot, - // but not upload new snapshots. - if let Some(snapshot_config) = compute_snapshot_config { - let storage_dir = storage_dir.ok_or(Error::Internal("missing storage dir"))?; - - let snapshot_metadata = super::checkpoints::upload( - object_stores.as_ref(), - storage_dir, - snapshot_config, - compute_result, - ) - .await - .change_context(Error::Internal("uploading snapshot"))?; - snapshots.push(snapshot_metadata); + if let Some(store) = store { + snapshots.push(store.finish(object_stores, compute_result).await?); } else { tracing::info!("No snapshot config; not uploading compute store.") } diff --git a/crates/sparrow-runtime/src/execute/compute_store_guard.rs b/crates/sparrow-runtime/src/execute/compute_store_guard.rs new file mode 100644 index 000000000..8ce574afc --- /dev/null +++ b/crates/sparrow-runtime/src/execute/compute_store_guard.rs @@ -0,0 +1,82 @@ +use std::sync::Arc; + +use error_stack::{IntoReport, IntoReportCompat, ResultExt}; +use prost_wkt_types::Timestamp; +use sparrow_api::kaskada::v1alpha::{ComputeSnapshot, ComputeSnapshotConfig, PlanHash}; +use sparrow_instructions::ComputeStore; +use tempfile::TempDir; +use tracing::Instrument; + +use crate::execute::error::Error; +use crate::execute::{checkpoints, ComputeResult}; +use crate::stores::ObjectStoreRegistry; + +pub(super) struct ComputeStoreGuard { + dir: TempDir, + store: Arc, + config: ComputeSnapshotConfig, +} + +// The path prefix to the local compute store db. 
+const STORE_PATH_PREFIX: &str = "compute_snapshot_"; + +impl ComputeStoreGuard { + pub async fn try_new( + config: ComputeSnapshotConfig, + object_stores: &ObjectStoreRegistry, + max_allowed_max_event_time: Timestamp, + plan_hash: &PlanHash, + ) -> error_stack::Result { + let dir = tempfile::Builder::new() + .prefix(&STORE_PATH_PREFIX) + .tempdir() + .into_report() + .change_context(Error::internal_msg("create snapshot dir"))?; + + // If a `resume_from` path is specified, download the existing state from s3. + if let Some(resume_from) = &config.resume_from { + checkpoints::download(resume_from, object_stores, dir.path(), &config) + .instrument(tracing::info_span!("Downloading checkpoint files")) + .await + .change_context(Error::internal_msg("download snapshot"))?; + } else { + tracing::info!("No snapshot set to resume from. Using empty compute store."); + } + + let store = ComputeStore::try_new(dir.path(), &max_allowed_max_event_time, plan_hash) + .into_report() + .change_context(Error::internal_msg("loading compute store"))?; + Ok(Self { dir, store, config }) + } + + pub async fn finish( + self, + object_stores: &ObjectStoreRegistry, + compute_result: ComputeResult, + ) -> error_stack::Result { + // Write the max input time to the store. + self.store + .put_max_event_time(&compute_result.max_input_timestamp) + .into_report() + .change_context(Error::Internal("failed to report max event time"))?; + + // Now that everything has completed, we attempt to get the compute store out. + // This lets us explicitly drop the store here. + match Arc::try_unwrap(self.store) { + Ok(owned_compute_store) => std::mem::drop(owned_compute_store), + Err(_) => panic!("unable to reclaim compute store"), + }; + + super::checkpoints::upload(object_stores, self.dir, self.config, compute_result) + .await + .change_context(Error::Internal("uploading snapshot")) + } + + pub fn store(&self) -> Arc { + self.store.clone() + } + + pub fn store_ref(&self) -> &ComputeStore { + self.store.as_ref() + } +} diff --git a/crates/sparrow-runtime/src/execute/key_hash_inverse.rs b/crates/sparrow-runtime/src/execute/key_hash_inverse.rs index 645ff753b..2e1433fab 100644 --- a/crates/sparrow-runtime/src/execute/key_hash_inverse.rs +++ b/crates/sparrow-runtime/src/execute/key_hash_inverse.rs @@ -1,7 +1,7 @@ use std::str::FromStr; use anyhow::Context; -use arrow::array::{Array, ArrayRef, PrimitiveArray, UInt64Array}; +use arrow::array::{Array, ArrayRef, AsArray, PrimitiveArray, UInt64Array}; use arrow::datatypes::{DataType, UInt64Type}; use error_stack::{IntoReportCompat, ResultExt}; @@ -114,6 +114,25 @@ impl KeyHashInverse { .change_context(Error::ReadingMetadata)?; } + // HACKY: Add the in-memory batches to the key hash inverse. 
+ let in_memory = data_context + .tables_for_grouping(primary_grouping) + .flat_map(|table| { + table.in_memory.as_ref().map(|batch| { + let keys = batch + .column_by_name(&table.config().group_column_name) + .unwrap(); + let key_hashes = batch.columns()[2].clone(); + (keys.clone(), key_hashes.clone()) + }) + }); + for (keys, key_hashes) in in_memory { + self.add(keys, key_hashes.as_primitive()) + .into_report() + .change_context(Error::ReadingMetadata) + .unwrap(); + } + Ok(()) } diff --git a/crates/sparrow-runtime/src/execute/operation.rs b/crates/sparrow-runtime/src/execute/operation.rs index 96edac38e..8414b858d 100644 --- a/crates/sparrow-runtime/src/execute/operation.rs +++ b/crates/sparrow-runtime/src/execute/operation.rs @@ -44,9 +44,7 @@ use error_stack::{IntoReport, IntoReportCompat, Report, Result, ResultExt}; use futures::Future; use prost_wkt_types::Timestamp; use sparrow_api::kaskada::v1alpha::operation_plan::tick_operation::TickBehavior; -use sparrow_api::kaskada::v1alpha::{ - operation_plan, ComputePlan, LateBoundValue, OperationPlan, PlanHash, -}; +use sparrow_api::kaskada::v1alpha::{operation_plan, ComputePlan, LateBoundValue, OperationPlan}; use sparrow_arrow::scalar_value::ScalarValue; use sparrow_compiler::DataContext; use sparrow_instructions::ComputeStore; @@ -80,7 +78,6 @@ use crate::Batch; /// the method to create a table reader on to it. pub(crate) struct OperationContext { pub plan: ComputePlan, - pub plan_hash: PlanHash, pub object_stores: Arc, pub data_context: DataContext, pub compute_store: Option>, diff --git a/crates/sparrow-runtime/src/execute/operation/scan.rs b/crates/sparrow-runtime/src/execute/operation/scan.rs index c533ff01b..46e58ccda 100644 --- a/crates/sparrow-runtime/src/execute/operation/scan.rs +++ b/crates/sparrow-runtime/src/execute/operation/scan.rs @@ -156,6 +156,27 @@ impl ScanOperation { ))?, ); + if let Some(in_memory) = &table_info.in_memory { + // Hacky. When doing the Python-builder, use the in-memory batch. + // Ideally, this would be merged with the contents of the file. + // Bonus points if it deduplicates. That would allow us to use the + // in-memory batch as the "hot-store" for history+stream hybrid + // queries. + assert!(requested_slice.is_none()); + let batch = in_memory.clone(); + return Ok(Box::new(Self { + projected_schema, + input_stream: futures::stream::once(async move { + Batch::try_new_from_batch(batch) + .into_report() + .change_context(Error::internal_msg("invalid input")) + }) + .boxed(), + key_hash_index: KeyHashIndex::default(), + progress_updates_tx: context.progress_updates_tx.clone(), + })); + } + // Figure out the projected columns from the table schema. 
// // TODO: Can we clean anything up by changing the table reader API @@ -316,7 +337,7 @@ mod tests { use sparrow_api::kaskada::v1alpha::{self, data_type}; use sparrow_api::kaskada::v1alpha::{ expression_plan, literal, operation_plan, ComputePlan, ComputeTable, ExpressionPlan, - Literal, OperationInputRef, OperationPlan, PlanHash, PreparedFile, SlicePlan, TableConfig, + Literal, OperationInputRef, OperationPlan, PreparedFile, SlicePlan, TableConfig, TableMetadata, }; use sparrow_arrow::downcast::downcast_primitive_array; @@ -448,7 +469,6 @@ mod tests { operations: vec![plan], ..ComputePlan::default() }, - plan_hash: PlanHash::default(), object_stores: Arc::new(ObjectStoreRegistry::default()), data_context, compute_store: None, diff --git a/crates/sparrow-runtime/src/execute/operation/testing.rs b/crates/sparrow-runtime/src/execute/operation/testing.rs index fc619808b..33e3e27bd 100644 --- a/crates/sparrow-runtime/src/execute/operation/testing.rs +++ b/crates/sparrow-runtime/src/execute/operation/testing.rs @@ -4,7 +4,7 @@ use anyhow::Context; use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use arrow::record_batch::RecordBatch; use itertools::Itertools; -use sparrow_api::kaskada::v1alpha::{ComputePlan, OperationPlan, PlanHash}; +use sparrow_api::kaskada::v1alpha::{ComputePlan, OperationPlan}; use sparrow_compiler::DataContext; use crate::execute::key_hash_inverse::{KeyHashInverse, ThreadSafeKeyHashInverse}; @@ -181,7 +181,6 @@ pub(super) async fn run_operation( operations: vec![plan], ..ComputePlan::default() }, - plan_hash: PlanHash::default(), object_stores: Arc::new(ObjectStoreRegistry::default()), data_context: DataContext::default(), compute_store: None, @@ -239,7 +238,6 @@ pub(super) async fn run_operation_json( operations: vec![plan], ..ComputePlan::default() }, - plan_hash: PlanHash::default(), object_stores: Arc::new(ObjectStoreRegistry::default()), data_context: DataContext::default(), compute_store: None, diff --git a/crates/sparrow-runtime/src/execute/progress_reporter.rs b/crates/sparrow-runtime/src/execute/progress_reporter.rs index f4e858712..c01310770 100644 --- a/crates/sparrow-runtime/src/execute/progress_reporter.rs +++ b/crates/sparrow-runtime/src/execute/progress_reporter.rs @@ -135,18 +135,15 @@ impl ProgressTracker { flight_record_path: None, plan_yaml_path: None, compute_snapshots: Vec::new(), - destination: Some(destination), + destination, }) } - fn destination_to_output(&mut self) -> error_stack::Result { + fn destination_to_output(&mut self) -> error_stack::Result, Error> { // Clone the output paths in for object store destinations - let destination = self - .destination - .as_ref() - .ok_or(Error::Internal("expected destination"))?; - match destination { - destination::Destination::ObjectStore(store) => Ok(Destination { + match self.destination.as_ref() { + None => Ok(None), + Some(destination::Destination::ObjectStore(store)) => Ok(Some(Destination { destination: Some(destination::Destination::ObjectStore( ObjectStoreDestination { file_type: store.file_type, @@ -156,18 +153,18 @@ impl ProgressTracker { }), }, )), - }), + })), #[cfg(not(feature = "pulsar"))] - output_to::Destination::Pulsar(pulsar) => { + Some(destination::Destination::Pulsar(pulsar)) => { error_stack::bail!(Error::FeatureNotEnabled { feature: "pulsar" }) } #[cfg(feature = "pulsar")] - destination::Destination::Pulsar(pulsar) => { + Some(destination::Destination::Pulsar(pulsar)) => { let config = pulsar .config .as_ref() .ok_or(Error::internal_msg("missing config"))?; - Ok(Destination 
{ + Ok(Some(Destination { destination: Some(destination::Destination::Pulsar(PulsarDestination { config: Some(PulsarConfig { broker_service_url: config.broker_service_url.clone(), @@ -179,7 +176,7 @@ impl ProgressTracker { admin_service_url: config.admin_service_url.clone(), }), })), - }) + })) } } } @@ -234,7 +231,7 @@ pub(super) fn progress_stream( } } - let output = match tracker.destination_to_output() { + let destination = match tracker.destination_to_output() { Ok(output) => output, Err(e) => { yield Err(e); @@ -249,7 +246,7 @@ pub(super) fn progress_stream( flight_record_path: None, plan_yaml_path: None, compute_snapshots, - destination: Some(output), + destination, }); yield final_result; break diff --git a/crates/sparrow-session/Cargo.toml b/crates/sparrow-session/Cargo.toml index 7de9cf7d8..31787de90 100644 --- a/crates/sparrow-session/Cargo.toml +++ b/crates/sparrow-session/Cargo.toml @@ -12,15 +12,20 @@ The Sparrow session builder. [dependencies] arrow-array.workspace = true arrow-schema.workspace = true +arrow-select.workspace = true derive_more.workspace = true error-stack.workspace = true +futures.workspace = true itertools.workspace = true smallvec.workspace = true sparrow-api = { path = "../sparrow-api" } sparrow-compiler = { path = "../sparrow-compiler" } +sparrow-plan = { path = "../sparrow-plan" } sparrow-runtime = { path = "../sparrow-runtime" } sparrow-syntax = { path = "../sparrow-syntax" } static_init.workspace = true +tokio.workspace = true +tokio-stream.workspace = true uuid.workspace = true [dev-dependencies] diff --git a/crates/sparrow-session/src/error.rs b/crates/sparrow-session/src/error.rs index 5ef35968b..7a4321491 100644 --- a/crates/sparrow-session/src/error.rs +++ b/crates/sparrow-session/src/error.rs @@ -20,6 +20,10 @@ pub enum Error { Prepare, #[display(fmt = "internal error")] Internal, + #[display(fmt = "compile query")] + Compile, + #[display(fmt = "execute query")] + Execute, } impl error_stack::Context for Error {} diff --git a/crates/sparrow-session/src/session.rs b/crates/sparrow-session/src/session.rs index 737dda4a9..79a795f5a 100644 --- a/crates/sparrow-session/src/session.rs +++ b/crates/sparrow-session/src/session.rs @@ -1,10 +1,17 @@ use std::borrow::Cow; +use arrow_array::RecordBatch; use arrow_schema::SchemaRef; use error_stack::{IntoReport, IntoReportCompat, ResultExt}; +use futures::{StreamExt, TryStreamExt}; use itertools::Itertools; -use sparrow_api::kaskada::v1alpha::{ComputeTable, FeatureSet, TableConfig, TableMetadata}; +use sparrow_api::kaskada::v1alpha::{ + ComputeTable, FeatureSet, PerEntityBehavior, TableConfig, TableMetadata, +}; use sparrow_compiler::{AstDfgRef, DataContext, Dfg, DiagnosticCollector}; +use sparrow_plan::TableId; +use sparrow_runtime::execute::output::Destination; +use sparrow_runtime::execute::ExecutionOptions; use sparrow_syntax::{ExprOp, LiteralValue, Located, Location, Resolved}; use uuid::Uuid; @@ -181,6 +188,87 @@ impl Session { Ok(result) } } + + pub fn execute(&self, expr: &Expr) -> error_stack::Result { + // TODO: Decorations? + let primary_group_info = self + .data_context + .group_info( + expr.0 + .grouping() + .expect("query to be grouped (non-literal)"), + ) + .expect("missing group info"); + let primary_grouping = primary_group_info.name().to_owned(); + let primary_grouping_key_type = primary_group_info.key_type(); + + // First, extract the necessary subset of the DFG as an expression. + // This will allow us to operate without mutating things. 
+ let expr = self.dfg.extract_simplest(expr.0.value()); + + // TODO: Run the egraph simplifier. + // TODO: Incremental? + // TODO: Slicing? + let plan = sparrow_compiler::plan::extract_plan_proto( + &self.data_context, + expr, + // TODO: Configure per-entity behavior. + PerEntityBehavior::Final, + primary_grouping, + primary_grouping_key_type, + ) + .into_report() + .change_context(Error::Compile)?; + + // Switch to the Tokio async pool. This seems gross. + // Create the runtime. + // + // TODO: Figure out how to asynchronously pass results back to Python? + let rt = tokio::runtime::Runtime::new() + .into_report() + .change_context(Error::Execute)?; + + // Spawn the root task + rt.block_on(async move { + let (output_tx, output_rx) = tokio::sync::mpsc::channel(10); + + let destination = Destination::Channel(output_tx); + let data_context = self.data_context.clone(); + let options = ExecutionOptions::default(); + + // Hacky. Use the existing execution logic. This does weird things with downloading checkpoints, etc. + let mut results = + sparrow_runtime::execute::execute_new(plan, destination, data_context, options) + .await + .change_context(Error::Execute)? + .boxed(); + + // Hacky. Try to get the last response so we can see if there are any errors, etc. + let mut _last = None; + while let Some(response) = results.try_next().await.change_context(Error::Execute)? { + _last = Some(response); + } + + let batches: Vec<_> = tokio_stream::wrappers::ReceiverStream::new(output_rx) + .collect() + .await; + + // Hacky: Assume we produce at least one batch. + // New execution plans contain the schema ref which cleans this up. + let schema = batches[0].schema(); + let batch = arrow_select::concat::concat_batches(&schema, &batches) + .into_report() + .change_context(Error::Execute)?; + Ok(batch) + }) + } + + pub(super) fn hacky_table_mut( + &mut self, + table_id: TableId, + ) -> &mut sparrow_compiler::TableInfo { + self.data_context.table_info_mut(table_id).unwrap() + } } #[static_init::dynamic] @@ -195,7 +283,7 @@ mod tests { use super::*; #[test] - fn session_test() { + fn session_compilation_test() { let mut session = Session::default(); let schema = Arc::new(Schema::new(vec![ diff --git a/crates/sparrow-session/src/table.rs b/crates/sparrow-session/src/table.rs index c4f7849db..eddd7e0b0 100644 --- a/crates/sparrow-session/src/table.rs +++ b/crates/sparrow-session/src/table.rs @@ -5,14 +5,15 @@ use arrow_array::RecordBatch; use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef}; use error_stack::ResultExt; use sparrow_compiler::TableInfo; +use sparrow_plan::TableId; use sparrow_runtime::preparer::Preparer; -use crate::{Error, Expr}; +use crate::{Error, Expr, Session}; pub struct Table { + table_id: TableId, pub expr: Expr, preparer: Preparer, - data: RecordBatch, } impl Table { @@ -29,14 +30,14 @@ impl Table { table_info.config().time_column_name.clone(), table_info.config().subsort_column_name.clone(), table_info.config().group_column_name.clone(), - prepared_schema.clone(), + prepared_schema, prepare_hash, ); Self { + table_id: table_info.table_id(), expr, preparer, - data: RecordBatch::new_empty(prepared_schema), } } @@ -44,20 +45,22 @@ impl Table { self.preparer.schema() } - pub fn add_data(&mut self, batch: RecordBatch) -> error_stack::Result<(), Error> { + pub fn add_data( + &mut self, + session: &mut Session, + batch: RecordBatch, + ) -> error_stack::Result<(), Error> { let prepared = self .preparer .prepare_batch(batch) .change_context(Error::Prepare)?; // TODO: Merge the data in.
- assert_eq!(self.data.num_rows(), 0); - self.data = prepared; - Ok(()) - } + let table_info = session.hacky_table_mut(self.table_id); - pub fn data(&self) -> &RecordBatch { - &self.data + assert!(table_info.in_memory.is_none()); + table_info.in_memory = Some(prepared); + Ok(()) } } diff --git a/sparrow-py/Cargo.lock b/sparrow-py/Cargo.lock index 6fbb70c29..ad40c73b8 100644 --- a/sparrow-py/Cargo.lock +++ b/sparrow-py/Cargo.lock @@ -3859,15 +3859,20 @@ version = "0.10.0" dependencies = [ "arrow-array", "arrow-schema", + "arrow-select", "derive_more", "error-stack", + "futures", "itertools 0.11.0", "smallvec", "sparrow-api", "sparrow-compiler", + "sparrow-plan", "sparrow-runtime", "sparrow-syntax", "static_init", + "tokio", + "tokio-stream", "uuid 1.4.1", ] diff --git a/sparrow-py/pysrc/sparrow_py/_ffi.pyi b/sparrow-py/pysrc/sparrow_py/_ffi.pyi index 71cc3bb69..83862551f 100644 --- a/sparrow-py/pysrc/sparrow_py/_ffi.pyi +++ b/sparrow-py/pysrc/sparrow_py/_ffi.pyi @@ -15,6 +15,7 @@ class Expr: def data_type_string(self) -> str: ... def equivalent(self, other: Expr) -> bool: ... def session(self) -> Session: ... + def execute(self) -> pa.RecordBatch: ... def call_udf(udf: Udf, result_type: pa.DataType, *args: pa.Array) -> pa.Array: ... @@ -31,7 +32,4 @@ class Table(Expr): ) -> None: ... @property def name(self) -> str: ... - def add_pyarrow(self, data: pa.RecordBatch) -> None: ... - - def prepared_data(self) -> pa.RecordBatch: ... \ No newline at end of file diff --git a/sparrow-py/pysrc/sparrow_py/expr.py b/sparrow-py/pysrc/sparrow_py/expr.py index 0104e2da8..c6b66a07d 100644 --- a/sparrow-py/pysrc/sparrow_py/expr.py +++ b/sparrow-py/pysrc/sparrow_py/expr.py @@ -8,6 +8,7 @@ from typing import Union from typing import final +import pandas as pd import pyarrow as pa import sparrow_py._ffi as _ffi @@ -49,16 +50,24 @@ def call(name: str, *args: Arg) -> "Expr": args : list[Expr] List of arguments to the expression. + Returns + ------- + Expression representing the given operation applied to the arguments. + Raises ------ + # noqa: DAR401 _augment_error TypeError If the argument types are invalid for the given function. + ValueError + If the argument values are invalid for the given function. """ ffi_args = [arg._ffi_expr if isinstance(arg, Expr) else arg for arg in args] session = next(arg._ffi_expr.session() for arg in args if isinstance(arg, Expr)) try: return Expr(_ffi.Expr(session=session, operation=name, args=ffi_args)) except TypeError as e: + # noqa: DAR401 raise _augment_error(args, TypeError(str(e))) except ValueError as e: raise _augment_error(args, ValueError(str(e))) @@ -154,6 +163,13 @@ def __getattr__(self, name: str) -> "Expr": ------- Expr Expression referencing the given field. + + Raises + ------ + AttributeError + When the base is a struct but the field is not found. + TypeError + When the base is not a struct. """ # It's easy for this method to cause infinite recursion, if anything # it references on `self` isn't defined. To try to avoid this, we only @@ -228,12 +244,17 @@ def __getitem__(self, key: Arg) -> "Expr": Parameters ---------- - key : Expr + key : Arg The key to index into the expression. Returns ------- Expression accessing the given index. + + Raises + ------ + TypeError + When the expression is not a struct, list, or map. 
""" data_type = self.data_type if isinstance(data_type, pa.StructType): @@ -244,3 +265,7 @@ def __getitem__(self, key: Arg) -> "Expr": return Expr.call("get_list", self, key) else: raise TypeError(f"Cannot index into {data_type}") + + def execute(self) -> pd.DataFrame: + """Execute the expression.""" + return self._ffi_expr.execute().to_pandas() diff --git a/sparrow-py/pytests/math_test.py b/sparrow-py/pytests/math_test.py new file mode 100644 index 000000000..c647b2eaa --- /dev/null +++ b/sparrow-py/pytests/math_test.py @@ -0,0 +1,67 @@ +"""Tests for the Kaskada query builder.""" +import random +import sys +from io import StringIO + +import pandas as pd +import pyarrow as pa +import pytest +from sparrow_py import Session +from sparrow_py import Table +from sparrow_py import math + + +@pytest.fixture +def session() -> Session: + """Create a session for testing.""" + session = Session() + return session + + +@pytest.fixture +def table_int64(session: Session) -> Table: + """Create an empty table for testing.""" + schema = pa.schema( + [ + pa.field("time", pa.string(), nullable=False), + pa.field("key", pa.string(), nullable=False), + pa.field("m", pa.int64()), + pa.field("n", pa.int64()), + ] + ) + return Table(session, "table1", "time", "key", schema) + + +def read_csv(csv_string: str, **kwargs) -> pd.DataFrame: + """Read CSV from a string.""" + return pd.read_csv(StringIO(csv_string), dtype_backend="pyarrow", **kwargs) + + +def test_read_table(session, table_int64) -> None: + input = read_csv( + "\n".join( + [ + "time,subsort,key,m,n", + "1996-12-19T16:39:57-08:00,0,A,5,10", + "1996-12-19T16:39:58-08:00,0,B,24,3", + "1996-12-19T16:39:59-08:00,0,A,17,6", + "1996-12-19T16:40:00-08:00,0,A,,9", + "1996-12-19T16:40:01-08:00,0,A,12,", + "1996-12-19T16:40:02-08:00,0,A,,", + ] + ), + dtype={ + "m": "Int64", + "n": "Int64", + }, + ) + table_int64.add_data(input) + + # TODO: Options for outputting to different destinations (eg., to CSV). 
+ # TODO: Allow running a single expression (eg., without field names) + + # result = (table_int64.m + table_int64.n).execute() + result = table_int64.execute() + pd.testing.assert_series_equal(input["time"], result["time"], check_dtype=False) + pd.testing.assert_series_equal(input["m"], result["m"], check_dtype=False) + pd.testing.assert_series_equal(input["n"], result["n"], check_dtype=False) diff --git a/sparrow-py/pytests/table_test.py b/sparrow-py/pytests/table_test.py index 8de971b23..39e6b197f 100644 --- a/sparrow-py/pytests/table_test.py +++ b/sparrow-py/pytests/table_test.py @@ -94,9 +94,9 @@ def test_add_dataframe(session, dataset) -> None: ] ) table = Table(session, "table1", "time", "key", schema) - assert table._ffi_table.prepared_data().num_rows == 0 + # assert table._ffi_table.prepared_data().num_rows == 0 table.add_data(dataset) - assert table._ffi_table.prepared_data().num_rows == len(dataset) - prepared = table._ffi_table.prepared_data().to_pandas() - assert prepared['_time'].is_monotonic_increasing + # assert table._ffi_table.prepared_data().num_rows == len(dataset) + prepared = table.execute() + assert prepared["_time"].is_monotonic_increasing diff --git a/sparrow-py/src/expr.rs b/sparrow-py/src/expr.rs index c80f0509a..94acef22d 100644 --- a/sparrow-py/src/expr.rs +++ b/sparrow-py/src/expr.rs @@ -1,3 +1,4 @@ +use crate::error::Error; use crate::session::Session; use arrow::pyarrow::ToPyArrow; use pyo3::exceptions::{PyRuntimeError, PyValueError}; @@ -59,8 +60,14 @@ impl Expr { self.session.clone() } + fn execute(&self, py: Python<'_>) -> Result<PyObject, Error> { + let session = self.session.rust_session()?; + let batches = session.execute(&self.rust_expr)?; + Ok(batches.to_pyarrow(py)?) + } + /// Return the `pyarrow` type of the resulting expression. - fn data_type(&self, py: Python) -> PyResult<Option<PyObject>> { + fn data_type(&self, py: Python<'_>) -> PyResult<Option<PyObject>> { match self.rust_expr.data_type() { Some(t) => Ok(Some(t.to_pyarrow(py)?)), _ => Ok(None), diff --git a/sparrow-py/src/table.rs b/sparrow-py/src/table.rs index 31182f416..9abc87dde 100644 --- a/sparrow-py/src/table.rs +++ b/sparrow-py/src/table.rs @@ -1,9 +1,9 @@ +use std::ops::DerefMut; use std::sync::Arc; use arrow::datatypes::Schema; -use arrow::pyarrow::{FromPyArrow, PyArrowType, ToPyArrow}; +use arrow::pyarrow::{FromPyArrow, PyArrowType}; use arrow::record_batch::RecordBatch; -use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; use sparrow_session::Table as RustTable; @@ -16,6 +16,7 @@ pub(crate) struct Table { #[pyo3(get)] name: String, rust_table: RustTable, + session: Session, } #[pymethods] impl Table { @@ -44,7 +45,11 @@ impl Table { )?; let rust_expr = rust_table.expr.clone(); - let table = Table { name, rust_table }; + let table = Table { + name, + rust_table, + session: session.clone(), + }; let expr = Expr { rust_expr, session }; Ok((table, expr)) } @@ -58,12 +63,9 @@ impl Table { /// - Python generators? /// TODO: Error handling fn add_pyarrow(&mut self, data: &PyAny) -> Result<()> { + let mut session = self.session.rust_session()?; let data = RecordBatch::from_pyarrow(data)?; - self.rust_table.add_data(data)?; + self.rust_table.add_data(session.deref_mut(), data)?; Ok(()) } - - fn prepared_data(&self, py: Python) -> PyResult<PyObject> { - self.rust_table.data().to_pyarrow(py) - } }
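
Taken together, the sparrow-py changes above allow building a Session, registering a Table, feeding it data (which is prepared and held as the table's in-memory batch), and executing an expression back to a pandas DataFrame. A minimal sketch of that flow, mirroring the new math_test.py and table_test.py tests — the CSV rows and the names `tab` and `df` are illustrative assumptions, not taken from the patch:

from io import StringIO

import pandas as pd
import pyarrow as pa
from sparrow_py import Session, Table

# Create a session and declare a table with its time column, key column, and schema.
session = Session()
schema = pa.schema(
    [
        pa.field("time", pa.string(), nullable=False),
        pa.field("key", pa.string(), nullable=False),
        pa.field("m", pa.int64()),
    ]
)
tab = Table(session, "table1", "time", "key", schema)

# Add data; it is prepared and stored as the table's in-memory RecordBatch.
df = pd.read_csv(
    StringIO(
        "time,key,m\n"
        "1996-12-19T16:39:57-08:00,A,5\n"
        "1996-12-19T16:39:58-08:00,B,24\n"
    ),
    dtype_backend="pyarrow",
)
tab.add_data(df)

# Compile the expression and run it through execute_new; results arrive as pandas.
result = tab.execute()
print(result["_time"].is_monotonic_increasing)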