diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
index 5b56b7079a897..b6cbb68b1843a 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
@@ -32,7 +32,7 @@ import org.apache.spark.{SparkEnv, SparkException, SparkSQLException}
 import org.apache.spark.api.python.PythonFunction.PythonAccumulator
 import org.apache.spark.connect.proto
 import org.apache.spark.internal.{Logging, LogKeys, MDC}
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.connect.common.InvalidPlanInput
@@ -450,14 +450,14 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
    */
   private[connect] def usePlanCache(rel: proto.Relation, cachePlan: Boolean)(
       transform: proto.Relation => LogicalPlan): LogicalPlan = {
-    val planCacheEnabled = Option(session)
-      .forall(_.sessionState.conf.getConf(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, true))
     // We only cache plans that have a plan ID.
-    val hasPlanId = rel.hasCommon && rel.getCommon.hasPlanId
+    val planCacheEnabled = rel.hasCommon && rel.getCommon.hasPlanId &&
+      Option(session)
+        .forall(_.sessionState.conf.getConf(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, true))
 
     def getPlanCache(rel: proto.Relation): Option[LogicalPlan] =
       planCache match {
-        case Some(cache) if planCacheEnabled && hasPlanId =>
+        case Some(cache) if planCacheEnabled =>
           Option(cache.getIfPresent(rel)) match {
             case Some(plan) =>
               logDebug(s"Using cached plan for relation '$rel': $plan")
@@ -466,20 +466,32 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
           }
         case _ => None
       }
-    def putPlanCache(rel: proto.Relation, plan: LogicalPlan): Unit =
+    def putPlanCache(rel: proto.Relation, plan: LogicalPlan): LogicalPlan = {
       planCache match {
-        case Some(cache) if planCacheEnabled && hasPlanId =>
-          cache.put(rel, plan)
-        case _ =>
+        case Some(cache) if planCacheEnabled =>
+          val analyzedPlan = if (plan.analyzed) {
+            plan
+          } else {
+            // Make sure that the plan is fully analyzed before being cached.
+            Dataset.ofRows(session, plan).logicalPlan
+          }
+          cache.put(rel, analyzedPlan)
+
+          // If the plan cache is enabled, always return the analyzed plan.
+          assert(analyzedPlan.analyzed)
+          analyzedPlan
+        case _ => plan
       }
+    }
 
     getPlanCache(rel)
       .getOrElse({
         val plan = transform(rel)
         if (cachePlan) {
           putPlanCache(rel, plan)
+        } else {
+          plan
         }
-        plan
       })
   }
 
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
index 21f84291a2f07..7d53c26ccb918 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
@@ -314,6 +314,8 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
       case Some(expectedCachedRelations) =>
         val cachedRelations = sessionHolder.getPlanCache.get.asMap().keySet().asScala
         assert(cachedRelations.size == expectedCachedRelations.size)
+        val cachedLogicalPlans = sessionHolder.getPlanCache.get.asMap().values().asScala
+        cachedLogicalPlans.foreach(plan => assert(plan.analyzed))
         expectedCachedRelations.foreach(relation => assert(cachedRelations.contains(relation)))
       case None => assert(sessionHolder.getPlanCache.isEmpty)
     }
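
Note: the contract this patch establishes is "analyze before caching, and return exactly what was cached", so a later cache hit can never observe a different (unanalyzed) plan than the first caller did. Below is a minimal, self-contained sketch of that contract, assuming Guava on the classpath (the same cache family Spark Connect's planCache uses). MiniPlan, PlanCacheSketch, and the simplified usePlanCache signature are hypothetical stand-ins for illustration, not Spark APIs.

import com.google.common.cache.{Cache, CacheBuilder}

object PlanCacheSketch {
  // Hypothetical stand-in for LogicalPlan; only the `analyzed` flag matters here.
  final case class MiniPlan(id: Long, analyzed: Boolean)

  private val cache: Cache[java.lang.Long, MiniPlan] =
    CacheBuilder.newBuilder().maximumSize(16).build[java.lang.Long, MiniPlan]()

  // Mirrors the patched logic: on a cache miss, analyze the transformed plan
  // before caching it, and hand the analyzed plan (not the raw one) back.
  def usePlanCache(key: Long, cachePlan: Boolean)(transform: Long => MiniPlan): MiniPlan =
    Option(cache.getIfPresent(key)).getOrElse {
      val plan = transform(key)
      if (cachePlan) {
        // "Analysis" is just flipping a flag in this sketch; the real code
        // runs Dataset.ofRows(session, plan).logicalPlan instead.
        val analyzedPlan = if (plan.analyzed) plan else plan.copy(analyzed = true)
        cache.put(key, analyzedPlan)
        analyzedPlan
      } else {
        plan
      }
    }

  def main(args: Array[String]): Unit = {
    val first = usePlanCache(42L, cachePlan = true)(id => MiniPlan(id, analyzed = false))
    assert(first.analyzed) // the caller sees the analyzed plan, not the raw one
    val second = usePlanCache(42L, cachePlan = true)(_ => sys.error("unreachable: cache hit"))
    assert(second eq first) // served from the cache, identical object
  }
}

With cachePlan = true the caller and the cache always hold the same analyzed plan, which is precisely what the new test assertion (plan.analyzed over every cached value) verifies in SparkConnectSessionHolderSuite.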