chore(connect): better error propagation & handling (#3675)
Depends on #3666. See here for the proper diff: universalmind303/Daft@rust-ray-exec...universalmind303:Daft:error-messages
1 parent 432714d · commit 809e411
Showing 11 changed files with 494 additions and 592 deletions.
@@ -0,0 +1,297 @@
use dashmap::DashMap;
use spark_connect::{
    command::CommandType, plan::OpType, spark_connect_service_server::SparkConnectService,
    AddArtifactsRequest, AddArtifactsResponse, AnalyzePlanRequest, AnalyzePlanResponse,
    ArtifactStatusesRequest, ArtifactStatusesResponse, ConfigRequest, ConfigResponse,
    ExecutePlanRequest, ExecutePlanResponse, FetchErrorDetailsRequest, FetchErrorDetailsResponse,
    InterruptRequest, InterruptResponse, Plan, ReattachExecuteRequest, ReleaseExecuteRequest,
    ReleaseExecuteResponse, ReleaseSessionRequest, ReleaseSessionResponse,
};
use tonic::{Request, Response, Status};
use tracing::debug;
use uuid::Uuid;

use crate::{
    display::SparkDisplay,
    invalid_argument_err, not_yet_implemented,
    response_builder::ResponseBuilder,
    session::Session,
    translation::{self, SparkAnalyzer},
    util::FromOptionalField,
};

#[derive(Default)]
pub struct DaftSparkConnectService {
    client_to_session: DashMap<Uuid, Session>, // To track session data
}

impl DaftSparkConnectService {
    fn get_session(
        &self,
        session_id: &str,
    ) -> Result<dashmap::mapref::one::RefMut<Uuid, Session>, Status> {
        let Ok(uuid) = Uuid::parse_str(session_id) else {
            return Err(Status::invalid_argument(
                "Invalid session_id format, must be a UUID",
            ));
        };

        let res = self
            .client_to_session
            .entry(uuid)
            .or_insert_with(|| Session::new(session_id.to_string()));

        Ok(res)
    }
}

#[tonic::async_trait]
impl SparkConnectService for DaftSparkConnectService {
    type ExecutePlanStream = std::pin::Pin<
        Box<dyn futures::Stream<Item = Result<ExecutePlanResponse, Status>> + Send + 'static>,
    >;
    type ReattachExecuteStream = std::pin::Pin<
        Box<dyn futures::Stream<Item = Result<ExecutePlanResponse, Status>> + Send + 'static>,
    >;

    #[tracing::instrument(skip_all)]
    async fn execute_plan(
        &self,
        request: Request<ExecutePlanRequest>,
    ) -> Result<Response<Self::ExecutePlanStream>, Status> {
        let request = request.into_inner();

        let session = self.get_session(&request.session_id)?;
        let operation_id = request
            .operation_id
            .unwrap_or_else(|| Uuid::new_v4().to_string());

        let rb = ResponseBuilder::new(&session, operation_id);

        // Proceed with executing the plan...
        let plan = request.plan.required("plan")?;
        let plan = plan.op_type.required("op_type")?;

        match plan {
            OpType::Root(relation) => {
                let result = session.execute_command(relation, rb).await?;
                Ok(Response::new(result))
            }
            OpType::Command(command) => {
                let command = command.command_type.required("command_type")?;

                match command {
                    CommandType::WriteOperation(op) => {
                        let result = session.execute_write_operation(op, rb).await?;
                        Ok(Response::new(result))
                    }
                    other => {
                        return not_yet_implemented!(
                            "Command type: {}",
                            command_type_to_str(&other)
                        )
                    }
                }
            }
        }
    }

    #[tracing::instrument(skip_all)]
    async fn config(
        &self,
        request: Request<ConfigRequest>,
    ) -> Result<Response<ConfigResponse>, Status> {
        let request = request.into_inner();

        let mut session = self.get_session(&request.session_id)?;

        let operation = request
            .operation
            .and_then(|op| op.op_type)
            .required("operation.op_type")?;

        use spark_connect::config_request::operation::OpType;

        let response = match operation {
            OpType::Set(op) => session.set(op),
            OpType::Get(op) => session.get(op),
            OpType::GetWithDefault(op) => session.get_with_default(op),
            OpType::GetOption(op) => session.get_option(op),
            OpType::GetAll(op) => session.get_all(op),
            OpType::Unset(op) => session.unset(op),
            OpType::IsModifiable(op) => session.is_modifiable(op),
        }?;

        Ok(Response::new(response))
    }

    #[tracing::instrument(skip_all)]
    async fn add_artifacts(
        &self,
        _request: Request<tonic::Streaming<AddArtifactsRequest>>,
    ) -> Result<Response<AddArtifactsResponse>, Status> {
        not_yet_implemented!("add_artifacts operation")
    }

    #[tracing::instrument(skip_all)]
    async fn analyze_plan(
        &self,
        request: Request<AnalyzePlanRequest>,
    ) -> Result<Response<AnalyzePlanResponse>, Status> {
        use spark_connect::analyze_plan_request::*;
        let request = request.into_inner();

        let AnalyzePlanRequest {
            session_id,
            analyze,
            ..
        } = request;

        let session = self.get_session(&session_id)?;
        let rb = ResponseBuilder::new(&session, Uuid::new_v4().to_string());

        let analyze = analyze.required("analyze")?;

        match analyze {
            Analyze::Schema(Schema { plan }) => {
                let Plan { op_type } = plan.required("plan")?;

                let OpType::Root(relation) = op_type.required("op_type")? else {
                    return invalid_argument_err!("op_type must be Root");
                };

                let translator = SparkAnalyzer::new(&session);

                let result = match translator.relation_to_spark_schema(relation).await {
                    Ok(schema) => schema,
                    Err(e) => {
                        return invalid_argument_err!(
                            "Failed to translate relation to schema: {e:?}"
                        );
                    }
                };
                Ok(Response::new(rb.schema_response(result)))
            }
            Analyze::DdlParse(DdlParse { ddl_string }) => {
                let daft_schema = match daft_sql::sql_schema(&ddl_string) {
                    Ok(daft_schema) => daft_schema,
                    Err(e) => return invalid_argument_err!("{e}"),
                };

                let daft_schema = daft_schema.to_struct();

                let schema = translation::to_spark_datatype(&daft_schema);

                Ok(Response::new(rb.schema_response(schema)))
            }
            Analyze::TreeString(TreeString { plan, level }) => {
                let plan = plan.required("plan")?;

                if let Some(level) = level {
                    debug!("ignoring tree string level: {level:?}");
                };

                let OpType::Root(input) = plan.op_type.required("op_type")? else {
                    return invalid_argument_err!("op_type must be Root");
                };

                if let Some(common) = &input.common {
                    if common.origin.is_some() {
                        debug!("Ignoring common metadata for relation: {common:?}; not yet implemented");
                    }
                }

                let translator = SparkAnalyzer::new(&session);
                let plan = Box::pin(translator.to_logical_plan(input))
                    .await
                    .unwrap()
                    .build();

                let schema = plan.schema();
                let tree_string = schema.repr_spark_string();
                Ok(Response::new(rb.treestring_response(tree_string)))
            }
            other => not_yet_implemented!("Analyze '{other:?}'"),
        }
    }

    #[tracing::instrument(skip_all)]
    async fn artifact_status(
        &self,
        _request: Request<ArtifactStatusesRequest>,
    ) -> Result<Response<ArtifactStatusesResponse>, Status> {
        not_yet_implemented!("artifact_status operation")
    }

    #[tracing::instrument(skip_all)]
    async fn interrupt(
        &self,
        _request: Request<InterruptRequest>,
    ) -> Result<Response<InterruptResponse>, Status> {
        not_yet_implemented!("interrupt operation")
    }

    #[tracing::instrument(skip_all)]
    async fn reattach_execute(
        &self,
        _request: Request<ReattachExecuteRequest>,
    ) -> Result<Response<Self::ReattachExecuteStream>, Status> {
        not_yet_implemented!("reattach_execute operation")
    }

    #[tracing::instrument(skip_all)]
    async fn release_execute(
        &self,
        request: Request<ReleaseExecuteRequest>,
    ) -> Result<Response<ReleaseExecuteResponse>, Status> {
        let request = request.into_inner();

        let session = self.get_session(&request.session_id)?;

        let response = ReleaseExecuteResponse {
            session_id: session.client_side_session_id().to_string(),
            server_side_session_id: session.server_side_session_id().to_string(),
            operation_id: None, // todo: set but not strictly required
        };

        Ok(Response::new(response))
    }

    #[tracing::instrument(skip_all)]
    async fn release_session(
        &self,
        _request: Request<ReleaseSessionRequest>,
    ) -> Result<Response<ReleaseSessionResponse>, Status> {
        not_yet_implemented!("release_session operation")
    }

    #[tracing::instrument(skip_all)]
    async fn fetch_error_details(
        &self,
        _request: Request<FetchErrorDetailsRequest>,
    ) -> Result<Response<FetchErrorDetailsResponse>, Status> {
        not_yet_implemented!("fetch_error_details operation")
    }
}

fn command_type_to_str(cmd_type: &CommandType) -> &str {
    match cmd_type {
        CommandType::RegisterFunction(_) => "RegisterFunction",
        CommandType::WriteOperation(_) => "WriteOperation",
        CommandType::CreateDataframeView(_) => "CreateDataframeView",
        CommandType::WriteOperationV2(_) => "WriteOperationV2",
        CommandType::SqlCommand(_) => "SqlCommand",
        CommandType::WriteStreamOperationStart(_) => "WriteStreamOperationStart",
        CommandType::StreamingQueryCommand(_) => "StreamingQueryCommand",
        CommandType::GetResourcesCommand(_) => "GetResourcesCommand",
        CommandType::StreamingQueryManagerCommand(_) => "StreamingQueryManagerCommand",
        CommandType::RegisterTableFunction(_) => "RegisterTableFunction",
        CommandType::StreamingQueryListenerBusCommand(_) => "StreamingQueryListenerBusCommand",
        CommandType::RegisterDataSource(_) => "RegisterDataSource",
        CommandType::CreateResourceProfileCommand(_) => "CreateResourceProfileCommand",
        CommandType::CheckpointCommand(_) => "CheckpointCommand",
        CommandType::RemoveCachedRemoteRelationCommand(_) => "RemoveCachedRemoteRelationCommand",
        CommandType::MergeIntoTableCommand(_) => "MergeIntoTableCommand",
        CommandType::Extension(_) => "Extension",
    }
}
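The error-propagation helpers imported from `crate` above (`FromOptionalField`, `invalid_argument_err!`, `not_yet_implemented!`) are defined elsewhere in the crate and are not part of this file. As a rough sketch of the pattern they enable, inferred only from their call sites above and not taken from the actual Daft source, they could look something like this (the gRPC status codes chosen here are assumptions):

    use tonic::Status;

    // Hypothetical sketch: lifts an `Option<T>` from a decoded protobuf message
    // into a `Result<T, Status>`, so a missing field becomes a gRPC error that
    // handlers can bubble up with `?` (e.g. `request.plan.required("plan")?`).
    pub trait FromOptionalField<T> {
        fn required(self, field: &str) -> Result<T, Status>;
    }

    impl<T> FromOptionalField<T> for Option<T> {
        fn required(self, field: &str) -> Result<T, Status> {
            // Status code is an assumption; the real helper may differ.
            self.ok_or_else(|| Status::invalid_argument(format!("required field '{field}' is missing")))
        }
    }

    // Hypothetical sketches: macros that build an `Err(Status)` with the
    // appropriate gRPC code, keeping the handler bodies above short. `Err(..)`
    // is generic over the Ok type, so these can be returned from any of the
    // `Result<Response<_>, Status>` methods directly.
    macro_rules! invalid_argument_err {
        ($($arg:tt)*) => {
            Err(::tonic::Status::invalid_argument(format!($($arg)*)))
        };
    }

    macro_rules! not_yet_implemented {
        ($($arg:tt)*) => {
            Err(::tonic::Status::unimplemented(format!($($arg)*)))
        };
    }

Because each helper produces a `tonic::Status`, every handler in the service can surface failures to the Spark Connect client as structured gRPC errors rather than opaque panics, which is the "better error propagation & handling" the commit title refers to.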