[FEAT] implement range operation and data streaming (#3267)
Co-authored-by: Cory Grinstead <cory.grinstead@gmail.com>
1 parent 62a936e · commit 1a4d259
Showing 11 changed files with 524 additions and 8 deletions.
@@ -0,0 +1,104 @@ (new file; path not shown on this page — likely `command.rs`, since other files in this diff `use crate::command::PlanIds`)
use std::future::ready;

use arrow2::io::ipc::write::StreamWriter;
use daft_table::Table;
use eyre::Context;
use futures::stream;
use spark_connect::{
    execute_plan_response::{ArrowBatch, ResponseType, ResultComplete},
    spark_connect_service_server::SparkConnectService,
    ExecutePlanResponse, Relation,
};
use tonic::Status;
use uuid::Uuid;

use crate::{convert::convert_data, DaftSparkConnectService, Session};

type DaftStream = <DaftSparkConnectService as SparkConnectService>::ExecutePlanStream;

pub struct PlanIds {
    session: String,
    server_side_session: String,
    operation: String,
}

impl PlanIds {
    pub fn gen_response(&self, table: &Table) -> eyre::Result<ExecutePlanResponse> {
        let mut data = Vec::new();

        let mut writer = StreamWriter::new(
            &mut data,
            arrow2::io::ipc::write::WriteOptions { compression: None },
        );

        let row_count = table.num_rows();

        let schema = table
            .schema
            .to_arrow()
            .wrap_err("Failed to convert Daft schema to Arrow schema")?;

        writer
            .start(&schema, None)
            .wrap_err("Failed to start Arrow stream writer with schema")?;

        let arrays = table.get_inner_arrow_arrays().collect();
        let chunk = arrow2::chunk::Chunk::new(arrays);

        writer
            .write(&chunk, None)
            .wrap_err("Failed to write Arrow chunk to stream writer")?;

        let response = ExecutePlanResponse {
            session_id: self.session.to_string(),
            server_side_session_id: self.server_side_session.to_string(),
            operation_id: self.operation.to_string(),
            response_id: Uuid::new_v4().to_string(), // todo: implement this
            metrics: None,                           // todo: implement this
            observed_metrics: vec![],
            schema: None,
            response_type: Some(ResponseType::ArrowBatch(ArrowBatch {
                row_count: row_count as i64,
                data,
                start_offset: None,
            })),
        };

        Ok(response)
    }
}

impl Session {
    pub async fn handle_root_command(
        &self,
        command: Relation,
        operation_id: String,
    ) -> Result<DaftStream, Status> {
        use futures::{StreamExt, TryStreamExt};

        let context = PlanIds {
            session: self.client_side_session_id().to_string(),
            server_side_session: self.server_side_session_id().to_string(),
            operation: operation_id.clone(),
        };

        let finished = ExecutePlanResponse {
            session_id: self.client_side_session_id().to_string(),
            server_side_session_id: self.server_side_session_id().to_string(),
            operation_id,
            response_id: Uuid::new_v4().to_string(),
            metrics: None,
            observed_metrics: vec![],
            schema: None,
            response_type: Some(ResponseType::ResultComplete(ResultComplete {})),
        };

        let stream = convert_data(command, &context)
            .map_err(|e| Status::internal(e.to_string()))?
            .chain(stream::once(ready(Ok(finished))));

        Ok(Box::pin(
            stream.map_err(|e| Status::internal(e.to_string())),
        ))
    }
}
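For context on what the client receives: each `ArrowBatch.data` payload produced by `gen_response` is a self-contained Arrow IPC stream (schema message followed by one chunk). A minimal decoding sketch using arrow2's stream reader — assuming arrow2 with the `io_ipc` feature; the helper name `decode_arrow_batch` is hypothetical, not part of this commit:

use std::io::Cursor;

use arrow2::io::ipc::read::{read_stream_metadata, StreamReader, StreamState};

// Hypothetical helper: decode one ArrowBatch payload and count its rows.
fn decode_arrow_batch(data: &[u8]) -> arrow2::error::Result<usize> {
    let mut cursor = Cursor::new(data);
    // The payload begins with the schema message written by `writer.start`.
    let metadata = read_stream_metadata(&mut cursor)?;
    let reader = StreamReader::new(cursor, metadata, None);

    let mut rows = 0;
    for state in reader {
        match state? {
            // One chunk per `writer.write` call on the server side.
            StreamState::Some(chunk) => rows += chunk.len(),
            // Waiting only occurs on partial streams; a complete payload ends cleanly.
            StreamState::Waiting => break,
        }
    }
    Ok(rows)
}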
@@ -0,0 +1,6 @@ (new file; likely `convert.rs` or `convert/mod.rs`, the root of the conversion module)
mod data_conversion;
mod formatting;
mod schema_conversion;

pub use data_conversion::convert_data;
pub use schema_conversion::connect_schema;
@@ -0,0 +1,61 @@ (new file; likely `convert/data_conversion.rs`, given the `mod data_conversion;` declaration above)
//! Relation handling for Spark Connect protocol.
//!
//! A Relation represents a structured dataset or transformation in Spark Connect.
//! It can be either a base relation (a direct data source) or a derived relation
//! (the result of operations on other relations).
//!
//! The protocol represents relations as trees of operations where:
//! - Each node is a Relation with metadata and an operation type
//! - Operations can reference other relations, forming a DAG
//! - The tree describes how to derive the final result
//!
//! Example flow for: SELECT age, COUNT(*) FROM employees WHERE dept='Eng' GROUP BY age
//!
//! ```text
//! Aggregate (grouping by age)
//!   ↳ Filter (department = 'Engineering')
//!     ↳ Read (employees table)
//! ```
//!
//! Relations abstract away:
//! - Physical storage details
//! - Distributed computation
//! - Query optimization
//! - Data source specifics
//!
//! This allows Spark to optimize and execute queries efficiently across a cluster
//! while providing a consistent API regardless of the underlying data source.
use eyre::{eyre, Context};
use futures::Stream;
use spark_connect::{relation::RelType, ExecutePlanResponse, Relation};
use tracing::trace;

use crate::convert::formatting::RelTypeExt;

mod range;
use range::range;

use crate::command::PlanIds;

pub fn convert_data(
    plan: Relation,
    context: &PlanIds,
) -> eyre::Result<impl Stream<Item = eyre::Result<ExecutePlanResponse>> + Unpin> {
    // `common` carries metadata shared across all relation types; for now we only trace it.
    if let Some(common) = &plan.common {
        trace!("Processing relation with plan_id: {:?}", common.plan_id);
    }

    let rel_type = plan.rel_type.ok_or_else(|| eyre!("rel_type is None"))?;

    match rel_type {
        RelType::Range(input) => range(input, context).wrap_err("parsing Range"),
        other => Err(eyre!("Unsupported top-level relation: {}", other.name())),
    }
}
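To make the doc comment concrete: the smallest input `convert_data` accepts is a Relation whose `rel_type` is `Range`. A minimal sketch of the proto message a client call like `spark.range(0, 10, 2)` deserializes to — field shapes taken from the destructuring in the `range` handler below, function name hypothetical:

use spark_connect::{relation::RelType, Range, Relation};

// Hypothetical construction of the message convert_data receives.
fn example_range_relation() -> Relation {
    Relation {
        // `common` (plan_id and other metadata) is optional; convert_data only traces it.
        common: None,
        rel_type: Some(RelType::Range(Range {
            start: Some(0),       // optional in the proto; the handler defaults it to 0
            end: 10,
            step: 2,
            num_partitions: None, // rejected by the handler if set
        })),
    }
}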
@@ -0,0 +1,48 @@ (new file; likely `convert/data_conversion/range.rs`, given the `mod range;` declaration above)
use std::future::ready;

use daft_core::prelude::Series;
use daft_schema::prelude::Schema;
use daft_table::Table;
use eyre::{ensure, Context};
use futures::{stream, Stream};
use spark_connect::{ExecutePlanResponse, Range};

use crate::command::PlanIds;

pub fn range(
    range: Range,
    channel: &PlanIds,
) -> eyre::Result<impl Stream<Item = eyre::Result<ExecutePlanResponse>> + Unpin> {
    let Range {
        start,
        end,
        step,
        num_partitions,
    } = range;

    let start = start.unwrap_or(0);

    ensure!(num_partitions.is_none(), "num_partitions is not supported");

    let step = usize::try_from(step).wrap_err("step must be a positive integer")?;
    ensure!(step > 0, "step must be greater than 0");

    let arrow_array: arrow2::array::Int64Array = (start..end).step_by(step).map(Some).collect();
    let len = arrow_array.len();

    let singleton_series = Series::try_from((
        "range",
        Box::new(arrow_array) as Box<dyn arrow2::array::Array>,
    ))
    .wrap_err("creating singleton series")?;

    let singleton_table = Table::new_with_size(
        Schema::new(vec![singleton_series.field().clone()])?,
        vec![singleton_series],
        len,
    )?;

    let response = channel.gen_response(&singleton_table)?;

    Ok(stream::once(ready(Ok(response))))
}
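As a sanity check on the arithmetic above, a tiny test sketch (name hypothetical) of the iterator the handler materializes:

#[test]
fn range_materializes_half_open_interval() {
    // Mirrors `(start..end).step_by(step)` in the handler:
    // a half-open interval, so `end` itself is excluded.
    let values: Vec<i64> = (0i64..10).step_by(2).collect();
    assert_eq!(values, vec![0, 2, 4, 6, 8]); // 5 rows reported in ArrowBatch.row_count
}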
@@ -0,0 +1,69 @@ (new file; likely `convert/formatting.rs`)
use spark_connect::relation::RelType;

/// Extension trait for RelType to add a `name` method.
pub trait RelTypeExt {
    /// Returns the name of the RelType as a string.
    fn name(&self) -> &'static str;
}

impl RelTypeExt for RelType {
    fn name(&self) -> &'static str {
        match self {
            Self::Read(_) => "Read",
            Self::Project(_) => "Project",
            Self::Filter(_) => "Filter",
            Self::Join(_) => "Join",
            Self::SetOp(_) => "SetOp",
            Self::Sort(_) => "Sort",
            Self::Limit(_) => "Limit",
            Self::Aggregate(_) => "Aggregate",
            Self::Sql(_) => "Sql",
            Self::LocalRelation(_) => "LocalRelation",
            Self::Sample(_) => "Sample",
            Self::Offset(_) => "Offset",
            Self::Deduplicate(_) => "Deduplicate",
            Self::Range(_) => "Range",
            Self::SubqueryAlias(_) => "SubqueryAlias",
            Self::Repartition(_) => "Repartition",
            Self::ToDf(_) => "ToDf",
            Self::WithColumnsRenamed(_) => "WithColumnsRenamed",
            Self::ShowString(_) => "ShowString",
            Self::Drop(_) => "Drop",
            Self::Tail(_) => "Tail",
            Self::WithColumns(_) => "WithColumns",
            Self::Hint(_) => "Hint",
            Self::Unpivot(_) => "Unpivot",
            Self::ToSchema(_) => "ToSchema",
            Self::RepartitionByExpression(_) => "RepartitionByExpression",
            Self::MapPartitions(_) => "MapPartitions",
            Self::CollectMetrics(_) => "CollectMetrics",
            Self::Parse(_) => "Parse",
            Self::GroupMap(_) => "GroupMap",
            Self::CoGroupMap(_) => "CoGroupMap",
            Self::WithWatermark(_) => "WithWatermark",
            Self::ApplyInPandasWithState(_) => "ApplyInPandasWithState",
            Self::HtmlString(_) => "HtmlString",
            Self::CachedLocalRelation(_) => "CachedLocalRelation",
            Self::CachedRemoteRelation(_) => "CachedRemoteRelation",
            Self::CommonInlineUserDefinedTableFunction(_) => "CommonInlineUserDefinedTableFunction",
            Self::AsOfJoin(_) => "AsOfJoin",
            Self::CommonInlineUserDefinedDataSource(_) => "CommonInlineUserDefinedDataSource",
            Self::WithRelations(_) => "WithRelations",
            Self::Transpose(_) => "Transpose",
            Self::FillNa(_) => "FillNa",
            Self::DropNa(_) => "DropNa",
            Self::Replace(_) => "Replace",
            Self::Summary(_) => "Summary",
            Self::Crosstab(_) => "Crosstab",
            Self::Describe(_) => "Describe",
            Self::Cov(_) => "Cov",
            Self::Corr(_) => "Corr",
            Self::ApproxQuantile(_) => "ApproxQuantile",
            Self::FreqItems(_) => "FreqItems",
            Self::SampleBy(_) => "SampleBy",
            Self::Catalog(_) => "Catalog",
            Self::Extension(_) => "Extension",
            Self::Unknown(_) => "Unknown",
        }
    }
}
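The only consumer in this commit is the error path in `convert_data`; a minimal sketch of the call site pattern (helper name hypothetical):

use spark_connect::relation::RelType;

use crate::convert::formatting::RelTypeExt;

// Hypothetical helper mirroring the unsupported-relation error in convert_data.
fn unsupported_message(rel: &RelType) -> String {
    format!("Unsupported top-level relation: {}", rel.name())
}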
@@ -0,0 +1,56 @@ (new file; likely `convert/schema_conversion.rs`)
use spark_connect::{
    data_type::{Kind, Long, Struct, StructField},
    relation::RelType,
    DataType, Relation,
};

#[tracing::instrument(skip_all)]
pub fn connect_schema(input: Relation) -> Result<DataType, tonic::Status> {
    if input.common.is_some() {
        tracing::warn!("We do not currently look at common fields");
    }

    let result = match input
        .rel_type
        .ok_or_else(|| tonic::Status::internal("rel_type is None"))?
    {
        RelType::Range(spark_connect::Range { num_partitions, .. }) => {
            if num_partitions.is_some() {
                return Err(tonic::Status::unimplemented(
                    "num_partitions is not supported",
                ));
            }

            let long = Long {
                type_variation_reference: 0,
            };

            let id_field = StructField {
                name: "id".to_string(),
                data_type: Some(DataType {
                    kind: Some(Kind::Long(long)),
                }),
                nullable: false,
                metadata: None,
            };

            let fields = vec![id_field];

            let strct = Struct {
                fields,
                type_variation_reference: 0,
            };

            DataType {
                kind: Some(Kind::Struct(strct)),
            }
        }
        other => {
            return Err(tonic::Status::unimplemented(format!(
                "Unsupported relation type: {other:?}"
            )))
        }
    };

    Ok(result)
}
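A hedged test sketch of the contract above: for a Range relation, the returned type is a struct with a single non-nullable `id: Long` field, matching Spark's `spark.range` schema. Test name and construction are hypothetical, not part of this commit:

#[test]
fn range_schema_is_single_long_id() {
    use spark_connect::{data_type::Kind, relation::RelType, Range, Relation};

    let rel = Relation {
        common: None,
        rel_type: Some(RelType::Range(Range {
            start: Some(0),
            end: 5,
            step: 1,
            num_partitions: None,
        })),
    };

    let data_type = connect_schema(rel).expect("schema conversion should succeed");
    let Some(Kind::Struct(strct)) = data_type.kind else {
        panic!("expected a struct data type");
    };
    assert_eq!(strct.fields.len(), 1);
    assert_eq!(strct.fields[0].name, "id");
    assert!(!strct.fields[0].nullable);
}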