From c3176a79de62710e5ec6dec18f7f7bc277066e12 Mon Sep 17 00:00:00 2001 From: Harsh Motwani Date: Tue, 15 Oct 2024 08:13:10 +0200 Subject: [PATCH 01/31] [SPARK-49451][SQL][FOLLOW-UP] Improve duplicate key exception test ### What changes were proposed in this pull request? This test improves a unit test case where json strings with duplicate keys are prohibited by checking the cause of the exception instead of just the root exception. ### Why are the changes needed? Earlier, the test only checked the top error class but not the cause of the error which should be `VARIANT_DUPLICATE_KEY`. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? ### Was this patch authored or co-authored using generative AI tooling? NA Closes #48464 from harshmotw-db/harshmotw-db/minor_test_fix. Authored-by: Harsh Motwani Signed-off-by: Max Gekk --- .../spark/sql/VariantEndToEndSuite.scala | 28 ++++++++++++------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala index 19d4ac23709b6..fe5c6ef004920 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.sql -import org.apache.spark.SparkThrowable +import org.apache.spark.{SparkException, SparkRuntimeException} import org.apache.spark.sql.QueryTest.sameRows import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} @@ -359,16 +359,24 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession { val expectedMetadata: Array[Byte] = Array(VERSION, 3, 0, 1, 2, 3, 'a', 'b', 'c') assert(actual === new VariantVal(expectedValue, expectedMetadata)) } - withSQLConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS.key -> "false") { - val df = Seq(json).toDF("j") - .selectExpr("from_json(j,'variant')") - checkError( - exception = intercept[SparkThrowable] { + // Check whether the parse_json and from_json expressions throw the correct exception. + Seq("from_json(j, 'variant')", "parse_json(j)").foreach { expr => + withSQLConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS.key -> "false") { + val df = Seq(json).toDF("j").selectExpr(expr) + val exception = intercept[SparkException] { df.collect() - }, - condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", - parameters = Map("badRecord" -> json, "failFastMode" -> "FAILFAST") - ) + } + checkError( + exception = exception, + condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", + parameters = Map("badRecord" -> json, "failFastMode" -> "FAILFAST") + ) + checkError( + exception = exception.getCause.asInstanceOf[SparkRuntimeException], + condition = "VARIANT_DUPLICATE_KEY", + parameters = Map("key" -> "a") + ) + } } } } From 2ccdabaa79ac4233db0a97c8bd79cbd55cb2d773 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vladan=20Vasi=C4=87?= Date: Tue, 15 Oct 2024 21:20:47 +0800 Subject: [PATCH 02/31] [SPARK-49956] Disabled collations with collect_set expression MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? In this PR, I propose to disallow collated strings in `collect_set` expression. ### Why are the changes needed? Proposed changes are necessary in order to achieve correct behavior of the expressions mentioned above. 
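As a rough illustration (this snippet is not taken from the patch; the `COLLATE` syntax and session setup are assumed), the newly guarded case looks like this:

```scala
// Minimal sketch, assuming a Spark session with collation support:
// aggregating a non-binary collated string with collect_set is now rejected
// during analysis instead of silently deduplicating with binary equality.
spark.sql(
  """SELECT collect_set(col COLLATE UNICODE)
    |FROM VALUES ('a'), ('b'), ('a') AS tab(col)""".stripMargin
).collect()
// => AnalysisException with DATATYPE_MISMATCH.UNSUPPORTED_INPUT_TYPE:
//    the input of `collect_set` cannot be "MAP" or "COLLATED STRING"
```

Binary-stable inputs (for example the default `UTF8_BINARY` collation) keep working as before.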
### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? This patch was tested by modifying existing test case in `CollationSQLExpressionSuite`. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48456 from vladanvasi-db/vladanvasi-db/collect-set-collated-disablement. Authored-by: Vladan Vasić Signed-off-by: Wenchen Fan --- .../expressions/aggregate/collect.scala | 8 ++++--- .../sql/CollationSQLExpressionsSuite.scala | 22 +++++++++++++------ .../spark/sql/DataFrameAggregateSuite.scala | 2 +- 3 files changed, 21 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index 0a4882bfada17..3270c6e87e2cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.trees.UnaryLike -import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils} +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils, UnsafeRowUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} import org.apache.spark.sql.types._ import org.apache.spark.util.BoundedPriorityQueue @@ -145,6 +145,7 @@ case class CollectList( """, group = "agg_funcs", since = "2.0.0") +// TODO: Make CollectSet collation aware case class CollectSet( child: Expression, mutableAggBufferOffset: Int = 0, @@ -178,14 +179,15 @@ case class CollectSet( } override def checkInputDataTypes(): TypeCheckResult = { - if (!child.dataType.existsRecursively(_.isInstanceOf[MapType])) { + if (!child.dataType.existsRecursively(_.isInstanceOf[MapType]) && + UnsafeRowUtils.isBinaryStable(child.dataType)) { TypeCheckResult.TypeCheckSuccess } else { DataTypeMismatch( errorSubClass = "UNSUPPORTED_INPUT_TYPE", messageParameters = Map( "functionName" -> toSQLId(prettyName), - "dataType" -> toSQLType(MapType) + "dataType" -> (s"${toSQLType(MapType)} " + "or \"COLLATED STRING\"") ) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index ce6818652d2b5..d568cd77050fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -2819,16 +2819,24 @@ class CollationSQLExpressionsSuite } } - test("collect_set supports collation") { + test("collect_set does not support collation") { val collation = "UNICODE" val query = s"SELECT collect_set(col) FROM VALUES ('a'), ('b'), ('a') AS tab(col);" withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - val result = sql(query).collect().head.getSeq[String](0).toSet - val expected = Set("a", "b") - assert(result == expected) - // check result row data type - val dataType = ArrayType(StringType(collation), false) - assert(sql(query).schema.head.dataType == dataType) + checkError( + exception = intercept[AnalysisException] { + sql(query) + }, + condition = 
"DATATYPE_MISMATCH.UNSUPPORTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "functionName" -> "`collect_set`", + "dataType" -> "\"MAP\" or \"COLLATED STRING\"", + "sqlExpr" -> "\"collect_set(col)\""), + context = ExpectedContext( + fragment = "collect_set(col)", + start = 7, + stop = 22)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index e80c3b23a7db3..25f4d9f62354a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -648,7 +648,7 @@ class DataFrameAggregateSuite extends QueryTest condition = "DATATYPE_MISMATCH.UNSUPPORTED_INPUT_TYPE", parameters = Map( "functionName" -> "`collect_set`", - "dataType" -> "\"MAP\"", + "dataType" -> "\"MAP\" or \"COLLATED STRING\"", "sqlExpr" -> "\"collect_set(b)\"" ), context = ExpectedContext( From 14c01ebda5a8229f6d9419ede88cb8c71044dd92 Mon Sep 17 00:00:00 2001 From: Vladimir Golubev Date: Tue, 15 Oct 2024 15:21:43 +0200 Subject: [PATCH 03/31] [SPARK-49974][SQL] Move resolveRelations(...) out of the Analyzer.scala ### What changes were proposed in this pull request? Move resolveRelation(...) and some of its dependencies out to a separate class, because it's reasonably self-contained. ### Why are the changes needed? Analyzer.scala is 4K+ lines long, so it makes sense to gradually split it. ### Does this PR introduce _any_ user-facing change? No, just the code is moved to a separate class. ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? copilot.vim. Closes #48475 from vladimirg-db/vladimirg-db/refactor-resolve-relations. 
Authored-by: Vladimir Golubev Signed-off-by: Max Gekk --- .../sql/catalyst/analysis/Analyzer.scala | 180 ++----------- .../analysis/RelationResolution.scala | 245 ++++++++++++++++++ 2 files changed, 260 insertions(+), 165 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 5d41c07b47842..49f3092390536 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -39,7 +39,6 @@ import org.apache.spark.sql.catalyst.optimizer.OptimizeUpdateFields import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.catalyst.trees.AlwaysProcess import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin import org.apache.spark.sql.catalyst.trees.TreePattern._ @@ -203,6 +202,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor with CheckAnalysis with SQLConfHelper with ColumnResolutionHelper { private val v1SessionCatalog: SessionCatalog = catalogManager.v1SessionCatalog + private val relationResolution = new RelationResolution(catalogManager) override protected def validatePlanChanges( previousPlan: LogicalPlan, @@ -972,30 +972,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } } - private def isResolvingView: Boolean = AnalysisContext.get.catalogAndNamespace.nonEmpty - private def isReferredTempViewName(nameParts: Seq[String]): Boolean = { - AnalysisContext.get.referredTempViewNames.exists { n => - (n.length == nameParts.length) && n.zip(nameParts).forall { - case (a, b) => resolver(a, b) - } - } - } - - // If we are resolving database objects (relations, functions, etc.) insides views, we may need to - // expand single or multi-part identifiers with the current catalog and namespace of when the - // view was created. - private def expandIdentifier(nameParts: Seq[String]): Seq[String] = { - if (!isResolvingView || isReferredTempViewName(nameParts)) return nameParts - - if (nameParts.length == 1) { - AnalysisContext.get.catalogAndNamespace :+ nameParts.head - } else if (catalogManager.isCatalogRegistered(nameParts.head)) { - nameParts - } else { - AnalysisContext.get.catalogAndNamespace.head +: nameParts - } - } - /** * Adds metadata columns to output for child relations when nodes are missing resolved attributes. 
* @@ -1122,7 +1098,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor case i @ InsertIntoStatement(table, _, _, _, _, _, _) => val relation = table match { case u: UnresolvedRelation if !u.isStreaming => - resolveRelation(u).getOrElse(u) + relationResolution.resolveRelation(u).getOrElse(u) case other => other } @@ -1139,7 +1115,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor case write: V2WriteCommand => write.table match { case u: UnresolvedRelation if !u.isStreaming => - resolveRelation(u).map(unwrapRelationPlan).map { + relationResolution.resolveRelation(u).map(unwrapRelationPlan).map { case v: View => throw QueryCompilationErrors.writeIntoViewNotAllowedError( v.desc.identifier, write) case r: DataSourceV2Relation => write.withNewTable(r) @@ -1154,12 +1130,12 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } case u: UnresolvedRelation => - resolveRelation(u).map(resolveViews).getOrElse(u) + relationResolution.resolveRelation(u).map(resolveViews).getOrElse(u) case r @ RelationTimeTravel(u: UnresolvedRelation, timestamp, version) if timestamp.forall(ts => ts.resolved && !SubqueryExpression.hasSubquery(ts)) => val timeTravelSpec = TimeTravelSpec.create(timestamp, version, conf.sessionLocalTimeZone) - resolveRelation(u, timeTravelSpec).getOrElse(r) + relationResolution.resolveRelation(u, timeTravelSpec).getOrElse(r) case u @ UnresolvedTable(identifier, cmd, suggestAlternative) => lookupTableOrView(identifier).map { @@ -1194,29 +1170,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor }.getOrElse(u) } - private def lookupTempView(identifier: Seq[String]): Option[TemporaryViewRelation] = { - // We are resolving a view and this name is not a temp view when that view was created. We - // return None earlier here. - if (isResolvingView && !isReferredTempViewName(identifier)) return None - v1SessionCatalog.getRawLocalOrGlobalTempView(identifier) - } - - private def resolveTempView( - identifier: Seq[String], - isStreaming: Boolean = false, - isTimeTravel: Boolean = false): Option[LogicalPlan] = { - lookupTempView(identifier).map { v => - val tempViewPlan = v1SessionCatalog.getTempViewRelation(v) - if (isStreaming && !tempViewPlan.isStreaming) { - throw QueryCompilationErrors.readNonStreamingTempViewError(identifier.quoted) - } - if (isTimeTravel) { - throw QueryCompilationErrors.timeTravelUnsupportedError(toSQLId(identifier)) - } - tempViewPlan - } - } - /** * Resolves relations to `ResolvedTable` or `Resolved[Temp/Persistent]View`. This is * for resolving DDL and misc commands. 
@@ -1224,10 +1177,10 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor private def lookupTableOrView( identifier: Seq[String], viewOnly: Boolean = false): Option[LogicalPlan] = { - lookupTempView(identifier).map { tempView => + relationResolution.lookupTempView(identifier).map { tempView => ResolvedTempView(identifier.asIdentifier, tempView.tableMeta) }.orElse { - expandIdentifier(identifier) match { + relationResolution.expandIdentifier(identifier) match { case CatalogAndIdentifier(catalog, ident) => if (viewOnly && !CatalogV2Util.isSessionCatalog(catalog)) { throw QueryCompilationErrors.catalogOperationNotSupported(catalog, "views") @@ -1246,113 +1199,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } } } - - private def createRelation( - catalog: CatalogPlugin, - ident: Identifier, - table: Option[Table], - options: CaseInsensitiveStringMap, - isStreaming: Boolean): Option[LogicalPlan] = { - table.map { - // To utilize this code path to execute V1 commands, e.g. INSERT, - // either it must be session catalog, or tracksPartitionsInCatalog - // must be false so it does not require use catalog to manage partitions. - // Obviously we cannot execute V1Table by V1 code path if the table - // is not from session catalog and the table still requires its catalog - // to manage partitions. - case v1Table: V1Table if CatalogV2Util.isSessionCatalog(catalog) - || !v1Table.catalogTable.tracksPartitionsInCatalog => - if (isStreaming) { - if (v1Table.v1Table.tableType == CatalogTableType.VIEW) { - throw QueryCompilationErrors.permanentViewNotSupportedByStreamingReadingAPIError( - ident.quoted) - } - SubqueryAlias( - catalog.name +: ident.asMultipartIdentifier, - UnresolvedCatalogRelation(v1Table.v1Table, options, isStreaming = true)) - } else { - v1SessionCatalog.getRelation(v1Table.v1Table, options) - } - - case table => - if (isStreaming) { - val v1Fallback = table match { - case withFallback: V2TableWithV1Fallback => - Some(UnresolvedCatalogRelation(withFallback.v1Table, isStreaming = true)) - case _ => None - } - SubqueryAlias( - catalog.name +: ident.asMultipartIdentifier, - StreamingRelationV2(None, table.name, table, options, table.columns.toAttributes, - Some(catalog), Some(ident), v1Fallback)) - } else { - SubqueryAlias( - catalog.name +: ident.asMultipartIdentifier, - DataSourceV2Relation.create(table, Some(catalog), Some(ident), options)) - } - } - } - - /** - * Resolves relations to v1 relation if it's a v1 table from the session catalog, or to v2 - * relation. This is for resolving DML commands and SELECT queries. 
- */ - private def resolveRelation( - u: UnresolvedRelation, - timeTravelSpec: Option[TimeTravelSpec] = None): Option[LogicalPlan] = { - val timeTravelSpecFromOptions = TimeTravelSpec.fromOptions( - u.options, - conf.getConf(SQLConf.TIME_TRAVEL_TIMESTAMP_KEY), - conf.getConf(SQLConf.TIME_TRAVEL_VERSION_KEY), - conf.sessionLocalTimeZone - ) - if (timeTravelSpec.nonEmpty && timeTravelSpecFromOptions.nonEmpty) { - throw new AnalysisException("MULTIPLE_TIME_TRAVEL_SPEC", Map.empty[String, String]) - } - val finalTimeTravelSpec = timeTravelSpec.orElse(timeTravelSpecFromOptions) - resolveTempView(u.multipartIdentifier, u.isStreaming, finalTimeTravelSpec.isDefined).orElse { - expandIdentifier(u.multipartIdentifier) match { - case CatalogAndIdentifier(catalog, ident) => - val key = - ((catalog.name +: ident.namespace :+ ident.name).toImmutableArraySeq, - finalTimeTravelSpec) - AnalysisContext.get.relationCache.get(key).map { cache => - val cachedRelation = cache.transform { - case multi: MultiInstanceRelation => - val newRelation = multi.newInstance() - newRelation.copyTagsFrom(multi) - newRelation - } - u.getTagValue(LogicalPlan.PLAN_ID_TAG).map { planId => - val cachedConnectRelation = cachedRelation.clone() - cachedConnectRelation.setTagValue(LogicalPlan.PLAN_ID_TAG, planId) - cachedConnectRelation - }.getOrElse(cachedRelation) - }.orElse { - val writePrivilegesString = - Option(u.options.get(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES)) - val table = CatalogV2Util.loadTable( - catalog, ident, finalTimeTravelSpec, writePrivilegesString) - val loaded = createRelation( - catalog, ident, table, u.clearWritePrivileges.options, u.isStreaming) - loaded.foreach(AnalysisContext.get.relationCache.update(key, _)) - u.getTagValue(LogicalPlan.PLAN_ID_TAG).map { planId => - loaded.map { loadedRelation => - val loadedConnectRelation = loadedRelation.clone() - loadedConnectRelation.setTagValue(LogicalPlan.PLAN_ID_TAG, planId) - loadedConnectRelation - } - }.getOrElse(loaded) - } - case _ => None - } - } - } - - /** Consumes an unresolved relation and resolves it to a v1 or v2 relation or temporary view. 
*/ - def resolveRelationOrTempView(u: UnresolvedRelation): LogicalPlan = { - EliminateSubqueryAliases(resolveRelation(u).getOrElse(u)) - } } /** Handle INSERT INTO for DSv2 */ @@ -2135,7 +1981,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor if (ResolveFunctions.lookupBuiltinOrTempFunction(nameParts, Some(f)).isDefined) { f } else { - val CatalogAndIdentifier(catalog, ident) = expandIdentifier(nameParts) + val CatalogAndIdentifier(catalog, ident) = + relationResolution.expandIdentifier(nameParts) val fullName = normalizeFuncName((catalog.name +: ident.namespace :+ ident.name).toImmutableArraySeq) if (externalFunctionNameSet.contains(fullName)) { @@ -2186,7 +2033,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor ResolvedNonPersistentFunc(nameParts.head, V1Function(info)) } }.getOrElse { - val CatalogAndIdentifier(catalog, ident) = expandIdentifier(nameParts) + val CatalogAndIdentifier(catalog, ident) = + relationResolution.expandIdentifier(nameParts) val fullName = catalog.name +: ident.namespace :+ ident.name CatalogV2Util.loadFunction(catalog, ident).map { func => ResolvedPersistentFunc(catalog.asFunctionCatalog, ident, func) @@ -2198,7 +2046,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor withPosition(u) { try { val resolvedFunc = resolveBuiltinOrTempTableFunction(u.name, u.functionArgs).getOrElse { - val CatalogAndIdentifier(catalog, ident) = expandIdentifier(u.name) + val CatalogAndIdentifier(catalog, ident) = + relationResolution.expandIdentifier(u.name) if (CatalogV2Util.isSessionCatalog(catalog)) { v1SessionCatalog.resolvePersistentTableFunction( ident.asFunctionIdentifier, u.functionArgs) @@ -2355,7 +2204,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor private[analysis] def resolveFunction(u: UnresolvedFunction): Expression = { withPosition(u) { resolveBuiltinOrTempFunction(u.nameParts, u.arguments, u).getOrElse { - val CatalogAndIdentifier(catalog, ident) = expandIdentifier(u.nameParts) + val CatalogAndIdentifier(catalog, ident) = + relationResolution.expandIdentifier(u.nameParts) if (CatalogV2Util.isSessionCatalog(catalog)) { resolveV1Function(ident.asFunctionIdentifier, u.arguments, u) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala new file mode 100644 index 0000000000000..08be456f090e2 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.catalog.{ + CatalogTableType, + TemporaryViewRelation, + UnresolvedCatalogRelation +} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} +import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 +import org.apache.spark.sql.connector.catalog.{ + CatalogManager, + CatalogPlugin, + CatalogV2Util, + Identifier, + LookupCatalog, + Table, + V1Table, + V2TableWithV1Fallback +} +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ +import org.apache.spark.sql.errors.{DataTypeErrorsBase, QueryCompilationErrors} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.ArrayImplicits._ + +class RelationResolution(override val catalogManager: CatalogManager) + extends DataTypeErrorsBase + with Logging + with LookupCatalog + with SQLConfHelper { + val v1SessionCatalog = catalogManager.v1SessionCatalog + + /** + * If we are resolving database objects (relations, functions, etc.) inside views, we may need to + * expand single or multi-part identifiers with the current catalog and namespace of when the + * view was created. + */ + def expandIdentifier(nameParts: Seq[String]): Seq[String] = { + if (!isResolvingView || isReferredTempViewName(nameParts)) { + return nameParts + } + + if (nameParts.length == 1) { + AnalysisContext.get.catalogAndNamespace :+ nameParts.head + } else if (catalogManager.isCatalogRegistered(nameParts.head)) { + nameParts + } else { + AnalysisContext.get.catalogAndNamespace.head +: nameParts + } + } + + /** + * Lookup temporary view by `identifier`. Returns `None` if the view wasn't found. + */ + def lookupTempView(identifier: Seq[String]): Option[TemporaryViewRelation] = { + // We are resolving a view and this name is not a temp view when that view was created. We + // return None earlier here. + if (isResolvingView && !isReferredTempViewName(identifier)) { + return None + } + + v1SessionCatalog.getRawLocalOrGlobalTempView(identifier) + } + + /** + * Resolve relation `u` to v1 relation if it's a v1 table from the session catalog, or to v2 + * relation. This is for resolving DML commands and SELECT queries. 
+ */ + def resolveRelation( + u: UnresolvedRelation, + timeTravelSpec: Option[TimeTravelSpec] = None): Option[LogicalPlan] = { + val timeTravelSpecFromOptions = TimeTravelSpec.fromOptions( + u.options, + conf.getConf(SQLConf.TIME_TRAVEL_TIMESTAMP_KEY), + conf.getConf(SQLConf.TIME_TRAVEL_VERSION_KEY), + conf.sessionLocalTimeZone + ) + if (timeTravelSpec.nonEmpty && timeTravelSpecFromOptions.nonEmpty) { + throw new AnalysisException("MULTIPLE_TIME_TRAVEL_SPEC", Map.empty[String, String]) + } + val finalTimeTravelSpec = timeTravelSpec.orElse(timeTravelSpecFromOptions) + resolveTempView( + u.multipartIdentifier, + u.isStreaming, + finalTimeTravelSpec.isDefined + ).orElse { + expandIdentifier(u.multipartIdentifier) match { + case CatalogAndIdentifier(catalog, ident) => + val key = + ( + (catalog.name +: ident.namespace :+ ident.name).toImmutableArraySeq, + finalTimeTravelSpec + ) + AnalysisContext.get.relationCache + .get(key) + .map { cache => + val cachedRelation = cache.transform { + case multi: MultiInstanceRelation => + val newRelation = multi.newInstance() + newRelation.copyTagsFrom(multi) + newRelation + } + u.getTagValue(LogicalPlan.PLAN_ID_TAG) + .map { planId => + val cachedConnectRelation = cachedRelation.clone() + cachedConnectRelation.setTagValue(LogicalPlan.PLAN_ID_TAG, planId) + cachedConnectRelation + } + .getOrElse(cachedRelation) + } + .orElse { + val writePrivilegesString = + Option(u.options.get(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES)) + val table = + CatalogV2Util.loadTable(catalog, ident, finalTimeTravelSpec, writePrivilegesString) + val loaded = createRelation( + catalog, + ident, + table, + u.clearWritePrivileges.options, + u.isStreaming + ) + loaded.foreach(AnalysisContext.get.relationCache.update(key, _)) + u.getTagValue(LogicalPlan.PLAN_ID_TAG) + .map { planId => + loaded.map { loadedRelation => + val loadedConnectRelation = loadedRelation.clone() + loadedConnectRelation.setTagValue(LogicalPlan.PLAN_ID_TAG, planId) + loadedConnectRelation + } + } + .getOrElse(loaded) + } + case _ => None + } + } + } + + private def createRelation( + catalog: CatalogPlugin, + ident: Identifier, + table: Option[Table], + options: CaseInsensitiveStringMap, + isStreaming: Boolean): Option[LogicalPlan] = { + table.map { + // To utilize this code path to execute V1 commands, e.g. INSERT, + // either it must be session catalog, or tracksPartitionsInCatalog + // must be false so it does not require use catalog to manage partitions. + // Obviously we cannot execute V1Table by V1 code path if the table + // is not from session catalog and the table still requires its catalog + // to manage partitions. 
+ case v1Table: V1Table + if CatalogV2Util.isSessionCatalog(catalog) + || !v1Table.catalogTable.tracksPartitionsInCatalog => + if (isStreaming) { + if (v1Table.v1Table.tableType == CatalogTableType.VIEW) { + throw QueryCompilationErrors.permanentViewNotSupportedByStreamingReadingAPIError( + ident.quoted + ) + } + SubqueryAlias( + catalog.name +: ident.asMultipartIdentifier, + UnresolvedCatalogRelation(v1Table.v1Table, options, isStreaming = true) + ) + } else { + v1SessionCatalog.getRelation(v1Table.v1Table, options) + } + + case table => + if (isStreaming) { + val v1Fallback = table match { + case withFallback: V2TableWithV1Fallback => + Some(UnresolvedCatalogRelation(withFallback.v1Table, isStreaming = true)) + case _ => None + } + SubqueryAlias( + catalog.name +: ident.asMultipartIdentifier, + StreamingRelationV2( + None, + table.name, + table, + options, + table.columns.toAttributes, + Some(catalog), + Some(ident), + v1Fallback + ) + ) + } else { + SubqueryAlias( + catalog.name +: ident.asMultipartIdentifier, + DataSourceV2Relation.create(table, Some(catalog), Some(ident), options) + ) + } + } + } + + private def resolveTempView( + identifier: Seq[String], + isStreaming: Boolean = false, + isTimeTravel: Boolean = false): Option[LogicalPlan] = { + lookupTempView(identifier).map { v => + val tempViewPlan = v1SessionCatalog.getTempViewRelation(v) + if (isStreaming && !tempViewPlan.isStreaming) { + throw QueryCompilationErrors.readNonStreamingTempViewError(identifier.quoted) + } + if (isTimeTravel) { + throw QueryCompilationErrors.timeTravelUnsupportedError(toSQLId(identifier)) + } + tempViewPlan + } + } + + private def isResolvingView: Boolean = AnalysisContext.get.catalogAndNamespace.nonEmpty + + private def isReferredTempViewName(nameParts: Seq[String]): Boolean = { + val resolver = conf.resolver + AnalysisContext.get.referredTempViewNames.exists { n => + (n.length == nameParts.length) && n.zip(nameParts).forall { + case (a, b) => resolver(a, b) + } + } + } +} From f2bd31453885817455993a7eb7dfce42f7ab3ff6 Mon Sep 17 00:00:00 2001 From: RaleSapic Date: Tue, 15 Oct 2024 19:59:47 +0200 Subject: [PATCH 04/31] [SPARK-49916][SQL] Throw appropriate Exception for type mismatch between ColumnType and data type in some rows ### What changes were proposed in this pull request? In this PR, I introduced new exception to be thrown when there is a mismatch between type of the elements in the array column type, and the actual values of that column. Currently, this can happen in Postgres SQL when type of a column is real[] (array of floats), and insert command provides real[][] (array of arrays of floats). In our case, we are letting this to be converted to Spark SQL Types, and fail when trying to read it later. This PR is catching currently thrown internal exception (java.lang.ClassCastException) and re-throw a newly-introduced exception for this problem. ### Why are the changes needed? Throwing better exception that will help Spark user to better understand why the query failed. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing, and newly-added unit tests for Postgres. ### Was this patch authored or co-authored using generative AI tooling? No Closes #48397 from RaleSapic/rastko.sapic@databricks.com/fix-cast-error-in-array-column-types. 
Lead-authored-by: RaleSapic Co-authored-by: Rastko Sapic Signed-off-by: Max Gekk --- .../resources/error/error-conditions.json | 6 ++ .../jdbc/v2/PostgresIntegrationSuite.scala | 101 +++++++++++++++++- .../sql/errors/QueryExecutionErrors.scala | 6 ++ .../datasources/jdbc/JdbcUtils.scala | 20 +++- 4 files changed, 126 insertions(+), 7 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 8b2a57d6da3dd..8272daadb9159 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -606,6 +606,12 @@ ], "sqlState" : "42711" }, + "COLUMN_ARRAY_ELEMENT_TYPE_MISMATCH" : { + "message" : [ + "Some values in field are incompatible with the column array type. Expected type ." + ], + "sqlState" : "0A000" + }, "COLUMN_NOT_DEFINED_IN_TABLE" : { "message" : [ " column is not defined in table , defined table columns are: ." diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala index 05f02a402353b..f70b500f974a4 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.jdbc.v2 import java.sql.Connection -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkSQLException} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog @@ -65,9 +65,104 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT |) """.stripMargin ).executeUpdate() - connection.prepareStatement( - "CREATE TABLE datetime (name VARCHAR(32), date1 DATE, time1 TIMESTAMP)") + + connection.prepareStatement("CREATE TABLE array_test_table (int_array int[]," + + "float_array FLOAT8[], timestamp_array TIMESTAMP[], string_array TEXT[]," + + "datetime_array TIMESTAMPTZ[], array_of_int_arrays INT[][])").executeUpdate() + + val query = + """ + INSERT INTO array_test_table + (int_array, float_array, timestamp_array, string_array, + datetime_array, array_of_int_arrays) + VALUES + ( + ARRAY[1, 2, 3], -- Array of integers + ARRAY[1.1, 2.2, 3.3], -- Array of floats + ARRAY['2023-01-01 12:00'::timestamp, '2023-06-01 08:30'::timestamp], + ARRAY['hello', 'world'], -- Array of strings + ARRAY['2023-10-04 12:00:00+00'::timestamptz, + '2023-12-01 14:15:00+00'::timestamptz], + ARRAY[ARRAY[1, 2]] -- Array of arrays of integers + ), + ( + ARRAY[10, 20, 30], -- Another set of data + ARRAY[10.5, 20.5, 30.5], + ARRAY['2022-01-01 09:15'::timestamp, '2022-03-15 07:45'::timestamp], + ARRAY['postgres', 'arrays'], + ARRAY['2022-11-22 09:00:00+00'::timestamptz, + '2022-12-31 23:59:59+00'::timestamptz], + ARRAY[ARRAY[10, 20]] + ); + """ + connection.prepareStatement(query).executeUpdate() + + connection.prepareStatement("CREATE TABLE array_int (col int[])").executeUpdate() + connection.prepareStatement("CREATE TABLE array_bigint(col bigint[])").executeUpdate() + connection.prepareStatement("CREATE TABLE array_smallint (col smallint[])").executeUpdate() + connection.prepareStatement("CREATE TABLE array_boolean (col 
boolean[])").executeUpdate() + connection.prepareStatement("CREATE TABLE array_float (col real[])").executeUpdate() + connection.prepareStatement("CREATE TABLE array_double (col float8[])").executeUpdate() + connection.prepareStatement("CREATE TABLE array_timestamp (col timestamp[])").executeUpdate() + connection.prepareStatement("CREATE TABLE array_timestamptz (col timestamptz[])") + .executeUpdate() + + connection.prepareStatement("INSERT INTO array_int VALUES (array[array[10]])").executeUpdate() + connection.prepareStatement("INSERT INTO array_bigint VALUES (array[array[10]])") + .executeUpdate() + connection.prepareStatement("INSERT INTO array_smallint VALUES (array[array[10]])") + .executeUpdate() + connection.prepareStatement("INSERT INTO array_boolean VALUES (array[array[true]])") + .executeUpdate() + connection.prepareStatement("INSERT INTO array_float VALUES (array[array[10.5]])") + .executeUpdate() + connection.prepareStatement("INSERT INTO array_double VALUES (array[array[10.1]])") .executeUpdate() + connection.prepareStatement("INSERT INTO array_timestamp VALUES (" + + "array[array['2022-01-01 09:15'::timestamp]])").executeUpdate() + connection.prepareStatement("INSERT INTO array_timestamptz VALUES " + + "(array[array['2022-01-01 09:15'::timestamptz]])").executeUpdate() + connection.prepareStatement( + "CREATE TABLE datetime (name VARCHAR(32), date1 DATE, time1 TIMESTAMP)") + .executeUpdate() + } + + test("Test multi-dimensional column types") { + // This test is used to verify that the multi-dimensional + // column types are supported by the JDBC V2 data source. + // We do not verify any result output + // + val df = spark.read.format("jdbc") + .option("url", jdbcUrl) + .option("dbtable", "array_test_table") + .load() + df.collect() + + val array_tables = Array( + ("array_int", "\"ARRAY\""), + ("array_bigint", "\"ARRAY\""), + ("array_smallint", "\"ARRAY\""), + ("array_boolean", "\"ARRAY\""), + ("array_float", "\"ARRAY\""), + ("array_double", "\"ARRAY\""), + ("array_timestamp", "\"ARRAY\""), + ("array_timestamptz", "\"ARRAY\"") + ) + + array_tables.foreach { case (dbtable, arrayType) => + checkError( + exception = intercept[SparkSQLException] { + val df = spark.read.format("jdbc") + .option("url", jdbcUrl) + .option("dbtable", dbtable) + .load() + df.collect() + }, + condition = "COLUMN_ARRAY_ELEMENT_TYPE_MISMATCH", + parameters = Map("pos" -> "0", "type" -> arrayType), + sqlState = Some("0A000") + ) + } } override def dataPreparation(connection: Connection): Unit = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 6e64e7e9e39bf..3bc229e9693e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -1257,6 +1257,12 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE "dataType" -> toSQLType(dataType))) } + def wrongDatatypeInSomeRows(pos: Int, dataType: DataType): SparkSQLException = { + new SparkSQLException( + errorClass = "COLUMN_ARRAY_ELEMENT_TYPE_MISMATCH", + messageParameters = Map("pos" -> pos.toString(), "type" -> toSQLType(dataType))) + } + def rootConverterReturnNullError(): SparkRuntimeException = { new SparkRuntimeException( errorClass = "INVALID_JSON_ROOT_FIELD", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 7946068b9452e..6e79a2f2a3267 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -585,14 +585,26 @@ object JdbcUtils extends Logging with SQLConfHelper { arr => new GenericArrayData(elementConversion(et0)(arr)) } + case IntegerType => arrayConverter[Int]((i: Int) => i) + case FloatType => arrayConverter[Float]((f: Float) => f) + case DoubleType => arrayConverter[Double]((d: Double) => d) + case ShortType => arrayConverter[Short]((s: Short) => s) + case BooleanType => arrayConverter[Boolean]((b: Boolean) => b) + case LongType => arrayConverter[Long]((l: Long) => l) + case _ => (array: Object) => array.asInstanceOf[Array[Any]] } (rs: ResultSet, row: InternalRow, pos: Int) => - val array = nullSafeConvert[java.sql.Array]( - input = rs.getArray(pos + 1), - array => new GenericArrayData(elementConversion(et)(array.getArray))) - row.update(pos, array) + try { + val array = nullSafeConvert[java.sql.Array]( + input = rs.getArray(pos + 1), + array => new GenericArrayData(elementConversion(et)(array.getArray()))) + row.update(pos, array) + } catch { + case e: java.lang.ClassCastException => + throw QueryExecutionErrors.wrongDatatypeInSomeRows(pos, dt) + } case NullType => (_: ResultSet, row: InternalRow, pos: Int) => row.update(pos, null) From 5f2bd5c10dc7f7f1ed3d2bd286ad98f284b2032c Mon Sep 17 00:00:00 2001 From: Chenhao Li Date: Tue, 15 Oct 2024 20:21:42 +0200 Subject: [PATCH 05/31] [SPARK-49959][SQL] Fix ColumnarArray.copy() to read nulls from the correct offset ### What changes were proposed in this pull request? `ColumnarArray` represents an array containing elements from `data[offset]` to `data[offset + length)`. When copying the array, the null flag should also be read starting from `offset` rather than 0. Some expressions depend on this utility function. For example, this bug can lead to incorrect results in `ArrayTransform`. ### Why are the changes needed? Fix correctness issue. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Two unit tests, one with `ArrayTransform`, and the other tests `ColumnarArray` directly. Both the tests would fail without the change in the PR. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48458 from chenhao-db/fix_ColumnarArray_copy. 
Authored-by: Chenhao Li Signed-off-by: Max Gekk --- .../org/apache/spark/sql/vectorized/ColumnarArray.java | 2 +- .../org/apache/spark/sql/DataFrameComplexTypeSuite.scala | 9 +++++++++ .../sql/execution/vectorized/ColumnVectorSuite.scala | 6 ++++++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java index 721e6a60befe2..12a2879794b10 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java @@ -58,7 +58,7 @@ public int numElements() { private UnsafeArrayData setNullBits(UnsafeArrayData arrayData) { if (data.hasNull()) { for (int i = 0; i < length; i++) { - if (data.isNullAt(i)) { + if (data.isNullAt(offset + i)) { arrayData.setNullAt(i); } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala index 8c1cc6c3bea1d..48ea0e01a4372 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala @@ -39,6 +39,15 @@ import org.apache.spark.unsafe.types.CalendarInterval class DataFrameComplexTypeSuite extends QueryTest with SharedSparkSession { import testImplicits._ + test("ArrayTransform with scan input") { + withTempPath { f => + spark.sql("select array(array(1, null, 3), array(4, 5, null), array(null, 8, 9)) as a") + .write.parquet(f.getAbsolutePath) + val df = spark.read.parquet(f.getAbsolutePath).selectExpr("transform(a, (x, i) -> x)") + checkAnswer(df, Row(Seq(Seq(1, null, 3), Seq(4, 5, null), Seq(null, 8, 9)))) + } + } + test("UDF on struct") { val f = udf((a: String) => a) val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala index aca968745d198..0cc4f7bf2548e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -504,6 +504,12 @@ class ColumnVectorSuite extends SparkFunSuite with SQLHelper { val arr = new ColumnarArray(testVector, 0, testVector.capacity) assert(arr.toSeq(testVector.dataType) == expected) assert(arr.copy().toSeq(testVector.dataType) == expected) + + if (expected.nonEmpty) { + val withOffset = new ColumnarArray(testVector, 1, testVector.capacity - 1) + assert(withOffset.toSeq(testVector.dataType) == expected.tail) + assert(withOffset.copy().toSeq(testVector.dataType) == expected.tail) + } } testVectors("getInts with dictionary and nulls", 3, IntegerType) { testVector => From 1269b35a7e6dec646374e54bb0867e176a2f834a Mon Sep 17 00:00:00 2001 From: panbingkun Date: Tue, 15 Oct 2024 20:26:41 +0200 Subject: [PATCH 06/31] [SPARK-49954][SQL] Codegen Support for SchemaOfJson (by Invoke & RuntimeReplaceable) ### What changes were proposed in this pull request? The pr aims to add `Codegen` Support for `schema_of_json`. ### Why are the changes needed? - improve codegen coverage. - simplified code. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? 
Pass GA & Existed UT (eg: JsonFunctionsSuite#`*schema_of_json*`) ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48473 from panbingkun/SPARK-49954_scala. Authored-by: panbingkun Signed-off-by: Max Gekk --- .../sql/catalyst/expressions/ExprUtils.scala | 4 +- .../json/JsonExpressionEvalUtils.scala | 53 +++++++++++++++++++ .../expressions/jsonExpressions.scala | 38 ++++++------- .../function_schema_of_json.explain | 2 +- ...nction_schema_of_json_with_options.explain | 2 +- 5 files changed, 74 insertions(+), 25 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala index 08cb03edb78b6..38b927f5bbf38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala @@ -32,11 +32,11 @@ import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeWithCaseA import org.apache.spark.sql.types.{DataType, MapType, StringType, StructType, VariantType} import org.apache.spark.unsafe.types.UTF8String -object ExprUtils extends QueryErrorsBase { +object ExprUtils extends EvalHelper with QueryErrorsBase { def evalTypeExpr(exp: Expression): DataType = { if (exp.foldable) { - exp.eval() match { + prepareForEval(exp).eval() match { case s: UTF8String if s != null => val dataType = DataType.parseTypeWithFallback( s.toString, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala new file mode 100644 index 0000000000000..65c95c8240f4f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.expressions.json + +import com.fasterxml.jackson.core.JsonFactory + +import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JsonInferSchema, JSONOptions} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{ArrayType, DataType, StructType} +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils + +object JsonExpressionEvalUtils { + + def schemaOfJson( + jsonFactory: JsonFactory, + jsonOptions: JSONOptions, + jsonInferSchema: JsonInferSchema, + json: UTF8String): UTF8String = { + val dt = Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, json)) { parser => + parser.nextToken() + // To match with schema inference from JSON datasource. + jsonInferSchema.inferField(parser) match { + case st: StructType => + jsonInferSchema.canonicalizeType(st, jsonOptions).getOrElse(StructType(Nil)) + case at: ArrayType if at.elementType.isInstanceOf[StructType] => + jsonInferSchema + .canonicalizeType(at.elementType, jsonOptions) + .map(ArrayType(_, containsNull = at.containsNull)) + .getOrElse(ArrayType(StructType(Nil), containsNull = at.containsNull)) + case other: DataType => + jsonInferSchema.canonicalizeType(other, jsonOptions).getOrElse( + SQLConf.get.defaultStringType) + } + } + + UTF8String.fromString(dt.sql) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index e01531cc821c9..3118fe9a2eb44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper -import org.apache.spark.sql.catalyst.expressions.json.JsonExpressionUtils +import org.apache.spark.sql.catalyst.expressions.json.{JsonExpressionEvalUtils, JsonExpressionUtils} import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils import org.apache.spark.sql.catalyst.json._ @@ -878,7 +878,9 @@ case class StructsToJson( case class SchemaOfJson( child: Expression, options: Map[String, String]) - extends UnaryExpression with CodegenFallback with QueryErrorsBase { + extends UnaryExpression + with RuntimeReplaceable + with QueryErrorsBase { def this(child: Expression) = this(child, Map.empty[String, String]) @@ -919,26 +921,20 @@ case class SchemaOfJson( } } - override def eval(v: InternalRow): Any = { - val dt = Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, json)) { parser => - parser.nextToken() - // To match with schema inference from JSON datasource. 
- jsonInferSchema.inferField(parser) match { - case st: StructType => - jsonInferSchema.canonicalizeType(st, jsonOptions).getOrElse(StructType(Nil)) - case at: ArrayType if at.elementType.isInstanceOf[StructType] => - jsonInferSchema - .canonicalizeType(at.elementType, jsonOptions) - .map(ArrayType(_, containsNull = at.containsNull)) - .getOrElse(ArrayType(StructType(Nil), containsNull = at.containsNull)) - case other: DataType => - jsonInferSchema.canonicalizeType(other, jsonOptions).getOrElse( - SQLConf.get.defaultStringType) - } - } + @transient private lazy val jsonFactoryObjectType = ObjectType(classOf[JsonFactory]) + @transient private lazy val jsonOptionsObjectType = ObjectType(classOf[JSONOptions]) + @transient private lazy val jsonInferSchemaObjectType = ObjectType(classOf[JsonInferSchema]) - UTF8String.fromString(dt.sql) - } + override def replacement: Expression = StaticInvoke( + JsonExpressionEvalUtils.getClass, + dataType, + "schemaOfJson", + Seq(Literal(jsonFactory, jsonFactoryObjectType), + Literal(jsonOptions, jsonOptionsObjectType), + Literal(jsonInferSchema, jsonInferSchemaObjectType), + child), + Seq(jsonFactoryObjectType, jsonOptionsObjectType, jsonInferSchemaObjectType, child.dataType) + ) override def prettyName: String = "schema_of_json" diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json.explain index 8ec799bc58084..b400aeeca5af2 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json.explain @@ -1,2 +1,2 @@ -Project [schema_of_json([{"col":01}]) AS schema_of_json([{"col":01}])#0] +Project [static_invoke(JsonExpressionEvalUtils.schemaOfJson(com.fasterxml.jackson.core.JsonFactory, org.apache.spark.sql.catalyst.json.JSONOptions, org.apache.spark.sql.catalyst.json.JsonInferSchema, [{"col":01}])) AS schema_of_json([{"col":01}])#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json_with_options.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json_with_options.explain index 13867949177a4..b400aeeca5af2 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json_with_options.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json_with_options.explain @@ -1,2 +1,2 @@ -Project [schema_of_json([{"col":01}], (allowNumericLeadingZeros,true)) AS schema_of_json([{"col":01}])#0] +Project [static_invoke(JsonExpressionEvalUtils.schemaOfJson(com.fasterxml.jackson.core.JsonFactory, org.apache.spark.sql.catalyst.json.JSONOptions, org.apache.spark.sql.catalyst.json.JsonInferSchema, [{"col":01}])) AS schema_of_json([{"col":01}])#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] From 0b488280bef0a59a8b76ff9de728d82663522369 Mon Sep 17 00:00:00 2001 From: Jovan Pavlovic Date: Tue, 15 Oct 2024 20:34:14 +0200 Subject: [PATCH 07/31] [SPARK-49911][SQL] Fix semantic of support binary equality ### What changes were proposed in this pull request? With introduction of trim collation, what was known as supportsBinaryEquality changes, it is now split in isUtf8BinaryType and usesTrimCollation so that it has correct semantics. 
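As a rough sketch of the new semantics (not part of this patch; the `UTF8_BINARY_RTRIM` collation name is taken from the comments in `CollationFactory.java` below, and `fetchCollation(String)` is assumed to accept it):

```scala
import org.apache.spark.sql.catalyst.util.CollationFactory

// A trailing-space-trimming collation still has the UTF8 binary base type,
// but it can no longer guarantee binary equality or ordering, because the
// constructor now derives:
//   supportsBinaryEquality = !supportsSpaceTrimming && isUtf8BinaryType
val rtrim = CollationFactory.fetchCollation("UTF8_BINARY_RTRIM")
assert(rtrim.isUtf8BinaryType)
assert(!rtrim.supportsBinaryEquality && !rtrim.supportsBinaryOrdering)

// Plain UTF8_BINARY (no trimming) keeps full binary semantics.
val binary = CollationFactory.fetchCollation("UTF8_BINARY")
assert(binary.isUtf8BinaryType && binary.supportsBinaryEquality)
```

Callers that previously branched on `supportsBinaryEquality` to mean "UTF8_BINARY family" now branch on `isUtf8BinaryType`, as the `CollationSupport` changes below show.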
### Why are the changes needed? With introduction of trim collation, what was known as supportsBinaryEquality changes, it is now split in isUtf8BinaryType and usesTrimCollation so that it has correct semantics. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Everything is covered with existing tests, no new functionality is added. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48472 from jovanpavl-db/fix-semantic-of-supportBinaryEquality. Authored-by: Jovan Pavlovic Signed-off-by: Max Gekk --- .../util/CollationAwareUTF8String.java | 4 +- .../sql/catalyst/util/CollationFactory.java | 51 ++++---- .../sql/catalyst/util/CollationSupport.java | 122 +++++++++--------- .../unsafe/types/CollationFactorySuite.scala | 8 +- .../sql/catalyst/util/UnsafeRowUtils.scala | 2 +- .../expressions/HashExpressionsSuite.scala | 2 +- .../aggregate/HashMapGenerator.scala | 4 +- .../org/apache/spark/sql/CollationSuite.scala | 12 +- 8 files changed, 106 insertions(+), 99 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java index fb610a5d96f17..d67697eaea38b 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java @@ -1363,9 +1363,9 @@ public static UTF8String trimRight( public static UTF8String[] splitSQL(final UTF8String input, final UTF8String delim, final int limit, final int collationId) { - if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { + if (CollationFactory.fetchCollation(collationId).isUtf8BinaryType) { return input.split(delim, limit); - } else if (CollationFactory.fetchCollation(collationId).supportsLowercaseEquality) { + } else if (CollationFactory.fetchCollation(collationId).isUtf8LcaseType) { return lowercaseSplitSQL(input, delim, limit); } else { return icuSplitSQL(input, delim, limit, collationId); diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 01f6c7e0331b0..50bb93465921e 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -160,6 +160,18 @@ public static class Collation { */ public final boolean supportsSpaceTrimming; + /** + * Is Utf8 binary type as indicator if collation base type is UTF8 binary. Note currently only + * collations Utf8_Binary and Utf8_Binary_RTRIM are considered as Utf8 binary type. + */ + public final boolean isUtf8BinaryType; + + /** + * Is Utf8 lcase type as indicator if collation base type is UTF8 lcase. Note currently only + * collations Utf8_Lcase and Utf8_Lcase_RTRIM are considered as Utf8 Lcase type. 
+ */ + public final boolean isUtf8LcaseType; + public Collation( String collationName, String provider, @@ -168,9 +180,8 @@ public Collation( String version, ToLongFunction hashFunction, BiFunction equalsFunction, - boolean supportsBinaryEquality, - boolean supportsBinaryOrdering, - boolean supportsLowercaseEquality, + boolean isUtf8BinaryType, + boolean isUtf8LcaseType, boolean supportsSpaceTrimming) { this.collationName = collationName; this.provider = provider; @@ -178,14 +189,13 @@ public Collation( this.comparator = comparator; this.version = version; this.hashFunction = hashFunction; - this.supportsBinaryEquality = supportsBinaryEquality; - this.supportsBinaryOrdering = supportsBinaryOrdering; - this.supportsLowercaseEquality = supportsLowercaseEquality; + this.isUtf8BinaryType = isUtf8BinaryType; + this.isUtf8LcaseType = isUtf8LcaseType; this.equalsFunction = equalsFunction; this.supportsSpaceTrimming = supportsSpaceTrimming; - - // De Morgan's Law to check supportsBinaryOrdering => supportsBinaryEquality - assert(!supportsBinaryOrdering || supportsBinaryEquality); + this.supportsBinaryEquality = !supportsSpaceTrimming && isUtf8BinaryType; + this.supportsBinaryOrdering = !supportsSpaceTrimming && isUtf8BinaryType; + this.supportsLowercaseEquality = !supportsSpaceTrimming && isUtf8LcaseType; // No Collation can simultaneously support binary equality and lowercase equality assert(!supportsBinaryEquality || !supportsLowercaseEquality); @@ -567,9 +577,8 @@ protected Collation buildCollation() { "1.0", hashFunction, equalsFunction, - /* supportsBinaryEquality = */ true, - /* supportsBinaryOrdering = */ true, - /* supportsLowercaseEquality = */ false, + /* isUtf8BinaryType = */ true, + /* isUtf8LcaseType = */ false, spaceTrimming != SpaceTrimming.NONE); } else { Comparator comparator; @@ -595,9 +604,8 @@ protected Collation buildCollation() { "1.0", hashFunction, (s1, s2) -> comparator.compare(s1, s2) == 0, - /* supportsBinaryEquality = */ false, - /* supportsBinaryOrdering = */ false, - /* supportsLowercaseEquality = */ true, + /* isUtf8BinaryType = */ false, + /* isUtf8LcaseType = */ true, spaceTrimming != SpaceTrimming.NONE); } } @@ -982,9 +990,8 @@ protected Collation buildCollation() { ICU_COLLATOR_VERSION, hashFunction, (s1, s2) -> comparator.compare(s1, s2) == 0, - /* supportsBinaryEquality = */ false, - /* supportsBinaryOrdering = */ false, - /* supportsLowercaseEquality = */ false, + /* isUtf8BinaryType = */ false, + /* isUtf8LcaseType = */ false, spaceTrimming != SpaceTrimming.NONE); } @@ -1191,9 +1198,9 @@ public static UTF8String getCollationKey(UTF8String input, int collationId) { if (collation.supportsSpaceTrimming) { input = Collation.CollationSpec.applyTrimmingPolicy(input, collationId); } - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return input; - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return CollationAwareUTF8String.lowerCaseCodePoints(input); } else { CollationKey collationKey = collation.collator.getCollationKey( @@ -1207,9 +1214,9 @@ public static byte[] getCollationKeyBytes(UTF8String input, int collationId) { if (collation.supportsSpaceTrimming) { input = Collation.CollationSpec.applyTrimmingPolicy(input, collationId); } - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return input.getBytes(); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return 
CollationAwareUTF8String.lowerCaseCodePoints(input).getBytes(); } else { return collation.collator.getCollationKey( diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index f05d9e512568f..978b663cc25c9 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -37,9 +37,9 @@ public final class CollationSupport { public static class StringSplitSQL { public static UTF8String[] exec(final UTF8String s, final UTF8String d, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(s, d); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(s, d); } else { return execICU(s, d, collationId); @@ -48,9 +48,9 @@ public static UTF8String[] exec(final UTF8String s, final UTF8String d, final in public static String genCode(final String s, final String d, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringSplitSQL.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", s, d); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", s, d); } else { return String.format(expr + "ICU(%s, %s, %d)", s, d, collationId); @@ -71,9 +71,9 @@ public static UTF8String[] execICU(final UTF8String string, final UTF8String del public static class Contains { public static boolean exec(final UTF8String l, final UTF8String r, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(l, r); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(l, r); } else { return execICU(l, r, collationId); @@ -82,9 +82,9 @@ public static boolean exec(final UTF8String l, final UTF8String r, final int col public static String genCode(final String l, final String r, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.Contains.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", l, r); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", l, r); } else { return String.format(expr + "ICU(%s, %s, %d)", l, r, collationId); @@ -109,9 +109,9 @@ public static class StartsWith { public static boolean exec(final UTF8String l, final UTF8String r, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(l, r); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(l, r); } else { return execICU(l, r, collationId); @@ -120,9 +120,9 @@ public static boolean exec(final 
UTF8String l, final UTF8String r, public static String genCode(final String l, final String r, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StartsWith.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", l, r); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", l, r); } else { return String.format(expr + "ICU(%s, %s, %d)", l, r, collationId); @@ -146,9 +146,9 @@ public static boolean execICU(final UTF8String l, final UTF8String r, public static class EndsWith { public static boolean exec(final UTF8String l, final UTF8String r, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(l, r); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(l, r); } else { return execICU(l, r, collationId); @@ -157,9 +157,9 @@ public static boolean exec(final UTF8String l, final UTF8String r, final int col public static String genCode(final String l, final String r, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.EndsWith.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", l, r); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", l, r); } else { return String.format(expr + "ICU(%s, %s, %d)", l, r, collationId); @@ -184,9 +184,9 @@ public static boolean execICU(final UTF8String l, final UTF8String r, public static class Upper { public static UTF8String exec(final UTF8String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return useICU ? execBinaryICU(v) : execBinary(v); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(v); } else { return execICU(v, collationId); @@ -195,10 +195,10 @@ public static UTF8String exec(final UTF8String v, final int collationId, boolean public static String genCode(final String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.Upper.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { String funcName = useICU ? "BinaryICU" : "Binary"; return String.format(expr + "%s(%s)", funcName, v); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s)", v); } else { return String.format(expr + "ICU(%s, %d)", v, collationId); @@ -221,9 +221,9 @@ public static UTF8String execICU(final UTF8String v, final int collationId) { public static class Lower { public static UTF8String exec(final UTF8String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return useICU ? 
execBinaryICU(v) : execBinary(v); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(v); } else { return execICU(v, collationId); @@ -232,10 +232,10 @@ public static UTF8String exec(final UTF8String v, final int collationId, boolean public static String genCode(final String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.Lower.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { String funcName = useICU ? "BinaryICU" : "Binary"; return String.format(expr + "%s(%s)", funcName, v); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s)", v); } else { return String.format(expr + "ICU(%s, %d)", v, collationId); @@ -258,9 +258,9 @@ public static UTF8String execICU(final UTF8String v, final int collationId) { public static class InitCap { public static UTF8String exec(final UTF8String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return useICU ? execBinaryICU(v) : execBinary(v); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(v); } else { return execICU(v, collationId); @@ -270,10 +270,10 @@ public static UTF8String exec(final UTF8String v, final int collationId, boolean public static String genCode(final String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.InitCap.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { String funcName = useICU ? 
"BinaryICU" : "Binary"; return String.format(expr + "%s(%s)", funcName, v); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s)", v); } else { return String.format(expr + "ICU(%s, %d)", v, collationId); @@ -296,7 +296,7 @@ public static UTF8String execICU(final UTF8String v, final int collationId) { public static class FindInSet { public static int exec(final UTF8String word, final UTF8String set, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(word, set); } else { return execCollationAware(word, set, collationId); @@ -305,7 +305,7 @@ public static int exec(final UTF8String word, final UTF8String set, final int co public static String genCode(final String word, final String set, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.FindInSet.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", word, set); } else { return String.format(expr + "CollationAware(%s, %s, %d)", word, set, collationId); @@ -324,9 +324,9 @@ public static class StringInstr { public static int exec(final UTF8String string, final UTF8String substring, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(string, substring); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(string, substring); } else { return execICU(string, substring, collationId); @@ -336,9 +336,9 @@ public static String genCode(final String string, final String substring, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringInstr.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", string, substring); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", string, substring); } else { return String.format(expr + "ICU(%s, %s, %d)", string, substring, collationId); @@ -360,9 +360,9 @@ public static class StringReplace { public static UTF8String exec(final UTF8String src, final UTF8String search, final UTF8String replace, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(src, search, replace); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(src, search, replace); } else { return execICU(src, search, replace, collationId); @@ -372,9 +372,9 @@ public static String genCode(final String src, final String search, final String final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringReplace.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s, %s)", src, search, replace); - } else if 
(collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s, %s)", src, search, replace); } else { return String.format(expr + "ICU(%s, %s, %s, %d)", src, search, replace, collationId); @@ -398,9 +398,9 @@ public static class StringLocate { public static int exec(final UTF8String string, final UTF8String substring, final int start, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(string, substring, start); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(string, substring, start); } else { return execICU(string, substring, start, collationId); @@ -410,9 +410,9 @@ public static String genCode(final String string, final String substring, final final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringLocate.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s, %d)", string, substring, start); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s, %d)", string, substring, start); } else { return String.format(expr + "ICU(%s, %s, %d, %d)", string, substring, start, collationId); @@ -436,9 +436,9 @@ public static class SubstringIndex { public static UTF8String exec(final UTF8String string, final UTF8String delimiter, final int count, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(string, delimiter, count); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(string, delimiter, count); } else { return execICU(string, delimiter, count, collationId); @@ -448,9 +448,9 @@ public static String genCode(final String string, final String delimiter, final String count, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.SubstringIndex.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s, %s)", string, delimiter, count); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s, %s)", string, delimiter, count); } else { return String.format(expr + "ICU(%s, %s, %s, %d)", string, delimiter, count, collationId); @@ -474,9 +474,9 @@ public static class StringTranslate { public static UTF8String exec(final UTF8String source, Map dict, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(source, dict); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(source, dict); } else { return execICU(source, dict, collationId); @@ -503,9 +503,9 @@ public static UTF8String exec( final UTF8String trimString, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - 
if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(srcString, trimString); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(srcString, trimString); } else { return execICU(srcString, trimString, collationId); @@ -520,9 +520,9 @@ public static String genCode( final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringTrim.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", srcString, trimString); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", srcString, trimString); } else { return String.format(expr + "ICU(%s, %s, %d)", srcString, trimString, collationId); @@ -559,9 +559,9 @@ public static UTF8String exec( final UTF8String trimString, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(srcString, trimString); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(srcString, trimString); } else { return execICU(srcString, trimString, collationId); @@ -576,9 +576,9 @@ public static String genCode( final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringTrimLeft.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", srcString, trimString); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", srcString, trimString); } else { return String.format(expr + "ICU(%s, %s, %d)", srcString, trimString, collationId); @@ -614,9 +614,9 @@ public static UTF8String exec( final UTF8String trimString, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(srcString, trimString); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(srcString, trimString); } else { return execICU(srcString, trimString, collationId); @@ -631,9 +631,9 @@ public static String genCode( final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringTrimRight.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", srcString, trimString); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", srcString, trimString); } else { return String.format(expr + "ICU(%s, %s, %d)", srcString, trimString, collationId); @@ -669,7 +669,7 @@ public static UTF8String execICU( public static boolean supportsLowercaseRegex(final int collationId) { // for regex, only Unicode case-insensitive matching is possible, // so UTF8_LCASE is treated as UNICODE_CI in this context - return CollationFactory.fetchCollation(collationId).supportsLowercaseEquality; + return 
CollationFactory.fetchCollation(collationId).isUtf8LcaseType; } static final int lowercaseRegexFlags = Pattern.UNICODE_CASE | Pattern.CASE_INSENSITIVE; diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala index a565d2d347636..df9af1579d4f1 100644 --- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala @@ -38,22 +38,22 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig assert(UTF8_BINARY_COLLATION_ID == 0) val utf8Binary = fetchCollation(UTF8_BINARY_COLLATION_ID) assert(utf8Binary.collationName == "UTF8_BINARY") - assert(utf8Binary.supportsBinaryEquality) + assert(utf8Binary.isUtf8BinaryType) assert(UTF8_LCASE_COLLATION_ID == 1) val utf8Lcase = fetchCollation(UTF8_LCASE_COLLATION_ID) assert(utf8Lcase.collationName == "UTF8_LCASE") - assert(!utf8Lcase.supportsBinaryEquality) + assert(!utf8Lcase.isUtf8BinaryType) assert(UNICODE_COLLATION_ID == (1 << 29)) val unicode = fetchCollation(UNICODE_COLLATION_ID) assert(unicode.collationName == "UNICODE") - assert(!unicode.supportsBinaryEquality) + assert(!unicode.isUtf8BinaryType) assert(UNICODE_CI_COLLATION_ID == ((1 << 29) | (1 << 17))) val unicodeCi = fetchCollation(UNICODE_CI_COLLATION_ID) assert(unicodeCi.collationName == "UNICODE_CI") - assert(!unicodeCi.supportsBinaryEquality) + assert(!unicodeCi.isUtf8BinaryType) } test("UTF8_BINARY and ICU root locale collation names") { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala index 40b8bccafaad2..118dd92c3ed54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala @@ -207,7 +207,7 @@ object UnsafeRowUtils { def isBinaryStable(dataType: DataType): Boolean = !dataType.existsRecursively { case st: StringType => val collation = CollationFactory.fetchCollation(st.collationId) - (!collation.supportsBinaryEquality || collation.supportsSpaceTrimming) + (!collation.supportsBinaryEquality) case _ => false } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala index 6f3890cafd2ac..92ef24bb8ec63 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala @@ -636,7 +636,7 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(murmur3Hash1, interpretedHash1) checkEvaluation(murmur3Hash2, interpretedHash2) - if (CollationFactory.fetchCollation(collation).supportsBinaryEquality) { + if (CollationFactory.fetchCollation(collation).isUtf8BinaryType) { assert(interpretedHash1 != interpretedHash2) } else { assert(interpretedHash1 == interpretedHash2) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index 3b1f349520f39..19a36483abe6d 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -173,9 +173,9 @@ abstract class HashMapGenerator( ${hashBytes(bytes)} """ } - case st: StringType if st.supportsBinaryEquality && !st.usesTrimCollation => + case st: StringType if st.supportsBinaryEquality => hashBytes(s"$input.getBytes()") - case st: StringType if !st.supportsBinaryEquality || st.usesTrimCollation => + case st: StringType if !st.supportsBinaryEquality => hashLong(s"CollationFactory.fetchCollation(${st.collationId})" + s".hashFunction.applyAsLong($input)") case CalendarIntervalType => hashInt(s"$input.hashCode()") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 4234d73c1794d..b6da0b169f050 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -1333,7 +1333,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { for (codeGen <- Seq("NO_CODEGEN", "CODEGEN_ONLY")) { val collationSetup = if (collation.isEmpty) "" else " COLLATE " + collation val supportsBinaryEquality = collation.isEmpty || collation == "UNICODE" || - CollationFactory.fetchCollation(collation).supportsBinaryEquality + CollationFactory.fetchCollation(collation).isUtf8BinaryType test(s"Group by on map containing$collationSetup strings ($codeGen)") { val tableName = "t" @@ -1558,7 +1558,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ) // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { + if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { assert(collectFirst(queryPlan) { case b: HashJoin => b.leftKeys.head }.head.isInstanceOf[CollationKey]) @@ -1615,7 +1615,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ) // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { + if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { assert(collectFirst(queryPlan) { case b: BroadcastHashJoinExec => b.leftKeys.head }.head.asInstanceOf[ArrayTransform].function.asInstanceOf[LambdaFunction]. @@ -1676,7 +1676,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ) // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { + if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { assert(collectFirst(queryPlan) { case b: BroadcastHashJoinExec => b.leftKeys.head }.head.asInstanceOf[ArrayTransform].function. @@ -1735,7 +1735,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ) // Only if collation doesn't support binary equality, collation key should be injected. 
- if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { + if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { assert(queryPlan.toString().contains("collationkey")) } else { assert(!queryPlan.toString().contains("collationkey")) @@ -1794,7 +1794,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ) // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { + if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { assert(queryPlan.toString().contains("collationkey")) } else { assert(!queryPlan.toString().contains("collationkey")) From 856cfe7dc02b91b9e7bec9c4f404f3a2a3d30e60 Mon Sep 17 00:00:00 2001 From: Cheng Pan Date: Tue, 15 Oct 2024 11:34:38 -0700 Subject: [PATCH 08/31] [SPARK-49969][BUILD] Simplify dependency management in YARN module ### What changes were proposed in this pull request? This PR simplifies dependency management in YARN module by pruning unnecessary test scope dependency which pulls from the vanilla Hadoop client. ### Why are the changes needed? Since 3.2 (SPARK-33212), Spark moved from the vanilla Hadoop3 client to the shaded Hadoop3 client, significantly simplifying dependency management, some hack rules of dependency to address the odd issues can be removed to simplify the Maven/SBT configuration files now. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? - pass SBT test: `build/sbt -Pyarn yarn/test` - pass Maven test: `build/mvn -Pyarn -pl :spark-yarn_2.13 clean install -DskipTests -am && build/mvn -Pyarn -pl :spark-yarn_2.13 test` - verified no affection on runtime deps: `dev/test-dependencies.sh` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48468 from pan3793/SPARK-49969. Authored-by: Cheng Pan Signed-off-by: Dongjoon Hyun --- project/SparkBuild.scala | 17 +--- resource-managers/yarn/pom.xml | 139 +++++++-------------------------- 2 files changed, 30 insertions(+), 126 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 737efa8f7846b..a87e0af0b542f 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -1072,20 +1072,9 @@ object DependencyOverrides { object ExcludedDependencies { lazy val settings = Seq( libraryDependencies ~= { libs => libs.filterNot(_.name == "groovy-all") }, - // SPARK-33705: Due to sbt compiler issues, it brings exclusions defined in maven pom back to - // the classpath directly and assemble test scope artifacts to assembly/target/scala-xx/jars, - // which is also will be added to the classpath of some unit tests that will build a subprocess - // to run `spark-submit`, e.g. HiveThriftServer2Test. - // - // These artifacts are for the jersey-1 API but Spark use jersey-2 ones, so it cause test - // flakiness w/ jar conflicts issues. - // - // Also jersey-1 is only used by yarn module(see resource-managers/yarn/pom.xml) for testing - // purpose only. Here we exclude them from the whole project scope and add them w/ yarn only. 
excludeDependencies ++= Seq( - ExclusionRule(organization = "com.sun.jersey"), ExclusionRule(organization = "ch.qos.logback"), - ExclusionRule("javax.ws.rs", "jsr311-api")) + ExclusionRule("javax.servlet", "javax.servlet-api")) ) } @@ -1229,10 +1218,6 @@ object YARN { val hadoopProvidedProp = "spark.yarn.isHadoopProvided" lazy val settings = Seq( - excludeDependencies --= Seq( - ExclusionRule(organization = "com.sun.jersey"), - ExclusionRule("javax.servlet", "javax.servlet-api"), - ExclusionRule("javax.ws.rs", "jsr311-api")), Compile / unmanagedResources := (Compile / unmanagedResources).value.filter(!_.getName.endsWith(s"$propFileName")), genConfigProperties := { diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml index 770a550030f51..5a10aa797c1b1 100644 --- a/resource-managers/yarn/pom.xml +++ b/resource-managers/yarn/pom.xml @@ -29,43 +29,8 @@ Spark Project YARN yarn - 1.19 - - - hadoop-3 - - true - - - - org.apache.hadoop - hadoop-client-runtime - ${hadoop.version} - ${hadoop.deps.scope} - - - org.apache.hadoop - hadoop-client-minicluster - ${hadoop.version} - test - - - - org.bouncycastle - bcprov-jdk18on - test - - - org.bouncycastle - bcpkix-jdk18on - test - - - - - org.apache.spark @@ -102,6 +67,35 @@ org.apache.hadoop hadoop-client-api ${hadoop.version} + ${hadoop.deps.scope} + + + org.apache.hadoop + hadoop-client-runtime + ${hadoop.version} + ${hadoop.deps.scope} + + + org.apache.hadoop + hadoop-client-minicluster + ${hadoop.version} + test + + + + javax.xml.bind + jaxb-api + test + + + org.bouncycastle + bcprov-jdk18on + test + + + org.bouncycastle + bcpkix-jdk18on + test @@ -135,22 +129,6 @@ - - - org.eclipse.jetty.orbit - javax.servlet.jsp - 2.2.0.v201112011158 - test - - - org.eclipse.jetty.orbit - javax.servlet.jsp.jstl - 1.2.0.v201105211821 - test - - org.mockito mockito-core @@ -166,65 +144,6 @@ byte-buddy-agent test - - - - com.sun.jersey - jersey-core - test - ${jersey-1.version} - - - com.sun.jersey - jersey-json - test - ${jersey-1.version} - - - com.sun.jersey - jersey-server - test - ${jersey-1.version} - - - com.sun.jersey.contribs - jersey-guice - test - ${jersey-1.version} - - - com.sun.jersey - jersey-servlet - test - ${jersey-1.version} - - - - - ${hive.group} - hive-exec - ${hive.classifier} - provided - - - ${hive.group} - hive-metastore - provided - - - org.apache.thrift - libthrift - provided - - - org.apache.thrift - libfb303 - provided - From 0f6bc3bce9c4083dc603a780f65e64e7973694f0 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Tue, 15 Oct 2024 15:19:03 -0700 Subject: [PATCH 09/31] [SPARK-49973][INFRA] Upgrade python to 3.11 for non-python tests ### What changes were proposed in this pull request? Upgrade python to 3.11 for non-python tests: `build` and `buf` ### Why are the changes needed? to be consistent with PySpark tests ### Does this PR introduce _any_ user-facing change? no, infra-only ### How was this patch tested? ci ### Was this patch authored or co-authored using generative AI tooling? no Closes #48474 from zhengruifeng/infra_py_311. 
Authored-by: Ruifeng Zheng Signed-off-by: Dongjoon Hyun --- .github/workflows/build_and_test.yml | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 43ac6b50052ae..14d93a498fc59 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -264,20 +264,20 @@ jobs: with: distribution: zulu java-version: ${{ matrix.java }} - - name: Install Python 3.9 + - name: Install Python 3.11 uses: actions/setup-python@v5 # We should install one Python that is higher than 3+ for SQL and Yarn because: # - SQL component also has Python related tests, for example, IntegratedUDFTestUtils. # - Yarn has a Python specific test too, for example, YarnClusterSuite. if: contains(matrix.modules, 'yarn') || (contains(matrix.modules, 'sql') && !contains(matrix.modules, 'sql-')) || contains(matrix.modules, 'connect') with: - python-version: '3.9' + python-version: '3.11' architecture: x64 - - name: Install Python packages (Python 3.9) + - name: Install Python packages (Python 3.11) if: (contains(matrix.modules, 'sql') && !contains(matrix.modules, 'sql-')) || contains(matrix.modules, 'connect') run: | - python3.9 -m pip install 'numpy>=1.20.0' pyarrow pandas scipy unittest-xml-reporting 'lxml==4.9.4' 'grpcio==1.62.0' 'grpcio-status==1.62.0' 'protobuf==4.25.1' - python3.9 -m pip list + python3.11 -m pip install 'numpy>=1.20.0' pyarrow pandas scipy unittest-xml-reporting 'lxml==4.9.4' 'grpcio==1.62.0' 'grpcio-status==1.62.0' 'protobuf==4.25.1' + python3.11 -m pip list # Run the tests. - name: Run tests env: ${{ fromJSON(inputs.envs) }} @@ -608,14 +608,14 @@ jobs: with: input: sql/connect/common/src/main against: 'https://github.com/apache/spark.git#branch=branch-3.5,subdir=connector/connect/common/src/main' - - name: Install Python 3.9 + - name: Install Python 3.11 uses: actions/setup-python@v5 with: - python-version: '3.9' + python-version: '3.11' - name: Install dependencies for Python CodeGen check run: | - python3.9 -m pip install 'black==23.9.1' 'protobuf==4.25.1' 'mypy==1.8.0' 'mypy-protobuf==3.3.0' - python3.9 -m pip list + python3.11 -m pip install 'black==23.9.1' 'protobuf==4.25.1' 'mypy==1.8.0' 'mypy-protobuf==3.3.0' + python3.11 -m pip list - name: Python CodeGen check run: ./dev/connect-check-protos.py From 111473dec898f7fb629d4b8d8e0d85e12e04c931 Mon Sep 17 00:00:00 2001 From: HiuFung Kwok Date: Tue, 15 Oct 2024 15:23:24 -0700 Subject: [PATCH 10/31] [SPARK-49964][BUILD] Remove `ws-rs-api` package ### What changes were proposed in this pull request? - To Remove the dependency of `javax.ws.rs.ws-rs-api` as it's no longer required. Prior discussion can be found on: - https://github.com/apache/spark/pull/41340 - https://github.com/apache/spark/pull/45154 ### Why are the changes needed? In the past, the codebase used to have a few .scala classes referencing and using the `ws-rs-api`, such as https://github.com/apache/spark/commit/b7fdc23ccc5967de5799d8cf6f14289e71f29a1e#diff-9c5fb3d1b7e3b0f54bc5c4182965c4fe1f9023d449017cece3005d3f90e8e4d8R624-R627 However as the time passed by, all usages of `ws-rs-api` are either got removed / refactored. Hence there is no need to have it import on root POM as now and we can always re-introduce it later, if the usage can be justified again. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Unit-test, to make sure the codebase is not impacted by the removal of the dependency. 
### Was this patch authored or co-authored using generative AI tooling? No Closes #48461 from hiufung-kwok/ft-hf-SPARK-49963-remove-ws-rs-api. Authored-by: HiuFung Kwok Signed-off-by: Dongjoon Hyun --- pom.xml | 5 ----- 1 file changed, 5 deletions(-) diff --git a/pom.xml b/pom.xml index 2b89454873782..cab7f7f595434 100644 --- a/pom.xml +++ b/pom.xml @@ -1115,11 +1115,6 @@ jersey-client ${jersey.version} - - javax.ws.rs - javax.ws.rs-api - 2.0.1 - javax.xml.bind jaxb-api From add4a9c6bd5d507dac99b2206cc940ffc01aac1a Mon Sep 17 00:00:00 2001 From: panbingkun Date: Tue, 15 Oct 2024 15:25:42 -0700 Subject: [PATCH 11/31] [SPARK-49922][BUILD] Upgrade `sbt-assembly` to `2.3.0` ### What changes were proposed in this pull request? The pr aims to upgrade `sbt-assembly` from `2.2.0` to `2.3.0` ### Why are the changes needed? - `sbt-assembly`, the full release notes: https://github.com/sbt/sbt-assembly/releases/tag/v2.3.0 - Bug fixed: Fixes assembly not creating parent directories by Roiocam in https://github.com/sbt/sbt-assembly/pull/525 Throws error when a misconfigured assemblyOutputPath is detected by hygt in https://github.com/sbt/sbt-assembly/pull/523 ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48409 from panbingkun/SPARK-49922. Authored-by: panbingkun Signed-off-by: Dongjoon Hyun --- project/plugins.sbt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/project/plugins.sbt b/project/plugins.sbt index 67d739452d8da..b2d0177e6a411 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -24,7 +24,7 @@ libraryDependencies += "com.puppycrawl.tools" % "checkstyle" % "10.17.0" // checkstyle uses guava 33.1.0-jre. libraryDependencies += "com.google.guava" % "guava" % "33.1.0-jre" -addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "2.2.0") +addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "2.3.0") addSbtPlugin("com.github.sbt" % "sbt-eclipse" % "6.2.0") From 0e75d19a736aa18fe77414991ebb7e3577a43af8 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Tue, 15 Oct 2024 15:28:31 -0700 Subject: [PATCH 12/31] [SPARK-49792][PYTHON][BUILD] Upgrade to numpy 2 for building and testing Spark branches ### What changes were proposed in this pull request? Upgrade numpy to 2.1.0 for building and testing Spark branches. Failed tests are categorized into the following groups: - Most of test failures fixed are related to https://github.com/pandas-dev/pandas/issues/59838#event-14332587978. - Replaced np.mat with np.asmatrix. - TODO: SPARK-49793 ### Why are the changes needed? Ensure compatibility with newer NumPy, which is utilized by Pandas (on Spark). ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48180 from xinrong-meng/np_upgrade. 
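
A minimal, illustrative sketch (not part of the patch) of the doctest-level behavior change behind the `bool()`/`int()`/`float()` wrappers and the `np.mat` -> `np.asmatrix` replacement; it assumes the NumPy 2 scalar-repr change described by NEP 51:

```python
import numpy as np

# Doctests compare the repr of each expression's result. Under NumPy 2
# (NEP 51), NumPy scalars no longer repr as bare Python numbers, so
# expectations written as `2.0` or `True` stop matching.
repr(np.float64(2.0))   # NumPy 1.x: '2.0'    NumPy 2.x: 'np.float64(2.0)'
repr(np.bool_(True))    # NumPy 1.x: 'True'   NumPy 2.x: 'np.True_'

# Casting to builtins keeps doctest output stable across both major
# versions, which is what the doctest edits below do.
float(np.float64(2.0))  # 2.0 on both
bool(np.bool_(True))    # True on both

# np.mat is gone from the NumPy 2 main namespace; np.asmatrix remains.
m = np.asmatrix([[1.0, 2.0], [3.0, 4.0]])
m.shape                 # (2, 2)
```
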
Authored-by: Xinrong Meng Signed-off-by: Dongjoon Hyun --- dev/infra/Dockerfile | 6 +- python/pyspark/ml/classification.py | 8 +-- python/pyspark/ml/regression.py | 10 +-- python/pyspark/ml/tests/test_functions.py | 5 ++ python/pyspark/ml/tuning.py | 2 +- python/pyspark/mllib/classification.py | 25 +++---- python/pyspark/mllib/feature.py | 4 +- python/pyspark/mllib/random.py | 42 ++++++------ python/pyspark/mllib/regression.py | 80 +++++++++++------------ python/pyspark/pandas/generic.py | 8 +-- python/pyspark/pandas/indexing.py | 8 +-- python/pyspark/pandas/series.py | 8 +-- 12 files changed, 107 insertions(+), 99 deletions(-) diff --git a/dev/infra/Dockerfile b/dev/infra/Dockerfile index 1619b009e9364..10a39497c8ed9 100644 --- a/dev/infra/Dockerfile +++ b/dev/infra/Dockerfile @@ -24,7 +24,7 @@ LABEL org.opencontainers.image.ref.name="Apache Spark Infra Image" # Overwrite this label to avoid exposing the underlying Ubuntu OS version label LABEL org.opencontainers.image.version="" -ENV FULL_REFRESH_DATE 20241002 +ENV FULL_REFRESH_DATE 20241007 ENV DEBIAN_FRONTEND noninteractive ENV DEBCONF_NONINTERACTIVE_SEEN true @@ -91,10 +91,10 @@ RUN mkdir -p /usr/local/pypy/pypy3.9 && \ ln -sf /usr/local/pypy/pypy3.9/bin/pypy /usr/local/bin/pypy3.9 && \ ln -sf /usr/local/pypy/pypy3.9/bin/pypy /usr/local/bin/pypy3 RUN curl -sS https://bootstrap.pypa.io/get-pip.py | pypy3 -RUN pypy3 -m pip install 'numpy==1.26.4' 'six==1.16.0' 'pandas==2.2.3' scipy coverage matplotlib lxml +RUN pypy3 -m pip install numpy 'six==1.16.0' 'pandas==2.2.3' scipy coverage matplotlib lxml -ARG BASIC_PIP_PKGS="numpy==1.26.4 pyarrow>=15.0.0 six==1.16.0 pandas==2.2.3 scipy plotly>=4.8 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2" +ARG BASIC_PIP_PKGS="numpy pyarrow>=15.0.0 six==1.16.0 pandas==2.2.3 scipy plotly>=4.8 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2" # Python deps for Spark Connect ARG CONNECT_PIP_PKGS="grpcio==1.62.0 grpcio-status==1.62.0 protobuf==4.25.1 googleapis-common-protos==1.56.4 graphviz==0.20.3" diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 937753b50bb13..b89755d9c18a5 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -699,7 +699,7 @@ class LinearSVC( >>> model_path = temp_path + "/svm_model" >>> model.save(model_path) >>> model2 = LinearSVCModel.load(model_path) - >>> model.coefficients[0] == model2.coefficients[0] + >>> bool(model.coefficients[0] == model2.coefficients[0]) True >>> model.intercept == model2.intercept True @@ -1210,7 +1210,7 @@ class LogisticRegression( >>> model_path = temp_path + "/lr_model" >>> blorModel.save(model_path) >>> model2 = LogisticRegressionModel.load(model_path) - >>> blorModel.coefficients[0] == model2.coefficients[0] + >>> bool(blorModel.coefficients[0] == model2.coefficients[0]) True >>> blorModel.intercept == model2.intercept True @@ -2038,9 +2038,9 @@ class RandomForestClassifier( >>> result = model.transform(test0).head() >>> result.prediction 0.0 - >>> numpy.argmax(result.probability) + >>> int(numpy.argmax(result.probability)) 0 - >>> numpy.argmax(result.newRawPrediction) + >>> int(numpy.argmax(result.newRawPrediction)) 0 >>> result.leafId DenseVector([0.0, 0.0, 0.0]) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index d08e241b41d23..d7cc27e274279 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -266,7 +266,7 @@ class 
LinearRegression( True >>> abs(model.transform(test0).head().newPrediction - (-1.0)) < 0.001 True - >>> abs(model.coefficients[0] - 1.0) < 0.001 + >>> bool(abs(model.coefficients[0] - 1.0) < 0.001) True >>> abs(model.intercept - 0.0) < 0.001 True @@ -283,11 +283,11 @@ class LinearRegression( >>> model_path = temp_path + "/lr_model" >>> model.save(model_path) >>> model2 = LinearRegressionModel.load(model_path) - >>> model.coefficients[0] == model2.coefficients[0] + >>> bool(model.coefficients[0] == model2.coefficients[0]) True - >>> model.intercept == model2.intercept + >>> bool(model.intercept == model2.intercept) True - >>> model.transform(test0).take(1) == model2.transform(test0).take(1) + >>> bool(model.transform(test0).take(1) == model2.transform(test0).take(1)) True >>> model.numFeatures 1 @@ -2542,7 +2542,7 @@ class GeneralizedLinearRegression( >>> model2 = GeneralizedLinearRegressionModel.load(model_path) >>> model.intercept == model2.intercept True - >>> model.coefficients[0] == model2.coefficients[0] + >>> bool(model.coefficients[0] == model2.coefficients[0]) True >>> model.transform(df).take(1) == model2.transform(df).take(1) True diff --git a/python/pyspark/ml/tests/test_functions.py b/python/pyspark/ml/tests/test_functions.py index 7df0a26394140..e67e46ded67bd 100644 --- a/python/pyspark/ml/tests/test_functions.py +++ b/python/pyspark/ml/tests/test_functions.py @@ -18,6 +18,7 @@ import numpy as np +from pyspark.loose_version import LooseVersion from pyspark.ml.functions import predict_batch_udf from pyspark.sql.functions import array, struct, col from pyspark.sql.types import ArrayType, DoubleType, IntegerType, StructType, StructField, FloatType @@ -193,6 +194,10 @@ def predict(inputs): batch_sizes = preds["preds"].to_numpy() self.assertTrue(all(batch_sizes <= batch_size)) + # TODO(SPARK-49793): enable the test below + @unittest.skipIf( + LooseVersion(np.__version__) >= LooseVersion("2"), "Caching does not work with numpy 2" + ) def test_caching(self): def make_predict_fn(): # emulate loading a model, this should only be invoked once (per worker process) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index e8713d81c4d62..888beff663523 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -706,7 +706,7 @@ class CrossValidator( >>> cvModel = cv.fit(dataset) >>> cvModel.getNumFolds() 3 - >>> cvModel.avgMetrics[0] + >>> float(cvModel.avgMetrics[0]) 0.5 >>> path = tempfile.mkdtemp() >>> model_path = path + "/model" diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 1e1795d9fb3d4..bf8fd04dc2837 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -172,9 +172,9 @@ class LogisticRegressionModel(LinearClassificationModel): >>> path = tempfile.mkdtemp() >>> lrm.save(sc, path) >>> sameModel = LogisticRegressionModel.load(sc, path) - >>> sameModel.predict(numpy.array([0.0, 1.0])) + >>> int(sameModel.predict(numpy.array([0.0, 1.0]))) 1 - >>> sameModel.predict(SparseVector(2, {0: 1.0})) + >>> int(sameModel.predict(SparseVector(2, {0: 1.0}))) 0 >>> from shutil import rmtree >>> try: @@ -555,7 +555,7 @@ class SVMModel(LinearClassificationModel): >>> svm.predict(sc.parallelize([[1.0]])).collect() [1] >>> svm.clearThreshold() - >>> svm.predict(numpy.array([1.0])) + >>> float(svm.predict(numpy.array([1.0]))) 1.44... 
>>> sparse_data = [ @@ -573,9 +573,9 @@ class SVMModel(LinearClassificationModel): >>> path = tempfile.mkdtemp() >>> svm.save(sc, path) >>> sameModel = SVMModel.load(sc, path) - >>> sameModel.predict(SparseVector(2, {1: 1.0})) + >>> int(sameModel.predict(SparseVector(2, {1: 1.0}))) 1 - >>> sameModel.predict(SparseVector(2, {0: -1.0})) + >>> int(sameModel.predict(SparseVector(2, {0: -1.0}))) 0 >>> from shutil import rmtree >>> try: @@ -756,11 +756,11 @@ class NaiveBayesModel(Saveable, Loader["NaiveBayesModel"]): ... LabeledPoint(1.0, [1.0, 0.0]), ... ] >>> model = NaiveBayes.train(sc.parallelize(data)) - >>> model.predict(numpy.array([0.0, 1.0])) + >>> float(model.predict(numpy.array([0.0, 1.0]))) 0.0 - >>> model.predict(numpy.array([1.0, 0.0])) + >>> float(model.predict(numpy.array([1.0, 0.0]))) 1.0 - >>> model.predict(sc.parallelize([[1.0, 0.0]])).collect() + >>> list(map(float, model.predict(sc.parallelize([[1.0, 0.0]])).collect())) [1.0] >>> sparse_data = [ ... LabeledPoint(0.0, SparseVector(2, {1: 0.0})), @@ -768,15 +768,18 @@ class NaiveBayesModel(Saveable, Loader["NaiveBayesModel"]): ... LabeledPoint(1.0, SparseVector(2, {0: 1.0})) ... ] >>> model = NaiveBayes.train(sc.parallelize(sparse_data)) - >>> model.predict(SparseVector(2, {1: 1.0})) + >>> float(model.predict(SparseVector(2, {1: 1.0}))) 0.0 - >>> model.predict(SparseVector(2, {0: 1.0})) + >>> float(model.predict(SparseVector(2, {0: 1.0}))) 1.0 >>> import os, tempfile >>> path = tempfile.mkdtemp() >>> model.save(sc, path) >>> sameModel = NaiveBayesModel.load(sc, path) - >>> sameModel.predict(SparseVector(2, {0: 1.0})) == model.predict(SparseVector(2, {0: 1.0})) + >>> bool(( + ... sameModel.predict(SparseVector(2, {0: 1.0})) == + ... model.predict(SparseVector(2, {0: 1.0})) + ... )) True >>> from shutil import rmtree >>> try: diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 24884f4853371..915a55595cb53 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -554,9 +554,9 @@ class PCA: ... Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0])] >>> model = PCA(2).fit(sc.parallelize(data)) >>> pcArray = model.transform(Vectors.sparse(5, [(1, 1.0), (3, 7.0)])).toArray() - >>> pcArray[0] + >>> float(pcArray[0]) 1.648... - >>> pcArray[1] + >>> float(pcArray[1]) -4.013... 
""" diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py index 80bbd717071dc..dbe1048a64b36 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/random.py @@ -134,9 +134,9 @@ def normalRDD( >>> stats = x.stats() >>> stats.count() 1000 - >>> abs(stats.mean() - 0.0) < 0.1 + >>> bool(abs(stats.mean() - 0.0) < 0.1) True - >>> abs(stats.stdev() - 1.0) < 0.1 + >>> bool(abs(stats.stdev() - 1.0) < 0.1) True """ return callMLlibFunc("normalRDD", sc._jsc, size, numPartitions, seed) @@ -186,10 +186,10 @@ def logNormalRDD( >>> stats = x.stats() >>> stats.count() 1000 - >>> abs(stats.mean() - expMean) < 0.5 + >>> bool(abs(stats.mean() - expMean) < 0.5) True >>> from math import sqrt - >>> abs(stats.stdev() - expStd) < 0.5 + >>> bool(abs(stats.stdev() - expStd) < 0.5) True """ return callMLlibFunc( @@ -238,7 +238,7 @@ def poissonRDD( >>> abs(stats.mean() - mean) < 0.5 True >>> from math import sqrt - >>> abs(stats.stdev() - sqrt(mean)) < 0.5 + >>> bool(abs(stats.stdev() - sqrt(mean)) < 0.5) True """ return callMLlibFunc("poissonRDD", sc._jsc, float(mean), size, numPartitions, seed) @@ -285,7 +285,7 @@ def exponentialRDD( >>> abs(stats.mean() - mean) < 0.5 True >>> from math import sqrt - >>> abs(stats.stdev() - sqrt(mean)) < 0.5 + >>> bool(abs(stats.stdev() - sqrt(mean)) < 0.5) True """ return callMLlibFunc("exponentialRDD", sc._jsc, float(mean), size, numPartitions, seed) @@ -336,9 +336,9 @@ def gammaRDD( >>> stats = x.stats() >>> stats.count() 1000 - >>> abs(stats.mean() - expMean) < 0.5 + >>> bool(abs(stats.mean() - expMean) < 0.5) True - >>> abs(stats.stdev() - expStd) < 0.5 + >>> bool(abs(stats.stdev() - expStd) < 0.5) True """ return callMLlibFunc( @@ -384,7 +384,7 @@ def uniformVectorRDD( >>> mat = np.matrix(RandomRDDs.uniformVectorRDD(sc, 10, 10).collect()) >>> mat.shape (10, 10) - >>> mat.max() <= 1.0 and mat.min() >= 0.0 + >>> bool(mat.max() <= 1.0 and mat.min() >= 0.0) True >>> RandomRDDs.uniformVectorRDD(sc, 10, 10, 4).getNumPartitions() 4 @@ -430,9 +430,9 @@ def normalVectorRDD( >>> mat = np.matrix(RandomRDDs.normalVectorRDD(sc, 100, 100, seed=1).collect()) >>> mat.shape (100, 100) - >>> abs(mat.mean() - 0.0) < 0.1 + >>> bool(abs(mat.mean() - 0.0) < 0.1) True - >>> abs(mat.std() - 1.0) < 0.1 + >>> bool(abs(mat.std() - 1.0) < 0.1) True """ return callMLlibFunc("normalVectorRDD", sc._jsc, numRows, numCols, numPartitions, seed) @@ -488,9 +488,9 @@ def logNormalVectorRDD( >>> mat = np.matrix(m) >>> mat.shape (100, 100) - >>> abs(mat.mean() - expMean) < 0.1 + >>> bool(abs(mat.mean() - expMean) < 0.1) True - >>> abs(mat.std() - expStd) < 0.1 + >>> bool(abs(mat.std() - expStd) < 0.1) True """ return callMLlibFunc( @@ -545,13 +545,13 @@ def poissonVectorRDD( >>> import numpy as np >>> mean = 100.0 >>> rdd = RandomRDDs.poissonVectorRDD(sc, mean, 100, 100, seed=1) - >>> mat = np.mat(rdd.collect()) + >>> mat = np.asmatrix(rdd.collect()) >>> mat.shape (100, 100) - >>> abs(mat.mean() - mean) < 0.5 + >>> bool(abs(mat.mean() - mean) < 0.5) True >>> from math import sqrt - >>> abs(mat.std() - sqrt(mean)) < 0.5 + >>> bool(abs(mat.std() - sqrt(mean)) < 0.5) True """ return callMLlibFunc( @@ -599,13 +599,13 @@ def exponentialVectorRDD( >>> import numpy as np >>> mean = 0.5 >>> rdd = RandomRDDs.exponentialVectorRDD(sc, mean, 100, 100, seed=1) - >>> mat = np.mat(rdd.collect()) + >>> mat = np.asmatrix(rdd.collect()) >>> mat.shape (100, 100) - >>> abs(mat.mean() - mean) < 0.5 + >>> bool(abs(mat.mean() - mean) < 0.5) True >>> from math import sqrt - >>> abs(mat.std() - 
sqrt(mean)) < 0.5 + >>> bool(abs(mat.std() - sqrt(mean)) < 0.5) True """ return callMLlibFunc( @@ -662,9 +662,9 @@ def gammaVectorRDD( >>> mat = np.matrix(RandomRDDs.gammaVectorRDD(sc, shape, scale, 100, 100, seed=1).collect()) >>> mat.shape (100, 100) - >>> abs(mat.mean() - expMean) < 0.1 + >>> bool(abs(mat.mean() - expMean) < 0.1) True - >>> abs(mat.std() - expStd) < 0.1 + >>> bool(abs(mat.std() - expStd) < 0.1) True """ return callMLlibFunc( diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index f1003327912d0..87f05bc0979b8 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -144,9 +144,9 @@ class LinearRegressionModelBase(LinearModel): -------- >>> from pyspark.mllib.linalg import SparseVector >>> lrmb = LinearRegressionModelBase(np.array([1.0, 2.0]), 0.1) - >>> abs(lrmb.predict(np.array([-1.03, 7.777])) - 14.624) < 1e-6 + >>> bool(abs(lrmb.predict(np.array([-1.03, 7.777])) - 14.624) < 1e-6) True - >>> abs(lrmb.predict(SparseVector(2, {0: -1.03, 1: 7.777})) - 14.624) < 1e-6 + >>> bool(abs(lrmb.predict(SparseVector(2, {0: -1.03, 1: 7.777})) - 14.624) < 1e-6) True """ @@ -190,23 +190,23 @@ class LinearRegressionModel(LinearRegressionModelBase): ... ] >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), iterations=10, ... initialWeights=np.array([1.0])) - >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 + >>> bool(abs(lrm.predict(np.array([0.0])) - 0) < 0.5) True - >>> abs(lrm.predict(np.array([1.0])) - 1) < 0.5 + >>> bool(abs(lrm.predict(np.array([1.0])) - 1) < 0.5) True - >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + >>> bool(abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5) True - >>> abs(lrm.predict(sc.parallelize([[1.0]])).collect()[0] - 1) < 0.5 + >>> bool(abs(lrm.predict(sc.parallelize([[1.0]])).collect()[0] - 1) < 0.5) True >>> import os, tempfile >>> path = tempfile.mkdtemp() >>> lrm.save(sc, path) >>> sameModel = LinearRegressionModel.load(sc, path) - >>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5 + >>> bool(abs(sameModel.predict(np.array([0.0])) - 0) < 0.5) True - >>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5 + >>> bool(abs(sameModel.predict(np.array([1.0])) - 1) < 0.5) True - >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + >>> bool(abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5) True >>> from shutil import rmtree >>> try: @@ -221,16 +221,16 @@ class LinearRegressionModel(LinearRegressionModelBase): ... ] >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), iterations=10, ... initialWeights=np.array([1.0])) - >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 + >>> bool(abs(lrm.predict(np.array([0.0])) - 0) < 0.5) True - >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + >>> bool(abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5) True >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), iterations=10, step=1.0, ... miniBatchFraction=1.0, initialWeights=np.array([1.0]), regParam=0.1, regType="l2", ... intercept=True, validateData=True) - >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 + >>> bool(abs(lrm.predict(np.array([0.0])) - 0) < 0.5) True - >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + >>> bool(abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5) True """ @@ -402,23 +402,23 @@ class LassoModel(LinearRegressionModelBase): ... ] >>> lrm = LassoWithSGD.train( ... 
sc.parallelize(data), iterations=10, initialWeights=np.array([1.0])) - >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 + >>> bool(abs(lrm.predict(np.array([0.0])) - 0) < 0.5) True - >>> abs(lrm.predict(np.array([1.0])) - 1) < 0.5 + >>> bool(abs(lrm.predict(np.array([1.0])) - 1) < 0.5) True - >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + >>> bool(abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5) True - >>> abs(lrm.predict(sc.parallelize([[1.0]])).collect()[0] - 1) < 0.5 + >>> bool(abs(lrm.predict(sc.parallelize([[1.0]])).collect()[0] - 1) < 0.5) True >>> import os, tempfile >>> path = tempfile.mkdtemp() >>> lrm.save(sc, path) >>> sameModel = LassoModel.load(sc, path) - >>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5 + >>> bool(abs(sameModel.predict(np.array([0.0])) - 0) < 0.5) True - >>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5 + >>> bool(abs(sameModel.predict(np.array([1.0])) - 1) < 0.5) True - >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + >>> bool(abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5) True >>> from shutil import rmtree >>> try: @@ -433,16 +433,16 @@ class LassoModel(LinearRegressionModelBase): ... ] >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), iterations=10, ... initialWeights=np.array([1.0])) - >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 + >>> bool(abs(lrm.predict(np.array([0.0])) - 0) < 0.5) True - >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + >>> bool(abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5) True >>> lrm = LassoWithSGD.train(sc.parallelize(data), iterations=10, step=1.0, ... regParam=0.01, miniBatchFraction=1.0, initialWeights=np.array([1.0]), intercept=True, ... validateData=True) - >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 + >>> bool(abs(lrm.predict(np.array([0.0])) - 0) < 0.5) True - >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + >>> bool(abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5) True """ @@ -580,23 +580,23 @@ class RidgeRegressionModel(LinearRegressionModelBase): ... ] >>> lrm = RidgeRegressionWithSGD.train(sc.parallelize(data), iterations=10, ... initialWeights=np.array([1.0])) - >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 + >>> bool(abs(lrm.predict(np.array([0.0])) - 0) < 0.5) True - >>> abs(lrm.predict(np.array([1.0])) - 1) < 0.5 + >>> bool(abs(lrm.predict(np.array([1.0])) - 1) < 0.5) True - >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + >>> bool(abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5) True - >>> abs(lrm.predict(sc.parallelize([[1.0]])).collect()[0] - 1) < 0.5 + >>> bool(abs(lrm.predict(sc.parallelize([[1.0]])).collect()[0] - 1) < 0.5) True >>> import os, tempfile >>> path = tempfile.mkdtemp() >>> lrm.save(sc, path) >>> sameModel = RidgeRegressionModel.load(sc, path) - >>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5 + >>> bool(abs(sameModel.predict(np.array([0.0])) - 0) < 0.5) True - >>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5 + >>> bool(abs(sameModel.predict(np.array([1.0])) - 1) < 0.5) True - >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + >>> bool(abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5) True >>> from shutil import rmtree >>> try: @@ -611,16 +611,16 @@ class RidgeRegressionModel(LinearRegressionModelBase): ... ] >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), iterations=10, ... 
initialWeights=np.array([1.0])) - >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 + >>> bool(abs(lrm.predict(np.array([0.0])) - 0) < 0.5) True - >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + >>> bool(abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5) True >>> lrm = RidgeRegressionWithSGD.train(sc.parallelize(data), iterations=10, step=1.0, ... regParam=0.01, miniBatchFraction=1.0, initialWeights=np.array([1.0]), intercept=True, ... validateData=True) - >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 + >>> bool(abs(lrm.predict(np.array([0.0])) - 0) < 0.5) True - >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + >>> bool(abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5) True """ @@ -764,19 +764,19 @@ class IsotonicRegressionModel(Saveable, Loader["IsotonicRegressionModel"]): -------- >>> data = [(1, 0, 1), (2, 1, 1), (3, 2, 1), (1, 3, 1), (6, 4, 1), (17, 5, 1), (16, 6, 1)] >>> irm = IsotonicRegression.train(sc.parallelize(data)) - >>> irm.predict(3) + >>> float(irm.predict(3)) 2.0 - >>> irm.predict(5) + >>> float(irm.predict(5)) 16.5 - >>> irm.predict(sc.parallelize([3, 5])).collect() + >>> list(map(float, irm.predict(sc.parallelize([3, 5])).collect())) [2.0, 16.5] >>> import os, tempfile >>> path = tempfile.mkdtemp() >>> irm.save(sc, path) >>> sameModel = IsotonicRegressionModel.load(sc, path) - >>> sameModel.predict(3) + >>> float(sameModel.predict(3)) 2.0 - >>> sameModel.predict(5) + >>> float(sameModel.predict(5)) 16.5 >>> from shutil import rmtree >>> try: diff --git a/python/pyspark/pandas/generic.py b/python/pyspark/pandas/generic.py index 6e63cff1d37b9..55f15fd2eb1a2 100644 --- a/python/pyspark/pandas/generic.py +++ b/python/pyspark/pandas/generic.py @@ -2631,7 +2631,7 @@ def first_valid_index(self) -> Optional[Union[Scalar, Tuple[Scalar, ...]]]: 500 5.0 dtype: float64 - >>> s.first_valid_index() + >>> int(s.first_valid_index()) 300 Support for MultiIndex @@ -2950,7 +2950,7 @@ def get(self, key: Any, default: Optional[Any] = None) -> Any: 20 1 b 20 2 b - >>> df.x.get(10) + >>> int(df.x.get(10)) 0 >>> df.x.get(20) @@ -3008,7 +3008,7 @@ def squeeze(self, axis: Optional[Axis] = None) -> Union[Scalar, "DataFrame", "Se 0 2 dtype: int64 - >>> even_primes.squeeze() + >>> int(even_primes.squeeze()) 2 Squeezing objects with more than one value in every axis does nothing: @@ -3066,7 +3066,7 @@ def squeeze(self, axis: Optional[Axis] = None) -> Union[Scalar, "DataFrame", "Se Squeezing all axes will project directly into a scalar: - >>> df_1a.squeeze() + >>> int(df_1a.squeeze()) 3 """ if axis is not None: diff --git a/python/pyspark/pandas/indexing.py b/python/pyspark/pandas/indexing.py index b5bf65a4907b7..c93366a31e315 100644 --- a/python/pyspark/pandas/indexing.py +++ b/python/pyspark/pandas/indexing.py @@ -122,7 +122,7 @@ class AtIndexer(IndexerLike): Get value at specified row/column pair - >>> psdf.at[4, 'B'] + >>> int(psdf.at[4, 'B']) 2 Get array if an index occurs multiple times @@ -202,7 +202,7 @@ class iAtIndexer(IndexerLike): Get value at specified row/column pair - >>> df.iat[1, 2] + >>> int(df.iat[1, 2]) 1 Get value within a series @@ -214,7 +214,7 @@ class iAtIndexer(IndexerLike): 30 3 dtype: int64 - >>> psser.iat[1] + >>> int(psser.iat[1]) 2 """ @@ -853,7 +853,7 @@ class LocIndexer(LocIndexerLike): Single label for column. - >>> df.loc['cobra', 'shield'] + >>> int(df.loc['cobra', 'shield']) 2 List of labels for row. 
diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py index ff941b692f95f..7e276860fbab1 100644 --- a/python/pyspark/pandas/series.py +++ b/python/pyspark/pandas/series.py @@ -4558,7 +4558,7 @@ def pop(self, item: Name) -> Union["Series", Scalar]: C 2 dtype: int64 - >>> s.pop('A') + >>> int(s.pop('A')) 0 >>> s @@ -5821,7 +5821,7 @@ def asof(self, where: Union[Any, List]) -> Union[Scalar, "Series"]: A scalar `where`. - >>> s.asof(20) + >>> float(s.asof(20)) 2.0 For a sequence `where`, a Series is returned. The first value is @@ -5836,12 +5836,12 @@ def asof(self, where: Union[Any, List]) -> Union[Scalar, "Series"]: Missing values are not considered. The following is ``2.0``, not NaN, even though NaN is at the index location for ``30``. - >>> s.asof(30) + >>> float(s.asof(30)) 2.0 >>> s = ps.Series([1, 2, np.nan, 4], index=[10, 30, 20, 40]) >>> with ps.option_context("compute.eager_check", False): - ... s.asof(20) + ... float(s.asof(20)) ... 1.0 """ From 861b5e98e6e4f61e376d756f085e0290e01fc8f4 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Wed, 16 Oct 2024 08:49:10 +0800 Subject: [PATCH 13/31] [SPARK-49948][PS][CONNECT] Add parameter "precision" to pandas on Spark box plot ### What changes were proposed in this pull request? Add parameter "precision" to pandas on Spark box plot. ### Why are the changes needed? Previously, the box method used **kwds, allowing precision to be passed implicitly. Now, adding precision directly to the signature ensures clarity and explicit control, improving usability. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48445 from xinrong-meng/ps_box. Authored-by: Xinrong Meng Signed-off-by: Xinrong Meng --- python/pyspark/pandas/plot/core.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/python/pyspark/pandas/plot/core.py b/python/pyspark/pandas/plot/core.py index 7333fae1ad432..12c17a06f153b 100644 --- a/python/pyspark/pandas/plot/core.py +++ b/python/pyspark/pandas/plot/core.py @@ -841,7 +841,7 @@ def barh(self, x=None, y=None, **kwargs): elif isinstance(self.data, DataFrame): return self(kind="barh", x=x, y=y, **kwargs) - def box(self, **kwds): + def box(self, precision=0.01, **kwds): """ Make a box plot of the DataFrame columns. @@ -857,14 +857,13 @@ def box(self, **kwds): Parameters ---------- - **kwds : optional - Additional keyword arguments are documented in - :meth:`pyspark.pandas.Series.plot`. - precision: scalar, default = 0.01 This argument is used by pandas-on-Spark to compute approximate statistics for building a boxplot. Use *smaller* values to get more precise - statistics (matplotlib-only). + statistics. + **kwds : optional + Additional keyword arguments are documented in + :meth:`pyspark.pandas.Series.plot`. Returns ------- @@ -902,7 +901,7 @@ def box(self, **kwds): from pyspark.pandas import DataFrame, Series if isinstance(self.data, (Series, DataFrame)): - return self(kind="box", **kwds) + return self(kind="box", precision=precision, **kwds) def hist(self, bins=10, **kwds): """ From f3b2535d8d92c2210501f15c5845dd589414ffe3 Mon Sep 17 00:00:00 2001 From: Haejoon Lee Date: Wed, 16 Oct 2024 09:04:58 +0200 Subject: [PATCH 14/31] [SPARK-49970][SQL] Assign proper error condition for _LEGACY_ERROR_TEMP_2069 ### What changes were proposed in this pull request? 
This PR proposes to assign proper error condition & sqlstate for `_LEGACY_ERROR_TEMP_2069` ### Why are the changes needed? To improve the error message by assigning proper error condition and SQLSTATE ### Does this PR introduce _any_ user-facing change? No, only user-facing error message improved ### How was this patch tested? Updated the existing tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #48469 from itholic/LEGACY_2069. Lead-authored-by: Haejoon Lee Co-authored-by: Haejoon Lee Signed-off-by: Max Gekk --- .../src/main/resources/error/error-conditions.json | 11 ++++++----- .../spark/sql/errors/QueryExecutionErrors.scala | 2 +- .../datasources/v2/V2SessionCatalogSuite.scala | 2 +- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 8272daadb9159..d9880899347a3 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -376,6 +376,12 @@ ], "sqlState" : "429BB" }, + "CANNOT_REMOVE_RESERVED_PROPERTY" : { + "message" : [ + "Cannot remove reserved property: ." + ], + "sqlState" : "42000" + }, "CANNOT_RENAME_ACROSS_SCHEMA" : { "message" : [ "Renaming a across schemas is not allowed." @@ -6955,11 +6961,6 @@ "Missing database location." ] }, - "_LEGACY_ERROR_TEMP_2069" : { - "message" : [ - "Cannot remove reserved property: ." - ] - }, "_LEGACY_ERROR_TEMP_2070" : { "message" : [ "Writing job failed." diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 3bc229e9693e9..43fc0b567dcc2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -870,7 +870,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE def cannotRemoveReservedPropertyError(property: String): SparkUnsupportedOperationException = { new SparkUnsupportedOperationException( - errorClass = "_LEGACY_ERROR_TEMP_2069", + errorClass = "CANNOT_REMOVE_RESERVED_PROPERTY", messageParameters = Map("property" -> property)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala index c88f51a6b7d06..8091d6e64fdc1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala @@ -1173,7 +1173,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { exception = intercept[SparkUnsupportedOperationException] { catalog.alterNamespace(testNs, NamespaceChange.removeProperty(p)) }, - condition = "_LEGACY_ERROR_TEMP_2069", + condition = "CANNOT_REMOVE_RESERVED_PROPERTY", parameters = Map("property" -> p)) } From 2a1301133138ba0d5e2d969fc6428153903ffff1 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Wed, 16 Oct 2024 09:10:06 +0200 Subject: [PATCH 15/31] [SPARK-49966][SQL] Codegen Support for JsonToStructs(`from_json`) ### What changes were proposed in this pull request? The pr aims to add `Codegen` Support for `JsonToStructs`(`from_json`). ### Why are the changes needed? 
- improve codegen coverage. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA & Existed UT (eg: JsonFunctionsSuite#`*from_json*`) ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48466 from panbingkun/SPARK-49966. Authored-by: panbingkun Signed-off-by: Max Gekk --- .../json/JsonExpressionEvalUtils.scala | 64 ++++++++++++++++- .../expressions/jsonExpressions.scala | 72 +++++++------------ 2 files changed, 88 insertions(+), 48 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala index 65c95c8240f4f..6291e62304a38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala @@ -18,9 +18,14 @@ package org.apache.spark.sql.catalyst.expressions.json import com.fasterxml.jackson.core.JsonFactory -import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JsonInferSchema, JSONOptions} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.ExprUtils +import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils +import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JsonInferSchema, JSONOptions} +import org.apache.spark.sql.catalyst.util.{FailFastMode, FailureSafeParser, PermissiveMode} +import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{ArrayType, DataType, StructType} +import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType, VariantType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -51,3 +56,58 @@ object JsonExpressionEvalUtils { UTF8String.fromString(dt.sql) } } + +class JsonToStructsEvaluator( + options: Map[String, String], + nullableSchema: DataType, + nameOfCorruptRecord: String, + timeZoneId: Option[String], + variantAllowDuplicateKeys: Boolean) extends Serializable { + + // This converts parsed rows to the desired output by the given schema. 
+ @transient + private lazy val converter = nullableSchema match { + case _: StructType => + (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next() else null + case _: ArrayType => + (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next().getArray(0) else null + case _: MapType => + (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next().getMap(0) else null + } + + @transient + private lazy val parser = { + val parsedOptions = new JSONOptions(options, timeZoneId.get, nameOfCorruptRecord) + val mode = parsedOptions.parseMode + if (mode != PermissiveMode && mode != FailFastMode) { + throw QueryCompilationErrors.parseModeUnsupportedError("from_json", mode) + } + val (parserSchema, actualSchema) = nullableSchema match { + case s: StructType => + ExprUtils.verifyColumnNameOfCorruptRecord(s, parsedOptions.columnNameOfCorruptRecord) + (s, StructType(s.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))) + case other => + (StructType(Array(StructField("value", other))), other) + } + + val rawParser = new JacksonParser(actualSchema, parsedOptions, allowArrayAsStructs = false) + val createParser = CreateJacksonParser.utf8String _ + + new FailureSafeParser[UTF8String]( + input => rawParser.parse(input, createParser, identity[UTF8String]), + mode, + parserSchema, + parsedOptions.columnNameOfCorruptRecord) + } + + final def evaluate(json: UTF8String): Any = { + if (json == null) return null + nullableSchema match { + case _: VariantType => + VariantExpressionEvalUtils.parseJson(json, + allowDuplicateKeys = variantAllowDuplicateKeys) + case _ => + converter(parser.parse(json)) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 3118fe9a2eb44..6eef3d6f9d7df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -30,9 +30,8 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper -import org.apache.spark.sql.catalyst.expressions.json.{JsonExpressionEvalUtils, JsonExpressionUtils} +import org.apache.spark.sql.catalyst.expressions.json.{JsonExpressionEvalUtils, JsonExpressionUtils, JsonToStructsEvaluator} import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke -import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils import org.apache.spark.sql.catalyst.json._ import org.apache.spark.sql.catalyst.trees.TreePattern.{JSON_TO_STRUCT, TreePattern} import org.apache.spark.sql.catalyst.util._ @@ -639,7 +638,6 @@ case class JsonToStructs( variantAllowDuplicateKeys: Boolean = SQLConf.get.getConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS)) extends UnaryExpression with TimeZoneAwareExpression - with CodegenFallback with ExpectsInputTypes with NullIntolerant with QueryErrorsBase { @@ -647,7 +645,7 @@ case class JsonToStructs( // The JSON input data might be missing certain fields. We force the nullability // of the user-provided schema to avoid data corruptions. 
In particular, the parquet-mr encoder // can generate incorrect files if values are missing in columns declared as non-nullable. - val nullableSchema = schema.asNullable + private val nullableSchema: DataType = schema.asNullable override def nullable: Boolean = true @@ -680,53 +678,35 @@ case class JsonToStructs( messageParameters = Map("schema" -> toSQLType(nullableSchema))) } - // This converts parsed rows to the desired output by the given schema. - @transient - lazy val converter = nullableSchema match { - case _: StructType => - (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next() else null - case _: ArrayType => - (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next().getArray(0) else null - case _: MapType => - (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next().getMap(0) else null - } - - val nameOfCorruptRecord = SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD) - @transient lazy val parser = { - val parsedOptions = new JSONOptions(options, timeZoneId.get, nameOfCorruptRecord) - val mode = parsedOptions.parseMode - if (mode != PermissiveMode && mode != FailFastMode) { - throw QueryCompilationErrors.parseModeUnsupportedError("from_json", mode) - } - val (parserSchema, actualSchema) = nullableSchema match { - case s: StructType => - ExprUtils.verifyColumnNameOfCorruptRecord(s, parsedOptions.columnNameOfCorruptRecord) - (s, StructType(s.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))) - case other => - (StructType(Array(StructField("value", other))), other) - } - - val rawParser = new JacksonParser(actualSchema, parsedOptions, allowArrayAsStructs = false) - val createParser = CreateJacksonParser.utf8String _ - - new FailureSafeParser[UTF8String]( - input => rawParser.parse(input, createParser, identity[UTF8String]), - mode, - parserSchema, - parsedOptions.columnNameOfCorruptRecord) - } - override def dataType: DataType = nullableSchema override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) - override def nullSafeEval(json: Any): Any = nullableSchema match { - case _: VariantType => - VariantExpressionEvalUtils.parseJson(json.asInstanceOf[UTF8String], - allowDuplicateKeys = variantAllowDuplicateKeys) - case _ => - converter(parser.parse(json.asInstanceOf[UTF8String])) + @transient + private val nameOfCorruptRecord = SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD) + + @transient + private lazy val evaluator = new JsonToStructsEvaluator( + options, nullableSchema, nameOfCorruptRecord, timeZoneId, variantAllowDuplicateKeys) + + override def nullSafeEval(json: Any): Any = evaluator.evaluate(json.asInstanceOf[UTF8String]) + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val refEvaluator = ctx.addReferenceObj("evaluator", evaluator) + val eval = child.genCode(ctx) + val resultType = CodeGenerator.boxedType(dataType) + val resultTerm = ctx.freshName("result") + ev.copy(code = + code""" + |${eval.code} + |$resultType $resultTerm = ($resultType) $refEvaluator.evaluate(${eval.value}); + |boolean ${ev.isNull} = $resultTerm == null; + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |if (!${ev.isNull}) { + | ${ev.value} = $resultTerm; + |} + |""".stripMargin) } override def inputTypes: Seq[AbstractDataType] = StringTypeWithCaseAccentSensitivity :: Nil From 39112e4f2f8c1401ffa73c84398d3b8f0afa211a Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Wed, 16 Oct 2024 09:21:48 +0200 Subject: [PATCH 16/31] 
[SPARK-49946][CORE] Require an error class in `SparkOutOfMemoryError` ### What changes were proposed in this pull request? In the PR, I propose to remove the constructors that accept a plan string as an error message, and leave only constructors with the error classes. ### Why are the changes needed? To migrate the code which uses `SparkOutOfMemoryError` on new error framework. ### Does this PR introduce _any_ user-facing change? No, it shouldn't because the exception is supposed to raised by Spark. ### How was this patch tested? By running the modified test suites: ``` $ build/sbt "core/testOnly *ExecutorSuite" ``` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48442 from MaxGekk/req-error-cond-SparkOutOfMemoryError. Authored-by: Max Gekk Signed-off-by: Max Gekk --- .../main/resources/error/error-conditions.json | 15 +++++++++++++++ .../spark/memory/SparkOutOfMemoryError.java | 8 -------- .../apache/spark/memory/TaskMemoryManager.java | 16 +++++++--------- .../unsafe/sort/UnsafeInMemorySorter.java | 3 ++- .../apache/spark/executor/ExecutorSuite.scala | 10 ++++++++-- .../spark/sql/errors/QueryExecutionErrors.scala | 3 ++- .../execution/aggregate/HashAggregateExec.scala | 2 +- .../aggregate/TungstenAggregationIterator.scala | 4 +++- 8 files changed, 38 insertions(+), 23 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index d9880899347a3..502558c21faa9 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -8711,6 +8711,21 @@ "Doesn't support month or year interval: " ] }, + "_LEGACY_ERROR_TEMP_3300" : { + "message" : [ + "error while calling spill() on : " + ] + }, + "_LEGACY_ERROR_TEMP_3301" : { + "message" : [ + "Not enough memory to grow pointer array" + ] + }, + "_LEGACY_ERROR_TEMP_3302" : { + "message" : [ + "No enough memory for aggregation" + ] + }, "_LEGACY_ERROR_USER_RAISED_EXCEPTION" : { "message" : [ "" diff --git a/core/src/main/java/org/apache/spark/memory/SparkOutOfMemoryError.java b/core/src/main/java/org/apache/spark/memory/SparkOutOfMemoryError.java index fa71eb066ff89..0e35ebecfd270 100644 --- a/core/src/main/java/org/apache/spark/memory/SparkOutOfMemoryError.java +++ b/core/src/main/java/org/apache/spark/memory/SparkOutOfMemoryError.java @@ -32,14 +32,6 @@ public final class SparkOutOfMemoryError extends OutOfMemoryError implements Spa String errorClass; Map messageParameters; - public SparkOutOfMemoryError(String s) { - super(s); - } - - public SparkOutOfMemoryError(OutOfMemoryError e) { - super(e.getMessage()); - } - public SparkOutOfMemoryError(String errorClass, Map messageParameters) { super(SparkThrowableHelper.getMessage(errorClass, messageParameters)); this.errorClass = errorClass; diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index df224bc902bff..bd9f58bf7415f 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -21,13 +21,7 @@ import java.io.InterruptedIOException; import java.io.IOException; import java.nio.channels.ClosedByInterruptException; -import java.util.Arrays; -import java.util.ArrayList; -import java.util.BitSet; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.TreeMap; +import java.util.*; 
import com.google.common.annotations.VisibleForTesting; @@ -291,8 +285,12 @@ private long trySpillAndAcquire( logger.error("error while calling spill() on {}", e, MDC.of(LogKeys.MEMORY_CONSUMER$.MODULE$, consumerToSpill)); // checkstyle.off: RegexpSinglelineJava - throw new SparkOutOfMemoryError("error while calling spill() on " + consumerToSpill + " : " - + e.getMessage()); + throw new SparkOutOfMemoryError( + "_LEGACY_ERROR_TEMP_3300", + new HashMap() {{ + put("consumerToSpill", consumerToSpill.toString()); + put("message", e.getMessage()); + }}); // checkstyle.on: RegexpSinglelineJava } } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 7579c0aefb250..761ced66f78cf 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -18,6 +18,7 @@ package org.apache.spark.util.collection.unsafe.sort; import java.util.Comparator; +import java.util.HashMap; import java.util.LinkedList; import javax.annotation.Nullable; @@ -215,7 +216,7 @@ public void expandPointerArray(LongArray newArray) { if (array != null) { if (newArray.size() < array.size()) { // checkstyle.off: RegexpSinglelineJava - throw new SparkOutOfMemoryError("Not enough memory to grow pointer array"); + throw new SparkOutOfMemoryError("_LEGACY_ERROR_TEMP_3301", new HashMap()); // checkstyle.on: RegexpSinglelineJava } Platform.copyMemory( diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 805e7ca467497..fa13092dc47aa 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -21,7 +21,7 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput} import java.lang.Thread.UncaughtExceptionHandler import java.net.URL import java.nio.ByteBuffer -import java.util.Properties +import java.util.{HashMap, Properties} import java.util.concurrent.{CountDownLatch, TimeUnit} import java.util.concurrent.atomic.AtomicBoolean @@ -522,7 +522,13 @@ class ExecutorSuite extends SparkFunSuite testThrowable(new OutOfMemoryError(), depthToCheck, isFatal = true) testThrowable(new InterruptedException(), depthToCheck, isFatal = false) testThrowable(new RuntimeException("test"), depthToCheck, isFatal = false) - testThrowable(new SparkOutOfMemoryError("test"), depthToCheck, isFatal = false) + testThrowable( + new SparkOutOfMemoryError( + "_LEGACY_ERROR_USER_RAISED_EXCEPTION", + new HashMap[String, String]() { + put("errorMessage", "test") + }), + depthToCheck, isFatal = false) } // Verify we can handle the cycle in the exception chain diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 43fc0b567dcc2..ebcc98a3af27a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -1112,7 +1112,8 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE def cannotAcquireMemoryToBuildUnsafeHashedRelationError(): Throwable = { new SparkOutOfMemoryError( - "_LEGACY_ERROR_TEMP_2107") + 
"_LEGACY_ERROR_TEMP_2107", + new java.util.HashMap[String, String]()) } def rowLargerThan256MUnsupportedError(): SparkUnsupportedOperationException = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 8f2b7ca5cba25..750b74aab384f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -682,7 +682,7 @@ case class HashAggregateExec( | $unsafeRowKeys, $unsafeRowKeyHash); | if ($unsafeRowBuffer == null) { | // failed to allocate the first page - | throw new $oomeClassName("No enough memory for aggregation"); + | throw new $oomeClassName("_LEGACY_ERROR_TEMP_3302", new java.util.HashMap()); | } |} """.stripMargin diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 1ebf0d143bd1f..2f1cda9d0f9be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.aggregate +import java.util + import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.memory.SparkOutOfMemoryError @@ -210,7 +212,7 @@ class TungstenAggregationIterator( if (buffer == null) { // failed to allocate the first page // scalastyle:off throwerror - throw new SparkOutOfMemoryError("No enough memory for aggregation") + throw new SparkOutOfMemoryError("_LEGACY_ERROR_TEMP_3302", new util.HashMap()) // scalastyle:on throwerror } } From 8d1cb76066e45bf24952124d3edc4357303067e5 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Wed, 16 Oct 2024 14:57:31 +0200 Subject: [PATCH 17/31] [SPARK-49987][SQL] Fix the error prompt when `seedExpression` is non-foldable in `randstr` ### What changes were proposed in this pull request? The pr aims to - fix the `error prompt` when `seedExpression` is `non-foldable` in `randstr`. - use `toSQLId` to set the parameter value `inputName` for `randstr ` and `uniform` of `NON_FOLDABLE_INPUT`. ### Why are the changes needed? - Let me take an example ```scala val df = Seq(1.1).toDF("a") df.createOrReplaceTempView("t") sql("SELECT randstr(1, a) from t").show(false) ``` - Before image ```shell [DATATYPE_MISMATCH.NON_FOLDABLE_INPUT] Cannot resolve "randstr(1, a)" due to data type mismatch: the input seedExpression should be a foldable INT or SMALLINT expression; however, got "a". SQLSTATE: 42K09; line 1 pos 7; 'Project [unresolvedalias(randstr(1, a#5, false))] +- SubqueryAlias t +- View (`t`, [a#5]) +- Project [value#1 AS a#5] +- LocalRelation [value#1] ``` - After ```shell [DATATYPE_MISMATCH.NON_FOLDABLE_INPUT] Cannot resolve "randstr(1, a)" due to data type mismatch: the input seed should be a foldable INT or SMALLINT expression; however, got "a". SQLSTATE: 42K09; line 1 pos 7; 'Project [unresolvedalias(randstr(1, a#5, false))] +- SubqueryAlias t +- View (`t`, [a#5]) +- Project [value#1 AS a#5] +- LocalRelation [value#1] ``` - The `parameter` name (`seedExpression`) in the error message does not match the `parameter` name (`seed`) seen in docs by the end-user. 
image ### Does this PR introduce _any_ user-facing change? Yes, When `seed` is `non-foldable `, the end-user will get a consistent experience in the error prompt. ### How was this patch tested? Update existed UT. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48490 from panbingkun/SPARK-49987. Authored-by: panbingkun Signed-off-by: Max Gekk --- .../sql/catalyst/expressions/randomExpressions.scala | 8 ++++---- .../resources/sql-tests/analyzer-results/random.sql.out | 8 ++++---- .../src/test/resources/sql-tests/results/random.sql.out | 8 ++++---- .../org/apache/spark/sql/DataFrameFunctionsSuite.scala | 4 ++-- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 3cec83facd01d..16bdaa1f7f708 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult, UnresolvedSeed} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch -import org.apache.spark.sql.catalyst.expressions.ExpectsInputTypes.{ordinalNumber, toSQLExpr, toSQLType} +import org.apache.spark.sql.catalyst.expressions.ExpectsInputTypes.{ordinalNumber, toSQLExpr, toSQLId, toSQLType} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.{BinaryLike, TernaryLike, UnaryLike} @@ -263,7 +263,7 @@ case class Uniform(min: Expression, max: Expression, seedExpression: Expression, result = DataTypeMismatch( errorSubClass = "NON_FOLDABLE_INPUT", messageParameters = Map( - "inputName" -> name, + "inputName" -> toSQLId(name), "inputType" -> requiredType, "inputExpr" -> toSQLExpr(expr))) } else expr.dataType match { @@ -374,14 +374,14 @@ case class RandStr( var result: TypeCheckResult = TypeCheckResult.TypeCheckSuccess def requiredType = "INT or SMALLINT" Seq((length, "length", 0), - (seedExpression, "seedExpression", 1)).foreach { + (seedExpression, "seed", 1)).foreach { case (expr: Expression, name: String, index: Int) => if (result == TypeCheckResult.TypeCheckSuccess) { if (!expr.foldable) { result = DataTypeMismatch( errorSubClass = "NON_FOLDABLE_INPUT", messageParameters = Map( - "inputName" -> name, + "inputName" -> toSQLId(name), "inputType" -> requiredType, "inputExpr" -> toSQLExpr(expr))) } else expr.dataType match { diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out index 133cd6a60a4fb..31919381c99b6 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out @@ -188,7 +188,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "sqlState" : "42K09", "messageParameters" : { "inputExpr" : "\"col\"", - "inputName" : "seed", + "inputName" : "`seed`", "inputType" : "integer or floating-point", "sqlExpr" : "\"uniform(10, 20, col)\"" }, @@ -211,7 +211,7 @@ 
org.apache.spark.sql.catalyst.ExtendedAnalysisException "sqlState" : "42K09", "messageParameters" : { "inputExpr" : "\"col\"", - "inputName" : "min", + "inputName" : "`min`", "inputType" : "integer or floating-point", "sqlExpr" : "\"uniform(col, 10, 0)\"" }, @@ -436,7 +436,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "sqlState" : "42K09", "messageParameters" : { "inputExpr" : "\"col\"", - "inputName" : "length", + "inputName" : "`length`", "inputType" : "INT or SMALLINT", "sqlExpr" : "\"randstr(col, 0)\"" }, @@ -459,7 +459,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "sqlState" : "42K09", "messageParameters" : { "inputExpr" : "\"col\"", - "inputName" : "seedExpression", + "inputName" : "`seed`", "inputType" : "INT or SMALLINT", "sqlExpr" : "\"randstr(10, col)\"" }, diff --git a/sql/core/src/test/resources/sql-tests/results/random.sql.out b/sql/core/src/test/resources/sql-tests/results/random.sql.out index 0b4e5e078ee15..01638abdcec6e 100644 --- a/sql/core/src/test/resources/sql-tests/results/random.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/random.sql.out @@ -240,7 +240,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "sqlState" : "42K09", "messageParameters" : { "inputExpr" : "\"col\"", - "inputName" : "seed", + "inputName" : "`seed`", "inputType" : "integer or floating-point", "sqlExpr" : "\"uniform(10, 20, col)\"" }, @@ -265,7 +265,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "sqlState" : "42K09", "messageParameters" : { "inputExpr" : "\"col\"", - "inputName" : "min", + "inputName" : "`min`", "inputType" : "integer or floating-point", "sqlExpr" : "\"uniform(col, 10, 0)\"" }, @@ -520,7 +520,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "sqlState" : "42K09", "messageParameters" : { "inputExpr" : "\"col\"", - "inputName" : "length", + "inputName" : "`length`", "inputType" : "INT or SMALLINT", "sqlExpr" : "\"randstr(col, 0)\"" }, @@ -545,7 +545,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "sqlState" : "42K09", "messageParameters" : { "inputExpr" : "\"col\"", - "inputName" : "seedExpression", + "inputName" : "`seed`", "inputType" : "INT or SMALLINT", "sqlExpr" : "\"randstr(10, col)\"" }, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 47691e1ccd40f..39c839ae5a518 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -478,7 +478,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { intercept[AnalysisException](df.select(expr)), condition = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", parameters = Map( - "inputName" -> "length", + "inputName" -> "`length`", "inputType" -> "INT or SMALLINT", "inputExpr" -> "\"a\"", "sqlExpr" -> "\"randstr(a, 10)\""), @@ -530,7 +530,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { intercept[AnalysisException](df.select(expr)), condition = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", parameters = Map( - "inputName" -> "min", + "inputName" -> "`min`", "inputType" -> "integer or floating-point", "inputExpr" -> "\"a\"", "sqlExpr" -> "\"uniform(a, 10)\""), From a3b91247b32083805fdd50e9f7f46e9a91b8fd8d Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 16 Oct 2024 07:22:59 -0700 Subject: [PATCH 18/31] [SPARK-49981][CORE][TESTS] Fix `AsyncRDDActionsSuite.FutureAction result, 
timeout` test case to be robust ### What changes were proposed in this pull request? This PR aims to fix `AsyncRDDActionsSuite.FutureAction result, timeout` test case to be robust. ### Why are the changes needed? To reduce the flakiness in GitHub Action CI. Previously, the sleep time is identical to the timeout time. It causes a flakiness in some environments like GitHub Action. - https://github.com/apache/spark/actions/runs/11298639789/job/31428018075 ``` AsyncRDDActionsSuite: ... - FutureAction result, timeout *** FAILED *** Expected exception java.util.concurrent.TimeoutException to be thrown, but no exception was thrown (AsyncRDDActionsSuite.scala:206) ``` ### Does this PR introduce _any_ user-facing change? No, this is a test-only change. ### How was this patch tested? Pass the CIs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48485 from dongjoon-hyun/SPARK-49981. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala index 4239180ba6c37..fb2bb83cb7fc4 100644 --- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala @@ -201,10 +201,10 @@ class AsyncRDDActionsSuite extends SparkFunSuite with TimeLimits { test("FutureAction result, timeout") { val f = sc.parallelize(1 to 100, 4) - .mapPartitions(itr => { Thread.sleep(20); itr }) + .mapPartitions(itr => { Thread.sleep(200); itr }) .countAsync() intercept[TimeoutException] { - ThreadUtils.awaitResult(f, Duration(20, "milliseconds")) + ThreadUtils.awaitResult(f, Duration(2, "milliseconds")) } } From bcfe62b9988f9b00c23de0b71acc1c6170edee9e Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 16 Oct 2024 07:24:33 -0700 Subject: [PATCH 19/31] [SPARK-49983][CORE][TESTS] Fix `BarrierTaskContextSuite.successively sync with allGather and barrier` test case to be robust ### What changes were proposed in this pull request? This PR aims to fix `BarrierTaskContextSuite.successively sync with allGather and barrier` test case to be robust. ### Why are the changes needed? The test case asserts the duration of partitions. However, this is flaky because we don't know when a partition is triggered before `barrier` sync. https://github.com/apache/spark/blob/0e75d19a736aa18fe77414991ebb7e3577a43af8/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala#L116-L118 Although we added `TestUtils.waitUntilExecutorsUp` at Apache Spark 3.0.0 like the following, - #28658 let's say a partition starts slowly than `38ms` and all partitions sleep `1s` exactly. Then, the test case fails like the following. - https://github.com/apache/spark/actions/runs/11298639789/job/31428018075 ``` BarrierTaskContextSuite: ... - successively sync with allGather and barrier *** FAILED *** 1038 was not less than or equal to 1000 (BarrierTaskContextSuite.scala:118) ``` According to the failure history here (SPARK-49983) and SPARK-31730, the slowness seems to be less than `200ms` when it happens. So, this PR aims to reduce the flakiness by capping the sleep up to 500ms while keeping the `1s` validation. There is no test coverage change because this test case focuses on the `successively sync with allGather and battier`. ### Does this PR introduce _any_ user-facing change? 
No, this is a test-only test case. ### How was this patch tested? Pass the CIs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48487 from dongjoon-hyun/SPARK-49983. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../org/apache/spark/scheduler/BarrierTaskContextSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala index 849832c57edaa..f00fb0d2cfa3f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala @@ -101,7 +101,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with val rdd2 = rdd.barrier().mapPartitions { it => val context = BarrierTaskContext.get() // Sleep for a random time before global sync. - Thread.sleep(Random.nextInt(1000)) + Thread.sleep(Random.nextInt(500)) context.barrier() val time1 = System.currentTimeMillis() // Sleep for a random time before global sync. From 60200ae195a124003cf77d4ab3872f1652b6b9c7 Mon Sep 17 00:00:00 2001 From: Uros Bojanic Date: Wed, 16 Oct 2024 18:48:51 +0200 Subject: [PATCH 20/31] [SPARK-49957][SQL] Scala API for string validation functions ### What changes were proposed in this pull request? Adding the Scala API for the 4 new string validation expressions: - is_valid_utf8 - make_valid_utf8 - validate_utf8 - try_validate_utf8 ### Why are the changes needed? Offer a complete Scala API for the new expressions in Spark 4.0. ### Does this PR introduce _any_ user-facing change? Yes, adding Scala API for the 4 new Spark expressions. ### How was this patch tested? New tests for the Scala API. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48454 from uros-db/api-validation. Authored-by: Uros Bojanic Signed-off-by: Max Gekk --- python/pyspark/sql/tests/test_functions.py | 4 +- .../org/apache/spark/sql/functions.scala | 38 +++++++++++++++++++ .../spark/sql/StringFunctionsSuite.scala | 38 +++++++++++++++++++ 3 files changed, 79 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index a51156e895c62..f6c1278c0dc7a 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -83,7 +83,9 @@ def test_function_parity(self): missing_in_py = jvm_fn_set.difference(py_fn_set) # Functions that we expect to be missing in python until they are added to pyspark - expected_missing_in_py = set() + expected_missing_in_py = set( + ["is_valid_utf8", "make_valid_utf8", "validate_utf8", "try_validate_utf8"] + ) self.assertEqual( expected_missing_in_py, missing_in_py, "Missing functions in pyspark not as expected" diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala index 4838bc5298bb3..4a9a20efd3a56 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala @@ -3911,6 +3911,44 @@ object functions { def encode(value: Column, charset: String): Column = Column.fn("encode", value, lit(charset)) + /** + * Returns true if the input is a valid UTF-8 string, otherwise returns false. 
+ * + * @group string_funcs + * @since 4.0.0 + */ + def is_valid_utf8(str: Column): Column = + Column.fn("is_valid_utf8", str) + + /** + * Returns a new string in which all invalid UTF-8 byte sequences, if any, are replaced by the + * Unicode replacement character (U+FFFD). + * + * @group string_funcs + * @since 4.0.0 + */ + def make_valid_utf8(str: Column): Column = + Column.fn("make_valid_utf8", str) + + /** + * Returns the input value if it corresponds to a valid UTF-8 string, or emits a + * SparkIllegalArgumentException exception otherwise. + * + * @group string_funcs + * @since 4.0.0 + */ + def validate_utf8(str: Column): Column = + Column.fn("validate_utf8", str) + + /** + * Returns the input value if it corresponds to a valid UTF-8 string, or NULL otherwise. + * + * @group string_funcs + * @since 4.0.0 + */ + def try_validate_utf8(str: Column): Column = + Column.fn("try_validate_utf8", str) + /** * Formats numeric column x to a format like '#,###,###.##', rounded to d decimal places with * HALF_EVEN round mode, and returns the result as a string column. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index ec240d71b851f..c94f57a11426a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -352,6 +352,44 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { // scalastyle:on } + test("UTF-8 string is valid") { + // scalastyle:off + checkAnswer(Seq("大千世界").toDF("a").select(is_valid_utf8($"a")), Row(true)) + checkAnswer(Seq(("abc", null)).toDF("a", "b").select(is_valid_utf8($"b")), Row(null)) + checkAnswer(Seq(Array[Byte](-1)).toDF("a").select(is_valid_utf8($"a")), Row(false)) + // scalastyle:on + } + + test("UTF-8 string make valid") { + // scalastyle:off + checkAnswer(Seq("大千世界").toDF("a").select(make_valid_utf8($"a")), Row("大千世界")) + checkAnswer(Seq(("abc", null)).toDF("a", "b").select(make_valid_utf8($"b")), Row(null)) + checkAnswer(Seq(Array[Byte](-1)).toDF("a").select(make_valid_utf8($"a")), Row("\uFFFD")) + // scalastyle:on + } + + test("UTF-8 string validate") { + // scalastyle:off + checkAnswer(Seq("大千世界").toDF("a").select(validate_utf8($"a")), Row("大千世界")) + checkAnswer(Seq(("abc", null)).toDF("a", "b").select(validate_utf8($"b")), Row(null)) + checkError( + exception = intercept[SparkIllegalArgumentException] { + Seq(Array[Byte](-1)).toDF("a").select(validate_utf8($"a")).collect() + }, + condition = "INVALID_UTF8_STRING", + parameters = Map("str" -> "\\xFF") + ) + // scalastyle:on + } + + test("UTF-8 string try validate") { + // scalastyle:off + checkAnswer(Seq("大千世界").toDF("a").select(try_validate_utf8($"a")), Row("大千世界")) + checkAnswer(Seq(("abc", null)).toDF("a", "b").select(try_validate_utf8($"b")), Row(null)) + checkAnswer(Seq(Array[Byte](-1)).toDF("a").select(try_validate_utf8($"a")), Row(null)) + // scalastyle:on + } + test("string translate") { val df = Seq(("translate", "")).toDF("a", "b") checkAnswer(df.select(translate($"a", "rnlt", "123")), Row("1a2s3ae")) From f860af67db34c9ae68076a867d4d61caf574cbb8 Mon Sep 17 00:00:00 2001 From: zml1206 Date: Thu, 17 Oct 2024 01:23:16 +0800 Subject: [PATCH 21/31] [SPARK-48155][FOLLOWUP][SQL] AQEPropagateEmptyRelation for left anti join should check if remain child is just BroadcastQueryStageExec ### What changes were proposed in this pull request? As title. ### Why are the changes needed? 
We encountered BroadcastNestedLoopJoin LeftAnti BuildLeft, and it's right is empty. It is left child of left outer BroadcastHashJoin. The case is more complicated, part of the Initial Plan is as follows ``` :- Project (214) : +- BroadcastHashJoin LeftOuter BuildRight (213) : :- BroadcastNestedLoopJoin LeftAnti BuildLeft (211) : : :- BroadcastExchange (187) : : : +- Project (186) : : : +- Filter (185) : : : +- Scan parquet (31) : : +- LocalLimit (210) : : +- Project (209) : : +- BroadcastHashJoin Inner BuildLeft (208) : : :- BroadcastExchange (194) : : : +- Project (193) : : : +- BroadcastHashJoin LeftOuter BuildRight (192) : : : :- Project (189) : : : : +- Filter (188) : : : : +- Scan parquet (37) : : : +- BroadcastExchange (191) ``` After AQEPropagateEmptyRelation, report an error "HashJoin should not take LeftOuter as the JoinType with building left side" ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48300 from zml1206/SPARK-48155-followup. Authored-by: zml1206 Signed-off-by: Wenchen Fan --- .../optimizer/PropagateEmptyRelation.scala | 3 +- .../adaptive/AdaptiveQueryExecSuite.scala | 32 +++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala index 832af340c3397..d23d43acc217b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala @@ -111,7 +111,8 @@ abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSup // Except is handled as LeftAnti by `ReplaceExceptWithAntiJoin` rule. 
case LeftOuter | LeftSemi | LeftAnti if isLeftEmpty => empty(p) case LeftSemi if isRightEmpty | isFalseCondition => empty(p) - case LeftAnti if isRightEmpty | isFalseCondition => p.left + case LeftAnti if (isRightEmpty | isFalseCondition) && canExecuteWithoutJoin(p.left) => + p.left case FullOuter if isLeftEmpty && isRightEmpty => empty(p) case LeftOuter | FullOuter if isRightEmpty && canExecuteWithoutJoin(p.left) => Project(p.left.output ++ nullValueProjectList(p.right), p.left) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index c5e64c96b2c8a..4bf993f82495b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -2829,6 +2829,38 @@ class AdaptiveQueryExecSuite assert(findTopLevelBroadcastNestedLoopJoin(adaptivePlan).size == 1) assert(findTopLevelUnion(adaptivePlan).size == 0) } + + withSQLConf( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "100") { + withTempView("t1", "t2", "t3", "t4") { + Seq(1).toDF().createOrReplaceTempView("t1") + spark.range(100).createOrReplaceTempView("t2") + spark.range(2).createOrReplaceTempView("t3") + spark.range(2).createOrReplaceTempView("t4") + val (_, adaptivePlan) = runAdaptiveAndVerifyResult( + """ + |SELECT tt2.value + |FROM ( + | SELECT value + | FROM t1 + | WHERE NOT EXISTS ( + | SELECT 1 + | FROM ( + | SELECT t2.id + | FROM t2 + | JOIN t3 ON t2.id = t3.id + | AND t2.id > 100 + | ) tt + | WHERE t1.value = tt.id + | ) + | AND t1.value = 1 + |) tt2 + | LEFT JOIN t4 ON tt2.value = t4.id + |""".stripMargin + ) + assert(findTopLevelBroadcastNestedLoopJoin(adaptivePlan).size == 1) + } + } } test("SPARK-39915: Dataset.repartition(N) may not create N partitions") { From 31a411773a3e97adb833289f8c695b37802cfedb Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 16 Oct 2024 11:27:35 -0700 Subject: [PATCH 22/31] [SPARK-49057][SQL][TESTS][FOLLOWUP] Handle `_LEGACY_ERROR_TEMP_2235` error case ### What changes were proposed in this pull request? This PR aims to fix a flaky test by handling `_LEGACY_ERROR_TEMP_2235`(multiple failures exception) in addition to the single exception. ### Why are the changes needed? After merging - #47533 The following failures were reported multiple times in the PR and today. - https://github.com/apache/spark/actions/runs/11358629880/job/31593568476 - https://github.com/apache/spark/actions/runs/11367718498/job/31621128680 - https://github.com/apache/spark/actions/runs/11360602982/job/31598792247 ``` [info] - SPARK-47148: AQE should avoid to submit shuffle job on cancellation *** FAILED *** (6 seconds, 92 milliseconds) [info] "Multiple failures in stage materialization." did not contain "coalesce test error" (AdaptiveQueryExecSuite.scala:939) ``` The root cause is that `AdaptiveSparkPlanExec.cleanUpAndThrowException` throws two types of exceptions. When there are multiple errors, `_LEGACY_ERROR_TEMP_2235` is thrown. We need to handle this too in the test case. 
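A minimal sketch of the assertion pattern this implies (illustrative helper, not the suite's exact code): inspect the thrown error, its cause, and any suppressed errors before matching the message, so both the single-failure rethrow and the `_LEGACY_ERROR_TEMP_2235` multiple-failures wrapper are covered.
```scala
// Hypothetical helper: collect every non-null message reachable from the thrown error.
def reachableMessages(error: Throwable): Seq[String] =
  (Seq(error) ++ Option(error.getCause) ++ error.getSuppressed.toSeq)
    .flatMap(t => Option(t.getMessage))
// e.g. assert(reachableMessages(error).exists(_.contains("coalesce test error")))
```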
https://github.com/apache/spark/blob/bcfe62b9988f9b00c23de0b71acc1c6170edee9e/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala#L843-L850 https://github.com/apache/spark/blob/bcfe62b9988f9b00c23de0b71acc1c6170edee9e/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala#L1916-L1921 ### Does this PR introduce _any_ user-facing change? No, this is a test-only change. ### How was this patch tested? Pass the CIs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48498 from dongjoon-hyun/SPARK-49057. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 4bf993f82495b..8e9ba6c8e21d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -936,7 +936,8 @@ class AdaptiveQueryExecSuite val error = intercept[SparkException] { joined.collect() } - assert(error.getMessage() contains "coalesce test error") + assert((Seq(error) ++ Option(error.getCause) ++ error.getSuppressed()).exists( + e => e.getMessage() != null && e.getMessage().contains("coalesce test error"))) val adaptivePlan = joined.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec] From f5e6b05e486efb3d67fd06166ca8f103efb750dc Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Wed, 16 Oct 2024 23:24:39 +0200 Subject: [PATCH 23/31] [SPARK-49643][SQL] Merge _LEGACY_ERROR_TEMP_2042 into ARITHMETIC_OVERFLOW ### What changes were proposed in this pull request? Merging related legacy error to its proper class. ### Why are the changes needed? We want to get remove legacy errors, as they are not properly migrated to the new system of errors. Also, [PR](https://github.com/apache/spark/pull/48206/files#diff-0ffd087e0d4e1618761a42c91b8712fd469e758f4789ca2fafdefff753fe81d5) started getting to big, so this is an effort to split the change needed. ### Does this PR introduce _any_ user-facing change? Yes, legacy error is now merged into ARITHMETIC_OVERFLOW. ### How was this patch tested? Existing tests check that the error message stayed the same. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48496 from mihailom-db/error2042. Authored-by: Mihailo Milosevic Signed-off-by: Max Gekk --- .../src/main/resources/error/error-conditions.json | 5 ----- .../sql/catalyst/expressions/intervalExpressions.scala | 2 +- .../apache/spark/sql/errors/QueryExecutionErrors.scala | 10 ---------- .../expressions/IntervalExpressionsSuite.scala | 2 +- 4 files changed, 2 insertions(+), 17 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 502558c21faa9..fdc00549cc088 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -6864,11 +6864,6 @@ " is not implemented." ] }, - "_LEGACY_ERROR_TEMP_2042" : { - "message" : [ - ". If necessary set to false to bypass this error." 
- ] - }, "_LEGACY_ERROR_TEMP_2045" : { "message" : [ "Unsupported table change: " diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index 13676733a9bad..d18630f542020 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -336,7 +336,7 @@ case class MakeInterval( val iu = IntervalUtils.getClass.getName.stripSuffix("$") val secFrac = sec.getOrElse("0") val failOnErrorBranch = if (failOnError) { - "throw QueryExecutionErrors.arithmeticOverflowError(e);" + """throw QueryExecutionErrors.arithmeticOverflowError(e.getMessage(), "", null);""" } else { s"${ev.isNull} = true;" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index ebcc98a3af27a..edc1b909292df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -599,16 +599,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE messageParameters = Map("methodName" -> methodName)) } - def arithmeticOverflowError(e: ArithmeticException): SparkArithmeticException = { - new SparkArithmeticException( - errorClass = "_LEGACY_ERROR_TEMP_2042", - messageParameters = Map( - "message" -> e.getMessage, - "ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key)), - context = Array.empty, - summary = "") - } - def binaryArithmeticCauseOverflowError( eval1: Short, symbol: String, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala index 7caf23490a0ce..78bc77b9dc2ab 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala @@ -266,7 +266,7 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val intervalExpr = MakeInterval(Literal(years), Literal(months), Literal(weeks), Literal(days), Literal(hours), Literal(minutes), Literal(Decimal(secFrac, Decimal.MAX_LONG_DIGITS, 6))) - checkExceptionInExpression[ArithmeticException](intervalExpr, EmptyRow, "") + checkExceptionInExpression[ArithmeticException](intervalExpr, EmptyRow, "ARITHMETIC_OVERFLOW") } withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { From e92bf3746ebf52028e2bc2168583bf9e1f463434 Mon Sep 17 00:00:00 2001 From: Tinglong Liao Date: Thu, 17 Oct 2024 08:13:07 +0900 Subject: [PATCH 24/31] [SPARK-49978][R] Move sparkR deprecation warning to package attach time ### What changes were proposed in this pull request? Previously, the output deprecation warning happens in the `spark.session` function, in this PR, we move it to the `.onAttach` function so it will be triggered whenever library is attached ### Why are the changes needed? 
I believe having the warning message at attach time has the following benefits: - **A more prompt warning.** If the deprecation is for the whole package instead of just the `sparkR.session` function, it is more intuitive for the warning to show up at attach time instead of waiting until a later time. - **Do not rely on the assumption that "every sparkR user will run the sparkR.session method".** This assumption may not hold true all the time. For example, some hosted Spark platforms like Databricks already configure the Spark session in the background and therefore would not show the warning message. So making this change should give this warning notification a broader reach. - **A less intrusive warning.** The previous warning showed up every time `sparkR.session` was called, but the new warning message will only show up once even if the user runs multiple `library`/`require` commands. ### Does this PR introduce _any_ user-facing change? **Yes** 1. No more warning message in the sparkR.session method 2. Warning message on library attach (when calling the `library`/`require` functions) (screenshot) 3. Able to suppress the warning by setting `SPARKR_SUPPRESS_DEPRECATION_WARNING` (screenshot) ### How was this patch tested? Just a simple migration change; it will rely on the existing pre/post-merge checks and this existing test. Also did manual testing (see the previous section for screenshots). ### Was this patch authored or co-authored using generative AI tooling? No Closes #48482 from tinglongliao-db/sparkR-deprecation-migration. Authored-by: Tinglong Liao Signed-off-by: Hyukjin Kwon --- R/pkg/DESCRIPTION | 1 + R/pkg/R/sparkR.R | 6 ------ R/pkg/R/zzz.R | 30 ++++++++++++++++++++++++++++++ 3 files changed, 31 insertions(+), 6 deletions(-) create mode 100644 R/pkg/R/zzz.R diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index f7dd261c10fd2..49000c62d1063 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -57,6 +57,7 @@ Collate: 'types.R' 'utils.R' 'window.R' + 'zzz.R' RoxygenNote: 7.1.2 VignetteBuilder: knitr NeedsCompilation: no diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 29c05b0db7c2d..1b5faad376eaa 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -403,12 +403,6 @@ sparkR.session <- function( sparkPackages = "", enableHiveSupport = TRUE, ...) { - - if (Sys.getenv("SPARKR_SUPPRESS_DEPRECATION_WARNING") == "") { - warning( - "SparkR is deprecated from Apache Spark 4.0.0 and will be removed in a future version.") - } - sparkConfigMap <- convertNamedListToEnv(sparkConfig) namedParams <- list(...) if (length(namedParams) > 0) { diff --git a/R/pkg/R/zzz.R b/R/pkg/R/zzz.R new file mode 100644 index 0000000000000..947bd543b75e0 --- /dev/null +++ b/R/pkg/R/zzz.R @@ -0,0 +1,30 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# zzz.R - package startup message + +.onAttach <- function(...)
{ + if (Sys.getenv("SPARKR_SUPPRESS_DEPRECATION_WARNING") == "") { + packageStartupMessage( + paste0( + "Warning: ", + "SparkR is deprecated in Apache Spark 4.0.0 and will be removed in a future release. ", + "To continue using Spark in R, we recommend using sparklyr instead: ", + "https://spark.posit.co/get-started/" + ) + ) + } +} From 224d3ba1a2cde664fb94a96a4af1defac9ea401c Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Thu, 17 Oct 2024 10:37:05 +0900 Subject: [PATCH 25/31] [SPARK-49986][INFRA] Restore `scipy` installation in dockerfile MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? Restore `scipy` installation in dockerfile ### Why are the changes needed? https://docs.scipy.org/doc/scipy-1.13.1/building/index.html#system-level-dependencies > If you want to use the system Python and pip, you will need: C, C++, and Fortran compilers (typically gcc, g++, and gfortran). ... `scipy` actually depends on `gfortran`, but `apt-get remove --purge -y 'gfortran-11'` broke this dependency. ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? manually check with the first commit https://github.com/apache/spark/pull/48489/commits/5be0dfa2431653c00c430424867dcc3918078226: move `apt-get remove --purge -y 'gfortran-11'` ahead of `scipy` installation, then the installation fails with ``` #18 394.3 Collecting scipy #18 394.4 Downloading scipy-1.13.1.tar.gz (57.2 MB) #18 395.2 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 57.2/57.2 MB 76.7 MB/s eta 0:00:00 #18 401.3 Installing build dependencies: started #18 410.5 Installing build dependencies: finished with status 'done' #18 410.5 Getting requirements to build wheel: started #18 410.7 Getting requirements to build wheel: finished with status 'done' #18 410.7 Installing backend dependencies: started #18 411.8 Installing backend dependencies: finished with status 'done' #18 411.8 Preparing metadata (pyproject.toml): started #18 414.9 Preparing metadata (pyproject.toml): finished with status 'error' #18 414.9 error: subprocess-exited-with-error #18 414.9 #18 414.9 × Preparing metadata (pyproject.toml) did not run successfully. 
#18 414.9 │ exit code: 1 #18 414.9 ╰─> [42 lines of output] #18 414.9 + meson setup /tmp/pip-install-y77ar9d0/scipy_1e543e0816ed4b26984415533ae9079d /tmp/pip-install-y77ar9d0/scipy_1e543e0816ed4b26984415533ae9079d/.mesonpy-xqfvs4ek -Dbuildtype=release -Db_ndebug=if-release -Db_vscrt=md --native-file=/tmp/pip-install-y77ar9d0/scipy_1e543e0816ed4b26984415533ae9079d/.mesonpy-xqfvs4ek/meson-python-native-file.ini #18 414.9 The Meson build system #18 414.9 Version: 1.5.2 #18 414.9 Source dir: /tmp/pip-install-y77ar9d0/scipy_1e543e0816ed4b26984415533ae9079d #18 414.9 Build dir: /tmp/pip-install-y77ar9d0/scipy_1e543e0816ed4b26984415533ae9079d/.mesonpy-xqfvs4ek #18 414.9 Build type: native build #18 414.9 Project name: scipy #18 414.9 Project version: 1.13.1 #18 414.9 C compiler for the host machine: cc (gcc 11.4.0 "cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0") #18 414.9 C linker for the host machine: cc ld.bfd 2.38 #18 414.9 C++ compiler for the host machine: c++ (gcc 11.4.0 "c++ (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0") #18 414.9 C++ linker for the host machine: c++ ld.bfd 2.38 #18 414.9 Cython compiler for the host machine: cython (cython 3.0.11) #18 414.9 Host machine cpu family: x86_64 #18 414.9 Host machine cpu: x86_64 #18 414.9 Program python found: YES (/usr/local/bin/pypy3) #18 414.9 Run-time dependency python found: YES 3.9 #18 414.9 Program cython found: YES (/tmp/pip-build-env-v_vnvt3h/overlay/bin/cython) #18 414.9 Compiler for C supports arguments -Wno-unused-but-set-variable: YES #18 414.9 Compiler for C supports arguments -Wno-unused-function: YES #18 414.9 Compiler for C supports arguments -Wno-conversion: YES #18 414.9 Compiler for C supports arguments -Wno-misleading-indentation: YES #18 414.9 Library m found: YES #18 414.9 #18 414.9 ../meson.build:78:0: ERROR: Unknown compiler(s): [['gfortran'], ['flang'], ['nvfortran'], ['pgfortran'], ['ifort'], ['ifx'], ['g95']] #18 414.9 The following exception(s) were encountered: #18 414.9 Running `gfortran --version` gave "[Errno 2] No such file or directory: 'gfortran'" #18 414.9 Running `gfortran -V` gave "[Errno 2] No such file or directory: 'gfortran'" #18 414.9 Running `flang --version` gave "[Errno 2] No such file or directory: 'flang'" #18 414.9 Running `flang -V` gave "[Errno 2] No such file or directory: 'flang'" #18 414.9 Running `nvfortran --version` gave "[Errno 2] No such file or directory: 'nvfortran'" #18 414.9 Running `nvfortran -V` gave "[Errno 2] No such file or directory: 'nvfortran'" #18 414.9 Running `pgfortran --version` gave "[Errno 2] No such file or directory: 'pgfortran'" #18 414.9 Running `pgfortran -V` gave "[Errno 2] No such file or directory: 'pgfortran'" #18 414.9 Running `ifort --version` gave "[Errno 2] No such file or directory: 'ifort'" #18 414.9 Running `ifort -V` gave "[Errno 2] No such file or directory: 'ifort'" #18 414.9 Running `ifx --version` gave "[Errno 2] No such file or directory: 'ifx'" #18 414.9 Running `ifx -V` gave "[Errno 2] No such file or directory: 'ifx'" #18 414.9 Running `g95 --version` gave "[Errno 2] No such file or directory: 'g95'" #18 414.9 Running `g95 -V` gave "[Errno 2] No such file or directory: 'g95'" #18 414.9 #18 414.9 A full log can be found at /tmp/pip-install-y77ar9d0/scipy_1e543e0816ed4b26984[4155](https://github.com/zhengruifeng/spark/actions/runs/11357130578/job/31589506939#step:7:4161)33ae9079d/.mesonpy-xqfvs4ek/meson-logs/meson-log.txt #18 414.9 [end of output] ``` see https://github.com/zhengruifeng/spark/actions/runs/11357130578/job/31589506939 ### Was this patch 
authored or co-authored using generative AI tooling? no Closes #48489 from zhengruifeng/infra_scipy. Authored-by: Ruifeng Zheng Signed-off-by: Hyukjin Kwon --- dev/infra/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/infra/Dockerfile b/dev/infra/Dockerfile index 10a39497c8ed9..1edeed775880b 100644 --- a/dev/infra/Dockerfile +++ b/dev/infra/Dockerfile @@ -152,6 +152,6 @@ RUN python3.13 -m pip install lxml numpy>=2.1 && \ python3.13 -m pip cache purge # Remove unused installation packages to free up disk space -RUN apt-get remove --purge -y 'gfortran-11' 'humanity-icon-theme' 'nodejs-doc' || true +RUN apt-get remove --purge -y 'humanity-icon-theme' 'nodejs-doc' RUN apt-get autoremove --purge -y RUN apt-get clean From baa5f408a0985d703b4a1e4c5490c77b239180c4 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Thu, 17 Oct 2024 11:29:00 +0900 Subject: [PATCH 26/31] [SPARK-49945][PS][CONNECT] Add alias for `distributed_id` ### What changes were proposed in this pull request? 1, make `registerInternalExpression` support alias; 2, add alias `distributed_id` for `MonotonicallyIncreasingID` (rename `distributed_index` to `distributed_id` to be more consistent with existing `distributed_sequence_id`); 3, remove `distributedIndex` from `PythonSQLUtils` ### Why are the changes needed? make PS on Connect more consistent with Classic: ```py In [9]: ps.set_option("compute.default_index_type", "distributed") In [10]: spark_frame = ps.range(10).to_spark() In [11]: InternalFrame.attach_default_index(spark_frame).explain(True) ``` before: ![image](https://github.com/user-attachments/assets/6ce1fb5f-a3c6-42d5-a21e-3925207cb4d0) ``` == Parsed Logical Plan == 'Project ['monotonically_increasing_id() AS __index_level_0__#27, 'id] +- 'Project ['id] +- Project [__index_level_0__#19L, id#16L, monotonically_increasing_id() AS __natural_order__#22L] +- Project [monotonically_increasing_id() AS __index_level_0__#19L, id#16L] +- Range (0, 10, step=1, splits=Some(12)) ... ``` after: ![image](https://github.com/user-attachments/assets/00d3a8a1-251c-4cee-851e-c10f294d5248) ``` == Parsed Logical Plan == 'Project ['distributed_id() AS __index_level_0__#65, *] +- 'Project ['id] +- Project [__index_level_0__#45L, id#42L, monotonically_increasing_id() AS __natural_order__#48L] +- Project [distributed_id() AS __index_level_0__#45L, id#42L] +- Range (0, 10, step=1, splits=Some(12)) ... ``` ### Does this PR introduce _any_ user-facing change? spark ui ### How was this patch tested? existing test and manually check ### Was this patch authored or co-authored using generative AI tooling? no Closes #48439 from zhengruifeng/distributed_index. 
Authored-by: Ruifeng Zheng Signed-off-by: Hyukjin Kwon --- python/pyspark/pandas/internal.py | 9 +-------- python/pyspark/pandas/spark/functions.py | 4 ++++ .../catalyst/analysis/FunctionRegistry.scala | 18 +++++++++++++++--- .../spark/sql/api/python/PythonSQLUtils.scala | 6 ------ 4 files changed, 20 insertions(+), 17 deletions(-) diff --git a/python/pyspark/pandas/internal.py b/python/pyspark/pandas/internal.py index 6063641e22e3b..90c361547b814 100644 --- a/python/pyspark/pandas/internal.py +++ b/python/pyspark/pandas/internal.py @@ -909,14 +909,7 @@ def attach_sequence_column(sdf: PySparkDataFrame, column_name: str) -> PySparkDa @staticmethod def attach_distributed_column(sdf: PySparkDataFrame, column_name: str) -> PySparkDataFrame: - scols = [scol_for(sdf, column) for column in sdf.columns] - # Does not add an alias to avoid having some changes in protobuf definition for now. - # The alias is more for query strings in DataFrame.explain, and they are cosmetic changes. - if is_remote(): - return sdf.select(F.monotonically_increasing_id().alias(column_name), *scols) - jvm = sdf.sparkSession._jvm - jcol = jvm.PythonSQLUtils.distributedIndex() - return sdf.select(PySparkColumn(jcol).alias(column_name), *scols) + return sdf.select(SF.distributed_id().alias(column_name), "*") @staticmethod def attach_distributed_sequence_column( diff --git a/python/pyspark/pandas/spark/functions.py b/python/pyspark/pandas/spark/functions.py index bdd11559df3b6..53146a163b1ef 100644 --- a/python/pyspark/pandas/spark/functions.py +++ b/python/pyspark/pandas/spark/functions.py @@ -79,6 +79,10 @@ def null_index(col: Column) -> Column: return _invoke_internal_function_over_columns("null_index", col) +def distributed_id() -> Column: + return _invoke_internal_function_over_columns("distributed_id") + + def distributed_sequence_id() -> Column: return _invoke_internal_function_over_columns("distributed_sequence_id") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index d03d8114e9976..abe61619a2331 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -895,9 +895,20 @@ object FunctionRegistry { /** Registry for internal functions used by Connect and the Column API. 
*/ private[sql] val internal: SimpleFunctionRegistry = new SimpleFunctionRegistry - private def registerInternalExpression[T <: Expression : ClassTag](name: String): Unit = { - val (info, builder) = FunctionRegistryBase.build(name, None) - internal.internalRegisterFunction(FunctionIdentifier(name), info, builder) + private def registerInternalExpression[T <: Expression : ClassTag]( + name: String, + setAlias: Boolean = false): Unit = { + val (info, builder) = FunctionRegistryBase.build[T](name, None) + val newBuilder = if (setAlias) { + (expressions: Seq[Expression]) => { + val expr = builder(expressions) + expr.setTagValue(FUNC_ALIAS, name) + expr + } + } else { + builder + } + internal.internalRegisterFunction(FunctionIdentifier(name), info, newBuilder) } registerInternalExpression[Product]("product") @@ -911,6 +922,7 @@ object FunctionRegistry { registerInternalExpression[Days]("days") registerInternalExpression[Hours]("hours") registerInternalExpression[UnwrapUDT]("unwrap_udt") + registerInternalExpression[MonotonicallyIncreasingID]("distributed_id", setAlias = true) registerInternalExpression[DistributedSequenceID]("distributed_sequence_id") registerInternalExpression[PandasProduct]("pandas_product") registerInternalExpression[PandasStddev]("pandas_stddev") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index 08395ef4c347c..a66a6e54a7c8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -154,12 +154,6 @@ private[sql] object PythonSQLUtils extends Logging { def namedArgumentExpression(name: String, e: Column): Column = NamedArgumentExpression(name, e) - def distributedIndex(): Column = { - val expr = MonotonicallyIncreasingID() - expr.setTagValue(FunctionRegistry.FUNC_ALIAS, "distributed_index") - expr - } - @scala.annotation.varargs def fn(name: String, arguments: Column*): Column = Column.fn(name, arguments: _*) From 9af705d27cae1ce9918f0467ecff6da10b311ab6 Mon Sep 17 00:00:00 2001 From: Haejoon Lee Date: Thu, 17 Oct 2024 11:32:53 +0900 Subject: [PATCH 27/31] [SPARK-49951][SQL] Assign proper error condition for _LEGACY_ERROR_TEMP_(1099|3085) ### What changes were proposed in this pull request? This PR proposes to assign proper error condition & sqlstate for _LEGACY_ERROR_TEMP_(1099|3085) ### Why are the changes needed? To improve the error message by assigning proper error condition and SQLSTATE ### Does this PR introduce _any_ user-facing change? No, only user-facing error message improved ### How was this patch tested? Updated the existing tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #48449 from itholic/SPARK-49951. 
Authored-by: Haejoon Lee Signed-off-by: Haejoon Lee --- .../resources/error/error-conditions.json | 16 ++++++--------- .../spark/sql/avro/AvroDataToCatalyst.scala | 20 +++++++------------ .../spark/sql/avro/AvroFunctionsSuite.scala | 11 ++++++++++ .../sql/errors/QueryCompilationErrors.scala | 10 ++++------ .../expressions/CsvExpressionsSuite.scala | 17 ++++++++++------ .../apache/spark/sql/CsvFunctionsSuite.scala | 8 +++----- .../apache/spark/sql/JsonFunctionsSuite.scala | 8 +++----- .../execution/datasources/xml/XmlSuite.scala | 8 +++----- 8 files changed, 48 insertions(+), 50 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index fdc00549cc088..3e4848658f14a 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -3827,6 +3827,12 @@ ], "sqlState" : "42617" }, + "PARSE_MODE_UNSUPPORTED" : { + "message" : [ + "The function doesn't support the mode. Acceptable modes are PERMISSIVE and FAILFAST." + ], + "sqlState" : "42601" + }, "PARSE_SYNTAX_ERROR" : { "message" : [ "Syntax error at or near ." @@ -6045,11 +6051,6 @@ "DataType '' is not supported by ." ] }, - "_LEGACY_ERROR_TEMP_1099" : { - "message" : [ - "() doesn't support the mode. Acceptable modes are and ." - ] - }, "_LEGACY_ERROR_TEMP_1103" : { "message" : [ "Unsupported component type in arrays." @@ -8096,11 +8097,6 @@ "No handler for UDF/UDAF/UDTF '': " ] }, - "_LEGACY_ERROR_TEMP_3085" : { - "message" : [ - "from_avro() doesn't support the mode. Acceptable modes are and ." - ] - }, "_LEGACY_ERROR_TEMP_3086" : { "message" : [ "Cannot persist into Hive metastore as table property keys may not start with 'spark.sql.': " diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala index 0b85b208242cb..9c8b2d0375588 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala @@ -24,10 +24,10 @@ import org.apache.avro.generic.GenericDatumReader import org.apache.avro.io.{BinaryDecoder, DecoderFactory} import org.apache.spark.SparkException -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, SpecificInternalRow, UnaryExpression} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.util.{FailFastMode, ParseMode, PermissiveMode} +import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types._ private[sql] case class AvroDataToCatalyst( @@ -80,12 +80,9 @@ private[sql] case class AvroDataToCatalyst( @transient private lazy val parseMode: ParseMode = { val mode = avroOptions.parseMode if (mode != PermissiveMode && mode != FailFastMode) { - throw new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_3085", - messageParameters = Map( - "name" -> mode.name, - "permissiveMode" -> PermissiveMode.name, - "failFastMode" -> FailFastMode.name)) + throw QueryCompilationErrors.parseModeUnsupportedError( + prettyName, mode + ) } mode } @@ -123,12 +120,9 @@ private[sql] case class AvroDataToCatalyst( s"Current parse Mode: ${FailFastMode.name}. 
To process malformed records as null " + "result, try setting the option 'mode' as 'PERMISSIVE'.", e) case _ => - throw new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_3085", - messageParameters = Map( - "name" -> parseMode.name, - "permissiveMode" -> PermissiveMode.name, - "failFastMode" -> FailFastMode.name)) + throw QueryCompilationErrors.parseModeUnsupportedError( + prettyName, parseMode + ) } } } diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala index a7f7abadcf485..096cdfe0b9ee4 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala @@ -106,6 +106,17 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession { functions.from_avro( $"avro", avroTypeStruct, Map("mode" -> "PERMISSIVE").asJava)), expected) + + checkError( + exception = intercept[AnalysisException] { + avroStructDF.select( + functions.from_avro( + $"avro", avroTypeStruct, Map("mode" -> "DROPMALFORMED").asJava)).collect() + }, + condition = "PARSE_MODE_UNSUPPORTED", + parameters = Map( + "funcName" -> "`from_avro`", + "mode" -> "DROPMALFORMED")) } test("roundtrip in to_avro and from_avro - array with null") { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 9dc15c4a1b78d..431983214c482 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AnyValue import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.logical.{Assignment, InputParameter, Join, LogicalPlan, SerdeInfo, Window} import org.apache.spark.sql.catalyst.trees.{Origin, TreeNode} -import org.apache.spark.sql.catalyst.util.{quoteIdentifier, FailFastMode, ParseMode, PermissiveMode} +import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ParseMode} import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, UnboundFunction} @@ -1341,12 +1341,10 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat def parseModeUnsupportedError(funcName: String, mode: ParseMode): Throwable = { new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_1099", + errorClass = "PARSE_MODE_UNSUPPORTED", messageParameters = Map( - "funcName" -> funcName, - "mode" -> mode.name, - "permissiveMode" -> PermissiveMode.name, - "failFastMode" -> FailFastMode.name)) + "funcName" -> toSQLId(funcName), + "mode" -> mode.name)) } def nonFoldableArgumentError( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala index a89cb58c3e03b..249975f9c0d4c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala @@ -149,12 +149,17 @@ class CsvExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper with P test("unsupported mode") { val csvData = "---" val schema = StructType(StructField("a", DoubleType) :: Nil) - val exception = intercept[TestFailedException] { - checkEvaluation( - CsvToStructs(schema, Map("mode" -> DropMalformedMode.name), Literal(csvData), UTC_OPT), - InternalRow(null)) - }.getCause - assert(exception.getMessage.contains("from_csv() doesn't support the DROPMALFORMED mode")) + + checkError( + exception = intercept[TestFailedException] { + checkEvaluation( + CsvToStructs(schema, Map("mode" -> DropMalformedMode.name), Literal(csvData), UTC_OPT), + InternalRow(null)) + }.getCause.asInstanceOf[AnalysisException], + condition = "PARSE_MODE_UNSUPPORTED", + parameters = Map( + "funcName" -> "`from_csv`", + "mode" -> "DROPMALFORMED")) } test("infer schema of CSV strings") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala index e6907b8656482..970ed5843b3c5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala @@ -352,12 +352,10 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select(from_csv($"value", schema, Map("mode" -> "DROPMALFORMED"))).collect() }, - condition = "_LEGACY_ERROR_TEMP_1099", + condition = "PARSE_MODE_UNSUPPORTED", parameters = Map( - "funcName" -> "from_csv", - "mode" -> "DROPMALFORMED", - "permissiveMode" -> "PERMISSIVE", - "failFastMode" -> "FAILFAST")) + "funcName" -> "`from_csv`", + "mode" -> "DROPMALFORMED")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 7b19ad988d308..84408d8e2495d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -861,12 +861,10 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select(from_json($"value", schema, Map("mode" -> "DROPMALFORMED"))).collect() }, - condition = "_LEGACY_ERROR_TEMP_1099", + condition = "PARSE_MODE_UNSUPPORTED", parameters = Map( - "funcName" -> "from_json", - "mode" -> "DROPMALFORMED", - "permissiveMode" -> "PERMISSIVE", - "failFastMode" -> "FAILFAST")) + "funcName" -> "`from_json`", + "mode" -> "DROPMALFORMED")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala index 91f21c4a2ed34..059e4aadef2bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala @@ -1315,12 +1315,10 @@ class XmlSuite spark.sql(s"""SELECT schema_of_xml('1', map('mode', 'DROPMALFORMED'))""") .collect() }, - condition = "_LEGACY_ERROR_TEMP_1099", + condition = "PARSE_MODE_UNSUPPORTED", parameters = Map( - "funcName" -> "schema_of_xml", - "mode" -> "DROPMALFORMED", - "permissiveMode" -> "PERMISSIVE", - "failFastMode" -> FailFastMode.name) + "funcName" -> "`schema_of_xml`", + "mode" -> "DROPMALFORMED") ) } From 948aeba93e1a5898ec4f8e71ff4eb89e7514c43f Mon Sep 17 00:00:00 2001 From: panbingkun Date: Thu, 17 Oct 2024 11:33:01 +0900 Subject: [PATCH 28/31] [SPARK-49947][SQL][TESTS] 
Upgrade `MsSql` docker image version ### What changes were proposed in this pull request? The pr aims to upgrade the `MsSql` docker image version from `2022-CU14-ubuntu-22.04` to `2022-CU15-ubuntu-22.04`. ### Why are the changes needed? This will help Apache Spark test the latest `MsSql`. https://hub.docker.com/r/microsoft/mssql-server image ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48444 from panbingkun/SPARK-49947. Authored-by: panbingkun Signed-off-by: Hyukjin Kwon --- .../apache/spark/sql/jdbc/MsSQLServerDatabaseOnDocker.scala | 2 +- .../apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala | 4 ++-- .../spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala | 4 ++-- .../apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSQLServerDatabaseOnDocker.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSQLServerDatabaseOnDocker.scala index 9d3c7d1eca328..6bd33356cab3d 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSQLServerDatabaseOnDocker.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSQLServerDatabaseOnDocker.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.jdbc class MsSQLServerDatabaseOnDocker extends DatabaseOnDocker { override val imageName = sys.env.getOrElse("MSSQLSERVER_DOCKER_IMAGE_NAME", - "mcr.microsoft.com/mssql/server:2022-CU14-ubuntu-22.04") + "mcr.microsoft.com/mssql/server:2022-CU15-ubuntu-22.04") override val env = Map( "SA_PASSWORD" -> "Sapass123", "ACCEPT_EULA" -> "Y" diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala index 90cd68e6e1d24..62f088ebc2b6d 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala @@ -31,10 +31,10 @@ import org.apache.spark.sql.types.{BinaryType, DecimalType} import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., 2022-CU14-ubuntu-22.04): + * To run this test suite for a specific version (e.g., 2022-CU15-ubuntu-22.04): * {{{ * ENABLE_DOCKER_INTEGRATION_TESTS=1 - * MSSQLSERVER_DOCKER_IMAGE_NAME=mcr.microsoft.com/mssql/server:2022-CU14-ubuntu-22.04 + * MSSQLSERVER_DOCKER_IMAGE_NAME=mcr.microsoft.com/mssql/server:2022-CU15-ubuntu-22.04 * ./build/sbt -Pdocker-integration-tests * "docker-integration-tests/testOnly org.apache.spark.sql.jdbc.MsSqlServerIntegrationSuite" * }}} diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala index aaaaa28558342..d884ad4c62466 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala @@ -27,10 +27,10 @@ import 
org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., 2022-CU14-ubuntu-22.04): + * To run this test suite for a specific version (e.g., 2022-CU15-ubuntu-22.04): * {{{ * ENABLE_DOCKER_INTEGRATION_TESTS=1 - * MSSQLSERVER_DOCKER_IMAGE_NAME=mcr.microsoft.com/mssql/server:2022-CU14-ubuntu-22.04 + * MSSQLSERVER_DOCKER_IMAGE_NAME=mcr.microsoft.com/mssql/server:2022-CU15-ubuntu-22.04 * ./build/sbt -Pdocker-integration-tests "testOnly *v2*MsSqlServerIntegrationSuite" * }}} */ diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala index 9fb3bc4fba945..724c394a4f052 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala @@ -26,10 +26,10 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., 2022-CU14-ubuntu-22.04): + * To run this test suite for a specific version (e.g., 2022-CU15-ubuntu-22.04): * {{{ * ENABLE_DOCKER_INTEGRATION_TESTS=1 - * MSSQLSERVER_DOCKER_IMAGE_NAME=mcr.microsoft.com/mssql/server:2022-CU14-ubuntu-22.04 + * MSSQLSERVER_DOCKER_IMAGE_NAME=mcr.microsoft.com/mssql/server:2022-CU15-ubuntu-22.04 * ./build/sbt -Pdocker-integration-tests "testOnly *v2.MsSqlServerNamespaceSuite" * }}} */ From 070f2bdfb968c8080de1c6614c1def978df823d4 Mon Sep 17 00:00:00 2001 From: Changgyoo Park Date: Thu, 17 Oct 2024 11:35:14 +0900 Subject: [PATCH 29/31] [SPARK-49876][CONNECT] Get rid of global locks from Spark Connect Service ### What changes were proposed in this pull request? Get rid of global locks from Spark Connect Service. - ServerSideListenerHolder: AtomicReference replaces the global lock. - SparkConnectStreamingQueryCache: two global locks are replaced with ConcurrentHashMap and a mutex-protected per-tag data structure, i.e., global locks -> a per-tag lock. ### Why are the changes needed? Spark Connect Service doesn't limit the number of threads, susceptible to priority inversion because of heavy use of global locks. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests + modified an existing test. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48350 from changgyoopark-db/SPARK-49876-REMOVE-LOCKS. 
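As a rough illustration of the locking change described above (the class and member names below are invented for this sketch and are not part of the Spark Connect code; the actual change is in the diff that follows), replacing a globally synchronized map and a synchronized single-listener slot with `ConcurrentHashMap.compute` and an `AtomicReference` looks roughly like this:

```scala
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicReference

// Hypothetical cache: each update is atomic per key via compute(), so there is no
// service-wide lock that every caller has to queue behind.
class LockFreeCache[V >: Null <: AnyRef] {
  private val cache = new ConcurrentHashMap[String, V]()

  def register(key: String, value: V): Unit =
    cache.compute(key, (_, existing) => {
      if (existing != null) {
        // Entry is being replaced; a real implementation might log a warning here.
      }
      value
    })

  def lookup(key: String): Option[V] = Option(cache.get(key))
}

// Hypothetical single-listener slot: an AtomicReference replaces a lock-guarded Option.
class ListenerSlot[L >: Null <: AnyRef] {
  private val listener = new AtomicReference[L]()

  def isRegistered: Boolean = listener.getAcquire() != null
  def init(l: L): Unit = listener.setRelease(l)
  // getAndSet(null) lets exactly one caller observe and clean up the registered listener.
  def cleanUp(): Option[L] = Option(listener.getAndSet(null))
}
```

The point of the pattern is that contention becomes per key (or per slot): a slow operation on one streaming query no longer blocks unrelated queries behind a global lock, which is the priority-inversion risk mentioned above.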
Authored-by: Changgyoo Park Signed-off-by: Hyukjin Kwon --- .../SparkConnectListenerBusListener.scala | 22 +- .../SparkConnectStreamingQueryCache.scala | 239 ++++++++++-------- ...SparkConnectListenerBusListenerSuite.scala | 3 +- ...SparkConnectStreamingQueryCacheSuite.scala | 14 +- 4 files changed, 160 insertions(+), 118 deletions(-) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala index 7a0c067ab430b..445f40d25edcd 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.connect.service import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap} +import java.util.concurrent.atomic.AtomicReference import scala.jdk.CollectionConverters._ import scala.util.control.NonFatal @@ -41,7 +42,8 @@ private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) { // The server side listener that is responsible to stream streaming query events back to client. // There is only one listener per sessionHolder, but each listener is responsible for all events // of all streaming queries in the SparkSession. - var streamingQueryServerSideListener: Option[SparkConnectListenerBusListener] = None + var streamingQueryServerSideListener: AtomicReference[SparkConnectListenerBusListener] = + new AtomicReference() // The cache for QueryStartedEvent, key is query runId and value is the actual QueryStartedEvent. // Events for corresponding query will be sent back to client with // the WriteStreamOperationStart response, so that the client can handle the event before @@ -50,10 +52,8 @@ private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) { val streamingQueryStartedEventCache : ConcurrentMap[String, StreamingQueryListener.QueryStartedEvent] = new ConcurrentHashMap() - val lock = new Object() - - def isServerSideListenerRegistered: Boolean = lock.synchronized { - streamingQueryServerSideListener.isDefined + def isServerSideListenerRegistered: Boolean = { + streamingQueryServerSideListener.getAcquire() != null } /** @@ -65,10 +65,10 @@ private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) { * @param responseObserver * the responseObserver created from the first long running executeThread. */ - def init(responseObserver: StreamObserver[ExecutePlanResponse]): Unit = lock.synchronized { + def init(responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { val serverListener = new SparkConnectListenerBusListener(this, responseObserver) sessionHolder.session.streams.addListener(serverListener) - streamingQueryServerSideListener = Some(serverListener) + streamingQueryServerSideListener.setRelease(serverListener) } /** @@ -77,13 +77,13 @@ private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) { * exception. It removes the listener from the session, clears the cache. Also it sends back the * final ResultComplete response. 
*/ - def cleanUp(): Unit = lock.synchronized { - streamingQueryServerSideListener.foreach { listener => + def cleanUp(): Unit = { + var listener = streamingQueryServerSideListener.getAndSet(null) + if (listener != null) { sessionHolder.session.streams.removeListener(listener) listener.sendResultComplete() + streamingQueryStartedEventCache.clear() } - streamingQueryStartedEventCache.clear() - streamingQueryServerSideListener = None } } diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala index 48492bac62344..3da2548b456e8 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala @@ -17,11 +17,8 @@ package org.apache.spark.sql.connect.service -import java.util.concurrent.Executors -import java.util.concurrent.ScheduledExecutorService -import java.util.concurrent.TimeUnit +import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Executors, ScheduledExecutorService, TimeUnit} import java.util.concurrent.atomic.AtomicReference -import javax.annotation.concurrent.GuardedBy import scala.collection.mutable import scala.concurrent.{ExecutionContext, Future} @@ -61,36 +58,34 @@ private[connect] class SparkConnectStreamingQueryCache( sessionHolder: SessionHolder, query: StreamingQuery, tags: Set[String], - operationId: String): Unit = queryCacheLock.synchronized { - taggedQueriesLock.synchronized { - val value = QueryCacheValue( - userId = sessionHolder.userId, - sessionId = sessionHolder.sessionId, - session = sessionHolder.session, - query = query, - operationId = operationId, - expiresAtMs = None) - - val queryKey = QueryCacheKey(query.id.toString, query.runId.toString) - tags.foreach { tag => - taggedQueries - .getOrElseUpdate(tag, new mutable.ArrayBuffer[QueryCacheKey]) - .addOne(queryKey) - } - - queryCache.put(queryKey, value) match { - case Some(existing) => // Query is being replace. Not really expected. + operationId: String): Unit = { + val value = QueryCacheValue( + userId = sessionHolder.userId, + sessionId = sessionHolder.sessionId, + session = sessionHolder.session, + query = query, + operationId = operationId, + expiresAtMs = None) + + val queryKey = QueryCacheKey(query.id.toString, query.runId.toString) + tags.foreach { tag => addTaggedQuery(tag, queryKey) } + + queryCache.compute( + queryKey, + (key, existing) => { + if (existing != null) { // The query is being replaced: allowed, though not expected. logWarning(log"Replacing existing query in the cache (unexpected). " + log"Query Id: ${MDC(QUERY_ID, query.id)}.Existing value ${MDC(OLD_VALUE, existing)}, " + log"new value ${MDC(NEW_VALUE, value)}.") - case None => + } else { logInfo( log"Adding new query to the cache. Query Id ${MDC(QUERY_ID, query.id)}, " + log"value ${MDC(QUERY_CACHE_VALUE, value)}.") - } + } + value + }) - schedulePeriodicChecks() // Starts the scheduler thread if it hasn't started. - } + schedulePeriodicChecks() // Start the scheduler thread if it has not been started. 
} /** @@ -104,44 +99,35 @@ private[connect] class SparkConnectStreamingQueryCache( runId: String, tags: Set[String], session: SparkSession): Option[QueryCacheValue] = { - taggedQueriesLock.synchronized { - val key = QueryCacheKey(queryId, runId) - val result = getCachedQuery(QueryCacheKey(queryId, runId), session) - tags.foreach { tag => - taggedQueries.getOrElseUpdate(tag, new mutable.ArrayBuffer[QueryCacheKey]).addOne(key) - } - result - } + val queryKey = QueryCacheKey(queryId, runId) + val result = getCachedQuery(QueryCacheKey(queryId, runId), session) + tags.foreach { tag => addTaggedQuery(tag, queryKey) } + result } /** * Similar with [[getCachedQuery]] but it gets queries tagged previously. */ def getTaggedQuery(tag: String, session: SparkSession): Seq[QueryCacheValue] = { - taggedQueriesLock.synchronized { - taggedQueries - .get(tag) - .map { k => - k.flatMap(getCachedQuery(_, session)).toSeq - } - .getOrElse(Seq.empty[QueryCacheValue]) - } + val queryKeySet = Option(taggedQueries.get(tag)) + queryKeySet + .map(_.flatMap(k => getCachedQuery(k, session))) + .getOrElse(Seq.empty[QueryCacheValue]) } private def getCachedQuery( key: QueryCacheKey, session: SparkSession): Option[QueryCacheValue] = { - queryCacheLock.synchronized { - queryCache.get(key).flatMap { v => - if (v.session == session) { - v.expiresAtMs.foreach { _ => - // Extend the expiry time as the client is accessing it. - val expiresAtMs = clock.getTimeMillis() + stoppedQueryInactivityTimeout.toMillis - queryCache.put(key, v.copy(expiresAtMs = Some(expiresAtMs))) - } - Some(v) - } else None // Should be rare, may be client is trying access from a different session. - } + val value = Option(queryCache.get(key)) + value.flatMap { v => + if (v.session == session) { + v.expiresAtMs.foreach { _ => + // Extend the expiry time as the client is accessing it. + val expiresAtMs = clock.getTimeMillis() + stoppedQueryInactivityTimeout.toMillis + queryCache.put(key, v.copy(expiresAtMs = Some(expiresAtMs))) + } + Some(v) + } else None // Should be rare, may be client is trying access from a different session. } } @@ -154,7 +140,7 @@ private[connect] class SparkConnectStreamingQueryCache( sessionHolder: SessionHolder, blocking: Boolean = true): Seq[String] = { val operationIds = new mutable.ArrayBuffer[String]() - for ((k, v) <- queryCache) { + queryCache.forEach((k, v) => { if (v.userId.equals(sessionHolder.userId) && v.sessionId.equals(sessionHolder.sessionId)) { if (v.query.isActive && Option(v.session.streams.get(k.queryId)).nonEmpty) { logInfo( @@ -178,29 +164,27 @@ private[connect] class SparkConnectStreamingQueryCache( } } } - } + }) operationIds.toSeq } // Visible for testing private[service] def getCachedValue(queryId: String, runId: String): Option[QueryCacheValue] = - queryCache.get(QueryCacheKey(queryId, runId)) + Option(queryCache.get(QueryCacheKey(queryId, runId))) // Visible for testing. 
- private[service] def shutdown(): Unit = queryCacheLock.synchronized { + private[service] def shutdown(): Unit = { val executor = scheduledExecutor.getAndSet(null) if (executor != null) { ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES)) } } - @GuardedBy("queryCacheLock") - private val queryCache = new mutable.HashMap[QueryCacheKey, QueryCacheValue] - private val queryCacheLock = new Object + private val queryCache: ConcurrentMap[QueryCacheKey, QueryCacheValue] = + new ConcurrentHashMap[QueryCacheKey, QueryCacheValue] - @GuardedBy("queryCacheLock") - private val taggedQueries = new mutable.HashMap[String, mutable.ArrayBuffer[QueryCacheKey]] - private val taggedQueriesLock = new Object + private[service] val taggedQueries: ConcurrentMap[String, QueryCacheKeySet] = + new ConcurrentHashMap[String, QueryCacheKeySet] private var scheduledExecutor: AtomicReference[ScheduledExecutorService] = new AtomicReference[ScheduledExecutorService]() @@ -228,62 +212,109 @@ private[connect] class SparkConnectStreamingQueryCache( } } + private def addTaggedQuery(tag: String, queryKey: QueryCacheKey): Unit = { + taggedQueries.compute( + tag, + (k, v) => { + if (v == null || !v.addKey(queryKey)) { + // Create a new QueryCacheKeySet if the entry is absent or being removed. + var keys = mutable.HashSet.empty[QueryCacheKey] + keys.add(queryKey) + new QueryCacheKeySet(keys = keys) + } else { + v + } + }) + } + /** * Periodic maintenance task to do the following: * - Update status of query if it is inactive. Sets an expiry time for such queries * - Drop expired queries from the cache. */ - private def periodicMaintenance(): Unit = taggedQueriesLock.synchronized { + private def periodicMaintenance(): Unit = { + val nowMs = clock.getTimeMillis() - queryCacheLock.synchronized { - val nowMs = clock.getTimeMillis() + queryCache.forEach((k, v) => { + val id = k.queryId + val runId = k.runId + v.expiresAtMs match { - for ((k, v) <- queryCache) { - val id = k.queryId - val runId = k.runId - v.expiresAtMs match { + case Some(ts) if nowMs >= ts => // Expired. Drop references. + logInfo( + log"Removing references for id: ${MDC(QUERY_ID, id)} " + + log"runId: ${MDC(QUERY_RUN_ID, runId)} in " + + log"session ${MDC(SESSION_ID, v.sessionId)} after expiry period") + queryCache.remove(k) - case Some(ts) if nowMs >= ts => // Expired. Drop references. - logInfo( - log"Removing references for id: ${MDC(QUERY_ID, id)} " + - log"runId: ${MDC(QUERY_RUN_ID, runId)} in " + - log"session ${MDC(SESSION_ID, v.sessionId)} after expiry period") - queryCache.remove(k) + case Some(_) => // Inactive query waiting for expiration. Do nothing. + logInfo( + log"Waiting for the expiration for id: ${MDC(QUERY_ID, id)} " + + log"runId: ${MDC(QUERY_RUN_ID, runId)} in " + + log"session ${MDC(SESSION_ID, v.sessionId)}") + + case None => // Active query, check if it is stopped. Enable timeout if it is stopped. + val isActive = v.query.isActive && Option(v.session.streams.get(id)).nonEmpty - case Some(_) => // Inactive query waiting for expiration. Do nothing. + if (!isActive) { logInfo( - log"Waiting for the expiration for id: ${MDC(QUERY_ID, id)} " + + log"Marking query id: ${MDC(QUERY_ID, id)} " + log"runId: ${MDC(QUERY_RUN_ID, runId)} in " + - log"session ${MDC(SESSION_ID, v.sessionId)}") - - case None => // Active query, check if it is stopped. Enable timeout if it is stopped. 
- val isActive = v.query.isActive && Option(v.session.streams.get(id)).nonEmpty - - if (!isActive) { - logInfo( - log"Marking query id: ${MDC(QUERY_ID, id)} " + - log"runId: ${MDC(QUERY_RUN_ID, runId)} in " + - log"session ${MDC(SESSION_ID, v.sessionId)} inactive.") - val expiresAtMs = nowMs + stoppedQueryInactivityTimeout.toMillis - queryCache.put(k, v.copy(expiresAtMs = Some(expiresAtMs))) - // To consider: Clean up any runner registered for this query with the session holder - // for this session. Useful in case listener events are delayed (such delays are - // seen in practice, especially when users have heavy processing inside listeners). - // Currently such workers would be cleaned up when the connect session expires. - } - } + log"session ${MDC(SESSION_ID, v.sessionId)} inactive.") + val expiresAtMs = nowMs + stoppedQueryInactivityTimeout.toMillis + queryCache.put(k, v.copy(expiresAtMs = Some(expiresAtMs))) + // To consider: Clean up any runner registered for this query with the session holder + // for this session. Useful in case listener events are delayed (such delays are + // seen in practice, especially when users have heavy processing inside listeners). + // Currently such workers would be cleaned up when the connect session expires. + } } + }) - taggedQueries.toArray.foreach { case (key, value) => - value.zipWithIndex.toArray.foreach { case (queryKey, i) => - if (queryCache.contains(queryKey)) { - value.remove(i) - } + // Removes any tagged queries that do not correspond to cached queries. + taggedQueries.forEach((key, value) => { + if (value.filter(k => queryCache.containsKey(k))) { + taggedQueries.remove(key, value) + } + }) + } + + case class QueryCacheKeySet(keys: mutable.HashSet[QueryCacheKey]) { + + /** Tries to add the key if the set is not empty, otherwise returns false. */ + def addKey(key: QueryCacheKey): Boolean = { + keys.synchronized { + if (keys.isEmpty) { + // The entry is about to be removed. + return false } + keys.add(key) + true + } + } - if (value.isEmpty) { - taggedQueries.remove(key) + /** Removes the key and returns true if the set is empty. */ + def removeKey(key: QueryCacheKey): Boolean = { + keys.synchronized { + if (keys.remove(key)) { + return keys.isEmpty } + false + } + } + + /** Removes entries that do not satisfy the predicate. */ + def filter(pred: QueryCacheKey => Boolean): Boolean = { + keys.synchronized { + keys.filterInPlace(k => pred(k)) + keys.isEmpty + } + } + + /** Iterates over entries, apply the function individually, and then flatten the result. 
*/ + def flatMap[T](function: QueryCacheKey => Option[T]): Seq[T] = { + keys.synchronized { + keys.flatMap(k => function(k)).toSeq } } } diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListenerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListenerSuite.scala index d856ffaabc316..2404dea21d91e 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListenerSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListenerSuite.scala @@ -202,7 +202,8 @@ class SparkConnectListenerBusListenerSuite val listenerHolder = sessionHolder.streamingServersideListenerHolder eventually(timeout(5.seconds), interval(500.milliseconds)) { assert( - sessionHolder.streamingServersideListenerHolder.streamingQueryServerSideListener.isEmpty) + sessionHolder.streamingServersideListenerHolder.streamingQueryServerSideListener.get() == + null) assert(spark.streams.listListeners().size === listenerCntBeforeThrow) assert(listenerHolder.streamingQueryStartedEventCache.isEmpty) } diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala index 512a0a80c4a91..729a995f46145 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala @@ -48,6 +48,7 @@ class SparkConnectStreamingQueryCacheSuite extends SparkFunSuite with MockitoSug val queryId = UUID.randomUUID().toString val runId = UUID.randomUUID().toString + val tag = "test_tag" val mockSession = mock[SparkSession] val mockQuery = mock[StreamingQuery] val mockStreamingQueryManager = mock[StreamingQueryManager] @@ -67,13 +68,16 @@ class SparkConnectStreamingQueryCacheSuite extends SparkFunSuite with MockitoSug // Register the query. - sessionMgr.registerNewStreamingQuery(sessionHolder, mockQuery, Set.empty[String], "") + sessionMgr.registerNewStreamingQuery(sessionHolder, mockQuery, Set(tag), "") sessionMgr.getCachedValue(queryId, runId) match { case Some(v) => assert(v.sessionId == sessionHolder.sessionId) assert(v.expiresAtMs.isEmpty, "No expiry time should be set for active query") + val taggedQueries = sessionMgr.getTaggedQuery(tag, mockSession) + assert(taggedQueries.contains(v)) + case None => assert(false, "Query should be found") } @@ -127,6 +131,9 @@ class SparkConnectStreamingQueryCacheSuite extends SparkFunSuite with MockitoSug assert(sessionMgr.getCachedValue(queryId, runId).map(_.query).contains(mockQuery)) assert( sessionMgr.getCachedValue(queryId, restartedRunId).map(_.query).contains(restartedQuery)) + eventually(timeout(1.minute)) { + assert(sessionMgr.taggedQueries.containsKey(tag)) + } // Advance time by 1 minute and verify the first query is dropped from the cache. 
clock.advance(1.minute.toMillis) @@ -144,8 +151,11 @@ class SparkConnectStreamingQueryCacheSuite extends SparkFunSuite with MockitoSug clock.advance(1.minute.toMillis) eventually(timeout(1.minute)) { assert(sessionMgr.getCachedValue(queryId, restartedRunId).isEmpty) + assert(sessionMgr.getTaggedQuery(tag, mockSession).isEmpty) + } + eventually(timeout(1.minute)) { + assert(!sessionMgr.taggedQueries.containsKey(tag)) } - sessionMgr.shutdown() } } From e374b94a9c8b217156ce24137efbd404a38e4f21 Mon Sep 17 00:00:00 2001 From: Ziqi Liu Date: Thu, 17 Oct 2024 12:24:59 +0800 Subject: [PATCH 30/31] [SPARK-49979][SQL] Fix AQE hanging issue when collecting twice on a failed plan ### What changes were proposed in this pull request? Record failure/error status in query stage. And abort immediately upon seeing failed query stage when creating new query stages. ### Why are the changes needed? AQE has a potential hanging issue when we collect twice from a failed AQE plan, no new query stage will be created, and no stage will be submitted either. We will be waiting for a finish event forever, which will never come because that query stage has already failed in the previous run. ### Does this PR introduce _any_ user-facing change? NO ### How was this patch tested? New UT. ### Was this patch authored or co-authored using generative AI tooling? NO Closes #48484 from liuzqt/SPARK-49979. Authored-by: Ziqi Liu Signed-off-by: Wenchen Fan --- .../adaptive/AdaptiveSparkPlanExec.scala | 12 +++++++++++ .../execution/adaptive/QueryStageExec.scala | 9 ++++++++ .../adaptive/AdaptiveQueryExecSuite.scala | 21 +++++++++++++++++++ 3 files changed, 42 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index ffab67b7cae24..77efc4793359f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -340,6 +340,7 @@ case class AdaptiveSparkPlanExec( }(AdaptiveSparkPlanExec.executionContext) } catch { case e: Throwable => + stage.error.set(Some(e)) cleanUpAndThrowException(Seq(e), Some(stage.id)) } } @@ -355,6 +356,7 @@ case class AdaptiveSparkPlanExec( case StageSuccess(stage, res) => stage.resultOption.set(Some(res)) case StageFailure(stage, ex) => + stage.error.set(Some(ex)) errors.append(ex) } @@ -600,6 +602,7 @@ case class AdaptiveSparkPlanExec( newStages = Seq(newStage)) case q: QueryStageExec => + assertStageNotFailed(q) CreateStageResult(newPlan = q, allChildStagesMaterialized = q.isMaterialized, newStages = Seq.empty) @@ -815,6 +818,15 @@ case class AdaptiveSparkPlanExec( } } + private def assertStageNotFailed(stage: QueryStageExec): Unit = { + if (stage.hasFailed) { + throw stage.error.get().get match { + case fatal: SparkFatalException => fatal.throwable + case other => other + } + } + } + /** * Cancel all running stages with best effort and throw an Exception containing all stage * materialization errors and stage cancellation errors. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala index 51595e20ae5f8..2391fe740118d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala @@ -93,6 +93,13 @@ abstract class QueryStageExec extends LeafExecNode { private[adaptive] def resultOption: AtomicReference[Option[Any]] = _resultOption final def isMaterialized: Boolean = resultOption.get().isDefined + @transient + @volatile + protected var _error = new AtomicReference[Option[Throwable]](None) + + def error: AtomicReference[Option[Throwable]] = _error + final def hasFailed: Boolean = _error.get().isDefined + override def output: Seq[Attribute] = plan.output override def outputPartitioning: Partitioning = plan.outputPartitioning override def outputOrdering: Seq[SortOrder] = plan.outputOrdering @@ -203,6 +210,7 @@ case class ShuffleQueryStageExec( ReusedExchangeExec(newOutput, shuffle), _canonicalized) reuse._resultOption = this._resultOption + reuse._error = this._error reuse } @@ -249,6 +257,7 @@ case class BroadcastQueryStageExec( ReusedExchangeExec(newOutput, broadcast), _canonicalized) reuse._resultOption = this._resultOption + reuse._error = this._error reuse } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 8e9ba6c8e21d8..1df045764d8b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -3065,6 +3065,27 @@ class AdaptiveQueryExecSuite } } } + + test("SPARK-49979: AQE hang forever when collecting twice on a failed AQE plan") { + val func: Long => Boolean = (i : Long) => { + throw new Exception("SPARK-49979") + } + withUserDefinedFunction("func" -> true) { + spark.udf.register("func", func) + val df1 = spark.range(1024).select($"id".as("key1")) + val df2 = spark.range(2048).select($"id".as("key2")) + .withColumn("group_key", $"key2" % 1024) + val df = df1.filter(expr("func(key1)")).hint("MERGE").join(df2, $"key1" === $"key2") + .groupBy($"group_key").agg("key1" -> "count") + intercept[Throwable] { + df.collect() + } + // second collect should not hang forever + intercept[Throwable] { + df.collect() + } + } + } } /** From 175d56310fae187247dc240ed6694ea667201cf2 Mon Sep 17 00:00:00 2001 From: Utkarsh Date: Thu, 17 Oct 2024 12:27:02 +0800 Subject: [PATCH 31/31] [SPARK-49977][SQL] Use stack-based iterative computation to avoid creating many Scala List objects for deep expression trees ### What changes were proposed in this pull request? In some use cases with deep expression trees, the driver's heap shows many `scala.collection.immutable.$colon$colon` objects from the heap. The objects are allocated due to deep recursion in the `gatherCommutative` method which uses `flatmap` recursively. Each invocation of `flatmap` creates a new temporary Scala collection. 
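For illustration only — the tree type and methods below are made up for this sketch and are not the Spark code (the actual `gatherCommutative` rewrite appears in the diff at the end of this patch) — the same recursion-to-iteration change on a generic tree shows why an explicit stack avoids the per-call `flatMap` allocations:

```scala
import scala.collection.mutable

object GatherIteratively {
  sealed trait Tree
  final case class Node(children: Seq[Tree]) extends Tree
  final case class Leaf(value: Int) extends Tree

  // Recursive version: every Node visit calls flatMap, which materializes a temporary
  // collection of results, so a deep tree allocates one short-lived List per call.
  def leavesRecursive(tree: Tree): Seq[Int] = tree match {
    case Node(children) => children.flatMap(leavesRecursive)
    case Leaf(v) => Seq(v)
  }

  // Iterative version: one explicit stack and one result buffer are reused for the whole
  // traversal, so no intermediate collections are created (and no deep call stack either).
  def leavesIterative(root: Tree): Seq[Int] = {
    val result = mutable.Buffer[Int]()
    val stack = mutable.Stack[Tree](root)
    while (stack.nonEmpty) {
      stack.pop() match {
        case Node(children) => stack.pushAll(children)
        case Leaf(v) => result += v
      }
    }
    result.toSeq
  }

  def main(args: Array[String]): Unit = {
    // A skewed tree ~1000 levels deep; both traversals collect the same set of leaves.
    val deep = (1 to 1000).foldLeft(Leaf(0): Tree)((t, i) => Node(Seq(t, Leaf(i))))
    assert(leavesRecursive(deep).sorted == leavesIterative(deep).sorted)
  }
}
```

Note that the explicit stack visits elements in a different order than the recursive `flatMap`; a caller that depends on a particular ordering would have to account for that.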
Our claim is based on the following stack trace (>1K lines) of a thread in the driver below, truncated here for brevity: ``` "HiveServer2-Background-Pool: Thread-9867" #9867 daemon prio=5 os_prio=0 tid=0x00007f35080bf000 nid=0x33e7 runnable [0x00007f3393372000] java.lang.Thread.State: RUNNABLE at scala.collection.immutable.List$Appender$1.apply(List.scala:350) at scala.collection.immutable.List$Appender$1.apply(List.scala:341) at scala.collection.immutable.List.flatMap(List.scala:431) at org.apache.spark.sql.catalyst.expressions.CommutativeExpression.gatherCommutative(Expression.scala:1479) at org.apache.spark.sql.catalyst.expressions.CommutativeExpression.$anonfun$gatherCommutative$1(Expression.scala:1479) at org.apache.spark.sql.catalyst.expressions.CommutativeExpression$$Lambda$5280/143713747.apply(Unknown Source) at scala.collection.immutable.List.flatMap(List.scala:366) .... at org.apache.spark.sql.catalyst.expressions.CommutativeExpression.gatherCommutative(Expression.scala:1479) at org.apache.spark.sql.catalyst.expressions.CommutativeExpression.$anonfun$gatherCommutative$1(Expression.scala:1479) at org.apache.spark.sql.catalyst.expressions.CommutativeExpression$$Lambda$5280/143713747.apply(Unknown Source) at scala.collection.immutable.List.flatMap(List.scala:366) .... ``` This PR fixes the issue by using a stack-based iterative computation, completely avoiding the creation of temporary Scala objects. ### Why are the changes needed? Reduce heap usage of the driver ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing tests, refactor ### Was this patch authored or co-authored using generative AI tooling? No Closes #48481 from utkarsh39/SPARK-49977. Lead-authored-by: Utkarsh Co-authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- .../sql/catalyst/expressions/Expression.scala | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 6a57ba2aaa569..bb32e518ec39a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -1347,9 +1347,21 @@ trait CommutativeExpression extends Expression { /** Collects adjacent commutative operations. */ private def gatherCommutative( e: Expression, - f: PartialFunction[CommutativeExpression, Seq[Expression]]): Seq[Expression] = e match { - case c: CommutativeExpression if f.isDefinedAt(c) => f(c).flatMap(gatherCommutative(_, f)) - case other => other.canonicalized :: Nil + f: PartialFunction[CommutativeExpression, Seq[Expression]]): Seq[Expression] = { + val resultBuffer = scala.collection.mutable.Buffer[Expression]() + val stack = scala.collection.mutable.Stack[Expression](e) + + // [SPARK-49977]: Use iterative approach to avoid creating many temporary List objects + // for deep expression trees through recursion. + while (stack.nonEmpty) { + stack.pop() match { + case c: CommutativeExpression if f.isDefinedAt(c) => + stack.pushAll(f(c)) + case other => + resultBuffer += other.canonicalized + } + } + resultBuffer.toSeq } /**