[FEAT] implement range operation and data streaming (#3267)
Co-authored-by: Cory Grinstead <cory.grinstead@gmail.com>
andrewgazelka and universalmind303 authored Nov 13, 2024
1 parent 62a936e commit 1a4d259
Showing 11 changed files with 524 additions and 8 deletions.
4 changes: 4 additions & 0 deletions Cargo.lock

Generated file; diff not rendered.

4 changes: 4 additions & 0 deletions src/daft-connect/Cargo.toml
@@ -8,6 +8,10 @@ tonic = "0.12.3"
tracing-subscriber = {version = "0.3.18", features = ["env-filter"]}
tracing-tracy = "0.11.3"
uuid = {version = "1.10.0", features = ["v4"]}
arrow2.workspace = true
daft-core.workspace = true
daft-schema.workspace = true
daft-table.workspace = true
spark-connect.workspace = true
tracing.workspace = true

104 changes: 104 additions & 0 deletions src/daft-connect/src/command.rs
@@ -0,0 +1,104 @@
use std::future::ready;

use arrow2::io::ipc::write::StreamWriter;
use daft_table::Table;
use eyre::Context;
use futures::stream;
use spark_connect::{
execute_plan_response::{ArrowBatch, ResponseType, ResultComplete},
spark_connect_service_server::SparkConnectService,
ExecutePlanResponse, Relation,
};
use tonic::Status;
use uuid::Uuid;

use crate::{convert::convert_data, DaftSparkConnectService, Session};

type DaftStream = <DaftSparkConnectService as SparkConnectService>::ExecutePlanStream;

pub struct PlanIds {
session: String,
server_side_session: String,
operation: String,
}

impl PlanIds {
pub fn gen_response(&self, table: &Table) -> eyre::Result<ExecutePlanResponse> {
let mut data = Vec::new();

let mut writer = StreamWriter::new(
&mut data,
arrow2::io::ipc::write::WriteOptions { compression: None },
);

let row_count = table.num_rows();

let schema = table
.schema
.to_arrow()
.wrap_err("Failed to convert Daft schema to Arrow schema")?;

writer
.start(&schema, None)
.wrap_err("Failed to start Arrow stream writer with schema")?;

let arrays = table.get_inner_arrow_arrays().collect();
let chunk = arrow2::chunk::Chunk::new(arrays);

writer
.write(&chunk, None)
.wrap_err("Failed to write Arrow chunk to stream writer")?;

let response = ExecutePlanResponse {
session_id: self.session.to_string(),
server_side_session_id: self.server_side_session.to_string(),
operation_id: self.operation.to_string(),
response_id: Uuid::new_v4().to_string(), // todo: implement this
metrics: None, // todo: implement this
observed_metrics: vec![],
schema: None,
response_type: Some(ResponseType::ArrowBatch(ArrowBatch {
row_count: row_count as i64,
data,
start_offset: None,
})),
};

Ok(response)
}
}

impl Session {
pub async fn handle_root_command(
&self,
command: Relation,
operation_id: String,
) -> Result<DaftStream, Status> {
use futures::{StreamExt, TryStreamExt};

let context = PlanIds {
session: self.client_side_session_id().to_string(),
server_side_session: self.server_side_session_id().to_string(),
operation: operation_id.clone(),
};

let finished = ExecutePlanResponse {
session_id: self.client_side_session_id().to_string(),
server_side_session_id: self.server_side_session_id().to_string(),
operation_id,
response_id: Uuid::new_v4().to_string(),
metrics: None,
observed_metrics: vec![],
schema: None,
response_type: Some(ResponseType::ResultComplete(ResultComplete {})),
};

let stream = convert_data(command, &context)
.map_err(|e| Status::internal(e.to_string()))?
.chain(stream::once(ready(Ok(finished))));

Ok(Box::pin(
stream.map_err(|e| Status::internal(e.to_string())),
))
}
}
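Note on the payload above: `gen_response` serializes each `Table` as an Arrow IPC stream and ships the bytes in `ArrowBatch.data`. As a rough, hypothetical sketch (not part of this commit), a consumer holding that buffer could read it back with arrow2's stream reader; `decode_arrow_batch` is an assumed helper name:

```rust
use std::io::Cursor;

use arrow2::io::ipc::read::{read_stream_metadata, StreamReader, StreamState};

/// Hypothetical helper: count the rows encoded in an `ArrowBatch.data` payload.
fn decode_arrow_batch(data: &[u8]) -> arrow2::error::Result<usize> {
    let mut cursor = Cursor::new(data);
    // The stream starts with IPC metadata (the schema), followed by record batches.
    let metadata = read_stream_metadata(&mut cursor)?;
    let reader = StreamReader::new(cursor, metadata, None);

    let mut rows = 0;
    for state in reader {
        // A fully written buffer yields `Some(chunk)`; `Waiting` only appears for partial streams.
        if let StreamState::Some(chunk) = state? {
            rows += chunk.len();
        }
    }
    Ok(rows)
}
```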
6 changes: 6 additions & 0 deletions src/daft-connect/src/convert.rs
@@ -0,0 +1,6 @@
mod data_conversion;
mod formatting;
mod schema_conversion;

pub use data_conversion::convert_data;
pub use schema_conversion::connect_schema;
61 changes: 61 additions & 0 deletions src/daft-connect/src/convert/data_conversion.rs
@@ -0,0 +1,61 @@
//! Relation handling for Spark Connect protocol.
//!
//! A Relation represents a structured dataset or transformation in Spark Connect.
//! It can be either a base relation (a direct data source) or a derived relation
//! (the result of operations on other relations).
//!
//! The protocol represents relations as trees of operations where:
//! - Each node is a Relation with metadata and an operation type
//! - Operations can reference other relations, forming a DAG
//! - The tree describes how to derive the final result
//!
//! Example flow for: SELECT age, COUNT(*) FROM employees WHERE dept='Eng' GROUP BY age
//!
//! ```text
//! Aggregate (grouping by age)
//! ↳ Filter (dept = 'Eng')
//! ↳ Read (employees table)
//! ```
//!
//! Relations abstract away:
//! - Physical storage details
//! - Distributed computation
//! - Query optimization
//! - Data source specifics
//!
//! This allows Spark to optimize and execute queries efficiently across a cluster
//! while providing a consistent API regardless of the underlying data source.
use eyre::{eyre, Context};
use futures::Stream;
use spark_connect::{relation::RelType, ExecutePlanResponse, Relation};
use tracing::trace;

use crate::convert::formatting::RelTypeExt;

mod range;
use range::range;

use crate::command::PlanIds;

pub fn convert_data(
plan: Relation,
context: &PlanIds,
) -> eyre::Result<impl Stream<Item = eyre::Result<ExecutePlanResponse>> + Unpin> {
// First check common fields if needed
if let Some(common) = &plan.common {
// contains metadata shared across all relation types
// Log or handle common fields if necessary
trace!("Processing relation with plan_id: {:?}", common.plan_id);
}

let rel_type = plan.rel_type.ok_or_else(|| eyre!("rel_type is None"))?;

match rel_type {
RelType::Range(input) => range(input, context).wrap_err("parsing Range"),
other => Err(eyre!("Unsupported top-level relation: {}", other.name())),
}
}
48 changes: 48 additions & 0 deletions src/daft-connect/src/convert/data_conversion/range.rs
@@ -0,0 +1,48 @@
use std::future::ready;

use daft_core::prelude::Series;
use daft_schema::prelude::Schema;
use daft_table::Table;
use eyre::{ensure, Context};
use futures::{stream, Stream};
use spark_connect::{ExecutePlanResponse, Range};

use crate::command::PlanIds;

pub fn range(
range: Range,
channel: &PlanIds,
) -> eyre::Result<impl Stream<Item = eyre::Result<ExecutePlanResponse>> + Unpin> {
let Range {
start,
end,
step,
num_partitions,
} = range;

let start = start.unwrap_or(0);

ensure!(num_partitions.is_none(), "num_partitions is not supported");

let step = usize::try_from(step).wrap_err("step must be a positive integer")?;
ensure!(step > 0, "step must be greater than 0");

let arrow_array: arrow2::array::Int64Array = (start..end).step_by(step).map(Some).collect();
let len = arrow_array.len();

let singleton_series = Series::try_from((
"range",
Box::new(arrow_array) as Box<dyn arrow2::array::Array>,
))
.wrap_err("creating singleton series")?;

let singleton_table = Table::new_with_size(
Schema::new(vec![singleton_series.field().clone()])?,
vec![singleton_series],
len,
)?;

let response = channel.gen_response(&singleton_table)?;

Ok(stream::once(ready(Ok(response))))
}
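For intuition on the collection in `range` above, the stepped iterator materializes directly into an arrow2 `Int64Array`; a standalone sketch with arbitrary values:

```rust
fn main() {
    // Mirrors the collection in `range`: (start..end) stepped by `step`, as nullable i64 values.
    let (start, end, step) = (0i64, 10i64, 3usize);
    let arr: arrow2::array::Int64Array = (start..end).step_by(step).map(Some).collect();

    // Yields 0, 3, 6, 9 -> four rows; `end` is exclusive, matching Spark's range semantics.
    assert_eq!(arr.len(), 4);
    assert_eq!(arr.values().as_slice(), &[0i64, 3, 6, 9][..]);
}
```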
69 changes: 69 additions & 0 deletions src/daft-connect/src/convert/formatting.rs
@@ -0,0 +1,69 @@
use spark_connect::relation::RelType;

/// Extension trait for RelType to add a `name` method.
pub trait RelTypeExt {
/// Returns the name of the RelType as a string.
fn name(&self) -> &'static str;
}

impl RelTypeExt for RelType {
fn name(&self) -> &'static str {
match self {
Self::Read(_) => "Read",
Self::Project(_) => "Project",
Self::Filter(_) => "Filter",
Self::Join(_) => "Join",
Self::SetOp(_) => "SetOp",
Self::Sort(_) => "Sort",
Self::Limit(_) => "Limit",
Self::Aggregate(_) => "Aggregate",
Self::Sql(_) => "Sql",
Self::LocalRelation(_) => "LocalRelation",
Self::Sample(_) => "Sample",
Self::Offset(_) => "Offset",
Self::Deduplicate(_) => "Deduplicate",
Self::Range(_) => "Range",
Self::SubqueryAlias(_) => "SubqueryAlias",
Self::Repartition(_) => "Repartition",
Self::ToDf(_) => "ToDf",
Self::WithColumnsRenamed(_) => "WithColumnsRenamed",
Self::ShowString(_) => "ShowString",
Self::Drop(_) => "Drop",
Self::Tail(_) => "Tail",
Self::WithColumns(_) => "WithColumns",
Self::Hint(_) => "Hint",
Self::Unpivot(_) => "Unpivot",
Self::ToSchema(_) => "ToSchema",
Self::RepartitionByExpression(_) => "RepartitionByExpression",
Self::MapPartitions(_) => "MapPartitions",
Self::CollectMetrics(_) => "CollectMetrics",
Self::Parse(_) => "Parse",
Self::GroupMap(_) => "GroupMap",
Self::CoGroupMap(_) => "CoGroupMap",
Self::WithWatermark(_) => "WithWatermark",
Self::ApplyInPandasWithState(_) => "ApplyInPandasWithState",
Self::HtmlString(_) => "HtmlString",
Self::CachedLocalRelation(_) => "CachedLocalRelation",
Self::CachedRemoteRelation(_) => "CachedRemoteRelation",
Self::CommonInlineUserDefinedTableFunction(_) => "CommonInlineUserDefinedTableFunction",
Self::AsOfJoin(_) => "AsOfJoin",
Self::CommonInlineUserDefinedDataSource(_) => "CommonInlineUserDefinedDataSource",
Self::WithRelations(_) => "WithRelations",
Self::Transpose(_) => "Transpose",
Self::FillNa(_) => "FillNa",
Self::DropNa(_) => "DropNa",
Self::Replace(_) => "Replace",
Self::Summary(_) => "Summary",
Self::Crosstab(_) => "Crosstab",
Self::Describe(_) => "Describe",
Self::Cov(_) => "Cov",
Self::Corr(_) => "Corr",
Self::ApproxQuantile(_) => "ApproxQuantile",
Self::FreqItems(_) => "FreqItems",
Self::SampleBy(_) => "SampleBy",
Self::Catalog(_) => "Catalog",
Self::Extension(_) => "Extension",
Self::Unknown(_) => "Unknown",
}
}
}
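The extension trait exists so unsupported relations can be reported by name, as `convert_data` does in its error path. A minimal sketch of a unit test that could sit in formatting.rs (assuming prost's derived `Default` on the generated messages):

```rust
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn range_name() {
        // Sketch only: `Range::default()` relies on prost deriving Default for generated messages.
        let rel = RelType::Range(spark_connect::Range::default());
        assert_eq!(rel.name(), "Range");
    }
}
```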
56 changes: 56 additions & 0 deletions src/daft-connect/src/convert/schema_conversion.rs
@@ -0,0 +1,56 @@
use spark_connect::{
data_type::{Kind, Long, Struct, StructField},
relation::RelType,
DataType, Relation,
};

#[tracing::instrument(skip_all)]
pub fn connect_schema(input: Relation) -> Result<DataType, tonic::Status> {
if input.common.is_some() {
tracing::warn!("We do not currently look at common fields");
}

let result = match input
.rel_type
.ok_or_else(|| tonic::Status::internal("rel_type is None"))?
{
RelType::Range(spark_connect::Range { num_partitions, .. }) => {
if num_partitions.is_some() {
return Err(tonic::Status::unimplemented(
"num_partitions is not supported",
));
}

let long = Long {
type_variation_reference: 0,
};

let id_field = StructField {
name: "id".to_string(),
data_type: Some(DataType {
kind: Some(Kind::Long(long)),
}),
nullable: false,
metadata: None,
};

let fields = vec![id_field];

let strct = Struct {
fields,
type_variation_reference: 0,
};

DataType {
kind: Some(Kind::Struct(strct)),
}
}
other => {
return Err(tonic::Status::unimplemented(format!(
"Unsupported relation type: {other:?}"
)))
}
};

Ok(result)
}
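As a hedged usage sketch (a hypothetical unit test, assuming the prost-generated field layout used throughout this diff), a `Range` relation run through `connect_schema` should come back as a struct with a single non-nullable `Long` field named `id`:

```rust
use spark_connect::{data_type::Kind, relation::RelType, Relation};

use crate::convert::connect_schema;

#[test]
fn range_schema_is_single_long_id() {
    let relation = Relation {
        common: None,
        rel_type: Some(RelType::Range(spark_connect::Range {
            start: Some(0),
            end: 10,
            step: 1,
            num_partitions: None,
        })),
    };

    let schema = connect_schema(relation).expect("Range schema should be supported");
    match schema.kind {
        Some(Kind::Struct(s)) => {
            assert_eq!(s.fields.len(), 1);
            assert_eq!(s.fields[0].name, "id");
            assert!(!s.fields[0].nullable);
        }
        other => panic!("expected a struct schema, got {other:?}"),
    }
}
```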
2 changes: 1 addition & 1 deletion src/daft-connect/src/err.rs
@@ -1,5 +1,5 @@
#[macro_export]
-macro_rules! invalid_argument {
+macro_rules! invalid_argument_err {
($arg: tt) => {{
let msg = format!($arg);
Err(::tonic::Status::invalid_argument(msg))