diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
index 5e3499573e9d9..6a43c6b8f9f4d 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
@@ -68,12 +68,12 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
     } else {
       DoNotCleanup
     }
+    val rel = request.getPlan.getRoot
     val dataframe =
-      Dataset.ofRows(
-        sessionHolder.session,
-        planner.transformRelation(request.getPlan.getRoot, cachePlan = true),
-        tracker,
-        shuffleCleanupMode)
+      sessionHolder
+        .updatePlanCache(
+          rel,
+          Dataset.ofRows(session, planner.transformRelation(rel), tracker, shuffleCleanupMode))
     responseObserver.onNext(createSchemaResponse(request.getSessionId, dataframe.schema))
     processAsArrowBatches(dataframe, responseObserver, executeHolder)
     responseObserver.onNext(MetricGenerator.createMetricsResponse(sessionHolder, dataframe))
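Note: with this change the execution path transforms the root relation once, builds the DataFrame, and then registers the result in the session's plan cache via `updatePlanCache`, which hands the same DataFrame back. Below is a minimal, self-contained sketch of that "cache and pass through" shape in plain Scala; the names `PlanKey`, `AnalyzedPlan`, and `PlanCacheSketch` are illustrative stand-ins, not Spark Connect classes.

```scala
import scala.collection.concurrent.TrieMap

// Stand-in types: in the real code the key is a proto.Relation and the value a LogicalPlan.
final case class PlanKey(planId: Long)
final case class AnalyzedPlan(description: String)

object PlanCacheSketch {
  private val cache = TrieMap.empty[PlanKey, AnalyzedPlan]

  // Record the analyzed result under its key and hand the same value back,
  // so call sites can wrap an existing expression without restructuring it.
  def updatePlanCache(key: PlanKey, plan: AnalyzedPlan): AnalyzedPlan = {
    cache.putIfAbsent(key, plan)
    plan
  }

  def main(args: Array[String]): Unit = {
    val plan = updatePlanCache(PlanKey(1L), AnalyzedPlan("Project [id]"))
    assert(cache.get(PlanKey(1L)).contains(plan))
  }
}
```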
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 56824bbb4a417..58f4c7772f402 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -117,32 +117,16 @@ class SparkConnectPlanner(
   private lazy val pythonExec =
     sys.env.getOrElse("PYSPARK_PYTHON", sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python3"))
 
-  /**
-   * The root of the query plan is a relation and we apply the transformations to it. The resolved
-   * logical plan will not get cached. If the result needs to be cached, use
-   * `transformRelation(rel, cachePlan = true)` instead.
-   * @param rel
-   *   The relation to transform.
-   * @return
-   *   The resolved logical plan.
-   */
-  @DeveloperApi
-  def transformRelation(rel: proto.Relation): LogicalPlan =
-    transformRelation(rel, cachePlan = false)
-
   /**
    * The root of the query plan is a relation and we apply the transformations to it.
    * @param rel
    *   The relation to transform.
-   * @param cachePlan
-   *   Set to true for a performance optimization, if the plan is likely to be reused, e.g. built
-   *   upon by further dataset transformation. The default is false.
    * @return
    *   The resolved logical plan.
    */
   @DeveloperApi
-  def transformRelation(rel: proto.Relation, cachePlan: Boolean): LogicalPlan = {
-    sessionHolder.usePlanCache(rel, cachePlan) { rel =>
+  def transformRelation(rel: proto.Relation): LogicalPlan = {
+    sessionHolder.usePlanCache(rel) { rel =>
       val plan = rel.getRelTypeCase match {
         // DataFrame API
         case proto.Relation.RelTypeCase.SHOW_STRING => transformShowString(rel.getShowString)
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
index 5b56b7079a897..7791aa9e269af 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
@@ -441,46 +441,53 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
    * `spark.connect.session.planCache.enabled` is true.
    * @param rel
    *   The relation to transform.
-   * @param cachePlan
-   *   Whether to cache the result logical plan.
    * @param transform
    *   Function to transform the relation into a logical plan.
    * @return
    *   The logical plan.
    */
-  private[connect] def usePlanCache(rel: proto.Relation, cachePlan: Boolean)(
+  private[connect] def usePlanCache(rel: proto.Relation)(
       transform: proto.Relation => LogicalPlan): LogicalPlan = {
-    val planCacheEnabled = Option(session)
-      .forall(_.sessionState.conf.getConf(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, true))
-    // We only cache plans that have a plan ID.
-    val hasPlanId = rel.hasCommon && rel.getCommon.hasPlanId
-
-    def getPlanCache(rel: proto.Relation): Option[LogicalPlan] =
-      planCache match {
-        case Some(cache) if planCacheEnabled && hasPlanId =>
-          Option(cache.getIfPresent(rel)) match {
-            case Some(plan) =>
-              logDebug(s"Using cached plan for relation '$rel': $plan")
-              Some(plan)
-            case None => None
-          }
-        case _ => None
-      }
-    def putPlanCache(rel: proto.Relation, plan: LogicalPlan): Unit =
-      planCache match {
-        case Some(cache) if planCacheEnabled && hasPlanId =>
-          cache.put(rel, plan)
-        case _ =>
+    val cachedPlan = planCache match {
+      case Some(cache) if planCacheEnabled(rel) =>
+        Option(cache.getIfPresent(rel)) match {
+          case Some(plan) =>
+            logDebug(s"Using cached plan for relation '$rel': $plan")
+            Some(plan)
+          case None => None
+        }
+      case _ => None
+    }
+    cachedPlan.getOrElse(transform(rel))
+  }
+
+  /**
+   * Update the plan cache with the supplied data frame.
+   *
+   * @param rel
+   *   A proto.Relation that is used as the key for the cache.
+   * @param df
+   *   A data frame containing the corresponding logical plan.
+   * @return
+   *   The supplied data frame is returned.
+   */
+  private[connect] def updatePlanCache(rel: proto.Relation, df: DataFrame): DataFrame = {
+    if (planCache.isDefined && planCacheEnabled(rel)) {
+      if (df.queryExecution.isLazyAnalysis) {
+        planCache.get.get(rel, { () => df.queryExecution.logical })
+      } else {
+        planCache.get.get(rel, { () => df.queryExecution.analyzed })
       }
+    }
+    df
+  }
 
-    getPlanCache(rel)
-      .getOrElse({
-        val plan = transform(rel)
-        if (cachePlan) {
-          putPlanCache(rel, plan)
-        }
-        plan
-      })
+  // Return true if the plan cache is enabled for the session and the relation.
+  private def planCacheEnabled(rel: proto.Relation): Boolean = {
+    // We only cache plans that have a plan ID.
+    rel.hasCommon && rel.getCommon.hasPlanId &&
+      Option(session)
+        .forall(_.sessionState.conf.getConf(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, true))
   }
 
   // For testing. Expose the plan cache for testing purposes.
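Note: `usePlanCache` is now a pure read path (return the cached plan if present, otherwise run the transformation), while `updatePlanCache` is the only write path: it stores the analyzed plan, or the unanalyzed logical plan when the query uses lazy analysis, and both paths are gated by `planCacheEnabled`. The sketch below restates that split with plain Scala collections and stand-in types (`Rel`, `Plan`, `Frame`); it is illustrative only, not the Spark Connect classes.

```scala
import scala.collection.mutable

// Stand-ins for proto.Relation, LogicalPlan and DataFrame.
final case class Rel(planId: Option[Long], sql: String)
final case class Plan(text: String, analyzed: Boolean)
final case class Frame(logical: Plan, analyzedPlan: Plan, lazyAnalysis: Boolean)

object SessionPlanCacheSketch {
  private val cache = mutable.Map.empty[Rel, Plan]
  private val cacheEnabled = true

  // Only relations carrying a plan ID are eligible, mirroring planCacheEnabled.
  private def planCacheEnabled(rel: Rel): Boolean = cacheEnabled && rel.planId.isDefined

  // Read path: consult the cache, fall back to transforming the relation.
  def usePlanCache(rel: Rel)(transform: Rel => Plan): Plan =
    if (planCacheEnabled(rel)) cache.getOrElse(rel, transform(rel)) else transform(rel)

  // Write path: cache the analyzed plan (or the unanalyzed one for lazy analysis)
  // and return the frame unchanged.
  def updatePlanCache(rel: Rel, frame: Frame): Frame = {
    if (planCacheEnabled(rel)) {
      val value = if (frame.lazyAnalysis) frame.logical else frame.analyzedPlan
      cache.getOrElseUpdate(rel, value)
    }
    frame
  }
}
```

Splitting the read and write paths lets callers decide when a plan is worth caching, namely after a full Dataset has been built, instead of threading a `cachePlan` flag through the planner.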
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
index 8ca021c5be39e..c596fc786a395 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
@@ -59,12 +59,13 @@ private[connect] class SparkConnectAnalyzeHandler(
     val session = sessionHolder.session
     val builder = proto.AnalyzePlanResponse.newBuilder()
 
-    def transformRelation(rel: proto.Relation) = planner.transformRelation(rel, cachePlan = true)
+    def transformRelation(rel: proto.Relation) = planner.transformRelation(rel)
 
     request.getAnalyzeCase match {
       case proto.AnalyzePlanRequest.AnalyzeCase.SCHEMA =>
-        val schema = Dataset
-          .ofRows(session, transformRelation(request.getSchema.getPlan.getRoot))
+        val rel = request.getSchema.getPlan.getRoot
+        val schema = sessionHolder
+          .updatePlanCache(rel, Dataset.ofRows(session, transformRelation(rel)))
           .schema
         builder.setSchema(
           proto.AnalyzePlanResponse.Schema
@@ -73,8 +74,9 @@ private[connect] class SparkConnectAnalyzeHandler(
             .build())
 
       case proto.AnalyzePlanRequest.AnalyzeCase.EXPLAIN =>
-        val queryExecution = Dataset
-          .ofRows(session, transformRelation(request.getExplain.getPlan.getRoot))
+        val rel = request.getExplain.getPlan.getRoot
+        val queryExecution = sessionHolder
+          .updatePlanCache(rel, Dataset.ofRows(session, transformRelation(rel)))
           .queryExecution
         val explainString = request.getExplain.getExplainMode match {
           case proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_SIMPLE =>
@@ -96,8 +98,9 @@ private[connect] class SparkConnectAnalyzeHandler(
             .build())
 
       case proto.AnalyzePlanRequest.AnalyzeCase.TREE_STRING =>
-        val schema = Dataset
-          .ofRows(session, transformRelation(request.getTreeString.getPlan.getRoot))
+        val rel = request.getTreeString.getPlan.getRoot
+        val schema = sessionHolder
+          .updatePlanCache(rel, Dataset.ofRows(session, transformRelation(rel)))
           .schema
         val treeString = if (request.getTreeString.hasLevel) {
           schema.treeString(request.getTreeString.getLevel)
@@ -111,8 +114,9 @@ private[connect] class SparkConnectAnalyzeHandler(
             .build())
 
       case proto.AnalyzePlanRequest.AnalyzeCase.IS_LOCAL =>
-        val isLocal = Dataset
-          .ofRows(session, transformRelation(request.getIsLocal.getPlan.getRoot))
+        val rel = request.getIsLocal.getPlan.getRoot
+        val isLocal = sessionHolder
+          .updatePlanCache(rel, Dataset.ofRows(session, transformRelation(rel)))
           .isLocal
         builder.setIsLocal(
           proto.AnalyzePlanResponse.IsLocal
@@ -121,8 +125,9 @@ private[connect] class SparkConnectAnalyzeHandler(
             .build())
 
       case proto.AnalyzePlanRequest.AnalyzeCase.IS_STREAMING =>
-        val isStreaming = Dataset
-          .ofRows(session, transformRelation(request.getIsStreaming.getPlan.getRoot))
+        val rel = request.getIsStreaming.getPlan.getRoot
+        val isStreaming = sessionHolder
+          .updatePlanCache(rel, Dataset.ofRows(session, transformRelation(rel)))
           .isStreaming
         builder.setIsStreaming(
           proto.AnalyzePlanResponse.IsStreaming
@@ -131,8 +136,9 @@ private[connect] class SparkConnectAnalyzeHandler(
             .build())
 
       case proto.AnalyzePlanRequest.AnalyzeCase.INPUT_FILES =>
-        val inputFiles = Dataset
-          .ofRows(session, transformRelation(request.getInputFiles.getPlan.getRoot))
+        val rel = request.getInputFiles.getPlan.getRoot
+        val inputFiles = sessionHolder
+          .updatePlanCache(rel, Dataset.ofRows(session, transformRelation(rel)))
           .inputFiles
         builder.setInputFiles(
           proto.AnalyzePlanResponse.InputFiles
@@ -156,20 +162,24 @@ private[connect] class SparkConnectAnalyzeHandler(
             .build())
 
       case proto.AnalyzePlanRequest.AnalyzeCase.SAME_SEMANTICS =>
-        val target = Dataset.ofRows(
-          session,
-          transformRelation(request.getSameSemantics.getTargetPlan.getRoot))
-        val other = Dataset.ofRows(
-          session,
-          transformRelation(request.getSameSemantics.getOtherPlan.getRoot))
+        val targetRel = request.getSameSemantics.getTargetPlan.getRoot
+        val target = sessionHolder
+          .updatePlanCache(targetRel, Dataset.ofRows(session, transformRelation(targetRel)))
+        val otherRel = request.getSameSemantics.getOtherPlan.getRoot
+        val other = sessionHolder
+          .updatePlanCache(otherRel, Dataset.ofRows(session, transformRelation(otherRel)))
         builder.setSameSemantics(
           proto.AnalyzePlanResponse.SameSemantics
            .newBuilder()
            .setResult(target.sameSemantics(other)))
 
      case proto.AnalyzePlanRequest.AnalyzeCase.SEMANTIC_HASH =>
-        val semanticHash = Dataset
-          .ofRows(session, transformRelation(request.getSemanticHash.getPlan.getRoot))
+        val rel = request.getSemanticHash.getPlan.getRoot
+        val semanticHash = sessionHolder
+          .updatePlanCache(
+            rel,
+            Dataset
+              .ofRows(session, transformRelation(rel)))
           .semanticHash()
         builder.setSemanticHash(
           proto.AnalyzePlanResponse.SemanticHash
@@ -177,8 +187,12 @@
             .setResult(semanticHash))
 
       case proto.AnalyzePlanRequest.AnalyzeCase.PERSIST =>
-        val target = Dataset
-          .ofRows(session, transformRelation(request.getPersist.getRelation))
+        val rel = request.getPersist.getRelation
+        val target = sessionHolder
+          .updatePlanCache(
+            rel,
+            Dataset
+              .ofRows(session, transformRelation(rel)))
         if (request.getPersist.hasStorageLevel) {
           target.persist(
             StorageLevelProtoConverter.toStorageLevel(request.getPersist.getStorageLevel))
@@ -188,8 +202,12 @@
         builder.setPersist(proto.AnalyzePlanResponse.Persist.newBuilder().build())
 
       case proto.AnalyzePlanRequest.AnalyzeCase.UNPERSIST =>
-        val target = Dataset
-          .ofRows(session, transformRelation(request.getUnpersist.getRelation))
+        val rel = request.getUnpersist.getRelation
+        val target = sessionHolder
+          .updatePlanCache(
+            rel,
+            Dataset
+              .ofRows(session, transformRelation(rel)))
         if (request.getUnpersist.hasBlocking) {
           target.unpersist(request.getUnpersist.getBlocking)
         } else {
@@ -198,8 +216,12 @@
         builder.setUnpersist(proto.AnalyzePlanResponse.Unpersist.newBuilder().build())
 
       case proto.AnalyzePlanRequest.AnalyzeCase.GET_STORAGE_LEVEL =>
-        val target = Dataset
-          .ofRows(session, transformRelation(request.getGetStorageLevel.getRelation))
+        val rel = request.getGetStorageLevel.getRelation
+        val target = sessionHolder
+          .updatePlanCache(
+            rel,
+            Dataset
+              .ofRows(session, transformRelation(rel)))
         val storageLevel = target.storageLevel
         builder.setGetStorageLevel(
           proto.AnalyzePlanResponse.GetStorageLevel
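Note: every analyze case now follows the same shape: extract the root relation, transform it once, wrap the result in `Dataset.ofRows`, route it through `sessionHolder.updatePlanCache`, and only then read the requested property. The self-contained sketch below restates that shared shape with stand-in types (`RootRel`, `AnalyzedFrame`); it is not the Spark Connect handler and is shown only to make the repeated pattern explicit.

```scala
import scala.collection.mutable

// Stand-ins for proto.Relation and an analyzed Dataset.
final case class RootRel(planId: Long)
final case class AnalyzedFrame(schema: String, isLocal: Boolean, inputFiles: Seq[String])

object AnalyzeHandlerSketch {
  private val planCache = mutable.Map.empty[RootRel, AnalyzedFrame]

  // Record the analyzed frame under its relation key and hand it back unchanged.
  private def updatePlanCache(rel: RootRel, frame: AnalyzedFrame): AnalyzedFrame = {
    planCache.getOrElseUpdate(rel, frame)
    frame
  }

  // Shared shape used by every analyze case: build, cache, then project a property.
  def analyze[A](rel: RootRel)(build: RootRel => AnalyzedFrame)(read: AnalyzedFrame => A): A =
    read(updatePlanCache(rel, build(rel)))

  def main(args: Array[String]): Unit = {
    val build = (r: RootRel) =>
      AnalyzedFrame("struct<id:bigint>", isLocal = true, inputFiles = Seq.empty)
    println(analyze(RootRel(42L))(build)(_.schema)) // struct<id:bigint>
    println(planCache.contains(RootRel(42L)))       // true
  }
}
```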
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
index 21f84291a2f07..e53d6251df7b3 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
@@ -31,7 +31,7 @@ import org.scalatest.time.SpanSugar._
 import org.apache.spark.SparkEnv
 import org.apache.spark.api.python.SimplePythonFunction
 import org.apache.spark.connect.proto
-import org.apache.spark.sql.IntegratedUDFTestUtils
+import org.apache.spark.sql.{Dataset, IntegratedUDFTestUtils}
 import org.apache.spark.sql.connect.SparkConnectTestUtils
 import org.apache.spark.sql.connect.common.InvalidPlanInput
 import org.apache.spark.sql.connect.config.Connect
@@ -309,11 +309,14 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
 
   private def assertPlanCache(
       sessionHolder: SessionHolder,
-      optionExpectedCachedRelations: Option[Set[proto.Relation]]) = {
+      optionExpectedCachedRelations: Option[Set[proto.Relation]],
+      expectAnalyzed: Boolean) = {
     optionExpectedCachedRelations match {
       case Some(expectedCachedRelations) =>
         val cachedRelations = sessionHolder.getPlanCache.get.asMap().keySet().asScala
         assert(cachedRelations.size == expectedCachedRelations.size)
+        val cachedLogicalPlans = sessionHolder.getPlanCache.get.asMap().values().asScala
+        cachedLogicalPlans.foreach(plan => assert(plan.analyzed == expectAnalyzed))
         expectedCachedRelations.foreach(relation => assert(cachedRelations.contains(relation)))
       case None => assert(sessionHolder.getPlanCache.isEmpty)
     }
@@ -345,29 +348,33 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
         .setCommon(proto.RelationCommon.newBuilder().setPlanId(Random.nextLong()).build())
         .build()
 
-      // If cachePlan is false, the cache is still empty.
-      planner.transformRelation(random1, cachePlan = false)
-      assertPlanCache(sessionHolder, Some(Set()))
+      // Transform the relation without analysis, the cache is still empty.
+      val random1Plan = planner.transformRelation(random1)
+      assertPlanCache(sessionHolder, Some(Set()), false)
 
-      // Put a random entry in cache.
-      planner.transformRelation(random1, cachePlan = true)
-      assertPlanCache(sessionHolder, Some(Set(random1)))
+      // Put a random entry in cache after analysis.
+      sessionHolder.updatePlanCache(random1, Dataset.ofRows(sessionHolder.session, random1Plan))
+      assertPlanCache(sessionHolder, Some(Set(random1)), true)
 
       // Put another random entry in cache.
-      planner.transformRelation(random2, cachePlan = true)
-      assertPlanCache(sessionHolder, Some(Set(random1, random2)))
+      val random2Plan = planner.transformRelation(random2)
+      sessionHolder.updatePlanCache(random2, Dataset.ofRows(sessionHolder.session, random2Plan))
+      assertPlanCache(sessionHolder, Some(Set(random1, random2)), true)
 
       // Analyze query1. We only cache the root relation, and the random1 is evicted.
-      planner.transformRelation(query1, cachePlan = true)
-      assertPlanCache(sessionHolder, Some(Set(random2, query1)))
+      val query1Plan = planner.transformRelation(query1)
+      sessionHolder.updatePlanCache(query1, Dataset.ofRows(sessionHolder.session, query1Plan))
+      assertPlanCache(sessionHolder, Some(Set(random2, query1)), true)
 
       // Put another random entry in cache.
-      planner.transformRelation(random3, cachePlan = true)
-      assertPlanCache(sessionHolder, Some(Set(query1, random3)))
+      val random3Plan = planner.transformRelation(random3)
+      sessionHolder.updatePlanCache(random3, Dataset.ofRows(sessionHolder.session, random3Plan))
+      assertPlanCache(sessionHolder, Some(Set(query1, random3)), true)
 
       // Analyze query2. As query1 is accessed during the process, it should be in the cache.
-      planner.transformRelation(query2, cachePlan = true)
-      assertPlanCache(sessionHolder, Some(Set(query1, query2)))
+      val query2Plan = planner.transformRelation(query2)
+      sessionHolder.updatePlanCache(query2, Dataset.ofRows(sessionHolder.session, query2Plan))
+      assertPlanCache(sessionHolder, Some(Set(query1, query2)), true)
     } finally {
       // Set back to default value.
       SparkEnv.get.conf.set(Connect.CONNECT_SESSION_PLAN_CACHE_SIZE, 5)
@@ -383,13 +390,10 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
 
       val query = buildRelation("select 1")
 
-      // If cachePlan is false, the cache is still None.
-      planner.transformRelation(query, cachePlan = false)
-      assertPlanCache(sessionHolder, None)
-
-      // Even if we specify "cachePlan = true", the cache is still None.
-      planner.transformRelation(query, cachePlan = true)
-      assertPlanCache(sessionHolder, None)
+      // The cache must be empty.
+      val plan = planner.transformRelation(query)
+      sessionHolder.updatePlanCache(query, Dataset.ofRows(sessionHolder.session, plan))
+      assertPlanCache(sessionHolder, None, true)
     } finally {
       // Set back to default value.
       SparkEnv.get.conf.set(Connect.CONNECT_SESSION_PLAN_CACHE_SIZE, 5)
@@ -404,14 +408,12 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
 
       val query = buildRelation("select 1")
 
-      // If cachePlan is false, the cache is still empty.
-      // Although the cache is created as cache size is greater than zero, it won't be used.
-      planner.transformRelation(query, cachePlan = false)
-      assertPlanCache(sessionHolder, Some(Set()))
+      // The cache must be empty.
+      val plan = planner.transformRelation(query)
+      sessionHolder.updatePlanCache(query, Dataset.ofRows(sessionHolder.session, plan))
+      assertPlanCache(sessionHolder, Some(Set()), true)
 
-      // Even if we specify "cachePlan = true", the cache is still empty.
-      planner.transformRelation(query, cachePlan = true)
-      assertPlanCache(sessionHolder, Some(Set()))
+      sessionHolder.session.conf.set(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED.key, true)
   }
 
   test("Test duplicate operation IDs") {
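Note: the test helper `assertPlanCache` now also checks whether every cached logical plan is analyzed, via the new `expectAnalyzed` flag, which is what distinguishes a plan that was merely transformed from one that went through `updatePlanCache`. Below is a minimal sketch of that assertion in plain Scala; the types are stand-ins, not the suite's real helper.

```scala
// Stand-in for a cached logical plan; the real check reads LogicalPlan.analyzed.
final case class CachedPlan(analyzed: Boolean)

object AssertPlanCacheSketch {
  // cache = None models "no cache was created at all"; Some(map) models a live cache.
  def assertPlanCache(
      cache: Option[Map[Long, CachedPlan]],
      expectedKeys: Option[Set[Long]],
      expectAnalyzed: Boolean): Unit = expectedKeys match {
    case Some(keys) =>
      val entries = cache.getOrElse(sys.error("plan cache expected but absent"))
      assert(entries.keySet == keys)
      entries.values.foreach(p => assert(p.analyzed == expectAnalyzed))
    case None =>
      assert(cache.isEmpty)
  }

  def main(args: Array[String]): Unit = {
    assertPlanCache(
      Some(Map(1L -> CachedPlan(analyzed = true))),
      Some(Set(1L)),
      expectAnalyzed = true)
    assertPlanCache(None, None, expectAnalyzed = true)
  }
}
```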