From e948a7cdabf34ed94dc8e16d77275184aa4f867f Mon Sep 17 00:00:00 2001
From: changgyoopark-db
Date: Tue, 21 Jan 2025 11:34:03 +0100
Subject: [PATCH] Impl

---
 python/pyspark/sql/tests/test_dataframe.py       |  4 ++-
 .../sql/connect/service/SessionHolder.scala      | 34 +++++++++++++------
 .../SparkConnectSessionHolderSuite.scala         |  2 ++
 3 files changed, 29 insertions(+), 11 deletions(-)

diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 706b8c0a8be81..27dafca1b80c0 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -440,7 +440,9 @@ def test_extended_hint_types(self):
         self.assertIsInstance(df.hint("broadcast", ["foo", "bar"]), type(df))
 
         with io.StringIO() as buf, redirect_stdout(buf):
-            hinted_df.explain(True)
+            # the plan cache may hold a fully analyzed plan
+            with self.sql_conf({"spark.connect.session.planCache.enabled": False}):
+                hinted_df.explain(True)
         explain_output = buf.getvalue()
         self.assertGreaterEqual(explain_output.count("1.2345"), 1)
         self.assertGreaterEqual(explain_output.count("what"), 1)
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
index 5b56b7079a897..2e5ffe2cd5fd8 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
@@ -32,7 +32,7 @@ import org.apache.spark.{SparkEnv, SparkException, SparkSQLException}
 import org.apache.spark.api.python.PythonFunction.PythonAccumulator
 import org.apache.spark.connect.proto
 import org.apache.spark.internal.{Logging, LogKeys, MDC}
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.connect.common.InvalidPlanInput
@@ -450,14 +450,14 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
    */
   private[connect] def usePlanCache(rel: proto.Relation, cachePlan: Boolean)(
       transform: proto.Relation => LogicalPlan): LogicalPlan = {
-    val planCacheEnabled = Option(session)
-      .forall(_.sessionState.conf.getConf(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, true))
     // We only cache plans that have a plan ID.
-    val hasPlanId = rel.hasCommon && rel.getCommon.hasPlanId
+    val planCacheEnabled = rel.hasCommon && rel.getCommon.hasPlanId &&
+      Option(session)
+        .forall(_.sessionState.conf.getConf(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, true))
 
     def getPlanCache(rel: proto.Relation): Option[LogicalPlan] =
       planCache match {
-        case Some(cache) if planCacheEnabled && hasPlanId =>
+        case Some(cache) if planCacheEnabled =>
           Option(cache.getIfPresent(rel)) match {
             case Some(plan) =>
               logDebug(s"Using cached plan for relation '$rel': $plan")
@@ -466,11 +466,24 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
           }
         case _ => None
       }
-    def putPlanCache(rel: proto.Relation, plan: LogicalPlan): Unit =
+    def putPlanCache(rel: proto.Relation, plan: LogicalPlan): LogicalPlan =
       planCache match {
-        case Some(cache) if planCacheEnabled && hasPlanId =>
-          cache.put(rel, plan)
-        case _ =>
+        case Some(cache) if planCacheEnabled =>
+          val analyzedPlan = if (plan.analyzed) {
+            plan
+          } else {
+            val qe = Dataset.ofRows(session, plan).queryExecution
+            if (qe.isLazyAnalysis) {
+              // The plan is intended to be lazily analyzed.
+              plan
+            } else {
+              // Make sure that the plan is fully analyzed before being cached.
+              qe.analyzed
+            }
+          }
+          cache.put(rel, analyzedPlan)
+          analyzedPlan
+        case _ => plan
       }
 
     getPlanCache(rel)
@@ -478,8 +491,9 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
         val plan = transform(rel)
         if (cachePlan) {
           putPlanCache(rel, plan)
+        } else {
+          plan
         }
-        plan
       })
   }
 
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
index 21f84291a2f07..7d53c26ccb918 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
@@ -314,6 +314,8 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
       case Some(expectedCachedRelations) =>
         val cachedRelations = sessionHolder.getPlanCache.get.asMap().keySet().asScala
         assert(cachedRelations.size == expectedCachedRelations.size)
+        val cachedLogicalPlans = sessionHolder.getPlanCache.get.asMap().values().asScala
+        cachedLogicalPlans.foreach(plan => assert(plan.analyzed))
         expectedCachedRelations.foreach(relation => assert(cachedRelations.contains(relation)))
       case None => assert(sessionHolder.getPlanCache.isEmpty)
     }
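
For context, the sketch below is a minimal, dependency-free model of the
pattern the SessionHolder change implements: transform a relation once, force
analysis eagerly, and cache the analyzed plan so later hits skip re-analysis.
Every name in it (Relation, Plan, analyze, PlanCacheSketch, and the
ConcurrentHashMap standing in for the cache) is a simplified assumption, not
Spark's real API; the actual code keys a bounded per-session cache by
proto.Relation, analyzes through Dataset.ofRows(session, plan).queryExecution,
and stores plans flagged by qe.isLazyAnalysis unanalyzed, since those are
meant to be analyzed lazily.

import java.util.concurrent.ConcurrentHashMap

// Simplified stand-ins for proto.Relation and LogicalPlan.
final case class Relation(planId: Option[Long], body: String)
final case class Plan(body: String, analyzed: Boolean = false, lazyAnalysis: Boolean = false)

object PlanCacheSketch {
  private val cache = new ConcurrentHashMap[Relation, Plan]()

  // Stand-in for running the analyzer over an unresolved plan.
  private def analyze(plan: Plan): Plan = plan.copy(analyzed = true)

  // Mirrors usePlanCache: only relations carrying a plan ID participate, and
  // whatever gets stored is fully analyzed unless analysis is meant to be lazy.
  def usePlanCache(rel: Relation, cachePlan: Boolean)(transform: Relation => Plan): Plan = {
    val cacheEnabled = rel.planId.isDefined
    if (!cacheEnabled) {
      transform(rel)
    } else {
      Option(cache.get(rel)) match {
        case Some(hit) => hit // cache hit: no re-transform, no re-analysis
        case None =>
          val plan = transform(rel)
          if (cachePlan) {
            // Cache the analyzed form so later hits skip re-analysis, but
            // leave already-analyzed and lazily-analyzed plans untouched.
            val toCache =
              if (plan.analyzed || plan.lazyAnalysis) plan else analyze(plan)
            cache.put(rel, toCache)
            toCache
          } else {
            plan
          }
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val rel = Relation(planId = Some(1L), body = "SELECT 1")
    val first = usePlanCache(rel, cachePlan = true)(r => Plan(r.body))
    // The second lookup is served from the cache; transform must not run again.
    val second = usePlanCache(rel, cachePlan = true)(_ => sys.error("re-transformed"))
    assert(first.analyzed && (first eq second))
    println(s"cached analyzed plan reused: $second")
  }
}

This is also why the test_dataframe.py hunk is needed: with the cache enabled,
the cached entry for hinted_df may already be fully analyzed, so explain(True)
can print a different plan shape than a freshly parsed one; the test therefore
pins spark.connect.session.planCache.enabled to false around the explain call.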