[SPARK-48258][PYTHON][CONNECT] Checkpoint and localCheckpoint in Spark Connect #46570
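For context, a minimal client-side usage sketch of what this PR enables. This is illustrative only: the connection URL is an assumption, and reliable `checkpoint()` additionally requires a checkpoint directory configured on the server side (as in classic Spark).

```python
# Usage sketch: checkpoint/localCheckpoint over Spark Connect.
# Assumes a Spark Connect server reachable at sc://localhost:15002.
from pyspark.sql import SparkSession

spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
df = spark.range(100).filter("id % 2 = 0")

# Truncates the lineage; the server caches the checkpointed result and hands
# back a DataFrame backed by a CachedRemoteRelation pointing at that cache.
local_cp = df.localCheckpoint()            # eager by default
lazy_cp = df.localCheckpoint(eager=False)  # checkpointed on first use

print(local_cp.count())
```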
`SparkConnectPlanner.scala`:

```diff
@@ -17,6 +17,8 @@
 package org.apache.spark.sql.connect.planner
 
+import java.util.UUID
+
 import scala.collection.mutable
 import scala.jdk.CollectionConverters._
 import scala.util.Try
 
@@ -33,13 +35,13 @@ import org.apache.spark.{Partition, SparkEnv, TaskContext}
 import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
 import org.apache.spark.connect.proto
-import org.apache.spark.connect.proto.{CreateResourceProfileCommand, ExecutePlanResponse, SqlCommand, StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, WriteStreamOperationStart, WriteStreamOperationStartResult}
+import org.apache.spark.connect.proto.{CheckpointCommand, CreateResourceProfileCommand, ExecutePlanResponse, SqlCommand, StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, WriteStreamOperationStart, WriteStreamOperationStartResult}
 import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult
 import org.apache.spark.connect.proto.Parse.ParseFormat
 import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.StreamingQueryInstance
 import org.apache.spark.connect.proto.WriteStreamOperationStart.TriggerCase
 import org.apache.spark.internal.{Logging, MDC}
-import org.apache.spark.internal.LogKeys.SESSION_ID
+import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, SESSION_ID}
 import org.apache.spark.ml.{functions => MLFunctions}
 import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, TaskResourceProfile, TaskResourceRequest}
 import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, SparkSession}
@@ -2581,6 +2583,10 @@ class SparkConnectPlanner(
         handleCreateResourceProfileCommand(
           command.getCreateResourceProfileCommand,
           responseObserver)
+      case proto.Command.CommandTypeCase.CHECKPOINT_COMMAND =>
+        handleCheckpointCommand(command.getCheckpointCommand, responseObserver)
+      case proto.Command.CommandTypeCase.REMOVE_CACHED_REMOTE_RELATION_COMMAND =>
+        handleRemoveCachedRemoteRelationCommand(command.getRemoveCachedRemoteRelationCommand)
 
       case _ => throw new UnsupportedOperationException(s"$command not supported.")
     }
@@ -3507,6 +3513,47 @@ class SparkConnectPlanner(
         .build())
   }
 
+  private def handleCheckpointCommand(
+      checkpointCommand: CheckpointCommand,
+      responseObserver: StreamObserver[proto.ExecutePlanResponse]): Unit = {
+    val target = Dataset
+      .ofRows(session, transformRelation(checkpointCommand.getRelation))
+    val checkpointed = if (checkpointCommand.hasLocal && checkpointCommand.hasEager) {
+      target.localCheckpoint(eager = checkpointCommand.getEager)
+    } else if (checkpointCommand.hasLocal) {
+      target.localCheckpoint()
+    } else if (checkpointCommand.hasEager) {
+      target.checkpoint(eager = checkpointCommand.getEager)
+    } else {
+      target.checkpoint()
+    }
+
+    val dfId = UUID.randomUUID().toString
+    logInfo(log"Caching DataFrame with id ${MDC(DATAFRAME_ID, dfId)}")
+    sessionHolder.cacheDataFrameById(dfId, checkpointed)
+
+    executeHolder.eventsManager.postFinished()
+    responseObserver.onNext(
+      proto.ExecutePlanResponse
+        .newBuilder()
+        .setSessionId(sessionId)
+        .setServerSideSessionId(sessionHolder.serverSessionId)
+        .setCheckpointCommandResult(
+          proto.CheckpointCommandResult
+            .newBuilder()
+            .setRelation(proto.CachedRemoteRelation.newBuilder().setRelationId(dfId).build())
+            .build())
+        .build())
+  }
+
+  private def handleRemoveCachedRemoteRelationCommand(
+      removeCachedRemoteRelationCommand: proto.RemoveCachedRemoteRelationCommand): Unit = {
+    val dfId = removeCachedRemoteRelationCommand.getRelation.getRelationId
+    logInfo(log"Removing DataFrame with id ${MDC(DATAFRAME_ID, dfId)} from the cache")
+    sessionHolder.removeCachedDataFrame(dfId)
+    executeHolder.eventsManager.postFinished()
+  }
+
   private val emptyLocalRelation = LocalRelation(
     output = AttributeReference("value", StringType, false)() :: Nil,
     data = Seq.empty)
```

Review thread on the `hasLocal && hasEager` branching in `handleCheckpointCommand`:

> **Reviewer:** You could also increase the visibility of …
>
> **Author:** I just did it locally, but I think it's actually better to keep it just as is; I think the current one is easier to read.
`python/pyspark/sql/connect/dataframe.py`:

```diff
@@ -16,7 +16,7 @@
 #
 
 # mypy: disable-error-code="override"
 
+from pyspark.sql.connect.proto import base_pb2 as spark_dot_connect_dot_base__pb2
 from pyspark.errors.exceptions.base import (
     SessionNotSameException,
     PySparkIndexError,
@@ -138,6 +138,41 @@ def __init__(
         # by __repr__ and _repr_html_ while eager evaluation opens.
         self._support_repr_html = False
         self._cached_schema: Optional[StructType] = None
+        self._cached_remote_relation_id: Optional[str] = None
+
+    def __del__(self) -> None:
+        # If the session is already closed, all cached DataFrames have been released already.
+        if not self._session.client.is_closed and self._cached_remote_relation_id is not None:
+            try:
+                command = plan.RemoveRemoteCachedRelation(
+                    plan.CachedRemoteRelation(relationId=self._cached_remote_relation_id)
+                ).command(session=self._session.client)
+                req = self._session.client._execute_plan_request_with_metadata()
+                if self._session.client._user_id:
+                    req.user_context.user_id = self._session.client._user_id
+                req.plan.command.CopyFrom(command)
+
+                for attempt in self._session.client._retrying():
+                    with attempt:
+                        # !!HACK ALERT!!
+                        # unary_stream does not work at Python's exit for unknown reasons.
+                        # Therefore, we open a unary_unary channel here instead.
+                        # See also :class:`SparkConnectServiceStub`.
+                        request_serializer = (
+                            spark_dot_connect_dot_base__pb2.ExecutePlanRequest.SerializeToString
+                        )
+                        response_deserializer = (
+                            spark_dot_connect_dot_base__pb2.ExecutePlanResponse.FromString
+                        )
+                        channel = self._session.client._channel.unary_unary(
+                            "/spark.connect.SparkConnectService/ExecutePlan",
+                            request_serializer=request_serializer,
+                            response_deserializer=response_deserializer,
+                        )
+                        metadata = self._session.client._builder.metadata()
+                        channel(req, metadata=metadata)  # type: ignore[arg-type]
+            except Exception as e:
+                warnings.warn(f"RemoveRemoteCachedRelation failed with exception: {e}.")
 
     def __reduce__(self) -> Tuple:
         """
```

Review thread on `__del__`:

> **Reviewer:** I'm no Python GC expert, but I am assuming some system thread is doing this work. Is it wise to execute an RPC from there?
>
> **Author:** At least Py4J does the same thing (socket connection).
>
> **Author:** And that (previous) maintainer knows this quite well. Although now I am the maintainer of Py4J myself :-).
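On the GC question above: `__del__` runs on whichever thread drops the last reference, possibly during interpreter shutdown, which is also why the code avoids the streaming RPC path there. For comparison, here is a minimal sketch (hypothetical names, and not what this PR does) of the `weakref.finalize` pattern the standard library recommends for this kind of cleanup, since finalizers run at interpreter exit before module teardown:

```python
# Sketch of weakref.finalize-based cleanup; CachedRelationHandle and
# remove_cached_relation are hypothetical stand-ins, not PySpark APIs.
import weakref


class CachedRelationHandle:
    def __init__(self, client, relation_id: str):
        # The callback must not capture `self`, or the object would never
        # become collectable; pass plain arguments instead.
        self._finalizer = weakref.finalize(self, self._release, client, relation_id)

    @staticmethod
    def _release(client, relation_id: str) -> None:
        # Runs at most once: on garbage collection or at interpreter exit.
        if not client.is_closed:
            client.remove_cached_relation(relation_id)  # hypothetical RPC wrapper
```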
Further down in `python/pyspark/sql/connect/dataframe.py`, the previously unimplemented methods move out of the `if not is_remote_only():` block and gain real implementations (note `local=False` for `checkpoint` versus `local=True` for `localCheckpoint`):

```diff
@@ -2096,19 +2131,25 @@ def writeTo(self, table: str) -> "DataFrameWriterV2":
     def offset(self, n: int) -> ParentDataFrame:
         return DataFrame(plan.Offset(child=self._plan, offset=n), session=self._session)
 
-    if not is_remote_only():
+    def checkpoint(self, eager: bool = True) -> "DataFrame":
+        cmd = plan.Checkpoint(child=self._plan, local=False, eager=eager)
+        _, properties = self._session.client.execute_command(cmd.command(self._session.client))
+        assert "checkpoint_command_result" in properties
+        checkpointed = properties["checkpoint_command_result"]
+        assert isinstance(checkpointed._plan, plan.CachedRemoteRelation)
+        checkpointed._cached_remote_relation_id = checkpointed._plan._relationId
+        return checkpointed
+
+    def localCheckpoint(self, eager: bool = True) -> "DataFrame":
+        cmd = plan.Checkpoint(child=self._plan, local=True, eager=eager)
+        _, properties = self._session.client.execute_command(cmd.command(self._session.client))
+        assert "checkpoint_command_result" in properties
+        checkpointed = properties["checkpoint_command_result"]
+        assert isinstance(checkpointed._plan, plan.CachedRemoteRelation)
+        checkpointed._cached_remote_relation_id = checkpointed._plan._relationId
+        return checkpointed
 
-        def checkpoint(self, eager: bool = True) -> "DataFrame":
-            raise PySparkNotImplementedError(
-                error_class="NOT_IMPLEMENTED",
-                message_parameters={"feature": "checkpoint()"},
-            )
-
-        def localCheckpoint(self, eager: bool = True) -> "DataFrame":
-            raise PySparkNotImplementedError(
-                error_class="NOT_IMPLEMENTED",
-                message_parameters={"feature": "localCheckpoint()"},
-            )
+    if not is_remote_only():
 
         def toJSON(self, use_unicode: bool = True) -> "RDD[str]":
             raise PySparkNotImplementedError(
@@ -2203,8 +2244,6 @@ def _test() -> None:
     if not is_remote_only():
         del pyspark.sql.dataframe.DataFrame.toJSON.__doc__
         del pyspark.sql.dataframe.DataFrame.rdd.__doc__
-        del pyspark.sql.dataframe.DataFrame.checkpoint.__doc__
-        del pyspark.sql.dataframe.DataFrame.localCheckpoint.__doc__
 
     globs["spark"] = (
         PySparkSession.builder.appName("sql.connect.dataframe tests")
```
Large diffs are not rendered by default.
Review thread on the protocol defaults for `CheckpointCommand`:

> **Reviewer:** The default is `eager = true`, right? Should the protocol encode this better? Currently the protocol defaults to `eager = false` if the field is not set, so my question is: should we flip the logic (i.e. replace this with `lazy`) so the default behavior does not require you to set additional fields? The same question for `local` ...
>
> **Author:** We can, but I actually followed the other cases here (see the `optional bool` fields in `relations.proto`).
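To make the `optional bool` point concrete: with explicit field presence, the server can distinguish "not set" from "explicitly false", which is what `handleCheckpointCommand` branches on via `hasEager`/`hasLocal`. A minimal sketch of the presence semantics, assuming the `commands_pb2` module and `CheckpointCommand` message generated from this PR's proto changes:

```python
# proto3 `optional bool` presence semantics, as used by CheckpointCommand.
from pyspark.sql.connect.proto import commands_pb2

cmd = commands_pb2.CheckpointCommand()
assert not cmd.HasField("eager")  # unset: planner falls back to Dataset.checkpoint()'s
                                  # own default (eager = true)

cmd.eager = False
assert cmd.HasField("eager")      # explicitly false: planner calls checkpoint(eager = false)
```

Without `optional`, a plain `bool eager` field would make `false` and "unset" indistinguishable, which is the ambiguity behind the reviewer's suggestion to flip the field to `lazy`.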