Experiment
changgyoopark-db committed Jan 23, 2025
1 parent: 43657b2 · commit: 337871f
Showing 4 changed files with 73 additions and 48 deletions.
python/pyspark/sql/tests/test_dataframe.py (4 changes: 1 addition & 3 deletions)
@@ -440,9 +440,7 @@ def test_extended_hint_types(self):
         self.assertIsInstance(df.hint("broadcast", ["foo", "bar"]), type(df))

         with io.StringIO() as buf, redirect_stdout(buf):
-            # the plan cache may hold a fully analyzed plan
-            with self.sql_conf({"spark.connect.session.planCache.enabled": False}):
-                hinted_df.explain(True)
+            hinted_df.explain(True)
             explain_output = buf.getvalue()
         self.assertGreaterEqual(explain_output.count("1.2345"), 1)
         self.assertGreaterEqual(explain_output.count("what"), 1)
SparkConnectPlanExecution.scala
@@ -68,12 +68,16 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
     } else {
       DoNotCleanup
     }
+    val rel = request.getPlan.getRoot
     val dataframe =
-      Dataset.ofRows(
-        sessionHolder.session,
-        planner.transformRelation(request.getPlan.getRoot, cachePlan = true),
-        tracker,
-        shuffleCleanupMode)
+      sessionHolder
+        .updatePlanCache(
+          rel,
+          Dataset.ofRows(
+            session,
+            planner.transformRelation(rel, cachePlan = true),
+            tracker,
+            shuffleCleanupMode))
     responseObserver.onNext(createSchemaResponse(request.getSessionId, dataframe.schema))
     processAsArrowBatches(dataframe, responseObserver, executeHolder)
     responseObserver.onNext(MetricGenerator.createMetricsResponse(sessionHolder, dataframe))
SessionHolder.scala
@@ -469,26 +469,27 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSession
     def putPlanCache(rel: proto.Relation, plan: LogicalPlan): LogicalPlan =
       planCache match {
         case Some(cache) if cachePlan && planCacheEnabled =>
-          val analyzedPlan = if (plan.analyzed) {
-            plan
-          } else {
-            val qe = session.sessionState.executePlan(plan)
-            if (qe.isLazyAnalysis) {
-              // The plan is intended to be lazily analyzed.
-              plan
-            } else {
-              // Make sure that the plan is fully analyzed before being cached.
-              qe.analyzed
-            }
-          }
-          cache.put(rel, analyzedPlan)
-          analyzedPlan
+          cache.put(rel, plan)
+          plan
         case _ => plan
       }

     getPlanCache(rel).getOrElse(putPlanCache(rel, transform(rel)))
   }

+  /**
+   * Update the plan cache with the supplied data frame.
+   */
+  private[connect] def updatePlanCache(rel: proto.Relation, df: DataFrame): DataFrame = {
+    if (!df.queryExecution.logical.analyzed &&
+      rel.hasCommon && rel.getCommon.hasPlanId &&
+      Option(session)
+        .forall(_.sessionState.conf.getConf(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, true))) {
+      planCache.foreach(_.put(rel, df.queryExecution.analyzed))
+    }
+    df
+  }
+
   // For testing. Expose the plan cache for testing purposes.
   private[service] def getPlanCache: Option[Cache[proto.Relation, LogicalPlan]] = planCache
 }
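The new updatePlanCache helper only writes to the cache when the supplied DataFrame's logical plan has not been analyzed yet, the relation carries a plan ID, and the plan cache is enabled for the session; in that case the analyzed plan is stored, and the DataFrame is returned unchanged either way. Below is a minimal, self-contained sketch of that guard in plain Scala; RelKey, Plan, PlanCacheSketch, and the mutable map are illustrative stand-ins for proto.Relation, the logical plan, and the session's plan cache, not the actual Spark Connect types.

import scala.collection.mutable

object PlanCacheSketch {
  // Stand-in for proto.Relation: only the fields the guard inspects.
  final case class RelKey(hasPlanId: Boolean, id: Long)

  // Stand-in for a logical plan that may or may not have been analyzed yet.
  final case class Plan(text: String, analyzed: Boolean)

  private val cache = mutable.Map.empty[RelKey, Plan]
  private val planCacheEnabled = true // stands in for the session's plan-cache flag

  // Guarded update, mirroring the shape of updatePlanCache above: store the analyzed
  // form only when the incoming plan is not yet analyzed, the key carries a plan ID,
  // and caching is enabled. The plan itself is always returned unchanged.
  def updatePlanCache(rel: RelKey, plan: Plan): Plan = {
    if (!plan.analyzed && rel.hasPlanId && planCacheEnabled) {
      cache.put(rel, plan.copy(analyzed = true)) // "analysis" simulated by flipping the flag
    }
    plan
  }

  def main(args: Array[String]): Unit = {
    val key = RelKey(hasPlanId = true, id = 1L)
    val fresh = Plan("Project [id]", analyzed = false)
    updatePlanCache(key, fresh)      // stored: the incoming plan was not analyzed yet
    println(cache.get(key))          // Some(Plan(Project [id],true))
    updatePlanCache(key, cache(key)) // no-op: an already-analyzed plan is not re-cached
  }
}

Returning the input unchanged keeps the helper cheap to thread through every call site below, whether or not the cache accepts the entry.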
SparkConnectAnalyzeHandler.scala
@@ -63,8 +63,9 @@ private[connect] class SparkConnectAnalyzeHandler(

     request.getAnalyzeCase match {
       case proto.AnalyzePlanRequest.AnalyzeCase.SCHEMA =>
-        val schema = Dataset
-          .ofRows(session, transformRelation(request.getSchema.getPlan.getRoot))
+        val rel = request.getSchema.getPlan.getRoot
+        val schema = sessionHolder
+          .updatePlanCache(rel, Dataset.ofRows(session, transformRelation(rel)))
           .schema
         builder.setSchema(
           proto.AnalyzePlanResponse.Schema
@@ -73,8 +74,9 @@
             .build())

       case proto.AnalyzePlanRequest.AnalyzeCase.EXPLAIN =>
-        val queryExecution = Dataset
-          .ofRows(session, transformRelation(request.getExplain.getPlan.getRoot))
+        val rel = request.getExplain.getPlan.getRoot
+        val queryExecution = sessionHolder
+          .updatePlanCache(rel, Dataset.ofRows(session, transformRelation(rel)))
           .queryExecution
         val explainString = request.getExplain.getExplainMode match {
           case proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_SIMPLE =>
@@ -96,8 +98,9 @@
             .build())

       case proto.AnalyzePlanRequest.AnalyzeCase.TREE_STRING =>
-        val schema = Dataset
-          .ofRows(session, transformRelation(request.getTreeString.getPlan.getRoot))
+        val rel = request.getTreeString.getPlan.getRoot
+        val schema = sessionHolder
+          .updatePlanCache(rel, Dataset.ofRows(session, transformRelation(rel)))
           .schema
         val treeString = if (request.getTreeString.hasLevel) {
           schema.treeString(request.getTreeString.getLevel)
@@ -111,8 +114,9 @@
             .build())

       case proto.AnalyzePlanRequest.AnalyzeCase.IS_LOCAL =>
-        val isLocal = Dataset
-          .ofRows(session, transformRelation(request.getIsLocal.getPlan.getRoot))
+        val rel = request.getIsLocal.getPlan.getRoot
+        val isLocal = sessionHolder
+          .updatePlanCache(rel, Dataset.ofRows(session, transformRelation(rel)))
           .isLocal
         builder.setIsLocal(
           proto.AnalyzePlanResponse.IsLocal
@@ -121,8 +125,9 @@
             .build())

       case proto.AnalyzePlanRequest.AnalyzeCase.IS_STREAMING =>
-        val isStreaming = Dataset
-          .ofRows(session, transformRelation(request.getIsStreaming.getPlan.getRoot))
+        val rel = request.getIsStreaming.getPlan.getRoot
+        val isStreaming = sessionHolder
+          .updatePlanCache(rel, Dataset.ofRows(session, transformRelation(rel)))
           .isStreaming
         builder.setIsStreaming(
           proto.AnalyzePlanResponse.IsStreaming
@@ -131,8 +136,9 @@
             .build())

       case proto.AnalyzePlanRequest.AnalyzeCase.INPUT_FILES =>
-        val inputFiles = Dataset
-          .ofRows(session, transformRelation(request.getInputFiles.getPlan.getRoot))
+        val rel = request.getInputFiles.getPlan.getRoot
+        val inputFiles = sessionHolder
+          .updatePlanCache(rel, Dataset.ofRows(session, transformRelation(rel)))
           .inputFiles
         builder.setInputFiles(
           proto.AnalyzePlanResponse.InputFiles
@@ -156,29 +162,37 @@
             .build())

       case proto.AnalyzePlanRequest.AnalyzeCase.SAME_SEMANTICS =>
-        val target = Dataset.ofRows(
-          session,
-          transformRelation(request.getSameSemantics.getTargetPlan.getRoot))
-        val other = Dataset.ofRows(
-          session,
-          transformRelation(request.getSameSemantics.getOtherPlan.getRoot))
+        val targetRel = request.getSameSemantics.getTargetPlan.getRoot
+        val target = sessionHolder
+          .updatePlanCache(targetRel, Dataset.ofRows(session, transformRelation(targetRel)))
+        val otherRel = request.getSameSemantics.getOtherPlan.getRoot
+        val other = sessionHolder
+          .updatePlanCache(otherRel, Dataset.ofRows(session, transformRelation(otherRel)))
         builder.setSameSemantics(
           proto.AnalyzePlanResponse.SameSemantics
             .newBuilder()
             .setResult(target.sameSemantics(other)))

       case proto.AnalyzePlanRequest.AnalyzeCase.SEMANTIC_HASH =>
-        val semanticHash = Dataset
-          .ofRows(session, transformRelation(request.getSemanticHash.getPlan.getRoot))
+        val rel = request.getSemanticHash.getPlan.getRoot
+        val semanticHash = sessionHolder
+          .updatePlanCache(
+            rel,
+            Dataset
+              .ofRows(session, transformRelation(rel)))
           .semanticHash()
         builder.setSemanticHash(
           proto.AnalyzePlanResponse.SemanticHash
             .newBuilder()
             .setResult(semanticHash))

       case proto.AnalyzePlanRequest.AnalyzeCase.PERSIST =>
-        val target = Dataset
-          .ofRows(session, transformRelation(request.getPersist.getRelation))
+        val rel = request.getPersist.getRelation
+        val target = sessionHolder
+          .updatePlanCache(
+            rel,
+            Dataset
+              .ofRows(session, transformRelation(rel)))
         if (request.getPersist.hasStorageLevel) {
           target.persist(
             StorageLevelProtoConverter.toStorageLevel(request.getPersist.getStorageLevel))
@@ -188,8 +202,12 @@
         builder.setPersist(proto.AnalyzePlanResponse.Persist.newBuilder().build())

       case proto.AnalyzePlanRequest.AnalyzeCase.UNPERSIST =>
-        val target = Dataset
-          .ofRows(session, transformRelation(request.getUnpersist.getRelation))
+        val rel = request.getUnpersist.getRelation
+        val target = sessionHolder
+          .updatePlanCache(
+            rel,
+            Dataset
+              .ofRows(session, transformRelation(rel)))
         if (request.getUnpersist.hasBlocking) {
           target.unpersist(request.getUnpersist.getBlocking)
         } else {
@@ -198,8 +216,12 @@
         builder.setUnpersist(proto.AnalyzePlanResponse.Unpersist.newBuilder().build())

       case proto.AnalyzePlanRequest.AnalyzeCase.GET_STORAGE_LEVEL =>
-        val target = Dataset
-          .ofRows(session, transformRelation(request.getGetStorageLevel.getRelation))
+        val rel = request.getGetStorageLevel.getRelation
+        val target = sessionHolder
+          .updatePlanCache(
+            rel,
+            Dataset
+              .ofRows(session, transformRelation(rel)))
         val storageLevel = target.storageLevel
         builder.setGetStorageLevel(
           proto.AnalyzePlanResponse.GetStorageLevel