[SPARK-48258][PYTHON][CONNECT] Checkpoint and localCheckpoint in Spark Connect #46570

Closed · wants to merge 3 commits
@@ -381,6 +381,9 @@ message ExecutePlanResponse {
// (Optional) Intermediate query progress reports.
ExecutionProgress execution_progress = 18;

// Response for command that checkpoints a DataFrame.
CheckpointCommandResult checkpoint_command_result = 19;

// Support arbitrary result objects.
google.protobuf.Any extension = 999;
}
@@ -1048,6 +1051,11 @@ message FetchErrorDetailsResponse {
}
}

message CheckpointCommandResult {
// (Required) The logical plan checkpointed.
CachedRemoteRelation relation = 1;
}

// Main interface for the SparkConnect service.
service SparkConnectService {

@@ -45,6 +45,8 @@ message Command {
StreamingQueryListenerBusCommand streaming_query_listener_bus_command = 11;
CommonInlineUserDefinedDataSource register_data_source = 12;
CreateResourceProfileCommand create_resource_profile_command = 13;
CheckpointCommand checkpoint_command = 14;
RemoveCachedRemoteRelationCommand remove_cached_remote_relation_command = 15;

// This field is used to mark extensions to the protocol. When plugins generate arbitrary
// Commands they can add them here. During the planning the correct resolution is done.
@@ -484,3 +486,21 @@ message CreateResourceProfileCommandResult {
// (Required) Server-side generated resource profile id.
int32 profile_id = 1;
}

// Command to remove a `CachedRemoteRelation`.
message RemoveCachedRemoteRelationCommand {
// (Required) The cached remote relation to be removed.
CachedRemoteRelation relation = 1;
}

message CheckpointCommand {
// (Required) The logical plan to checkpoint.
Relation relation = 1;

// (Optional) Whether to checkpoint locally, using a local temporary
// directory on the Spark Connect server (the Spark driver).
optional bool local = 2;

// (Optional) Whether to checkpoint this dataframe immediately.
optional bool eager = 3;
Contributor:

The default is eager = true, right? Should the protocol encode this better? Currently the protocol defaults to eager = false if the field is not set, so my question is: should we flip the logic (i.e. replace this with lazy) so that the default behavior does not require setting additional fields?

The same question applies to local...

Member Author:

We could, but I followed the existing convention here (see the other optional bool fields in relations.proto).

}
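
A minimal sketch (not part of this PR) illustrating the proto3 `optional` semantics discussed in the thread above: an `optional bool` lets the server distinguish a field that was never set from one explicitly set to false via `HasField`, so the wire default does not have to match the API default. It assumes the Python bindings generated from this PR's proto changes.

```python
import pyspark.sql.connect.proto as pb2

cmd = pb2.CheckpointCommand()
print(cmd.HasField("eager"))  # False: the field was never set
cmd.eager = False
print(cmd.HasField("eager"))  # True: explicitly set, even though the value is false
print(cmd.eager)              # False
```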
@@ -17,6 +17,8 @@

package org.apache.spark.sql.connect.planner

import java.util.UUID

import scala.collection.mutable
import scala.jdk.CollectionConverters._
import scala.util.Try
@@ -33,13 +35,13 @@ import org.apache.spark.{Partition, SparkEnv, TaskContext}
import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.{CreateResourceProfileCommand, ExecutePlanResponse, SqlCommand, StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, WriteStreamOperationStart, WriteStreamOperationStartResult}
import org.apache.spark.connect.proto.{CheckpointCommand, CreateResourceProfileCommand, ExecutePlanResponse, SqlCommand, StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, WriteStreamOperationStart, WriteStreamOperationStartResult}
import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult
import org.apache.spark.connect.proto.Parse.ParseFormat
import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.StreamingQueryInstance
import org.apache.spark.connect.proto.WriteStreamOperationStart.TriggerCase
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.SESSION_ID
import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, SESSION_ID}
import org.apache.spark.ml.{functions => MLFunctions}
import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, TaskResourceProfile, TaskResourceRequest}
import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, SparkSession}
@@ -2581,6 +2583,10 @@ class SparkConnectPlanner(
handleCreateResourceProfileCommand(
command.getCreateResourceProfileCommand,
responseObserver)
case proto.Command.CommandTypeCase.CHECKPOINT_COMMAND =>
handleCheckpointCommand(command.getCheckpointCommand, responseObserver)
case proto.Command.CommandTypeCase.REMOVE_CACHED_REMOTE_RELATION_COMMAND =>
handleRemoveCachedRemoteRelationCommand(command.getRemoveCachedRemoteRelationCommand)

case _ => throw new UnsupportedOperationException(s"$command not supported.")
}
@@ -3507,6 +3513,47 @@
.build())
}

private def handleCheckpointCommand(
checkpointCommand: CheckpointCommand,
responseObserver: StreamObserver[proto.ExecutePlanResponse]): Unit = {
val target = Dataset
.ofRows(session, transformRelation(checkpointCommand.getRelation))
val checkpointed = if (checkpointCommand.hasLocal && checkpointCommand.hasEager) {
Contributor:

You could also increase the visibility of Dataset.checkpoint(eager: Boolean, reliableCheckpoint: Boolean).

Member Author:

I tried that locally, but I think it's better to keep this as is; the current version is easier to read, and Dataset.checkpoint(eager: Boolean, reliableCheckpoint: Boolean) is private as well.

target.localCheckpoint(eager = checkpointCommand.getEager)
} else if (checkpointCommand.hasLocal) {
target.localCheckpoint()
} else if (checkpointCommand.hasEager) {
target.checkpoint(eager = checkpointCommand.getEager)
} else {
target.checkpoint()
}

val dfId = UUID.randomUUID().toString
logInfo(log"Caching DataFrame with id ${MDC(DATAFRAME_ID, dfId)}")
sessionHolder.cacheDataFrameById(dfId, checkpointed)

executeHolder.eventsManager.postFinished()
responseObserver.onNext(
proto.ExecutePlanResponse
.newBuilder()
.setSessionId(sessionId)
.setServerSideSessionId(sessionHolder.serverSessionId)
.setCheckpointCommandResult(
proto.CheckpointCommandResult
.newBuilder()
.setRelation(proto.CachedRemoteRelation.newBuilder().setRelationId(dfId).build())
.build())
.build())
}

private def handleRemoveCachedRemoteRelationCommand(
removeCachedRemoteRelationCommand: proto.RemoveCachedRemoteRelationCommand): Unit = {
val dfId = removeCachedRemoteRelationCommand.getRelation.getRelationId
logInfo(log"Removing DataFrame with id ${MDC(DATAFRAME_ID, dfId)} from the cache")
sessionHolder.removeCachedDataFrame(dfId)
executeHolder.eventsManager.postFinished()
}

private val emptyLocalRelation = LocalRelation(
output = AttributeReference("value", StringType, false)() :: Nil,
data = Seq.empty)
@@ -105,8 +105,9 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
val eventManager: SessionEventsManager = SessionEventsManager(this, new SystemClock())

// Mapping from relation ID (passed to client) to runtime dataframe. Used for callbacks like
// foreachBatch() in Streaming. Lazy since most sessions don't need it.
private lazy val dataFrameCache: ConcurrentMap[String, DataFrame] = new ConcurrentHashMap()
// foreachBatch() in Streaming, and DataFrame.checkpoint API. Lazy since most sessions don't
// need it.
private[spark] lazy val dataFrameCache: ConcurrentMap[String, DataFrame] = new ConcurrentHashMap()

// Mapping from id to StreamingQueryListener. Used for methods like removeListener() in
// StreamingQueryManager.
@@ -315,6 +315,12 @@ object SparkConnectService extends Logging {
previoslyObservedSessionId)
}

// For testing
private[spark] def getOrCreateIsolatedSession(
userId: String, sessionId: String): SessionHolder = {
getOrCreateIsolatedSession(userId, sessionId, None)
}

/**
* If there are no executions, return Left with System.currentTimeMillis of last active
* execution. Otherwise return Right with list of ExecuteInfo of all executions.
@@ -24,7 +24,7 @@
class SparkFrameMethodsParityTests(
SparkFrameMethodsTestsMixin, TestUtils, PandasOnSparkTestUtils, ReusedConnectTestCase
):
@unittest.skip("Test depends on checkpoint which is not supported from Spark Connect.")
@unittest.skip("Test depends on SparkContext which is not supported from Spark Connect.")
def test_checkpoint(self):
super().test_checkpoint()

@@ -34,10 +34,6 @@ def test_checkpoint(self):
def test_coalesce(self):
super().test_coalesce()

@unittest.skip("Test depends on localCheckpoint which is not supported from Spark Connect.")
def test_local_checkpoint(self):
super().test_local_checkpoint()

@unittest.skip(
"Test depends on RDD, and cannot use SQL expression due to Catalyst optimization"
)
12 changes: 11 additions & 1 deletion python/pyspark/sql/connect/client/core.py
@@ -66,7 +66,11 @@
from pyspark.sql.connect.profiler import ConnectProfilerCollector
from pyspark.sql.connect.client.reattach import ExecutePlanResponseReattachableIterator
from pyspark.sql.connect.client.retries import RetryPolicy, Retrying, DefaultPolicy
from pyspark.sql.connect.conversion import storage_level_to_proto, proto_to_storage_level
from pyspark.sql.connect.conversion import (
storage_level_to_proto,
proto_to_storage_level,
proto_to_remote_cached_dataframe,
)
import pyspark.sql.connect.proto as pb2
import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib
import pyspark.sql.connect.types as types
@@ -1400,6 +1404,12 @@ def handle_response(
if b.HasField("create_resource_profile_command_result"):
profile_id = b.create_resource_profile_command_result.profile_id
yield {"create_resource_profile_command_result": profile_id}
if b.HasField("checkpoint_command_result"):
yield {
"checkpoint_command_result": proto_to_remote_cached_dataframe(
b.checkpoint_command_result.relation
)
}

try:
if self._use_reattachable_execute:
23 changes: 17 additions & 6 deletions python/pyspark/sql/connect/conversion.py
@@ -48,12 +48,10 @@
import pyspark.sql.connect.proto as pb2
from pyspark.sql.pandas.types import to_arrow_schema, _dedup_names, _deduplicate_field_names

from typing import (
Any,
Callable,
Sequence,
List,
)
from typing import Any, Callable, Sequence, List, TYPE_CHECKING

if TYPE_CHECKING:
from pyspark.sql.connect.dataframe import DataFrame


class LocalDataToArrowConversion:
@@ -570,3 +568,16 @@ def proto_to_storage_level(storage_level: pb2.StorageLevel) -> StorageLevel:
deserialized=storage_level.deserialized,
replication=storage_level.replication,
)


def proto_to_remote_cached_dataframe(relation: pb2.CachedRemoteRelation) -> "DataFrame":
assert relation is not None and isinstance(relation, pb2.CachedRemoteRelation)

from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.connect.session import SparkSession
import pyspark.sql.connect.plan as plan

return DataFrame(
plan=plan.CachedRemoteRelation(relation.relation_id),
session=SparkSession.active(),
)
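
A hedged usage sketch of the helper above. It assumes an active Spark Connect session; the relation id is only a placeholder for the UUID that the server returns in `CheckpointCommandResult`.

```python
import pyspark.sql.connect.proto as pb2
from pyspark.sql.connect.conversion import proto_to_remote_cached_dataframe

# Placeholder id; in practice this comes back from the server.
rel = pb2.CachedRemoteRelation(relation_id="<server-generated-uuid>")
df = proto_to_remote_cached_dataframe(rel)  # DataFrame whose plan is the cached relation
```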
69 changes: 54 additions & 15 deletions python/pyspark/sql/connect/dataframe.py
@@ -16,7 +16,7 @@
#

# mypy: disable-error-code="override"

from pyspark.sql.connect.proto import base_pb2 as spark_dot_connect_dot_base__pb2
from pyspark.errors.exceptions.base import (
SessionNotSameException,
PySparkIndexError,
@@ -138,6 +138,41 @@ def __init__(
# by __repr__ and _repr_html_ while eager evaluation opens.
self._support_repr_html = False
self._cached_schema: Optional[StructType] = None
self._cached_remote_relation_id: Optional[str] = None

def __del__(self) -> None:
Contributor:

I'm no Python GC expert here; I assume some system thread is doing this work. Is it wise to execute an RPC from there?

Member Author:

At least Py4J does the same thing (socket connection).

Member Author:

And the previous maintainer knows it quite well, although I am the maintainer of Py4J now :-).

# If the session is already closed, all cached DataFrames should already have been released.
if not self._session.client.is_closed and self._cached_remote_relation_id is not None:
try:
command = plan.RemoveRemoteCachedRelation(
plan.CachedRemoteRelation(relationId=self._cached_remote_relation_id)
).command(session=self._session.client)
req = self._session.client._execute_plan_request_with_metadata()
if self._session.client._user_id:
req.user_context.user_id = self._session.client._user_id
req.plan.command.CopyFrom(command)

for attempt in self._session.client._retrying():
with attempt:
# !!HACK ALERT!!
# unary_stream does not work at Python interpreter exit for unknown reasons.
# Therefore, we open a unary_unary channel here instead.
# See also :class:`SparkConnectServiceStub`.
request_serializer = (
spark_dot_connect_dot_base__pb2.ExecutePlanRequest.SerializeToString
)
response_deserializer = (
spark_dot_connect_dot_base__pb2.ExecutePlanResponse.FromString
)
channel = self._session.client._channel.unary_unary(
"/spark.connect.SparkConnectService/ExecutePlan",
request_serializer=request_serializer,
response_deserializer=response_deserializer,
)
metadata = self._session.client._builder.metadata()
channel(req, metadata=metadata) # type: ignore[arg-type]
except Exception as e:
warnings.warn(f"RemoveRemoteCachedRelation failed with exception: {e}.")

def __reduce__(self) -> Tuple:
"""
@@ -2096,19 +2131,25 @@ def writeTo(self, table: str) -> "DataFrameWriterV2":
def offset(self, n: int) -> ParentDataFrame:
return DataFrame(plan.Offset(child=self._plan, offset=n), session=self._session)

if not is_remote_only():
def checkpoint(self, eager: bool = True) -> "DataFrame":
cmd = plan.Checkpoint(child=self._plan, local=False, eager=eager)
_, properties = self._session.client.execute_command(cmd.command(self._session.client))
assert "checkpoint_command_result" in properties
checkpointed = properties["checkpoint_command_result"]
assert isinstance(checkpointed._plan, plan.CachedRemoteRelation)
checkpointed._cached_remote_relation_id = checkpointed._plan._relationId
return checkpointed

def localCheckpoint(self, eager: bool = True) -> "DataFrame":
cmd = plan.Checkpoint(child=self._plan, local=True, eager=eager)
_, properties = self._session.client.execute_command(cmd.command(self._session.client))
assert "checkpoint_command_result" in properties
checkpointed = properties["checkpoint_command_result"]
assert isinstance(checkpointed._plan, plan.CachedRemoteRelation)
checkpointed._cached_remote_relation_id = checkpointed._plan._relationId
return checkpointed
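
A hedged usage sketch of the two new methods, assuming `spark` is a Spark Connect `SparkSession` and that a checkpoint directory is configured on the server for the reliable variant:

```python
df = spark.range(10)

# Reliable checkpoint: materialized on the Spark Connect server; the returned
# DataFrame is backed by a server-side CachedRemoteRelation.
reliable = df.checkpoint()

# Local checkpoint: stored in the server's (driver's) local temporary storage.
local = df.localCheckpoint(eager=True)
```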

def checkpoint(self, eager: bool = True) -> "DataFrame":
raise PySparkNotImplementedError(
error_class="NOT_IMPLEMENTED",
message_parameters={"feature": "checkpoint()"},
)

def localCheckpoint(self, eager: bool = True) -> "DataFrame":
raise PySparkNotImplementedError(
error_class="NOT_IMPLEMENTED",
message_parameters={"feature": "localCheckpoint()"},
)
if not is_remote_only():

def toJSON(self, use_unicode: bool = True) -> "RDD[str]":
raise PySparkNotImplementedError(
@@ -2203,8 +2244,6 @@ def _test() -> None:
if not is_remote_only():
del pyspark.sql.dataframe.DataFrame.toJSON.__doc__
del pyspark.sql.dataframe.DataFrame.rdd.__doc__
del pyspark.sql.dataframe.DataFrame.checkpoint.__doc__
del pyspark.sql.dataframe.DataFrame.localCheckpoint.__doc__

globs["spark"] = (
PySparkSession.builder.appName("sql.connect.dataframe tests")
32 changes: 31 additions & 1 deletion python/pyspark/sql/connect/plan.py
@@ -1785,9 +1785,39 @@ def command(self, session: "SparkConnectClient") -> proto.Command:
return cmd


# Catalog API (internal-only)
class RemoveRemoteCachedRelation(LogicalPlan):
def __init__(self, relation: CachedRemoteRelation) -> None:
super().__init__(None)
self._relation = relation

def command(self, session: "SparkConnectClient") -> proto.Command:
plan = self._create_proto_relation()
plan.cached_remote_relation.relation_id = self._relation._relationId
cmd = proto.Command()
cmd.remove_cached_remote_relation_command.relation.CopyFrom(plan.cached_remote_relation)
return cmd


class Checkpoint(LogicalPlan):
def __init__(self, child: Optional["LogicalPlan"], local: bool, eager: bool) -> None:
super().__init__(child)
self._local = local
self._eager = eager

def command(self, session: "SparkConnectClient") -> proto.Command:
cmd = proto.Command()
assert self._child is not None
cmd.checkpoint_command.CopyFrom(
proto.CheckpointCommand(
relation=self._child.plan(session),
local=self._local,
eager=self._eager,
)
)
return cmd


# Catalog API (internal-only)
class CurrentDatabase(LogicalPlan):
def __init__(self) -> None:
super().__init__(None)
214 changes: 108 additions & 106 deletions python/pyspark/sql/connect/proto/base_pb2.py

Large diffs are not rendered by default.
