[SPARK-50903][CONNECT] Update plan cache entries after analysis #49584

Status: Open. Wants to merge 1 commit into base: master.
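Previously, `transformRelation(rel, cachePlan = true)` wrote the freshly transformed (unanalyzed) plan into the session plan cache during transformation. The diff below splits this into a read-only lookup (`usePlanCache`) and an explicit write after analysis (`updatePlanCache`). A minimal sketch of the resulting shape; `Relation`, `LogicalPlan`, and the mutable map are hypothetical stand-ins for illustration, not the PR's code:

import scala.collection.mutable

// Simplified stand-ins, only to show the new two-step flow.
final case class Relation(planId: Option[Long])
final case class LogicalPlan(description: String)

class SessionHolderSketch(cache: mutable.Map[Relation, LogicalPlan]) {
  // Read path: transformRelation resolves through this; it never writes.
  def usePlanCache(rel: Relation)(transform: Relation => LogicalPlan): LogicalPlan =
    cache.getOrElse(rel, transform(rel))

  // Write path: called by the handlers once the Dataset (and hence the
  // analyzed plan) exists; an existing entry is left in place.
  def updatePlanCache(rel: Relation, plan: LogicalPlan): LogicalPlan = {
    if (rel.planId.isDefined) cache.getOrElseUpdate(rel, plan)
    plan
  }
}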
SparkConnectPlanExecution.scala
@@ -68,12 +68,12 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
     } else {
       DoNotCleanup
     }
+    val rel = request.getPlan.getRoot
     val dataframe =
-      Dataset.ofRows(
-        sessionHolder.session,
-        planner.transformRelation(request.getPlan.getRoot, cachePlan = true),
-        tracker,
-        shuffleCleanupMode)
+      sessionHolder
+        .updatePlanCache(
+          rel,
+          Dataset.ofRows(session, planner.transformRelation(rel), tracker, shuffleCleanupMode))
     responseObserver.onNext(createSchemaResponse(request.getSessionId, dataframe.schema))
     processAsArrowBatches(dataframe, responseObserver, executeHolder)
     responseObserver.onNext(MetricGenerator.createMetricsResponse(sessionHolder, dataframe))
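One detail of `updatePlanCache` (full diff in SessionHolder.scala below): the value cached comes from the Dataset's `QueryExecution`, so by this point the plan has been through analysis. A short sketch of that selection, using the `QueryExecution` fields named in the diff:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

// Sketch of the value selection inside updatePlanCache: queries under lazy
// analysis cache the unanalyzed logical plan; everything else caches the
// analyzed plan, so a later request with the same relation skips re-analysis.
def planToCache(df: DataFrame): LogicalPlan =
  if (df.queryExecution.isLazyAnalysis) df.queryExecution.logical
  else df.queryExecution.analyzed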
SparkConnectPlanner.scala
@@ -117,32 +117,16 @@ class SparkConnectPlanner(
   private lazy val pythonExec =
     sys.env.getOrElse("PYSPARK_PYTHON", sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python3"))

-  /**
-   * The root of the query plan is a relation and we apply the transformations to it. The resolved
-   * logical plan will not get cached. If the result needs to be cached, use
-   * `transformRelation(rel, cachePlan = true)` instead.
-   * @param rel
-   *   The relation to transform.
-   * @return
-   *   The resolved logical plan.
-   */
-  @DeveloperApi
-  def transformRelation(rel: proto.Relation): LogicalPlan =
-    transformRelation(rel, cachePlan = false)
-
   /**
    * The root of the query plan is a relation and we apply the transformations to it.
    * @param rel
    *   The relation to transform.
-   * @param cachePlan
-   *   Set to true for a performance optimization, if the plan is likely to be reused, e.g. built
-   *   upon by further dataset transformation. The default is false.
    * @return
    *   The resolved logical plan.
    */
   @DeveloperApi
-  def transformRelation(rel: proto.Relation, cachePlan: Boolean): LogicalPlan = {
-    sessionHolder.usePlanCache(rel, cachePlan) { rel =>
+  def transformRelation(rel: proto.Relation): LogicalPlan = {
+    sessionHolder.usePlanCache(rel) { rel =>
       val plan = rel.getRelTypeCase match {
         // DataFrame API
         case proto.Relation.RelTypeCase.SHOW_STRING => transformShowString(rel.getShowString)
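With the `cachePlan` flag gone, a caller that wants the resolved plan cached pairs `transformRelation` with `updatePlanCache`, as the execution and analyze handlers in this diff do. A hypothetical helper (not in the PR) showing that pairing; note `Dataset.ofRows` and `updatePlanCache` are package-private, so this only typechecks inside Spark's own packages:

import org.apache.spark.connect.proto
import org.apache.spark.sql.{DataFrame, Dataset}

// Hypothetical caller-side helper: resolve the relation through the read-only
// cache lookup, build the Dataset, then record the analyzed plan in the cache.
def resolveAndCache(
    sessionHolder: SessionHolder,
    planner: SparkConnectPlanner,
    rel: proto.Relation): DataFrame =
  sessionHolder.updatePlanCache(
    rel,
    Dataset.ofRows(sessionHolder.session, planner.transformRelation(rel)))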
SessionHolder.scala
@@ -441,46 +441,53 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSession)
    * `spark.connect.session.planCache.enabled` is true.
    * @param rel
    *   The relation to transform.
-   * @param cachePlan
-   *   Whether to cache the result logical plan.
    * @param transform
    *   Function to transform the relation into a logical plan.
    * @return
    *   The logical plan.
    */
-  private[connect] def usePlanCache(rel: proto.Relation, cachePlan: Boolean)(
+  private[connect] def usePlanCache(rel: proto.Relation)(
       transform: proto.Relation => LogicalPlan): LogicalPlan = {
-    val planCacheEnabled = Option(session)
-      .forall(_.sessionState.conf.getConf(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, true))
-    // We only cache plans that have a plan ID.
-    val hasPlanId = rel.hasCommon && rel.getCommon.hasPlanId
-
-    def getPlanCache(rel: proto.Relation): Option[LogicalPlan] =
-      planCache match {
-        case Some(cache) if planCacheEnabled && hasPlanId =>
-          Option(cache.getIfPresent(rel)) match {
-            case Some(plan) =>
-              logDebug(s"Using cached plan for relation '$rel': $plan")
-              Some(plan)
-            case None => None
-          }
-        case _ => None
-      }
-    def putPlanCache(rel: proto.Relation, plan: LogicalPlan): Unit =
-      planCache match {
-        case Some(cache) if planCacheEnabled && hasPlanId =>
-          cache.put(rel, plan)
-        case _ =>
-      }
-
-    getPlanCache(rel)
-      .getOrElse({
-        val plan = transform(rel)
-        if (cachePlan) {
-          putPlanCache(rel, plan)
-        }
-        plan
-      })
+    val cachedPlan = planCache match {
+      case Some(cache) if planCacheEnabled(rel) =>
+        Option(cache.getIfPresent(rel)) match {
+          case Some(plan) =>
+            logDebug(s"Using cached plan for relation '$rel': $plan")
+            Some(plan)
+          case None => None
+        }
+      case _ => None
+    }
+    cachedPlan.getOrElse(transform(rel))
   }

+  /**
+   * Update the plan cache with the supplied data frame.
+   *
+   * @param rel
+   *   A proto.Relation that is used as the key for the cache.
+   * @param df
+   *   A data frame containing the corresponding logical plan.
+   * @return
+   *   The supplied data frame is returned.
+   */
+  private[connect] def updatePlanCache(rel: proto.Relation, df: DataFrame): DataFrame = {
+    if (planCache.isDefined && planCacheEnabled(rel)) {
+      if (df.queryExecution.isLazyAnalysis) {
+        planCache.get.get(rel, { () => df.queryExecution.logical })
+      } else {
+        planCache.get.get(rel, { () => df.queryExecution.analyzed })
+      }
+    }
+    df
+  }
+
+  // Return true if the plan cache is enabled for the session and the relation.
+  private def planCacheEnabled(rel: proto.Relation): Boolean = {
+    // We only cache plans that have a plan ID.
+    rel.hasCommon && rel.getCommon.hasPlanId &&
+    Option(session)
+      .forall(_.sessionState.conf.getConf(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, true))
+  }

   // For testing. Expose the plan cache for testing purposes.
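Both methods go through the session's `planCache`, which the surrounding calls (`getIfPresent`, `get(rel, loader)`) indicate is a Guava cache. One behavioral point worth spelling out: Guava's `Cache.get(key, loader)` only computes and inserts when the key is absent, so `updatePlanCache` keeps an existing entry rather than replacing it. A self-contained sketch of that semantics, with hypothetical `Rel`/`Plan` stand-ins for the real key and value types:

import com.google.common.cache.{Cache, CacheBuilder}

// Hypothetical stand-ins for proto.Relation and LogicalPlan; this only
// demonstrates the Cache.get(key, loader) semantics updatePlanCache relies on.
final case class Rel(planId: Long)
final case class Plan(description: String)

object PlanCacheSemantics extends App {
  val cache: Cache[Rel, Plan] =
    CacheBuilder.newBuilder().maximumSize(16).build[Rel, Plan]()

  val rel = Rel(planId = 1L)

  // First call: no entry yet, so the loader runs and its result is stored.
  val first = cache.get(rel, () => Plan("analyzed"))

  // Second call: the existing entry wins and the loader never runs, so a
  // previously cached plan is kept rather than overwritten.
  val second = cache.get(rel, () => Plan("re-analyzed"))

  println(first)  // Plan(analyzed)
  println(second) // Plan(analyzed), not Plan(re-analyzed)
}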
SparkConnectAnalyzeHandler.scala
@@ -59,12 +59,13 @@ private[connect] class SparkConnectAnalyzeHandler(
     val session = sessionHolder.session
     val builder = proto.AnalyzePlanResponse.newBuilder()

-    def transformRelation(rel: proto.Relation) = planner.transformRelation(rel, cachePlan = true)
+    def transformRelation(rel: proto.Relation) = planner.transformRelation(rel)

     request.getAnalyzeCase match {
       case proto.AnalyzePlanRequest.AnalyzeCase.SCHEMA =>
-        val schema = Dataset
-          .ofRows(session, transformRelation(request.getSchema.getPlan.getRoot))
+        val rel = request.getSchema.getPlan.getRoot
+        val schema = sessionHolder
+          .updatePlanCache(rel, Dataset.ofRows(session, transformRelation(rel)))
           .schema
         builder.setSchema(
           proto.AnalyzePlanResponse.Schema
@@ -73,8 +74,9 @@
             .build())

       case proto.AnalyzePlanRequest.AnalyzeCase.EXPLAIN =>
-        val queryExecution = Dataset
-          .ofRows(session, transformRelation(request.getExplain.getPlan.getRoot))
+        val rel = request.getExplain.getPlan.getRoot
+        val queryExecution = sessionHolder
+          .updatePlanCache(rel, Dataset.ofRows(session, transformRelation(rel)))
           .queryExecution
         val explainString = request.getExplain.getExplainMode match {
           case proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_SIMPLE =>
@@ -96,8 +98,9 @@
             .build())

       case proto.AnalyzePlanRequest.AnalyzeCase.TREE_STRING =>
-        val schema = Dataset
-          .ofRows(session, transformRelation(request.getTreeString.getPlan.getRoot))
+        val rel = request.getTreeString.getPlan.getRoot
+        val schema = sessionHolder
+          .updatePlanCache(rel, Dataset.ofRows(session, transformRelation(rel)))
           .schema
         val treeString = if (request.getTreeString.hasLevel) {
           schema.treeString(request.getTreeString.getLevel)
@@ -111,8 +114,9 @@
             .build())

       case proto.AnalyzePlanRequest.AnalyzeCase.IS_LOCAL =>
-        val isLocal = Dataset
-          .ofRows(session, transformRelation(request.getIsLocal.getPlan.getRoot))
+        val rel = request.getIsLocal.getPlan.getRoot
+        val isLocal = sessionHolder
+          .updatePlanCache(rel, Dataset.ofRows(session, transformRelation(rel)))
           .isLocal
         builder.setIsLocal(
           proto.AnalyzePlanResponse.IsLocal
@@ -121,8 +125,9 @@
             .build())

       case proto.AnalyzePlanRequest.AnalyzeCase.IS_STREAMING =>
-        val isStreaming = Dataset
-          .ofRows(session, transformRelation(request.getIsStreaming.getPlan.getRoot))
+        val rel = request.getIsStreaming.getPlan.getRoot
+        val isStreaming = sessionHolder
+          .updatePlanCache(rel, Dataset.ofRows(session, transformRelation(rel)))
           .isStreaming
         builder.setIsStreaming(
           proto.AnalyzePlanResponse.IsStreaming
@@ -131,8 +136,9 @@
             .build())

       case proto.AnalyzePlanRequest.AnalyzeCase.INPUT_FILES =>
-        val inputFiles = Dataset
-          .ofRows(session, transformRelation(request.getInputFiles.getPlan.getRoot))
+        val rel = request.getInputFiles.getPlan.getRoot
+        val inputFiles = sessionHolder
+          .updatePlanCache(rel, Dataset.ofRows(session, transformRelation(rel)))
           .inputFiles
         builder.setInputFiles(
           proto.AnalyzePlanResponse.InputFiles
@@ -156,29 +162,37 @@
             .build())

       case proto.AnalyzePlanRequest.AnalyzeCase.SAME_SEMANTICS =>
-        val target = Dataset.ofRows(
-          session,
-          transformRelation(request.getSameSemantics.getTargetPlan.getRoot))
-        val other = Dataset.ofRows(
-          session,
-          transformRelation(request.getSameSemantics.getOtherPlan.getRoot))
+        val targetRel = request.getSameSemantics.getTargetPlan.getRoot
+        val target = sessionHolder
+          .updatePlanCache(targetRel, Dataset.ofRows(session, transformRelation(targetRel)))
+        val otherRel = request.getSameSemantics.getOtherPlan.getRoot
+        val other = sessionHolder
+          .updatePlanCache(otherRel, Dataset.ofRows(session, transformRelation(otherRel)))
         builder.setSameSemantics(
           proto.AnalyzePlanResponse.SameSemantics
             .newBuilder()
             .setResult(target.sameSemantics(other)))

       case proto.AnalyzePlanRequest.AnalyzeCase.SEMANTIC_HASH =>
-        val semanticHash = Dataset
-          .ofRows(session, transformRelation(request.getSemanticHash.getPlan.getRoot))
+        val rel = request.getSemanticHash.getPlan.getRoot
+        val semanticHash = sessionHolder
+          .updatePlanCache(
+            rel,
+            Dataset
+              .ofRows(session, transformRelation(rel)))
           .semanticHash()
         builder.setSemanticHash(
           proto.AnalyzePlanResponse.SemanticHash
             .newBuilder()
             .setResult(semanticHash))

       case proto.AnalyzePlanRequest.AnalyzeCase.PERSIST =>
-        val target = Dataset
-          .ofRows(session, transformRelation(request.getPersist.getRelation))
+        val rel = request.getPersist.getRelation
+        val target = sessionHolder
+          .updatePlanCache(
+            rel,
+            Dataset
+              .ofRows(session, transformRelation(rel)))
         if (request.getPersist.hasStorageLevel) {
           target.persist(
             StorageLevelProtoConverter.toStorageLevel(request.getPersist.getStorageLevel))
@@ -188,8 +202,12 @@
         builder.setPersist(proto.AnalyzePlanResponse.Persist.newBuilder().build())

       case proto.AnalyzePlanRequest.AnalyzeCase.UNPERSIST =>
-        val target = Dataset
-          .ofRows(session, transformRelation(request.getUnpersist.getRelation))
+        val rel = request.getUnpersist.getRelation
+        val target = sessionHolder
+          .updatePlanCache(
+            rel,
+            Dataset
+              .ofRows(session, transformRelation(rel)))
         if (request.getUnpersist.hasBlocking) {
           target.unpersist(request.getUnpersist.getBlocking)
         } else {
@@ -198,8 +216,12 @@
         builder.setUnpersist(proto.AnalyzePlanResponse.Unpersist.newBuilder().build())

       case proto.AnalyzePlanRequest.AnalyzeCase.GET_STORAGE_LEVEL =>
-        val target = Dataset
-          .ofRows(session, transformRelation(request.getGetStorageLevel.getRelation))
+        val rel = request.getGetStorageLevel.getRelation
+        val target = sessionHolder
+          .updatePlanCache(
+            rel,
+            Dataset
+              .ofRows(session, transformRelation(rel)))
         val storageLevel = target.storageLevel
         builder.setGetStorageLevel(
           proto.AnalyzePlanResponse.GetStorageLevel
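Each analyze case above now repeats the same three steps: resolve the relation read-only, build the Dataset, then cache the analyzed plan. If one wanted to factor that out, a hypothetical local helper (not part of this PR) next to the existing `transformRelation` def could look like:

// Hypothetical refactor: one local helper for the pattern every analyze case
// repeats above.
def analyzedDataFrame(rel: proto.Relation): DataFrame =
  sessionHolder.updatePlanCache(rel, Dataset.ofRows(session, transformRelation(rel)))

// The SCHEMA case, for example, would then shrink to:
// val schema = analyzedDataFrame(request.getSchema.getPlan.getRoot).schema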