diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
index 6ba100af1bb9a..e94e865873937 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
@@ -273,4 +273,22 @@ object Connect {
       .version("4.0.0")
       .timeConf(TimeUnit.MILLISECONDS)
       .createWithDefaultString("2s")
+
+  val CONNECT_SESSION_PLAN_CACHE_SIZE =
+    buildStaticConf("spark.connect.session.planCache.maxSize")
+      .doc("Sets the maximum number of cached resolved logical plans in a Spark Connect session." +
+        " Setting it to a value less than or equal to zero disables the plan cache.")
+      .version("4.0.0")
+      .intConf
+      .createWithDefault(5)
+
+  val CONNECT_SESSION_PLAN_CACHE_ENABLED =
+    buildConf("spark.connect.session.planCache.enabled")
+      .doc("When true, the cache of resolved logical plans is enabled if" +
+        s" '${CONNECT_SESSION_PLAN_CACHE_SIZE.key}' is greater than zero." +
+        s" When false, the cache is disabled even if '${CONNECT_SESSION_PLAN_CACHE_SIZE.key}' is" +
+        " greater than zero. The caching is best-effort and not guaranteed.")
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(true)
 }
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 5e7f3b74c2997..d8eb044e4f942 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -115,95 +115,118 @@ class SparkConnectPlanner(
   private lazy val pythonExec =
     sys.env.getOrElse("PYSPARK_PYTHON", sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python3"))
 
-  // The root of the query plan is a relation and we apply the transformations to it.
- def transformRelation(rel: proto.Relation): LogicalPlan = { - val plan = rel.getRelTypeCase match { - // DataFrame API - case proto.Relation.RelTypeCase.SHOW_STRING => transformShowString(rel.getShowString) - case proto.Relation.RelTypeCase.HTML_STRING => transformHtmlString(rel.getHtmlString) - case proto.Relation.RelTypeCase.READ => transformReadRel(rel.getRead) - case proto.Relation.RelTypeCase.PROJECT => transformProject(rel.getProject) - case proto.Relation.RelTypeCase.FILTER => transformFilter(rel.getFilter) - case proto.Relation.RelTypeCase.LIMIT => transformLimit(rel.getLimit) - case proto.Relation.RelTypeCase.OFFSET => transformOffset(rel.getOffset) - case proto.Relation.RelTypeCase.TAIL => transformTail(rel.getTail) - case proto.Relation.RelTypeCase.JOIN => transformJoinOrJoinWith(rel.getJoin) - case proto.Relation.RelTypeCase.AS_OF_JOIN => transformAsOfJoin(rel.getAsOfJoin) - case proto.Relation.RelTypeCase.DEDUPLICATE => transformDeduplicate(rel.getDeduplicate) - case proto.Relation.RelTypeCase.SET_OP => transformSetOperation(rel.getSetOp) - case proto.Relation.RelTypeCase.SORT => transformSort(rel.getSort) - case proto.Relation.RelTypeCase.DROP => transformDrop(rel.getDrop) - case proto.Relation.RelTypeCase.AGGREGATE => transformAggregate(rel.getAggregate) - case proto.Relation.RelTypeCase.SQL => transformSql(rel.getSql) - case proto.Relation.RelTypeCase.WITH_RELATIONS - if isValidSQLWithRefs(rel.getWithRelations) => - transformSqlWithRefs(rel.getWithRelations) - case proto.Relation.RelTypeCase.LOCAL_RELATION => - transformLocalRelation(rel.getLocalRelation) - case proto.Relation.RelTypeCase.SAMPLE => transformSample(rel.getSample) - case proto.Relation.RelTypeCase.RANGE => transformRange(rel.getRange) - case proto.Relation.RelTypeCase.SUBQUERY_ALIAS => - transformSubqueryAlias(rel.getSubqueryAlias) - case proto.Relation.RelTypeCase.REPARTITION => transformRepartition(rel.getRepartition) - case proto.Relation.RelTypeCase.FILL_NA => transformNAFill(rel.getFillNa) - case proto.Relation.RelTypeCase.DROP_NA => transformNADrop(rel.getDropNa) - case proto.Relation.RelTypeCase.REPLACE => transformReplace(rel.getReplace) - case proto.Relation.RelTypeCase.SUMMARY => transformStatSummary(rel.getSummary) - case proto.Relation.RelTypeCase.DESCRIBE => transformStatDescribe(rel.getDescribe) - case proto.Relation.RelTypeCase.COV => transformStatCov(rel.getCov) - case proto.Relation.RelTypeCase.CORR => transformStatCorr(rel.getCorr) - case proto.Relation.RelTypeCase.APPROX_QUANTILE => - transformStatApproxQuantile(rel.getApproxQuantile) - case proto.Relation.RelTypeCase.CROSSTAB => - transformStatCrosstab(rel.getCrosstab) - case proto.Relation.RelTypeCase.FREQ_ITEMS => transformStatFreqItems(rel.getFreqItems) - case proto.Relation.RelTypeCase.SAMPLE_BY => - transformStatSampleBy(rel.getSampleBy) - case proto.Relation.RelTypeCase.TO_SCHEMA => transformToSchema(rel.getToSchema) - case proto.Relation.RelTypeCase.TO_DF => - transformToDF(rel.getToDf) - case proto.Relation.RelTypeCase.WITH_COLUMNS_RENAMED => - transformWithColumnsRenamed(rel.getWithColumnsRenamed) - case proto.Relation.RelTypeCase.WITH_COLUMNS => transformWithColumns(rel.getWithColumns) - case proto.Relation.RelTypeCase.WITH_WATERMARK => - transformWithWatermark(rel.getWithWatermark) - case proto.Relation.RelTypeCase.CACHED_LOCAL_RELATION => - transformCachedLocalRelation(rel.getCachedLocalRelation) - case proto.Relation.RelTypeCase.HINT => transformHint(rel.getHint) - case proto.Relation.RelTypeCase.UNPIVOT => 
transformUnpivot(rel.getUnpivot) - case proto.Relation.RelTypeCase.REPARTITION_BY_EXPRESSION => - transformRepartitionByExpression(rel.getRepartitionByExpression) - case proto.Relation.RelTypeCase.MAP_PARTITIONS => - transformMapPartitions(rel.getMapPartitions) - case proto.Relation.RelTypeCase.GROUP_MAP => - transformGroupMap(rel.getGroupMap) - case proto.Relation.RelTypeCase.CO_GROUP_MAP => - transformCoGroupMap(rel.getCoGroupMap) - case proto.Relation.RelTypeCase.APPLY_IN_PANDAS_WITH_STATE => - transformApplyInPandasWithState(rel.getApplyInPandasWithState) - case proto.Relation.RelTypeCase.COMMON_INLINE_USER_DEFINED_TABLE_FUNCTION => - transformCommonInlineUserDefinedTableFunction(rel.getCommonInlineUserDefinedTableFunction) - case proto.Relation.RelTypeCase.CACHED_REMOTE_RELATION => - transformCachedRemoteRelation(rel.getCachedRemoteRelation) - case proto.Relation.RelTypeCase.COLLECT_METRICS => - transformCollectMetrics(rel.getCollectMetrics, rel.getCommon.getPlanId) - case proto.Relation.RelTypeCase.PARSE => transformParse(rel.getParse) - case proto.Relation.RelTypeCase.RELTYPE_NOT_SET => - throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.") - - // Catalog API (internal-only) - case proto.Relation.RelTypeCase.CATALOG => transformCatalog(rel.getCatalog) - - // Handle plugins for Spark Connect Relation types. - case proto.Relation.RelTypeCase.EXTENSION => - transformRelationPlugin(rel.getExtension) - case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.") - } - - if (rel.hasCommon && rel.getCommon.hasPlanId) { - plan.setTagValue(LogicalPlan.PLAN_ID_TAG, rel.getCommon.getPlanId) - } - plan + /** + * The root of the query plan is a relation and we apply the transformations to it. The resolved + * logical plan will not get cached. If the result needs to be cached, use + * `transformRelation(rel, cachePlan = true)` instead. + * @param rel + * The relation to transform. + * @return + * The resolved logical plan. + */ + def transformRelation(rel: proto.Relation): LogicalPlan = + transformRelation(rel, cachePlan = false) + + /** + * The root of the query plan is a relation and we apply the transformations to it. + * @param rel + * The relation to transform. + * @param cachePlan + * Set to true for a performance optimization, if the plan is likely to be reused, e.g. built + * upon by further dataset transformation. The default is false. + * @return + * The resolved logical plan. 
+ */ + def transformRelation(rel: proto.Relation, cachePlan: Boolean): LogicalPlan = { + sessionHolder.usePlanCache(rel, cachePlan) { rel => + val plan = rel.getRelTypeCase match { + // DataFrame API + case proto.Relation.RelTypeCase.SHOW_STRING => transformShowString(rel.getShowString) + case proto.Relation.RelTypeCase.HTML_STRING => transformHtmlString(rel.getHtmlString) + case proto.Relation.RelTypeCase.READ => transformReadRel(rel.getRead) + case proto.Relation.RelTypeCase.PROJECT => transformProject(rel.getProject) + case proto.Relation.RelTypeCase.FILTER => transformFilter(rel.getFilter) + case proto.Relation.RelTypeCase.LIMIT => transformLimit(rel.getLimit) + case proto.Relation.RelTypeCase.OFFSET => transformOffset(rel.getOffset) + case proto.Relation.RelTypeCase.TAIL => transformTail(rel.getTail) + case proto.Relation.RelTypeCase.JOIN => transformJoinOrJoinWith(rel.getJoin) + case proto.Relation.RelTypeCase.AS_OF_JOIN => transformAsOfJoin(rel.getAsOfJoin) + case proto.Relation.RelTypeCase.DEDUPLICATE => transformDeduplicate(rel.getDeduplicate) + case proto.Relation.RelTypeCase.SET_OP => transformSetOperation(rel.getSetOp) + case proto.Relation.RelTypeCase.SORT => transformSort(rel.getSort) + case proto.Relation.RelTypeCase.DROP => transformDrop(rel.getDrop) + case proto.Relation.RelTypeCase.AGGREGATE => transformAggregate(rel.getAggregate) + case proto.Relation.RelTypeCase.SQL => transformSql(rel.getSql) + case proto.Relation.RelTypeCase.WITH_RELATIONS + if isValidSQLWithRefs(rel.getWithRelations) => + transformSqlWithRefs(rel.getWithRelations) + case proto.Relation.RelTypeCase.LOCAL_RELATION => + transformLocalRelation(rel.getLocalRelation) + case proto.Relation.RelTypeCase.SAMPLE => transformSample(rel.getSample) + case proto.Relation.RelTypeCase.RANGE => transformRange(rel.getRange) + case proto.Relation.RelTypeCase.SUBQUERY_ALIAS => + transformSubqueryAlias(rel.getSubqueryAlias) + case proto.Relation.RelTypeCase.REPARTITION => transformRepartition(rel.getRepartition) + case proto.Relation.RelTypeCase.FILL_NA => transformNAFill(rel.getFillNa) + case proto.Relation.RelTypeCase.DROP_NA => transformNADrop(rel.getDropNa) + case proto.Relation.RelTypeCase.REPLACE => transformReplace(rel.getReplace) + case proto.Relation.RelTypeCase.SUMMARY => transformStatSummary(rel.getSummary) + case proto.Relation.RelTypeCase.DESCRIBE => transformStatDescribe(rel.getDescribe) + case proto.Relation.RelTypeCase.COV => transformStatCov(rel.getCov) + case proto.Relation.RelTypeCase.CORR => transformStatCorr(rel.getCorr) + case proto.Relation.RelTypeCase.APPROX_QUANTILE => + transformStatApproxQuantile(rel.getApproxQuantile) + case proto.Relation.RelTypeCase.CROSSTAB => + transformStatCrosstab(rel.getCrosstab) + case proto.Relation.RelTypeCase.FREQ_ITEMS => transformStatFreqItems(rel.getFreqItems) + case proto.Relation.RelTypeCase.SAMPLE_BY => + transformStatSampleBy(rel.getSampleBy) + case proto.Relation.RelTypeCase.TO_SCHEMA => transformToSchema(rel.getToSchema) + case proto.Relation.RelTypeCase.TO_DF => + transformToDF(rel.getToDf) + case proto.Relation.RelTypeCase.WITH_COLUMNS_RENAMED => + transformWithColumnsRenamed(rel.getWithColumnsRenamed) + case proto.Relation.RelTypeCase.WITH_COLUMNS => transformWithColumns(rel.getWithColumns) + case proto.Relation.RelTypeCase.WITH_WATERMARK => + transformWithWatermark(rel.getWithWatermark) + case proto.Relation.RelTypeCase.CACHED_LOCAL_RELATION => + transformCachedLocalRelation(rel.getCachedLocalRelation) + case proto.Relation.RelTypeCase.HINT => 
transformHint(rel.getHint) + case proto.Relation.RelTypeCase.UNPIVOT => transformUnpivot(rel.getUnpivot) + case proto.Relation.RelTypeCase.REPARTITION_BY_EXPRESSION => + transformRepartitionByExpression(rel.getRepartitionByExpression) + case proto.Relation.RelTypeCase.MAP_PARTITIONS => + transformMapPartitions(rel.getMapPartitions) + case proto.Relation.RelTypeCase.GROUP_MAP => + transformGroupMap(rel.getGroupMap) + case proto.Relation.RelTypeCase.CO_GROUP_MAP => + transformCoGroupMap(rel.getCoGroupMap) + case proto.Relation.RelTypeCase.APPLY_IN_PANDAS_WITH_STATE => + transformApplyInPandasWithState(rel.getApplyInPandasWithState) + case proto.Relation.RelTypeCase.COMMON_INLINE_USER_DEFINED_TABLE_FUNCTION => + transformCommonInlineUserDefinedTableFunction( + rel.getCommonInlineUserDefinedTableFunction) + case proto.Relation.RelTypeCase.CACHED_REMOTE_RELATION => + transformCachedRemoteRelation(rel.getCachedRemoteRelation) + case proto.Relation.RelTypeCase.COLLECT_METRICS => + transformCollectMetrics(rel.getCollectMetrics, rel.getCommon.getPlanId) + case proto.Relation.RelTypeCase.PARSE => transformParse(rel.getParse) + case proto.Relation.RelTypeCase.RELTYPE_NOT_SET => + throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.") + + // Catalog API (internal-only) + case proto.Relation.RelTypeCase.CATALOG => transformCatalog(rel.getCatalog) + + // Handle plugins for Spark Connect Relation types. + case proto.Relation.RelTypeCase.EXTENSION => + transformRelationPlugin(rel.getExtension) + case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.") + } + if (rel.hasCommon && rel.getCommon.hasPlanId) { + plan.setTagValue(LogicalPlan.PLAN_ID_TAG, rel.getCommon.getPlanId) + } + plan + } } @DeveloperApi diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala index 306b891485834..3dad57209982d 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala @@ -27,14 +27,17 @@ import scala.jdk.CollectionConverters._ import scala.util.Try import com.google.common.base.Ticker -import com.google.common.cache.CacheBuilder +import com.google.common.cache.{Cache, CacheBuilder} -import org.apache.spark.{SparkException, SparkSQLException} +import org.apache.spark.{SparkEnv, SparkException, SparkSQLException} import org.apache.spark.api.python.PythonFunction.PythonAccumulator +import org.apache.spark.connect.proto import org.apache.spark.internal.Logging import org.apache.spark.sql.DataFrame import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.connect.common.InvalidPlanInput +import org.apache.spark.sql.connect.config.Connect import org.apache.spark.sql.connect.planner.PythonStreamingQueryListener import org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper import org.apache.spark.sql.connect.service.SessionHolder.{ERROR_CACHE_SIZE, ERROR_CACHE_TIMEOUT_SEC} @@ -50,6 +53,27 @@ case class SessionKey(userId: String, sessionId: String) case class SessionHolder(userId: String, sessionId: String, session: SparkSession) extends Logging { + // Cache which stores recently resolved logical plans to improve the performance of plan analysis. 
+  // Only plans that explicitly specify "cachePlan = true" in transformRelation will be cached.
+  // Analyzing a large plan may be expensive, and it is not uncommon to build the plan step-by-step
+  // with several analysis passes during the process. This cache aids the recursive analysis by
+  // memoizing `LogicalPlan`s which may be a sub-tree in a subsequent plan.
+  private lazy val planCache: Option[Cache[proto.Relation, LogicalPlan]] = {
+    if (SparkEnv.get.conf.get(Connect.CONNECT_SESSION_PLAN_CACHE_SIZE) <= 0) {
+      logWarning(
+        s"Session plan cache is disabled due to non-positive cache size." +
+          s" Current value of '${Connect.CONNECT_SESSION_PLAN_CACHE_SIZE.key}' is" +
+          s" ${SparkEnv.get.conf.get(Connect.CONNECT_SESSION_PLAN_CACHE_SIZE)}.")
+      None
+    } else {
+      Some(
+        CacheBuilder
+          .newBuilder()
+          .maximumSize(SparkEnv.get.conf.get(Connect.CONNECT_SESSION_PLAN_CACHE_SIZE))
+          .build[proto.Relation, LogicalPlan]())
+    }
+  }
+
   // Time when the session was started.
   private val startTimeMs: Long = System.currentTimeMillis()
 
@@ -388,6 +412,57 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
    */
   private[connect] val pythonAccumulator: Option[PythonAccumulator] =
     Try(session.sparkContext.collectionAccumulator[Array[Byte]]).toOption
+
+  /**
+   * Transform a relation into a logical plan, using the plan cache if enabled. The plan cache is
+   * enabled only if `spark.connect.session.planCache.maxSize` is greater than zero AND
+   * `spark.connect.session.planCache.enabled` is true.
+   * @param rel
+   *   The relation to transform.
+   * @param cachePlan
+   *   Whether to cache the resulting logical plan.
+   * @param transform
+   *   Function to transform the relation into a logical plan.
+   * @return
+   *   The logical plan.
+   */
+  private[connect] def usePlanCache(rel: proto.Relation, cachePlan: Boolean)(
+      transform: proto.Relation => LogicalPlan): LogicalPlan = {
+    val planCacheEnabled =
+      Option(session).forall(_.conf.get(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, true))
+    // We only cache plans that have a plan ID.
+    val hasPlanId = rel.hasCommon && rel.getCommon.hasPlanId
+
+    def getPlanCache(rel: proto.Relation): Option[LogicalPlan] =
+      planCache match {
+        case Some(cache) if planCacheEnabled && hasPlanId =>
+          Option(cache.getIfPresent(rel)) match {
+            case Some(plan) =>
+              logDebug(s"Using cached plan for relation '$rel': $plan")
+              Some(plan)
+            case None => None
+          }
+        case _ => None
+      }
+    def putPlanCache(rel: proto.Relation, plan: LogicalPlan): Unit =
+      planCache match {
+        case Some(cache) if planCacheEnabled && hasPlanId =>
+          cache.put(rel, plan)
+        case _ =>
+      }
+
+    getPlanCache(rel)
+      .getOrElse({
+        val plan = transform(rel)
+        if (cachePlan) {
+          putPlanCache(rel, plan)
+        }
+        plan
+      })
+  }
+
+  // Exposes the plan cache for testing.
+ private[service] def getPlanCache: Option[Cache[proto.Relation, LogicalPlan]] = planCache } object SessionHolder { diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala index 3dfd29d6a8c66..6c5d95ac67d3d 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala @@ -58,10 +58,12 @@ private[connect] class SparkConnectAnalyzeHandler( val session = sessionHolder.session val builder = proto.AnalyzePlanResponse.newBuilder() + def transformRelation(rel: proto.Relation) = planner.transformRelation(rel, cachePlan = true) + request.getAnalyzeCase match { case proto.AnalyzePlanRequest.AnalyzeCase.SCHEMA => val schema = Dataset - .ofRows(session, planner.transformRelation(request.getSchema.getPlan.getRoot)) + .ofRows(session, transformRelation(request.getSchema.getPlan.getRoot)) .schema builder.setSchema( proto.AnalyzePlanResponse.Schema @@ -71,7 +73,7 @@ private[connect] class SparkConnectAnalyzeHandler( case proto.AnalyzePlanRequest.AnalyzeCase.EXPLAIN => val queryExecution = Dataset - .ofRows(session, planner.transformRelation(request.getExplain.getPlan.getRoot)) + .ofRows(session, transformRelation(request.getExplain.getPlan.getRoot)) .queryExecution val explainString = request.getExplain.getExplainMode match { case proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_SIMPLE => @@ -94,7 +96,7 @@ private[connect] class SparkConnectAnalyzeHandler( case proto.AnalyzePlanRequest.AnalyzeCase.TREE_STRING => val schema = Dataset - .ofRows(session, planner.transformRelation(request.getTreeString.getPlan.getRoot)) + .ofRows(session, transformRelation(request.getTreeString.getPlan.getRoot)) .schema val treeString = if (request.getTreeString.hasLevel) { schema.treeString(request.getTreeString.getLevel) @@ -109,7 +111,7 @@ private[connect] class SparkConnectAnalyzeHandler( case proto.AnalyzePlanRequest.AnalyzeCase.IS_LOCAL => val isLocal = Dataset - .ofRows(session, planner.transformRelation(request.getIsLocal.getPlan.getRoot)) + .ofRows(session, transformRelation(request.getIsLocal.getPlan.getRoot)) .isLocal builder.setIsLocal( proto.AnalyzePlanResponse.IsLocal @@ -119,7 +121,7 @@ private[connect] class SparkConnectAnalyzeHandler( case proto.AnalyzePlanRequest.AnalyzeCase.IS_STREAMING => val isStreaming = Dataset - .ofRows(session, planner.transformRelation(request.getIsStreaming.getPlan.getRoot)) + .ofRows(session, transformRelation(request.getIsStreaming.getPlan.getRoot)) .isStreaming builder.setIsStreaming( proto.AnalyzePlanResponse.IsStreaming @@ -129,7 +131,7 @@ private[connect] class SparkConnectAnalyzeHandler( case proto.AnalyzePlanRequest.AnalyzeCase.INPUT_FILES => val inputFiles = Dataset - .ofRows(session, planner.transformRelation(request.getInputFiles.getPlan.getRoot)) + .ofRows(session, transformRelation(request.getInputFiles.getPlan.getRoot)) .inputFiles builder.setInputFiles( proto.AnalyzePlanResponse.InputFiles @@ -155,10 +157,10 @@ private[connect] class SparkConnectAnalyzeHandler( case proto.AnalyzePlanRequest.AnalyzeCase.SAME_SEMANTICS => val target = Dataset.ofRows( session, - planner.transformRelation(request.getSameSemantics.getTargetPlan.getRoot)) + transformRelation(request.getSameSemantics.getTargetPlan.getRoot)) val 
other = Dataset.ofRows( session, - planner.transformRelation(request.getSameSemantics.getOtherPlan.getRoot)) + transformRelation(request.getSameSemantics.getOtherPlan.getRoot)) builder.setSameSemantics( proto.AnalyzePlanResponse.SameSemantics .newBuilder() @@ -166,7 +168,7 @@ private[connect] class SparkConnectAnalyzeHandler( case proto.AnalyzePlanRequest.AnalyzeCase.SEMANTIC_HASH => val semanticHash = Dataset - .ofRows(session, planner.transformRelation(request.getSemanticHash.getPlan.getRoot)) + .ofRows(session, transformRelation(request.getSemanticHash.getPlan.getRoot)) .semanticHash() builder.setSemanticHash( proto.AnalyzePlanResponse.SemanticHash @@ -175,7 +177,7 @@ private[connect] class SparkConnectAnalyzeHandler( case proto.AnalyzePlanRequest.AnalyzeCase.PERSIST => val target = Dataset - .ofRows(session, planner.transformRelation(request.getPersist.getRelation)) + .ofRows(session, transformRelation(request.getPersist.getRelation)) if (request.getPersist.hasStorageLevel) { target.persist( StorageLevelProtoConverter.toStorageLevel(request.getPersist.getStorageLevel)) @@ -186,7 +188,7 @@ private[connect] class SparkConnectAnalyzeHandler( case proto.AnalyzePlanRequest.AnalyzeCase.UNPERSIST => val target = Dataset - .ofRows(session, planner.transformRelation(request.getUnpersist.getRelation)) + .ofRows(session, transformRelation(request.getUnpersist.getRelation)) if (request.getUnpersist.hasBlocking) { target.unpersist(request.getUnpersist.getBlocking) } else { @@ -196,7 +198,7 @@ private[connect] class SparkConnectAnalyzeHandler( case proto.AnalyzePlanRequest.AnalyzeCase.GET_STORAGE_LEVEL => val target = Dataset - .ofRows(session, planner.transformRelation(request.getGetStorageLevel.getRelation)) + .ofRows(session, transformRelation(request.getGetStorageLevel.getRelation)) val storageLevel = target.storageLevel builder.setGetStorageLevel( proto.AnalyzePlanResponse.GetStorageLevel diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala index bb51b0a798207..62b4151aad8a5 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala @@ -23,14 +23,18 @@ import java.nio.file.Files import scala.collection.mutable import scala.jdk.CollectionConverters._ import scala.sys.process.Process +import scala.util.Random import com.google.common.collect.Lists import org.scalatest.time.SpanSugar._ +import org.apache.spark.SparkEnv import org.apache.spark.api.python.SimplePythonFunction +import org.apache.spark.connect.proto import org.apache.spark.sql.IntegratedUDFTestUtils import org.apache.spark.sql.connect.common.InvalidPlanInput -import org.apache.spark.sql.connect.planner.{PythonStreamingQueryListener, StreamingForeachBatchHelper} +import org.apache.spark.sql.connect.config.Connect +import org.apache.spark.sql.connect.planner.{PythonStreamingQueryListener, SparkConnectPlanner, StreamingForeachBatchHelper} import org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper.RunnerCleaner import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.util.ArrayImplicits._ @@ -289,4 +293,123 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession { 
spark.streams.listListeners().foreach(spark.streams.removeListener) } } + + private def buildRelation(query: String) = { + proto.Relation + .newBuilder() + .setSql( + proto.SQL + .newBuilder() + .setQuery(query) + .build()) + .setCommon(proto.RelationCommon.newBuilder().setPlanId(Random.nextLong()).build()) + .build() + } + + private def assertPlanCache( + sessionHolder: SessionHolder, + optionExpectedCachedRelations: Option[Set[proto.Relation]]) = { + optionExpectedCachedRelations match { + case Some(expectedCachedRelations) => + val cachedRelations = sessionHolder.getPlanCache.get.asMap().keySet().asScala + assert(cachedRelations.size == expectedCachedRelations.size) + expectedCachedRelations.foreach(relation => assert(cachedRelations.contains(relation))) + case None => assert(sessionHolder.getPlanCache.isEmpty) + } + } + + test("Test session plan cache") { + val sessionHolder = SessionHolder.forTesting(spark) + try { + // Set cache size to 2 + SparkEnv.get.conf.set(Connect.CONNECT_SESSION_PLAN_CACHE_SIZE, 2) + val planner = new SparkConnectPlanner(sessionHolder) + + val random1 = buildRelation("select 1") + val random2 = buildRelation("select 2") + val random3 = buildRelation("select 3") + val query1 = proto.Relation.newBuilder + .setLimit( + proto.Limit.newBuilder + .setLimit(10) + .setInput( + proto.Relation + .newBuilder() + .setRange(proto.Range.newBuilder().setStart(0).setStep(1).setEnd(20)) + .build())) + .setCommon(proto.RelationCommon.newBuilder().setPlanId(Random.nextLong()).build()) + .build() + val query2 = proto.Relation.newBuilder + .setLimit(proto.Limit.newBuilder.setLimit(5).setInput(query1)) + .setCommon(proto.RelationCommon.newBuilder().setPlanId(Random.nextLong()).build()) + .build() + + // If cachePlan is false, the cache is still empty. + planner.transformRelation(random1, cachePlan = false) + assertPlanCache(sessionHolder, Some(Set())) + + // Put a random entry in cache. + planner.transformRelation(random1, cachePlan = true) + assertPlanCache(sessionHolder, Some(Set(random1))) + + // Put another random entry in cache. + planner.transformRelation(random2, cachePlan = true) + assertPlanCache(sessionHolder, Some(Set(random1, random2))) + + // Analyze query1. We only cache the root relation, and the random1 is evicted. + planner.transformRelation(query1, cachePlan = true) + assertPlanCache(sessionHolder, Some(Set(random2, query1))) + + // Put another random entry in cache. + planner.transformRelation(random3, cachePlan = true) + assertPlanCache(sessionHolder, Some(Set(query1, random3))) + + // Analyze query2. As query1 is accessed during the process, it should be in the cache. + planner.transformRelation(query2, cachePlan = true) + assertPlanCache(sessionHolder, Some(Set(query1, query2))) + } finally { + // Set back to default value. + SparkEnv.get.conf.set(Connect.CONNECT_SESSION_PLAN_CACHE_SIZE, 5) + } + } + + test("Test session plan cache - cache size zero or negative") { + val sessionHolder = SessionHolder.forTesting(spark) + try { + // Set cache size to -1 + SparkEnv.get.conf.set(Connect.CONNECT_SESSION_PLAN_CACHE_SIZE, -1) + val planner = new SparkConnectPlanner(sessionHolder) + + val query = buildRelation("select 1") + + // If cachePlan is false, the cache is still None. + planner.transformRelation(query, cachePlan = false) + assertPlanCache(sessionHolder, None) + + // Even if we specify "cachePlan = true", the cache is still None. 
+ planner.transformRelation(query, cachePlan = true) + assertPlanCache(sessionHolder, None) + } finally { + // Set back to default value. + SparkEnv.get.conf.set(Connect.CONNECT_SESSION_PLAN_CACHE_SIZE, 5) + } + } + + test("Test session plan cache - disabled") { + val sessionHolder = SessionHolder.forTesting(spark) + // Disable plan cache of the session + sessionHolder.session.conf.set(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, false) + val planner = new SparkConnectPlanner(sessionHolder) + + val query = buildRelation("select 1") + + // If cachePlan is false, the cache is still empty. + // Although the cache is created as cache size is greater than zero, it won't be used. + planner.transformRelation(query, cachePlan = false) + assertPlanCache(sessionHolder, Some(Set())) + + // Even if we specify "cachePlan = true", the cache is still empty. + planner.transformRelation(query, cachePlan = true) + assertPlanCache(sessionHolder, Some(Set())) + } }
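Usage sketch (illustrative, not part of the patch): 'spark.connect.session.planCache.maxSize' is a static conf read from SparkEnv, so it only takes effect when the Connect server's SparkSession is created, whereas 'spark.connect.session.planCache.enabled' is a runtime conf that can be toggled per session, as the tests above do. A minimal sketch, with an arbitrary cache size of 16:

import org.apache.spark.sql.SparkSession

// Static conf: must be set before the server's SparkSession starts; a value <= 0 disables the cache.
val spark = SparkSession
  .builder()
  .config("spark.connect.session.planCache.maxSize", "16")
  .getOrCreate()

// Runtime conf: turn the (best-effort) plan caching off for this session without a restart.
spark.conf.set("spark.connect.session.planCache.enabled", "false")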