[SPARK-42960] [CONNECT] [SS] Add await_termination() and exception() API for Streaming Query in Python #40785

Closed
wants to merge 39 commits into from
Commits (all by WweiL):

- 4d0fcdd: done (Apr 6, 2023)
- 0ae7e33: add versionchanged to query and readwriter (Apr 7, 2023)
- b14fe81: fix conflict, put TODO to correct position (Apr 7, 2023)
- 17720b7: style (Apr 10, 2023)
- 1e68a3c: comments (Apr 10, 2023)
- dc05be8: address comments, add a new foreachBatch test class, remove all ELLIP… (Apr 10, 2023)
- c1674eb: minor (Apr 10, 2023)
- 23b9c93: minor (Apr 10, 2023)
- 60ddd01: lint (Apr 11, 2023)
- e576821: remove empty line (Apr 11, 2023)
- 304d01e: wip (Apr 12, 2023)
- 26e2488: remove several docs in connect readwriter.py and query.py to pass doc… (Apr 12, 2023)
- e25f7e6: minor, add back doc tests in module.py (Apr 12, 2023)
- aa1d4c2: style (Apr 12, 2023)
- e82678e: wip (Apr 12, 2023)
- 720e3d2: merge unittest (Apr 12, 2023)
- 1f3ba94: done (Apr 13, 2023)
- d8a6e9f: merge master (Apr 13, 2023)
- 0c2776a: minor (Apr 13, 2023)
- da6fc21: address comments (Apr 14, 2023)
- ce8615f: fix conflict (Apr 17, 2023)
- 7a377ec: style (Apr 17, 2023)
- be76438: wip (Apr 18, 2023)
- e489501: Merge branch 'master' into SPARK-42960-query-cmd-new (Apr 18, 2023)
- 7045f9d: adderss comments, change exception_message as optional (Apr 18, 2023)
- 83413c8: lint (Apr 18, 2023)
- 8e848c9: reformat py (Apr 18, 2023)
- ad73db6: done (Apr 18, 2023)
- ea3bb35: Merge branch 'master' into SPARK-42960-query-cmd-new (Apr 18, 2023)
- dd4d54b: regenerate proto files (Apr 18, 2023)
- 3b375cc: minor (Apr 18, 2023)
- 34c28e7: minor import import SparkConnectService._ (Apr 19, 2023)
- c487cd7: can you run tests one more time (Apr 19, 2023)
- 0c2133b: please pass (Apr 19, 2023)
- 2384ba1: merge master (Apr 19, 2023)
- da1d3e1: minor (Apr 19, 2023)
- eb19c2f: remove return None, remove optional tag in awaitTerminationResult (Apr 20, 2023)
- ed67070: add back return None (Apr 20, 2023)
- 566e9fc: merge master (Apr 20, 2023)
@@ -259,15 +259,21 @@ message StreamingQueryCommand {
bool process_all_available = 6;
// explain() API. Returns logical and physical plans.
ExplainCommand explain = 7;

// TODO(SPARK-42960) Add more commands: await_termination(), exception() etc.
// exception() API. Returns the exception in the query if any.
bool exception = 8;
// awaitTermination() API. Waits for the termination of the query.
AwaitTerminationCommand await_termination = 9;
}

message ExplainCommand {
// TODO: Consider reusing Explain from AnalyzePlanRequest message.
// We can not do this right now since it base.proto imports this file.
bool extended = 1;
}

message AwaitTerminationCommand {
optional int64 timeout_ms = 2;
}
}

// Response for commands on a streaming query.
@@ -279,6 +285,8 @@ message StreamingQueryCommandResult {
StatusResult status = 2;
RecentProgressResult recent_progress = 3;
ExplainResult explain = 4;
ExceptionResult exception = 5;
AwaitTerminationResult await_termination = 6;
}

message StatusResult {
@@ -298,6 +306,15 @@ message StreamingQueryCommandResult {
// Logical and physical plans as string
string result = 1;
}

message ExceptionResult {
Review discussion on ExceptionResult:

Contributor: We probably need at least the stacktrace. Leave a comment about what our thinking is there.

WweiL (author, Apr 13, 2023): Yes, I agree; currently the support for errors is limited. I tracked how they handle errors, and found that it goes through here:

    raise convert_exception(info, status.message) from None

And then in the convert_exception method:

    elif "org.apache.spark.sql.streaming.StreamingQueryException" in classes:
        return StreamingQueryException(message)

Only the message is directly passed. I guess we could file a ticket to wait for the batch side's change, and then align with them?

Contributor: I think we should try to be consistent with what the current exception() returns, which is return CapturedStreamingQueryException(msg, stackTrace, je.getCause())?

WweiL (author): Right, but stackTrace and cause are not included in Connect's error framework so far. There is an ongoing PR about this: #40575.

// Exception message as string
optional string exception_message = 1;
}

message AwaitTerminationResult {
bool terminated = 1;
}
}

// Command to get the output of 'SparkContext.resources'
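Taken together with the messages above, a client waiting at most five seconds for termination would send a StreamingQueryCommand like the following (protobuf text format; the id values are invented, and the field names inside query_id are assumed from StreamingQueryInstanceId):

```protobuf
# Illustrative StreamingQueryCommand in text format, not a captured message.
query_id {
  id: "b7f1a2c3-d4e5-4f60-9a8b-7c6d5e4f3a2b"
  run_id: "1a2b3c4d-5e6f-4a70-8b9c-0d1e2f3a4b5c"
}
# The command oneof: exactly one of status/stop/explain/exception/
# await_termination (etc.) may be set per request.
await_termination {
  timeout_ms: 5000
}
```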
@@ -52,7 +52,7 @@ import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, LiteralValueProtoConverter, StorageLevelProtoConverter, UdfPacket}
import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
import org.apache.spark.sql.connect.service.SparkConnectStreamHandler
import org.apache.spark.sql.connect.service.{SparkConnectService, SparkConnectStreamHandler}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.arrow.ArrowConverters
@@ -2255,6 +2255,23 @@ class SparkConnectPlanner(val session: SparkSession) {
.build()
respBuilder.setExplain(explain)

case StreamingQueryCommand.CommandCase.EXCEPTION =>
val result = query.exception
result.foreach(e =>
respBuilder.getExceptionBuilder
.setExceptionMessage(SparkConnectService.extractErrorMessage(e)))

case StreamingQueryCommand.CommandCase.AWAIT_TERMINATION =>
if (command.getAwaitTermination.hasTimeoutMs) {
val terminated = query.awaitTermination(command.getAwaitTermination.getTimeoutMs)
respBuilder.getAwaitTerminationBuilder
.setTerminated(terminated)
} else {
query.awaitTermination()
Review discussion on the no-timeout awaitTermination() path:

Member: Hmm, just to be extra clear, it will be disconnected when it reaches the gRPC timeout, am I correct?

WweiL (author): Right, this is intended at this stage. @rangadi will push an update regarding this, I believe.

Member: 👌

Contributor: The client keeps sending heartbeat messages to keep the RPC connection alive. That said, we would still need to improve the handling of this; e.g. it should exit if the client side disconnects.

respBuilder.getAwaitTerminationBuilder
.setTerminated(true)
}

case StreamingQueryCommand.CommandCase.COMMAND_NOT_SET =>
throw new IllegalArgumentException("Missing command in StreamingQueryCommand")
}
@@ -74,7 +74,6 @@ class SparkConnectService(debug: Boolean)
}

private def buildStatusFromThrowable(st: Throwable): RPCStatus = {
val message = StringUtils.abbreviate(st.getMessage, 2048)
RPCStatus
.newBuilder()
.setCode(RPCCode.INTERNAL_VALUE)
@@ -86,7 +85,7 @@
.setDomain("org.apache.spark")
.putMetadata("classes", compact(render(allClasses(st.getClass).map(_.getName))))
.build()))
.setMessage(if (message != null) message else "")
.setMessage(SparkConnectService.extractErrorMessage(st))
.build()
}

@@ -295,4 +294,13 @@ object SparkConnectService {
}
}
}

def extractErrorMessage(st: Throwable): String = {
val message = StringUtils.abbreviate(st.getMessage, 2048)
if (message != null) {
message
} else {
""
}
}
}
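The new extractErrorMessage helper centralizes the message handling shared by buildStatusFromThrowable and the streaming exception() path: abbreviate the throwable's message to 2048 characters, and fall back to an empty string when getMessage is null. A rough Python sketch of that behavior (the trailing "..." convention follows Commons Lang's StringUtils.abbreviate, which is an assumption about its exact output):

```python
def extract_error_message(message, max_len=2048):
    """Mimic extractErrorMessage: abbreviate to max_len chars, never return None."""
    if message is None:
        # Some throwables have no message; the RPC status message must be a string.
        return ""
    if len(message) <= max_len:
        return message
    # StringUtils.abbreviate truncates and appends "..." within the limit.
    return message[: max_len - 3] + "..."
```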
81 changes: 63 additions & 18 deletions python/pyspark/sql/connect/proto/commands_pb2.py

Large diffs are not rendered by default.

120 changes: 118 additions & 2 deletions python/pyspark/sql/connect/proto/commands_pb2.pyi
@@ -876,13 +876,41 @@ class StreamingQueryCommand(google.protobuf.message.Message):
self, field_name: typing_extensions.Literal["extended", b"extended"]
) -> None: ...

class AwaitTerminationCommand(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

TIMEOUT_MS_FIELD_NUMBER: builtins.int
timeout_ms: builtins.int
def __init__(
self,
*,
timeout_ms: builtins.int | None = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
"_timeout_ms", b"_timeout_ms", "timeout_ms", b"timeout_ms"
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"_timeout_ms", b"_timeout_ms", "timeout_ms", b"timeout_ms"
],
) -> None: ...
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_timeout_ms", b"_timeout_ms"]
) -> typing_extensions.Literal["timeout_ms"] | None: ...

QUERY_ID_FIELD_NUMBER: builtins.int
STATUS_FIELD_NUMBER: builtins.int
LAST_PROGRESS_FIELD_NUMBER: builtins.int
RECENT_PROGRESS_FIELD_NUMBER: builtins.int
STOP_FIELD_NUMBER: builtins.int
PROCESS_ALL_AVAILABLE_FIELD_NUMBER: builtins.int
EXPLAIN_FIELD_NUMBER: builtins.int
EXCEPTION_FIELD_NUMBER: builtins.int
AWAIT_TERMINATION_FIELD_NUMBER: builtins.int
@property
def query_id(self) -> global___StreamingQueryInstanceId:
"""(Required) Query instance. See `StreamingQueryInstanceId`."""
@@ -899,6 +927,11 @@ class StreamingQueryCommand(google.protobuf.message.Message):
@property
def explain(self) -> global___StreamingQueryCommand.ExplainCommand:
"""explain() API. Returns logical and physical plans."""
exception: builtins.bool
"""exception() API. Returns the exception in the query if any."""
@property
def await_termination(self) -> global___StreamingQueryCommand.AwaitTerminationCommand:
"""awaitTermination() API. Waits for the termination of the query."""
def __init__(
self,
*,
@@ -909,12 +942,18 @@
stop: builtins.bool = ...,
process_all_available: builtins.bool = ...,
explain: global___StreamingQueryCommand.ExplainCommand | None = ...,
exception: builtins.bool = ...,
await_termination: global___StreamingQueryCommand.AwaitTerminationCommand | None = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
"await_termination",
b"await_termination",
"command",
b"command",
"exception",
b"exception",
"explain",
b"explain",
"last_progress",
@@ -934,8 +973,12 @@
def ClearField(
self,
field_name: typing_extensions.Literal[
"await_termination",
b"await_termination",
"command",
b"command",
"exception",
b"exception",
"explain",
b"explain",
"last_progress",
@@ -955,7 +998,14 @@
def WhichOneof(
self, oneof_group: typing_extensions.Literal["command", b"command"]
) -> typing_extensions.Literal[
"status", "last_progress", "recent_progress", "stop", "process_all_available", "explain"
"status",
"last_progress",
"recent_progress",
"stop",
"process_all_available",
"explain",
"exception",
"await_termination",
] | None: ...

global___StreamingQueryCommand = StreamingQueryCommand
@@ -1033,10 +1083,60 @@ class StreamingQueryCommandResult(google.protobuf.message.Message):
self, field_name: typing_extensions.Literal["result", b"result"]
) -> None: ...

class ExceptionResult(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

EXCEPTION_MESSAGE_FIELD_NUMBER: builtins.int
exception_message: builtins.str
"""Exception message as string"""
def __init__(
self,
*,
exception_message: builtins.str | None = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
"_exception_message",
b"_exception_message",
"exception_message",
b"exception_message",
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"_exception_message",
b"_exception_message",
"exception_message",
b"exception_message",
],
) -> None: ...
def WhichOneof(
self,
oneof_group: typing_extensions.Literal["_exception_message", b"_exception_message"],
) -> typing_extensions.Literal["exception_message"] | None: ...

class AwaitTerminationResult(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

TERMINATED_FIELD_NUMBER: builtins.int
terminated: builtins.bool
def __init__(
self,
*,
terminated: builtins.bool = ...,
) -> None: ...
def ClearField(
self, field_name: typing_extensions.Literal["terminated", b"terminated"]
) -> None: ...

QUERY_ID_FIELD_NUMBER: builtins.int
STATUS_FIELD_NUMBER: builtins.int
RECENT_PROGRESS_FIELD_NUMBER: builtins.int
EXPLAIN_FIELD_NUMBER: builtins.int
EXCEPTION_FIELD_NUMBER: builtins.int
AWAIT_TERMINATION_FIELD_NUMBER: builtins.int
@property
def query_id(self) -> global___StreamingQueryInstanceId:
"""(Required) Query instance id. See `StreamingQueryInstanceId`."""
@@ -1046,17 +1146,27 @@ class StreamingQueryCommandResult(google.protobuf.message.Message):
def recent_progress(self) -> global___StreamingQueryCommandResult.RecentProgressResult: ...
@property
def explain(self) -> global___StreamingQueryCommandResult.ExplainResult: ...
@property
def exception(self) -> global___StreamingQueryCommandResult.ExceptionResult: ...
@property
def await_termination(self) -> global___StreamingQueryCommandResult.AwaitTerminationResult: ...
def __init__(
self,
*,
query_id: global___StreamingQueryInstanceId | None = ...,
status: global___StreamingQueryCommandResult.StatusResult | None = ...,
recent_progress: global___StreamingQueryCommandResult.RecentProgressResult | None = ...,
explain: global___StreamingQueryCommandResult.ExplainResult | None = ...,
exception: global___StreamingQueryCommandResult.ExceptionResult | None = ...,
await_termination: global___StreamingQueryCommandResult.AwaitTerminationResult | None = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
"await_termination",
b"await_termination",
"exception",
b"exception",
"explain",
b"explain",
"query_id",
@@ -1072,6 +1182,10 @@
def ClearField(
self,
field_name: typing_extensions.Literal[
"await_termination",
b"await_termination",
"exception",
b"exception",
"explain",
b"explain",
"query_id",
@@ -1086,7 +1200,9 @@
) -> None: ...
def WhichOneof(
self, oneof_group: typing_extensions.Literal["result_type", b"result_type"]
) -> typing_extensions.Literal["status", "recent_progress", "explain"] | None: ...
) -> typing_extensions.Literal[
"status", "recent_progress", "explain", "exception", "await_termination"
] | None: ...

global___StreamingQueryCommandResult = StreamingQueryCommandResult

28 changes: 23 additions & 5 deletions python/pyspark/sql/connect/streaming/query.py
@@ -24,6 +24,9 @@
from pyspark.sql.streaming.query import (
StreamingQuery as PySparkStreamingQuery,
)
from pyspark.errors.exceptions.connect import (
StreamingQueryException as CapturedStreamingQueryException,
)

Review comment from a Contributor on this import: why need this?


__all__ = [
"StreamingQuery", # TODO(SPARK-43032): "StreamingQueryManager"
@@ -66,11 +69,21 @@ def isActive(self) -> bool:

isActive.__doc__ = PySparkStreamingQuery.isActive.__doc__

# TODO (SPARK-42960): Implement and uncomment the doc
def awaitTermination(self, timeout: Optional[int] = None) -> Optional[bool]:
raise NotImplementedError()
cmd = pb2.StreamingQueryCommand()
if timeout is not None:
if not isinstance(timeout, (int, float)) or timeout <= 0:
raise ValueError("timeout must be a positive integer or float. Got %s" % timeout)
cmd.await_termination.timeout_ms = int(timeout * 1000)
terminated = self._execute_streaming_query_cmd(cmd).await_termination.terminated
return terminated
else:
await_termination_cmd = pb2.StreamingQueryCommand.AwaitTerminationCommand()
cmd.await_termination.CopyFrom(await_termination_cmd)
self._execute_streaming_query_cmd(cmd)
return None

# awaitTermination.__doc__ = PySparkStreamingQuery.awaitTermination.__doc__
awaitTermination.__doc__ = PySparkStreamingQuery.awaitTermination.__doc__

@property
def status(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -127,9 +140,14 @@ def explain(self, extended: bool = False) -> None:

explain.__doc__ = PySparkStreamingQuery.explain.__doc__

# TODO (SPARK-42960): Implement and uncomment the doc
def exception(self) -> Optional[StreamingQueryException]:
raise NotImplementedError()
cmd = pb2.StreamingQueryCommand()
cmd.exception = True
exception = self._execute_streaming_query_cmd(cmd).exception
if exception.HasField("exception_message"):
return CapturedStreamingQueryException(exception.exception_message)
else:
return None

exception.__doc__ = PySparkStreamingQuery.exception.__doc__
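The exception() implementation above relies on proto3 field presence: HasField("exception_message") distinguishes a failed query with an empty message from a query that has no error at all. A minimal sketch of the same presence logic, with a plain Optional standing in for the proto field (QueryError and exception_from_result are illustrative names, not part of PySpark):

```python
from typing import Optional

class QueryError(Exception):
    """Stand-in for the captured StreamingQueryException."""

def exception_from_result(exception_message: Optional[str]) -> Optional[QueryError]:
    # None models an unset optional field (no error);
    # an empty string is still a real, if uninformative, error.
    if exception_message is None:
        return None
    return QueryError(exception_message)
```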

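The client-side awaitTermination() reuses the argument checks from pyspark/sql/streaming/query.py below: the timeout must be a positive int or float, and seconds are converted to whole milliseconds before filling AwaitTerminationCommand.timeout_ms. That handling can be sketched on its own (timeout_to_ms is an illustrative helper, not a PySpark function):

```python
def timeout_to_ms(timeout):
    """Validate a timeout in seconds and convert it to whole milliseconds,
    mirroring the checks in awaitTermination()."""
    if timeout is None:
        return None  # no timeout: block until the query terminates
    if not isinstance(timeout, (int, float)) or timeout <= 0:
        raise ValueError("timeout must be a positive integer or float. Got %s" % timeout)
    return int(timeout * 1000)
```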
2 changes: 1 addition & 1 deletion python/pyspark/sql/streaming/query.py
@@ -196,7 +196,7 @@ def awaitTermination(self, timeout: Optional[int] = None) -> Optional[bool]:
>>> sq.stop()
"""
if timeout is not None:
if not isinstance(timeout, (int, float)) or timeout < 0:
if not isinstance(timeout, (int, float)) or timeout <= 0:
raise ValueError("timeout must be a positive integer or float. Got %s" % timeout)
return self._jsq.awaitTermination(int(timeout * 1000))
else: