Commit

Impl
changgyoopark-db committed Jan 23, 2025
1 parent 1a49237 commit 7b7430f
Showing 5 changed files with 141 additions and 112 deletions.
@@ -68,12 +68,12 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
     } else {
       DoNotCleanup
     }
+    val rel = request.getPlan.getRoot
     val dataframe =
-      Dataset.ofRows(
-        sessionHolder.session,
-        planner.transformRelation(request.getPlan.getRoot, cachePlan = true),
-        tracker,
-        shuffleCleanupMode)
+      sessionHolder
+        .updatePlanCache(
+          rel,
+          Dataset.ofRows(session, planner.transformRelation(rel), tracker, shuffleCleanupMode))
     responseObserver.onNext(createSchemaResponse(request.getSessionId, dataframe.schema))
     processAsArrowBatches(dataframe, responseObserver, executeHolder)
     responseObserver.onNext(MetricGenerator.createMetricsResponse(sessionHolder, dataframe))
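The commit splits cache reads from cache writes: `transformRelation` (via `usePlanCache`) only consults the cache, while each handler writes the finished plan back with `updatePlanCache`. A minimal runnable sketch of that split, with stand-in types (`Relation` and `LogicalPlan` are plain strings here, not the real protobuf and Catalyst classes):

import scala.collection.concurrent.TrieMap

object PlanCacheModel {
  type Relation = String
  type LogicalPlan = String

  private val cache = TrieMap.empty[Relation, LogicalPlan]

  // Planner side: read-through lookup that never writes (mirrors usePlanCache).
  def usePlanCache(rel: Relation)(transform: Relation => LogicalPlan): LogicalPlan =
    cache.getOrElse(rel, transform(rel))

  // Handler side: write the finished plan back and return the value unchanged
  // (mirrors updatePlanCache returning the supplied data frame).
  def updatePlanCache(rel: Relation, plan: LogicalPlan): LogicalPlan = {
    cache.putIfAbsent(rel, plan)
    plan
  }

  def main(args: Array[String]): Unit = {
    val rel = "range(10)"
    val plan = updatePlanCache(rel, usePlanCache(rel)(r => s"Analyzed($r)"))
    println(plan)                           // Analyzed(range(10))
    println(usePlanCache(rel)(_ => "miss")) // now served from the cache
  }
}

Centralizing the write in the handlers means the planner no longer needs to know whether a given request warrants caching; that decision moves to where the result is actually produced.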
@@ -117,32 +117,16 @@ class SparkConnectPlanner(
   private lazy val pythonExec =
     sys.env.getOrElse("PYSPARK_PYTHON", sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python3"))

-  /**
-   * The root of the query plan is a relation and we apply the transformations to it. The resolved
-   * logical plan will not get cached. If the result needs to be cached, use
-   * `transformRelation(rel, cachePlan = true)` instead.
-   * @param rel
-   *   The relation to transform.
-   * @return
-   *   The resolved logical plan.
-   */
-  @DeveloperApi
-  def transformRelation(rel: proto.Relation): LogicalPlan =
-    transformRelation(rel, cachePlan = false)
-
   /**
    * The root of the query plan is a relation and we apply the transformations to it.
    * @param rel
    *   The relation to transform.
-   * @param cachePlan
-   *   Set to true for a performance optimization, if the plan is likely to be reused, e.g. built
-   *   upon by further dataset transformation. The default is false.
    * @return
    *   The resolved logical plan.
    */
   @DeveloperApi
-  def transformRelation(rel: proto.Relation, cachePlan: Boolean): LogicalPlan = {
-    sessionHolder.usePlanCache(rel, cachePlan) { rel =>
+  def transformRelation(rel: proto.Relation): LogicalPlan = {
+    sessionHolder.usePlanCache(rel) { rel =>
     val plan = rel.getRelTypeCase match {
       // DataFrame API
       case proto.Relation.RelTypeCase.SHOW_STRING => transformShowString(rel.getShowString)
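For context, the body inside the `usePlanCache` block dispatches on the protobuf `RelTypeCase` of the relation. A toy model of that dispatch shape, with a sealed trait standing in for the protobuf enum (all names below are illustrative stubs, not the real planner API):

object RelDispatch {
  sealed trait Relation
  final case class ShowString(truncate: Int) extends Relation
  final case class Range(n: Long) extends Relation

  type LogicalPlan = String

  def transformShowString(r: ShowString): LogicalPlan = s"ShowString(truncate=${r.truncate})"
  def transformRange(r: Range): LogicalPlan = s"Range(0, ${r.n})"

  // Mirrors the exhaustive match over rel.getRelTypeCase in the planner.
  def transformRelation(rel: Relation): LogicalPlan = rel match {
    case s: ShowString => transformShowString(s)
    case r: Range      => transformRange(r)
  }

  def main(args: Array[String]): Unit =
    println(transformRelation(ShowString(truncate = 20)))
}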
@@ -441,46 +441,67 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSession
    * `spark.connect.session.planCache.enabled` is true.
    * @param rel
    *   The relation to transform.
-   * @param cachePlan
-   *   Whether to cache the result logical plan.
    * @param transform
    *   Function to transform the relation into a logical plan.
    * @return
    *   The logical plan.
    */
-  private[connect] def usePlanCache(rel: proto.Relation, cachePlan: Boolean)(
+  private[connect] def usePlanCache(rel: proto.Relation)(
       transform: proto.Relation => LogicalPlan): LogicalPlan = {
-    val planCacheEnabled = Option(session)
-      .forall(_.sessionState.conf.getConf(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, true))
-    // We only cache plans that have a plan ID.
-    val hasPlanId = rel.hasCommon && rel.getCommon.hasPlanId
-
-    def getPlanCache(rel: proto.Relation): Option[LogicalPlan] =
-      planCache match {
-        case Some(cache) if planCacheEnabled && hasPlanId =>
-          Option(cache.getIfPresent(rel)) match {
-            case Some(plan) =>
-              logDebug(s"Using cached plan for relation '$rel': $plan")
-              Some(plan)
-            case None => None
-          }
-        case _ => None
-      }
-    def putPlanCache(rel: proto.Relation, plan: LogicalPlan): Unit =
-      planCache match {
-        case Some(cache) if planCacheEnabled && hasPlanId =>
-          cache.put(rel, plan)
-        case _ =>
-      }
-
-    getPlanCache(rel)
-      .getOrElse({
-        val plan = transform(rel)
-        if (cachePlan) {
-          putPlanCache(rel, plan)
-        }
-        plan
-      })
+    val cachedPlan = planCache match {
+      case Some(cache) if planCacheEnabled(rel) =>
+        Option(cache.getIfPresent(rel)) match {
+          case Some(plan) =>
+            logDebug(s"Using cached plan for relation '$rel': $plan")
+            Some(plan)
+          case None => None
+        }
+      case _ => None
+    }
+    cachedPlan.getOrElse(transform(rel))
   }

+  /**
+   * Update the plan cache with the supplied data frame.
+   *
+   * @param rel
+   *   A proto.Relation that is used as the key for the cache.
+   * @param df
+   *   A data frame containing the corresponding logical plan.
+   * @return
+   *   The supplied data frame is returned.
+   */
+  private[connect] def updatePlanCache(rel: proto.Relation, df: DataFrame): DataFrame = {
+    if (planCache.isDefined && planCacheEnabled(rel)) {
+      val plan = if (df.queryExecution.isLazyAnalysis) {
+        // Try to cache the unanalyzed plan if the plan is intended to be lazily analyzed.
+        if (planCache.get.getIfPresent(rel) == null) {
+          Some(df.queryExecution.logical)
+        } else {
+          None
+        }
+      } else if (df.queryExecution.logical.analyzed) {
+        // The plan was analyzed during transformation or the cache was hit.
+        if (planCache.get.getIfPresent(rel) == null) {
+          Some(df.queryExecution.analyzed)
+        } else {
+          None
+        }
+      } else {
+        // Not being `isLazyAnalysis` and not analyzed implies that the plan is not in the cache.
+        Some(df.queryExecution.analyzed)
+      }
+      plan.foreach(p => planCache.get.put(rel, p))
+    }
+    df
+  }

+  // Return true if the plan cache is enabled for the session and the relation.
+  private def planCacheEnabled(rel: proto.Relation): Boolean = {
+    // We only cache plans that have a plan ID.
+    rel.hasCommon && rel.getCommon.hasPlanId &&
+    Option(session)
+      .forall(_.sessionState.conf.getConf(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, true))
+  }

   // For testing. Expose the plan cache for testing purposes.
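The lookup in `usePlanCache` wraps `cache.getIfPresent(rel)` in `Option(...)` because the underlying cache returns null on a miss. A small runnable sketch of that idiom, with `java.util.HashMap` standing in for the plan cache:

import java.util.{HashMap => JHashMap}

object NullSafeLookup {
  def main(args: Array[String]): Unit = {
    val cache = new JHashMap[String, String]()
    cache.put("rel-1", "Plan(rel-1)")

    // Option(null) collapses to None, so a miss falls through to the transform.
    def lookup(key: String): Option[String] = Option(cache.get(key))

    def usePlanCache(key: String)(transform: String => String): String =
      lookup(key).getOrElse(transform(key))

    println(usePlanCache("rel-1")(k => s"Transformed($k)")) // cache hit
    println(usePlanCache("rel-2")(k => s"Transformed($k)")) // falls back to transform
  }
}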
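`updatePlanCache` reduces to a small decision table: cache the unanalyzed plan when analysis is deferred, cache the analyzed plan otherwise, and skip the write when an entry is already present. `planCacheEnabled` additionally gates every write on a plan ID and the session config, where `Option(session).forall(...)` is vacuously true for a null session, so the cache defaults to enabled. A pure-function sketch of both (stand-in types, not the real Spark classes):

object PlanCacheDecision {
  sealed trait Choice
  case object CacheUnanalyzed extends Choice
  case object CacheAnalyzed extends Choice
  case object SkipWrite extends Choice

  // Mirrors the three branches of updatePlanCache.
  def decide(isLazyAnalysis: Boolean, analyzed: Boolean, alreadyCached: Boolean): Choice =
    if (isLazyAnalysis) {
      if (!alreadyCached) CacheUnanalyzed else SkipWrite
    } else if (analyzed) {
      if (!alreadyCached) CacheAnalyzed else SkipWrite
    } else {
      // Not lazily analyzed and not yet analyzed: the plan cannot be cached already.
      CacheAnalyzed
    }

  // Mirrors planCacheEnabled: None.forall is vacuously true, so a missing
  // session (or unset flag) leaves the cache enabled by default.
  def enabled(hasPlanId: Boolean, confFlag: java.lang.Boolean): Boolean =
    hasPlanId && Option(confFlag).forall(_.booleanValue)

  def main(args: Array[String]): Unit = {
    println(decide(isLazyAnalysis = true, analyzed = false, alreadyCached = false)) // CacheUnanalyzed
    println(decide(isLazyAnalysis = false, analyzed = true, alreadyCached = true))  // SkipWrite
    println(enabled(hasPlanId = true, confFlag = null))                             // true
  }
}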
@@ -59,12 +59,13 @@ private[connect] class SparkConnectAnalyzeHandler(
     val session = sessionHolder.session
     val builder = proto.AnalyzePlanResponse.newBuilder()

-    def transformRelation(rel: proto.Relation) = planner.transformRelation(rel, cachePlan = true)
+    def transformRelation(rel: proto.Relation) = planner.transformRelation(rel)

     request.getAnalyzeCase match {
       case proto.AnalyzePlanRequest.AnalyzeCase.SCHEMA =>
-        val schema = Dataset
-          .ofRows(session, transformRelation(request.getSchema.getPlan.getRoot))
+        val rel = request.getSchema.getPlan.getRoot
+        val schema = sessionHolder
+          .updatePlanCache(rel, Dataset.ofRows(session, transformRelation(rel)))
           .schema
         builder.setSchema(
           proto.AnalyzePlanResponse.Schema
@@ -73,8 +74,9 @@ private[connect] class SparkConnectAnalyzeHandler(
           .build())

       case proto.AnalyzePlanRequest.AnalyzeCase.EXPLAIN =>
-        val queryExecution = Dataset
-          .ofRows(session, transformRelation(request.getExplain.getPlan.getRoot))
+        val rel = request.getExplain.getPlan.getRoot
+        val queryExecution = sessionHolder
+          .updatePlanCache(rel, Dataset.ofRows(session, transformRelation(rel)))
           .queryExecution
         val explainString = request.getExplain.getExplainMode match {
           case proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_SIMPLE =>
@@ -96,8 +98,9 @@
           .build())

       case proto.AnalyzePlanRequest.AnalyzeCase.TREE_STRING =>
-        val schema = Dataset
-          .ofRows(session, transformRelation(request.getTreeString.getPlan.getRoot))
+        val rel = request.getTreeString.getPlan.getRoot
+        val schema = sessionHolder
+          .updatePlanCache(rel, Dataset.ofRows(session, transformRelation(rel)))
           .schema
         val treeString = if (request.getTreeString.hasLevel) {
           schema.treeString(request.getTreeString.getLevel)
@@ -111,8 +114,9 @@
           .build())

       case proto.AnalyzePlanRequest.AnalyzeCase.IS_LOCAL =>
-        val isLocal = Dataset
-          .ofRows(session, transformRelation(request.getIsLocal.getPlan.getRoot))
+        val rel = request.getIsLocal.getPlan.getRoot
+        val isLocal = sessionHolder
+          .updatePlanCache(rel, Dataset.ofRows(session, transformRelation(rel)))
           .isLocal
         builder.setIsLocal(
           proto.AnalyzePlanResponse.IsLocal
@@ -121,8 +125,9 @@
           .build())

       case proto.AnalyzePlanRequest.AnalyzeCase.IS_STREAMING =>
-        val isStreaming = Dataset
-          .ofRows(session, transformRelation(request.getIsStreaming.getPlan.getRoot))
+        val rel = request.getIsStreaming.getPlan.getRoot
+        val isStreaming = sessionHolder
+          .updatePlanCache(rel, Dataset.ofRows(session, transformRelation(rel)))
           .isStreaming
         builder.setIsStreaming(
           proto.AnalyzePlanResponse.IsStreaming
@@ -131,8 +136,9 @@
           .build())

       case proto.AnalyzePlanRequest.AnalyzeCase.INPUT_FILES =>
-        val inputFiles = Dataset
-          .ofRows(session, transformRelation(request.getInputFiles.getPlan.getRoot))
+        val rel = request.getInputFiles.getPlan.getRoot
+        val inputFiles = sessionHolder
+          .updatePlanCache(rel, Dataset.ofRows(session, transformRelation(rel)))
           .inputFiles
         builder.setInputFiles(
           proto.AnalyzePlanResponse.InputFiles
@@ -156,29 +162,37 @@
           .build())

       case proto.AnalyzePlanRequest.AnalyzeCase.SAME_SEMANTICS =>
-        val target = Dataset.ofRows(
-          session,
-          transformRelation(request.getSameSemantics.getTargetPlan.getRoot))
-        val other = Dataset.ofRows(
-          session,
-          transformRelation(request.getSameSemantics.getOtherPlan.getRoot))
+        val targetRel = request.getSameSemantics.getTargetPlan.getRoot
+        val target = sessionHolder
+          .updatePlanCache(targetRel, Dataset.ofRows(session, transformRelation(targetRel)))
+        val otherRel = request.getSameSemantics.getOtherPlan.getRoot
+        val other = sessionHolder
+          .updatePlanCache(otherRel, Dataset.ofRows(session, transformRelation(otherRel)))
         builder.setSameSemantics(
           proto.AnalyzePlanResponse.SameSemantics
             .newBuilder()
             .setResult(target.sameSemantics(other)))

       case proto.AnalyzePlanRequest.AnalyzeCase.SEMANTIC_HASH =>
-        val semanticHash = Dataset
-          .ofRows(session, transformRelation(request.getSemanticHash.getPlan.getRoot))
+        val rel = request.getSemanticHash.getPlan.getRoot
+        val semanticHash = sessionHolder
+          .updatePlanCache(
+            rel,
+            Dataset
+              .ofRows(session, transformRelation(rel)))
           .semanticHash()
         builder.setSemanticHash(
           proto.AnalyzePlanResponse.SemanticHash
             .newBuilder()
             .setResult(semanticHash))

       case proto.AnalyzePlanRequest.AnalyzeCase.PERSIST =>
-        val target = Dataset
-          .ofRows(session, transformRelation(request.getPersist.getRelation))
+        val rel = request.getPersist.getRelation
+        val target = sessionHolder
+          .updatePlanCache(
+            rel,
+            Dataset
+              .ofRows(session, transformRelation(rel)))
         if (request.getPersist.hasStorageLevel) {
           target.persist(
             StorageLevelProtoConverter.toStorageLevel(request.getPersist.getStorageLevel))
@@ -188,8 +202,12 @@
         builder.setPersist(proto.AnalyzePlanResponse.Persist.newBuilder().build())

       case proto.AnalyzePlanRequest.AnalyzeCase.UNPERSIST =>
-        val target = Dataset
-          .ofRows(session, transformRelation(request.getUnpersist.getRelation))
+        val rel = request.getUnpersist.getRelation
+        val target = sessionHolder
+          .updatePlanCache(
+            rel,
+            Dataset
+              .ofRows(session, transformRelation(rel)))
         if (request.getUnpersist.hasBlocking) {
           target.unpersist(request.getUnpersist.getBlocking)
         } else {
@@ -198,8 +216,12 @@
         builder.setUnpersist(proto.AnalyzePlanResponse.Unpersist.newBuilder().build())

       case proto.AnalyzePlanRequest.AnalyzeCase.GET_STORAGE_LEVEL =>
-        val target = Dataset
-          .ofRows(session, transformRelation(request.getGetStorageLevel.getRelation))
+        val rel = request.getGetStorageLevel.getRelation
+        val target = sessionHolder
+          .updatePlanCache(
+            rel,
+            Dataset
+              .ofRows(session, transformRelation(rel)))
         val storageLevel = target.storageLevel
         builder.setGetStorageLevel(
           proto.AnalyzePlanResponse.GetStorageLevel
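Each analyze case now repeats the same wrap: bind the root relation, transform it, and pass the resulting DataFrame through `sessionHolder.updatePlanCache`. A hypothetical helper, not part of this commit, could factor that repetition; a sketch with simplified stand-in types (`Relation`, `DataFrame`, and the holder are placeholders, not the real Spark Connect classes):

object AnalyzeHelper {
  type Relation = String
  type DataFrame = String

  class SessionHolder {
    // Stub: the real method writes the plan into the session's plan cache.
    def updatePlanCache(rel: Relation, df: DataFrame): DataFrame = df
  }

  // Factors the repeated "transform then write back" wrap from the handler.
  def cachedDataFrame(holder: SessionHolder, rel: Relation)(
      transform: Relation => DataFrame): DataFrame =
    holder.updatePlanCache(rel, transform(rel))

  def main(args: Array[String]): Unit = {
    val holder = new SessionHolder
    // e.g. the SCHEMA case would become:
    val schema = cachedDataFrame(holder, "schema-plan-root")(r => s"DataFrame($r)")
    println(schema)
  }
}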