
Checkpoint and localCheckpoint in Spark Connect
HyukjinKwon committed May 17, 2024
1 parent 714fc8c commit ec0d893
Showing 19 changed files with 691 additions and 356 deletions.
@@ -94,6 +94,7 @@ message AnalyzePlanRequest {
Persist persist = 14;
Unpersist unpersist = 15;
GetStorageLevel get_storage_level = 16;
Checkpoint checkpoint = 18;
}

message Schema {
@@ -199,6 +200,18 @@ message AnalyzePlanRequest {
// (Required) The logical plan to get the storage level.
Relation relation = 1;
}

message Checkpoint {
// (Required) The logical plan to checkpoint.
Relation relation = 1;

// (Optional) Whether to locally checkpoint using a temporary
// directory on the Spark Connect server (the Spark driver).
optional bool local = 2;

// (Optional) Whether to checkpoint this dataframe immediately.
optional bool eager = 3;
}
}

// Response to performing analysis of the query. Contains relevant metadata to be able to
@@ -224,6 +237,7 @@ message AnalyzePlanResponse {
Persist persist = 12;
Unpersist unpersist = 13;
GetStorageLevel get_storage_level = 14;
Checkpoint checkpoint = 16;
}

message Schema {
@@ -275,6 +289,11 @@ message AnalyzePlanResponse {
// (Required) The StorageLevel as a result of get_storage_level request.
StorageLevel storage_level = 1;
}

message Checkpoint {
// (Required) The logical plan checkpointed.
CachedRemoteRelation relation = 1;
}
}

// A request to be executed by the service.
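For context, here is a minimal Python sketch (not part of this commit) of how the new analyze message could be populated using the protos generated from this change, and what the matching response carries. The session and relation IDs below are placeholders.

```python
# Hypothetical sketch using the generated Python protos; IDs are placeholders.
import pyspark.sql.connect.proto as pb2

req = pb2.AnalyzePlanRequest()
req.session_id = "00000000-0000-0000-0000-000000000000"  # placeholder session ID
req.checkpoint.relation.CopyFrom(pb2.Relation())          # the plan to checkpoint
req.checkpoint.local = True   # optional: use Dataset.localCheckpoint on the driver
req.checkpoint.eager = True   # optional: materialize immediately

# The server replies with a CachedRemoteRelation identifying the checkpointed
# DataFrame it cached; the client wraps that ID in a new logical plan.
resp = pb2.AnalyzePlanResponse()
resp.checkpoint.relation.relation_id = "some-uuid"        # placeholder relation ID
print(resp.checkpoint.relation.relation_id)
```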
@@ -45,6 +45,7 @@ message Command {
StreamingQueryListenerBusCommand streaming_query_listener_bus_command = 11;
CommonInlineUserDefinedDataSource register_data_source = 12;
CreateResourceProfileCommand create_resource_profile_command = 13;
RemoveCachedRemoteRelationCommand remove_cached_remote_relation_command = 14;

// This field is used to mark extensions to the protocol. When plugins generate arbitrary
// Commands they can add them here. During the planning the correct resolution is done.
@@ -484,3 +485,9 @@ message CreateResourceProfileCommandResult {
// (Required) Server-side generated resource profile id.
int32 profile_id = 1;
}

// Command to remove a `CachedRemoteRelation`.
message RemoveCachedRemoteRelationCommand {
// (Required) The cached remote relation to be removed.
CachedRemoteRelation relation = 1;
}
@@ -39,7 +39,7 @@ import org.apache.spark.connect.proto.Parse.ParseFormat
import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.StreamingQueryInstance
import org.apache.spark.connect.proto.WriteStreamOperationStart.TriggerCase
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.SESSION_ID
import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, SESSION_ID}
import org.apache.spark.ml.{functions => MLFunctions}
import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, TaskResourceProfile, TaskResourceRequest}
import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, SparkSession}
@@ -2581,6 +2581,8 @@ class SparkConnectPlanner(
handleCreateResourceProfileCommand(
command.getCreateResourceProfileCommand,
responseObserver)
case proto.Command.CommandTypeCase.REMOVE_CACHED_REMOTE_RELATION_COMMAND =>
handleRemoveCachedRemoteRelationCommand(command.getRemoveCachedRemoteRelationCommand)

case _ => throw new UnsupportedOperationException(s"$command not supported.")
}
@@ -3507,6 +3509,14 @@ class SparkConnectPlanner(
.build())
}

private def handleRemoveCachedRemoteRelationCommand(
removeCachedRemoteRelationCommand: proto.RemoveCachedRemoteRelationCommand): Unit = {
val dfId = removeCachedRemoteRelationCommand.getRelation.getRelationId
logInfo(log"Removing DataFrame with id ${MDC(DATAFRAME_ID, dfId)} from the cache")
sessionHolder.removeCachedDataFrame(dfId)
executeHolder.eventsManager.postFinished()
}

private val emptyLocalRelation = LocalRelation(
output = AttributeReference("value", StringType, false)() :: Nil,
data = Seq.empty)
@@ -105,8 +105,9 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
val eventManager: SessionEventsManager = SessionEventsManager(this, new SystemClock())

// Mapping from relation ID (passed to client) to runtime dataframe. Used for callbacks like
// foreachBatch() in Streaming. Lazy since most sessions don't need it.
private lazy val dataFrameCache: ConcurrentMap[String, DataFrame] = new ConcurrentHashMap()
// foreachBatch() in Streaming, and DataFrame.checkpoint API. Lazy since most sessions don't
// need it.
private[spark] lazy val dataFrameCache: ConcurrentMap[String, DataFrame] = new ConcurrentHashMap()

// Mapping from id to StreamingQueryListener. Used for methods like removeListener() in
// StreamingQueryManager.
@@ -17,12 +17,15 @@

package org.apache.spark.sql.connect.service

import java.util.UUID

import scala.jdk.CollectionConverters._

import io.grpc.stub.StreamObserver

import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.DATAFRAME_ID
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, StorageLevelProtoConverter}
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
@@ -206,6 +209,29 @@ private[connect] class SparkConnectAnalyzeHandler(
.setStorageLevel(StorageLevelProtoConverter.toConnectProtoType(storageLevel))
.build())

case proto.AnalyzePlanRequest.AnalyzeCase.CHECKPOINT =>
val target = Dataset
.ofRows(session, planner.transformRelation(request.getCheckpoint.getRelation))
val isLocal = request.getCheckpoint.hasLocal && request.getCheckpoint.getLocal
val checkpointed = if (isLocal && request.getCheckpoint.hasEager) {
target.localCheckpoint(eager = request.getCheckpoint.getEager)
} else if (isLocal) {
target.localCheckpoint()
} else if (request.getCheckpoint.hasEager) {
target.checkpoint(eager = request.getCheckpoint.getEager)
} else {
target.checkpoint()
}

val dfId = UUID.randomUUID().toString
logInfo(log"Caching DataFrame with id ${MDC(DATAFRAME_ID, dfId)}")
sessionHolder.cacheDataFrameById(dfId, checkpointed)

builder.setCheckpoint(
proto.AnalyzePlanResponse.Checkpoint
.newBuilder()
.setRelation(proto.CachedRemoteRelation.newBuilder().setRelationId(dfId).build())
.build())

case other => throw InvalidPlanInput(s"Unknown Analyze Method $other!")
}

@@ -315,6 +315,12 @@ object SparkConnectService extends Logging {
previoslyObservedSessionId)
}

// For testing
private[spark] def getOrCreateIsolatedSession(
userId: String, sessionId: String): SessionHolder = {
getOrCreateIsolatedSession(userId, sessionId, None)
}

/**
* If there are no executions, return Left with System.currentTimeMillis of last active
* execution. Otherwise return Right with list of ExecuteInfo of all executions.
@@ -24,7 +24,7 @@
class SparkFrameMethodsParityTests(
SparkFrameMethodsTestsMixin, TestUtils, PandasOnSparkTestUtils, ReusedConnectTestCase
):
@unittest.skip("Test depends on checkpoint which is not supported from Spark Connect.")
@unittest.skip("Test depends on SparkContext which is not supported from Spark Connect.")
def test_checkpoint(self):
super().test_checkpoint()

@@ -34,10 +34,6 @@ def test_checkpoint(self):
def test_coalesce(self):
super().test_coalesce()

@unittest.skip("Test depends on localCheckpoint which is not supported from Spark Connect.")
def test_local_checkpoint(self):
super().test_local_checkpoint()

@unittest.skip(
"Test depends on RDD, and cannot use SQL expression due to Catalyst optimization"
)
19 changes: 18 additions & 1 deletion python/pyspark/sql/connect/client/core.py
@@ -66,7 +66,11 @@
from pyspark.sql.connect.profiler import ConnectProfilerCollector
from pyspark.sql.connect.client.reattach import ExecutePlanResponseReattachableIterator
from pyspark.sql.connect.client.retries import RetryPolicy, Retrying, DefaultPolicy
from pyspark.sql.connect.conversion import storage_level_to_proto, proto_to_storage_level
from pyspark.sql.connect.conversion import (
storage_level_to_proto,
proto_to_storage_level,
proto_to_remote_cached_dataframe,
)
import pyspark.sql.connect.proto as pb2
import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib
import pyspark.sql.connect.types as types
@@ -100,6 +104,7 @@
from google.rpc.error_details_pb2 import ErrorInfo
from pyspark.sql.connect._typing import DataTypeOrString
from pyspark.sql.datasource import DataSource
from pyspark.sql.connect.dataframe import DataFrame


class ChannelBuilder:
@@ -528,6 +533,7 @@ def __init__(
is_same_semantics: Optional[bool],
semantic_hash: Optional[int],
storage_level: Optional[StorageLevel],
replaced: Optional["DataFrame"],
):
self.schema = schema
self.explain_string = explain_string
@@ -540,6 +546,7 @@ def __init__(
self.is_same_semantics = is_same_semantics
self.semantic_hash = semantic_hash
self.storage_level = storage_level
self.replaced = replaced

@classmethod
def fromProto(cls, pb: Any) -> "AnalyzeResult":
Expand All @@ -554,6 +561,7 @@ def fromProto(cls, pb: Any) -> "AnalyzeResult":
is_same_semantics: Optional[bool] = None
semantic_hash: Optional[int] = None
storage_level: Optional[StorageLevel] = None
replaced: Optional["DataFrame"] = None

if pb.HasField("schema"):
schema = types.proto_schema_to_pyspark_data_type(pb.schema.schema)
@@ -581,6 +589,8 @@ def fromProto(cls, pb: Any) -> "AnalyzeResult":
pass
elif pb.HasField("get_storage_level"):
storage_level = proto_to_storage_level(pb.get_storage_level.storage_level)
elif pb.HasField("checkpoint"):
replaced = proto_to_remote_cached_dataframe(pb.checkpoint.relation)
else:
raise SparkConnectException("No analyze result found!")

@@ -596,6 +606,7 @@ def fromProto(cls, pb: Any) -> "AnalyzeResult":
is_same_semantics,
semantic_hash,
storage_level,
replaced,
)


Expand Down Expand Up @@ -1229,6 +1240,12 @@ def _analyze(self, method: str, **kwargs: Any) -> AnalyzeResult:
req.unpersist.blocking = cast(bool, kwargs.get("blocking"))
elif method == "get_storage_level":
req.get_storage_level.relation.CopyFrom(cast(pb2.Relation, kwargs.get("relation")))
elif method == "checkpoint":
req.checkpoint.relation.CopyFrom(cast(pb2.Relation, kwargs.get("relation")))
if kwargs.get("local", None) is not None:
req.checkpoint.local = cast(bool, kwargs.get("local"))
if kwargs.get("eager", None) is not None:
req.checkpoint.eager = cast(bool, kwargs.get("eager"))
else:
raise PySparkValueError(
error_class="UNSUPPORTED_OPERATION",
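To illustrate the round trip added here, the sketch below calls the internal `_analyze` API directly; it assumes an active Spark Connect session bound to the name `spark` and relies on internal attributes (`_plan`, `client`) that are not public API.

```python
# Hypothetical sketch against an active Spark Connect session `spark`; this uses
# internal client APIs only to show how the new "checkpoint" method flows.
df = spark.range(10)

relation = df._plan.plan(spark.client)
result = spark.client._analyze(
    method="checkpoint", relation=relation, local=True, eager=True
)

# The new `replaced` field carries a DataFrame whose plan is a
# CachedRemoteRelation pointing at the checkpointed result on the server.
assert result.replaced is not None
print(type(result.replaced._plan).__name__)  # expected: CachedRemoteRelation
```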
23 changes: 17 additions & 6 deletions python/pyspark/sql/connect/conversion.py
@@ -48,12 +48,10 @@
import pyspark.sql.connect.proto as pb2
from pyspark.sql.pandas.types import to_arrow_schema, _dedup_names, _deduplicate_field_names

from typing import (
Any,
Callable,
Sequence,
List,
)
from typing import Any, Callable, Sequence, List, TYPE_CHECKING

if TYPE_CHECKING:
from pyspark.sql.connect.dataframe import DataFrame


class LocalDataToArrowConversion:
@@ -570,3 +568,16 @@ def proto_to_storage_level(storage_level: pb2.StorageLevel) -> StorageLevel:
deserialized=storage_level.deserialized,
replication=storage_level.replication,
)


def proto_to_remote_cached_dataframe(relation: pb2.CachedRemoteRelation) -> "DataFrame":
assert relation is not None and isinstance(relation, pb2.CachedRemoteRelation)

from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.connect.session import SparkSession
import pyspark.sql.connect.plan as plan

return DataFrame(
plan=plan.CachedRemoteRelation(relation.relation_id),
session=SparkSession.active(),
)
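
A small usage sketch of the new helper follows; the relation ID is a placeholder, and an active Spark Connect session is assumed because the helper binds the result to `SparkSession.active()`.

```python
# Hypothetical usage of proto_to_remote_cached_dataframe; "some-uuid" is a
# placeholder and would normally come from an AnalyzePlanResponse.Checkpoint.
import pyspark.sql.connect.proto as pb2
from pyspark.sql.connect.conversion import proto_to_remote_cached_dataframe

cached = pb2.CachedRemoteRelation(relation_id="some-uuid")
df = proto_to_remote_cached_dataframe(cached)
# df is a Connect DataFrame whose plan is plan.CachedRemoteRelation("some-uuid").
```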
71 changes: 57 additions & 14 deletions python/pyspark/sql/connect/dataframe.py
@@ -16,7 +16,7 @@
#

# mypy: disable-error-code="override"

from pyspark.sql.connect.proto import base_pb2 as spark_dot_connect_dot_base__pb2
from pyspark.errors.exceptions.base import (
SessionNotSameException,
PySparkIndexError,
@@ -138,6 +138,41 @@ def __init__(
# by __repr__ and _repr_html_ while eager evaluation opens.
self._support_repr_html = False
self._cached_schema: Optional[StructType] = None
self._cached_remote_relation_id: Optional[str] = None

def __del__(self) -> None:
# If the session is already closed, all cached DataFrames should already be released.
if not self._session.client.is_closed and self._cached_remote_relation_id is not None:
try:
command = plan.RemoveRemoteCachedRelation(
plan.CachedRemoteRelation(relationId=self._cached_remote_relation_id)
).command(session=self._session.client)
req = self._session.client._execute_plan_request_with_metadata()
if self._session.client._user_id:
req.user_context.user_id = self._session.client._user_id
req.plan.command.CopyFrom(command)

for attempt in self._session.client._retrying():
with attempt:
# !!HACK ALERT!!
# unary_stream does not work at Python interpreter exit for unknown reasons.
# Therefore, we open a unary_unary channel here instead.
# See also :class:`SparkConnectServiceStub`.
request_serializer = (
spark_dot_connect_dot_base__pb2.ExecutePlanRequest.SerializeToString
)
response_deserializer = (
spark_dot_connect_dot_base__pb2.ExecutePlanResponse.FromString
)
channel = self._session.client._channel.unary_unary(
"/spark.connect.SparkConnectService/ExecutePlan",
request_serializer=request_serializer,
response_deserializer=response_deserializer,
)
metadata = self._session.client._builder.metadata()
channel(req, metadata=metadata) # type: ignore[arg-type]
except Exception as e:
warnings.warn(f"RemoveRemoteCachedRelation failed with exception: {e}.")

def __reduce__(self) -> Tuple:
"""
@@ -2096,19 +2131,29 @@ def writeTo(self, table: str) -> "DataFrameWriterV2":
def offset(self, n: int) -> ParentDataFrame:
return DataFrame(plan.Offset(child=self._plan, offset=n), session=self._session)

if not is_remote_only():
def checkpoint(self, eager: bool = True) -> "DataFrame":
relation = self._plan.plan(self._session.client)
result = self._session.client._analyze(
method="checkpoint", relation=relation, local=False, eager=eager
)
assert result.replaced is not None
assert isinstance(result.replaced._plan, plan.CachedRemoteRelation)
checkpointed = result.replaced
checkpointed._cached_remote_relation_id = result.replaced._plan._relationId
return checkpointed

def checkpoint(self, eager: bool = True) -> "DataFrame":
raise PySparkNotImplementedError(
error_class="NOT_IMPLEMENTED",
message_parameters={"feature": "checkpoint()"},
)
def localCheckpoint(self, eager: bool = True) -> "DataFrame":
relation = self._plan.plan(self._session.client)
result = self._session.client._analyze(
method="checkpoint", relation=relation, local=False, eager=eager
)
assert result.replaced is not None
assert isinstance(result.replaced._plan, plan.CachedRemoteRelation)
checkpointed = result.replaced
checkpointed._cached_remote_relation_id = result.replaced._plan._relationId
return checkpointed

def localCheckpoint(self, eager: bool = True) -> "DataFrame":
raise PySparkNotImplementedError(
error_class="NOT_IMPLEMENTED",
message_parameters={"feature": "localCheckpoint()"},
)
if not is_remote_only():

def toJSON(self, use_unicode: bool = True) -> "RDD[str]":
raise PySparkNotImplementedError(
@@ -2203,8 +2248,6 @@ def _test() -> None:
if not is_remote_only():
del pyspark.sql.dataframe.DataFrame.toJSON.__doc__
del pyspark.sql.dataframe.DataFrame.rdd.__doc__
del pyspark.sql.dataframe.DataFrame.checkpoint.__doc__
del pyspark.sql.dataframe.DataFrame.localCheckpoint.__doc__

globs["spark"] = (
PySparkSession.builder.appName("sql.connect.dataframe tests")
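Putting the pieces together, user-facing usage after this change could look like the sketch below; the connection string is a placeholder, and the reliable `checkpoint()` call assumes the server has a checkpoint directory configured.

```python
# Hypothetical end-to-end usage over Spark Connect; "sc://localhost" is a
# placeholder, and df.checkpoint() assumes a checkpoint directory is configured
# on the server (the Spark driver).
from pyspark.sql import SparkSession

spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
df = spark.range(100)

checkpointed = df.checkpoint()            # reliable checkpoint, eager by default
local = df.localCheckpoint(eager=False)   # lazy local checkpoint on the driver

# Both results are backed by a server-side CachedRemoteRelation. When the client
# objects are garbage-collected, __del__ sends RemoveCachedRemoteRelationCommand
# so the server can drop its cached DataFrame.
del checkpointed, local
```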
(The remaining changed files in this commit are not shown here.)
