From a9127068194a48786df4f429ceb4f908c71f7138 Mon Sep 17 00:00:00 2001 From: chenyu <119398199+chenyu-opensource@users.noreply.github.com> Date: Wed, 8 Nov 2023 19:16:48 +0800 Subject: [PATCH 01/15] [SPARK-45829][DOCS] Update the default value for spark.executor.logs.rolling.maxSize **What changes were proposed in this pull request?** The PR updates the default value of 'spark.executor.logs.rolling.maxSize' in configuration.html on the website. **Why are the changes needed?** The default value of 'spark.executor.logs.rolling.maxSize' is 1024 * 1024, but the website shows the wrong value. **Does this PR introduce any user-facing change?** No **How was this patch tested?** It doesn't need testing. **Was this patch authored or co-authored using generative AI tooling?** No Closes #43712 from chenyu-opensource/branch-SPARK-45829. Authored-by: chenyu <119398199+chenyu-opensource@users.noreply.github.com> Signed-off-by: Kent Yao --- docs/configuration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/configuration.md b/docs/configuration.md index 60cad24e71c44..3d54aaf6518be 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -684,7 +684,7 @@ Apart from these, the following properties are also available, and may be useful spark.executor.logs.rolling.maxSize - (none) + 1024 * 1024 Set the max size of the file in bytes by which the executor logs will be rolled over. Rolling is disabled by default. See spark.executor.logs.rolling.maxRetainedFiles From 1d8df4f6b99b836f4267b888e81d67c75b4dfdcd Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Wed, 8 Nov 2023 19:43:33 +0800 Subject: [PATCH 02/15] [SPARK-45606][SQL] Release restrictions on multi-layer runtime filter ### What changes were proposed in this pull request? Before https://github.com/apache/spark/pull/39170, Spark only supported injecting a runtime filter into the application side of a shuffle join at a single layer. Because it is not worth injecting another runtime filter when the column already has one, Spark restricts this at https://github.com/apache/spark/blob/7057952f6bc2c5cf97dd408effd1b18bee1cb8f4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala#L346 For example, `select * from bf1 join bf2 on bf1.c1 = bf2.c2 and bf1.c1 = bf2.b2 where bf2.a2 = 62` has two join conditions; without the restriction mentioned above, two runtime filters would be injected on `bf1.c1`. At that time, the restriction was reasonable. After https://github.com/apache/spark/pull/39170, Spark supports injecting a runtime filter into one side of any shuffle join across multiple layers, so the restriction on multi-layer runtime filters now looks outdated. For example, in `select * from bf1 join bf2 join bf3 on bf1.c1 = bf2.c2 and bf3.c3 = bf1.c1 where bf2.a2 = 5`, assume bf2 is the build side and a runtime filter is injected for bf1. We can't inject the same runtime filter for bf3 because there is already a runtime filter on `bf1.c1`. This behavior differs from the original one and is unexpected. The change in this PR doesn't affect the restriction mentioned above. ### Why are the changes needed? Release the restriction on multi-layer runtime filters and expand the optimization surface. ### Does this PR introduce _any_ user-facing change? 'No'. New feature. ### How was this patch tested? Test cases updated. Micro benchmark for q9 in TPC-H. **TPC-H 100** Query | Master(ms) | PR(ms) | Difference(ms) | Percent -- | -- | -- | -- | -- q9 | 26491 | 20725 | 5766 | 27.82% ### Was this patch authored or co-authored using generative AI tooling? 'No'.
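To see the released restriction end to end, the multi-join query from the description can be run against three small tables while inspecting the optimized plan. The sketch below is only an illustration under stated assumptions: the table/column names and row counts are made up to mirror the example, and the bloom-filter configuration keys shown are the standard runtime-filter switches, tuned low here so that tiny test tables qualify; it is not part of the patch.

```python
# Minimal PySpark sketch (illustrative, not from the patch) for eyeballing where
# bloom-filter runtime filters are injected in the multi-join example above.
# Table/column names and sizes are assumptions; thresholds are lowered because
# the defaults are sized for much larger scans.
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.optimizer.runtime.bloomFilter.enabled", "true")
spark.conf.set(
    "spark.sql.optimizer.runtime.bloomFilter.applicationSideScanSizeThreshold", "3000")

# Create bf1(c1, a1), bf2(c2, a2), bf3(c3, a3) to match the example query.
for name in ("bf1", "bf2", "bf3"):
    suffix = name[-1]
    df = spark.range(100_000).selectExpr(f"id AS c{suffix}", f"id % 10 AS a{suffix}")
    df.write.mode("overwrite").saveAsTable(name)

query = spark.sql("""
    SELECT * FROM bf1 JOIN bf2 JOIN bf3
    ON bf1.c1 = bf2.c2 AND bf3.c3 = bf1.c1
    WHERE bf2.a2 = 5
""")
# With the restriction released, both bf1.c1 and bf3.c3 can carry a bloom-filter
# runtime filter (look for the bloom-filter aggregate / might_contain expressions
# in the optimized plan); previously only one of them could.
query.explain(mode="extended")
```

Whether a filter is actually injected still depends on the creation- and application-side size thresholds, so the exact plan may differ from build to build.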
Closes #43449 from beliefer/SPARK-45606. Authored-by: Jiaan Geng Signed-off-by: Jiaan Geng --- .../optimizer/InjectRuntimeFilter.scala | 33 +++++++++---------- .../spark/sql/InjectRuntimeFilterSuite.scala | 8 ++--- 2 files changed, 18 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala index 5f5508d6b22c2..9c150f1f3308f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala @@ -247,15 +247,7 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J } } - private def hasBloomFilter( - left: LogicalPlan, - right: LogicalPlan, - leftKey: Expression, - rightKey: Expression): Boolean = { - findBloomFilterWithKey(left, leftKey) || findBloomFilterWithKey(right, rightKey) - } - - private def findBloomFilterWithKey(plan: LogicalPlan, key: Expression): Boolean = { + private def hasBloomFilter(plan: LogicalPlan, key: Expression): Boolean = { plan.exists { case Filter(condition, _) => splitConjunctivePredicates(condition).exists { @@ -277,28 +269,33 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J leftKeys.lazyZip(rightKeys).foreach((l, r) => { // Check if: // 1. There is already a DPP filter on the key - // 2. There is already a bloom filter on the key - // 3. The keys are simple cheap expressions + // 2. The keys are simple cheap expressions if (filterCounter < numFilterThreshold && !hasDynamicPruningSubquery(left, right, l, r) && - !hasBloomFilter(newLeft, newRight, l, r) && isSimpleExpression(l) && isSimpleExpression(r)) { val oldLeft = newLeft val oldRight = newRight - // Check if the current join is a shuffle join or a broadcast join that - // has a shuffle below it + // Check if: + // 1. The current join type supports prune the left side with runtime filter + // 2. The current join is a shuffle join or a broadcast join that + // has a shuffle below it + // 3. There is no bloom filter on the left key yet val hasShuffle = isProbablyShuffleJoin(left, right, hint) - if (canPruneLeft(joinType) && (hasShuffle || probablyHasShuffle(left))) { + if (canPruneLeft(joinType) && (hasShuffle || probablyHasShuffle(left)) && + !hasBloomFilter(newLeft, l)) { extractBeneficialFilterCreatePlan(left, right, l, r).foreach { case (filterCreationSideKey, filterCreationSidePlan) => newLeft = injectFilter(l, newLeft, filterCreationSideKey, filterCreationSidePlan) } } // Did we actually inject on the left? If not, try on the right - // Check if the current join is a shuffle join or a broadcast join that - // has a shuffle below it + // Check if: + // 1. The current join type supports prune the right side with runtime filter + // 2. The current join is a shuffle join or a broadcast join that + // has a shuffle below it + // 3. 
There is no bloom filter on the right key yet if (newLeft.fastEquals(oldLeft) && canPruneRight(joinType) && - (hasShuffle || probablyHasShuffle(right))) { + (hasShuffle || probablyHasShuffle(right)) && !hasBloomFilter(newRight, r)) { extractBeneficialFilterCreatePlan(right, left, r, l).foreach { case (filterCreationSideKey, filterCreationSidePlan) => newRight = injectFilter( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala index 2e57975ee6d1d..fc1524be13179 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala @@ -335,14 +335,12 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp "bf1.c1 = bf2.c2 where bf2.a2 = 5) as a join bf3 on bf3.c3 = a.c1", 2) assertRewroteWithBloomFilter("select * from (select * from bf1 right join bf2 on " + "bf1.c1 = bf2.c2 where bf2.a2 = 5) as a join bf3 on bf3.c3 = a.c1", 2) - // Can't leverage the transitivity of join keys due to runtime filters already exists. - // bf2 as creation side and inject runtime filter for bf1. assertRewroteWithBloomFilter("select * from bf1 join bf2 join bf3 on bf1.c1 = bf2.c2 " + - "and bf3.c3 = bf1.c1 where bf2.a2 = 5") + "and bf3.c3 = bf1.c1 where bf2.a2 = 5", 2) assertRewroteWithBloomFilter("select * from bf1 left outer join bf2 join bf3 on " + - "bf1.c1 = bf2.c2 and bf3.c3 = bf1.c1 where bf2.a2 = 5") + "bf1.c1 = bf2.c2 and bf3.c3 = bf1.c1 where bf2.a2 = 5", 2) assertRewroteWithBloomFilter("select * from bf1 right outer join bf2 join bf3 on " + - "bf1.c1 = bf2.c2 and bf3.c3 = bf1.c1 where bf2.a2 = 5") + "bf1.c1 = bf2.c2 and bf3.c3 = bf1.c1 where bf2.a2 = 5", 2) } withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000", From f866549a5aa86f379cb71732b97fa547f2c4eb0a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 8 Nov 2023 20:21:55 +0800 Subject: [PATCH 03/15] [SPARK-45816][SQL] Return `NULL` when overflowing during casting from timestamp to integers ### What changes were proposed in this pull request? Spark cast works in two modes: ansi and non-ansi. When overflowing during casting, the common behavior under non-ansi mode is to return null. However, casting from Timestamp to Int/Short/Byte returns a wrapping value now. The behavior to silently overflow doesn't make sense. This patch changes it to the common behavior, i.e., returning null. ### Why are the changes needed? Returning a wrapping value, e.g., negative one, during casting Timestamp to Int/Short/Byte could implicitly cause misinterpret casted result without caution. We also should follow the common behavior of overflowing handling. ### Does this PR introduce _any_ user-facing change? Yes. Overflowing during casting from Timestamp to Int/Short/Byte under non-ansi mode, returns null instead of wrapping value. ### How was this patch tested? Will add test or update test if any existing ones fail ### Was this patch authored or co-authored using generative AI tooling? No Closes #43694 from viirya/fix_cast_integers. 
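To make the behavior change concrete, here is a small PySpark sketch, assuming a build that contains this change and ANSI mode turned off; the 1900 timestamp mirrors the value used in the updated `CastWithAnsiOffSuite` test.

```python
# Minimal sketch of the new non-ANSI overflow behavior for timestamp -> integral
# casts. The 1900 timestamp is roughly -2.2e9 seconds from the epoch, which fits
# a LONG but overflows an INT.
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.ansi.enabled", "false")

spark.sql("""
    SELECT CAST(ts AS INT)  AS as_int,
           CAST(ts AS LONG) AS as_long
    FROM VALUES (TIMESTAMP'1900-05-05 18:34:56') AS t(ts)
""").show()
# With this patch: as_int is NULL (previously a wrapped value);
# as_long is the negative epoch-second count, unchanged.
```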
Authored-by: Liang-Chi Hsieh Signed-off-by: Jiaan Geng --- docs/sql-migration-guide.md | 1 + .../spark/sql/catalyst/expressions/Cast.scala | 51 +++++++++++-------- .../expressions/CastWithAnsiOffSuite.scala | 6 +-- 3 files changed, 33 insertions(+), 25 deletions(-) diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index b0dc49ed47683..5c00ce6558513 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -28,6 +28,7 @@ license: | - Since Spark 4.0, any read of SQL tables takes into consideration the SQL configs `spark.sql.files.ignoreCorruptFiles`/`spark.sql.files.ignoreMissingFiles` instead of the core config `spark.files.ignoreCorruptFiles`/`spark.files.ignoreMissingFiles`. - Since Spark 4.0, `spark.sql.hive.metastore` drops the support of Hive prior to 2.0.0 as they require JDK 8 that Spark does not support anymore. Users should migrate to higher versions. - Since Spark 4.0, `spark.sql.parquet.compression.codec` drops the support of codec name `lz4raw`, please use `lz4_raw` instead. +- Since Spark 4.0, when overflowing during casting timestamp to byte/short/int under non-ansi mode, Spark will return null instead a wrapping value. ## Upgrading from Spark SQL 3.4 to 3.5 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 62295fe260535..ee022c068b987 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -344,6 +344,7 @@ object Cast extends QueryErrorsBase { case (StringType, _) => true case (_, StringType) => false + case (TimestampType, ByteType | ShortType | IntegerType) => true case (FloatType | DoubleType, TimestampType) => true case (TimestampType, DateType) => false case (_, DateType) => true @@ -777,6 +778,14 @@ case class Cast( buildCast[Int](_, i => yearMonthIntervalToInt(i, x.startField, x.endField).toLong) } + private def errorOrNull(t: Any, from: DataType, to: DataType) = { + if (ansiEnabled) { + throw QueryExecutionErrors.castingCauseOverflowError(t, from, to) + } else { + null + } + } + // IntConverter private[this] def castToInt(from: DataType): Any => Any = from match { case StringType if ansiEnabled => @@ -788,17 +797,15 @@ case class Cast( buildCast[Boolean](_, b => if (b) 1 else 0) case DateType => buildCast[Int](_, d => null) - case TimestampType if ansiEnabled => + case TimestampType => buildCast[Long](_, t => { val longValue = timestampToLong(t) if (longValue == longValue.toInt) { longValue.toInt } else { - throw QueryExecutionErrors.castingCauseOverflowError(t, from, IntegerType) + errorOrNull(t, from, IntegerType) } }) - case TimestampType => - buildCast[Long](_, t => timestampToLong(t).toInt) case x: NumericType if ansiEnabled => val exactNumeric = PhysicalNumericType.exactNumeric(x) b => exactNumeric.toInt(b) @@ -826,17 +833,15 @@ case class Cast( buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort) case DateType => buildCast[Int](_, d => null) - case TimestampType if ansiEnabled => + case TimestampType => buildCast[Long](_, t => { val longValue = timestampToLong(t) if (longValue == longValue.toShort) { longValue.toShort } else { - throw QueryExecutionErrors.castingCauseOverflowError(t, from, ShortType) + errorOrNull(t, from, ShortType) } }) - case TimestampType => - buildCast[Long](_, t => timestampToLong(t).toShort) case x: NumericType if ansiEnabled => val 
exactNumeric = PhysicalNumericType.exactNumeric(x) b => @@ -875,17 +880,15 @@ case class Cast( buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte) case DateType => buildCast[Int](_, d => null) - case TimestampType if ansiEnabled => + case TimestampType => buildCast[Long](_, t => { val longValue = timestampToLong(t) if (longValue == longValue.toByte) { longValue.toByte } else { - throw QueryExecutionErrors.castingCauseOverflowError(t, from, ByteType) + errorOrNull(t, from, ByteType) } }) - case TimestampType => - buildCast[Long](_, t => timestampToLong(t).toByte) case x: NumericType if ansiEnabled => val exactNumeric = PhysicalNumericType.exactNumeric(x) b => @@ -1661,22 +1664,26 @@ case class Cast( integralType: String, from: DataType, to: DataType): CastFunction = { - if (ansiEnabled) { - val longValue = ctx.freshName("longValue") - val fromDt = ctx.addReferenceObj("from", from, from.getClass.getName) - val toDt = ctx.addReferenceObj("to", to, to.getClass.getName) - (c, evPrim, _) => - code""" + + val longValue = ctx.freshName("longValue") + val fromDt = ctx.addReferenceObj("from", from, from.getClass.getName) + val toDt = ctx.addReferenceObj("to", to, to.getClass.getName) + + (c, evPrim, evNull) => + val overflow = if (ansiEnabled) { + code"""throw QueryExecutionErrors.castingCauseOverflowError($c, $fromDt, $toDt);""" + } else { + code"$evNull = true;" + } + + code""" long $longValue = ${timestampToLongCode(c)}; if ($longValue == ($integralType) $longValue) { $evPrim = ($integralType) $longValue; } else { - throw QueryExecutionErrors.castingCauseOverflowError($c, $fromDt, $toDt); + $overflow } """ - } else { - (c, evPrim, _) => code"$evPrim = ($integralType) ${timestampToLongCode(c)};" - } } private[this] def castDayTimeIntervalToIntegralTypeCode( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOffSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOffSuite.scala index 1dbf03b1538a6..e260b6fdbdb52 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOffSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOffSuite.scala @@ -514,9 +514,9 @@ class CastWithAnsiOffSuite extends CastSuiteBase { val negativeTs = Timestamp.valueOf("1900-05-05 18:34:56.1") assert(negativeTs.getTime < 0) val expectedSecs = Math.floorDiv(negativeTs.getTime, MILLIS_PER_SECOND) - checkEvaluation(cast(negativeTs, ByteType), expectedSecs.toByte) - checkEvaluation(cast(negativeTs, ShortType), expectedSecs.toShort) - checkEvaluation(cast(negativeTs, IntegerType), expectedSecs.toInt) + checkEvaluation(cast(negativeTs, ByteType), null) + checkEvaluation(cast(negativeTs, ShortType), null) + checkEvaluation(cast(negativeTs, IntegerType), null) checkEvaluation(cast(negativeTs, LongType), expectedSecs) } } From 6abc4a1a58ef4e5d896717b10b2314dae2af78af Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Wed, 8 Nov 2023 15:51:50 +0300 Subject: [PATCH 04/15] [SPARK-45841][SQL] Expose stack trace by `DataFrameQueryContext` ### What changes were proposed in this pull request? In the PR, I propose to change the case class `DataFrameQueryContext`, and add stack traces as a field and override `callSite`, `fragment` using the new field `stackTrace`. ### Why are the changes needed? By exposing the stack trace, we give users opportunity to see all stack traces needed for debugging. ### Does this PR introduce _any_ user-facing change? 
No, `DataFrameQueryContext` hasn't been released yet. ### How was this patch tested? By running the modified test suite: ``` $ build/sbt "test:testOnly *DatasetSuite" ``` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43703 from MaxGekk/stack-traces-in-DataFrameQueryContext. Authored-by: Max Gekk Signed-off-by: Max Gekk --- .../sql/catalyst/trees/QueryContexts.scala | 33 ++++++++----------- .../org/apache/spark/sql/DatasetSuite.scala | 13 +++++--- 2 files changed, 22 insertions(+), 24 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala index 8d885d07ca8b0..874c834b75585 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala @@ -134,9 +134,7 @@ case class SQLQueryContext( override def callSite: String = throw new UnsupportedOperationException } -case class DataFrameQueryContext( - override val fragment: String, - override val callSite: String) extends QueryContext { +case class DataFrameQueryContext(stackTrace: Seq[StackTraceElement]) extends QueryContext { override val contextType = QueryContextType.DataFrame override def objectType: String = throw new UnsupportedOperationException @@ -144,6 +142,19 @@ case class DataFrameQueryContext( override def startIndex: Int = throw new UnsupportedOperationException override def stopIndex: Int = throw new UnsupportedOperationException + override val fragment: String = { + stackTrace.headOption.map { firstElem => + val methodName = firstElem.getMethodName + if (methodName.length > 1 && methodName(0) == '$') { + methodName.substring(1) + } else { + methodName + } + }.getOrElse("") + } + + override val callSite: String = stackTrace.tail.headOption.map(_.toString).getOrElse("") + override lazy val summary: String = { val builder = new StringBuilder builder ++= "== DataFrame ==\n" @@ -157,19 +168,3 @@ case class DataFrameQueryContext( builder.result() } } - -object DataFrameQueryContext { - def apply(elements: Array[StackTraceElement]): DataFrameQueryContext = { - val fragment = elements.headOption.map { firstElem => - val methodName = firstElem.getMethodName - if (methodName.length > 1 && methodName(0) == '$') { - methodName.substring(1) - } else { - methodName - } - }.getOrElse("") - val callSite = elements.tail.headOption.map(_.toString).getOrElse("") - - DataFrameQueryContext(fragment, callSite) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 66105d2ac429f..dcbd8948120ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoders, ExpressionEncod import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.BoxedIntEncoder import org.apache.spark.sql.catalyst.expressions.{CodegenObjectFactoryMode, GenericRowWithSchema} import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi} +import org.apache.spark.sql.catalyst.trees.DataFrameQueryContext import org.apache.spark.sql.catalyst.util.sideBySide import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SQLExecution} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper @@ -2668,16 +2669,18 @@ class DatasetSuite 
extends QueryTest withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { val df = Seq(1).toDS() var callSitePattern: String = null + val exception = intercept[AnalysisException] { + callSitePattern = getNextLineCallSitePattern() + val c = col("a") + df.select(c) + } checkError( - exception = intercept[AnalysisException] { - callSitePattern = getNextLineCallSitePattern() - val c = col("a") - df.select(c) - }, + exception, errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map("objectName" -> "`a`", "proposal" -> "`value`"), context = ExpectedContext(fragment = "col", callSitePattern = callSitePattern)) + assert(exception.context.head.asInstanceOf[DataFrameQueryContext].stackTrace.length == 2) } } } From b5408e1ce61ce2195de72dcf79d8355c16b4b92a Mon Sep 17 00:00:00 2001 From: panbingkun Date: Wed, 8 Nov 2023 08:24:52 -0800 Subject: [PATCH 05/15] [SPARK-45828][SQL] Remove deprecated method in dsl ### What changes were proposed in this pull request? The pr aims to remove `some deprecated method` in dsl. ### Why are the changes needed? After https://github.com/apache/spark/pull/36646 (Apache Spark 3.4.0), the method `def as(alias: Symbol): NamedExpression = Alias(expr, alias.name)()` and `def subquery(alias: Symbol): LogicalPlan = SubqueryAlias(alias.name, logicalPlan)` has been marked as `deprecated` and we need to remove it in `Spark 4.0`. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43708 from panbingkun/SPARK-45828. Authored-by: panbingkun Signed-off-by: Dongjoon Hyun --- .../scala/org/apache/spark/sql/catalyst/dsl/package.scala | 6 ------ .../sql/catalyst/optimizer/TransposeWindowSuite.scala | 8 ++++---- .../spark/sql/catalyst/plans/LogicalPlanSuite.scala | 6 +++--- 3 files changed, 7 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 5f85716fa2833..30d4c2dbb409f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -152,9 +152,6 @@ package object dsl { def desc: SortOrder = SortOrder(expr, Descending) def desc_nullsFirst: SortOrder = SortOrder(expr, Descending, NullsFirst, Seq.empty) def as(alias: String): NamedExpression = Alias(expr, alias)() - // TODO: Remove at Spark 4.0.0 - @deprecated("Use as(alias: String)", "3.4.0") - def as(alias: Symbol): NamedExpression = Alias(expr, alias.name)() } trait ExpressionConversions { @@ -468,9 +465,6 @@ package object dsl { limit: Int): LogicalPlan = WindowGroupLimit(partitionSpec, orderSpec, rankLikeFunction, limit, logicalPlan) - // TODO: Remove at Spark 4.0.0 - @deprecated("Use subquery(alias: String)", "3.4.0") - def subquery(alias: Symbol): LogicalPlan = SubqueryAlias(alias.name, logicalPlan) def subquery(alias: String): LogicalPlan = SubqueryAlias(alias, logicalPlan) def as(alias: String): LogicalPlan = SubqueryAlias(alias, logicalPlan) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala index 8d4c2de10e34f..f4d520bbb4439 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala @@ -146,15 +146,15 @@ class TransposeWindowSuite extends PlanTest { test("SPARK-38034: transpose two adjacent windows with compatible partitions " + "which is not a prefix") { val query = testRelation - .window(Seq(sum(c).as(Symbol("sum_a_2"))), partitionSpec4, orderSpec2) - .window(Seq(sum(c).as(Symbol("sum_a_1"))), partitionSpec3, orderSpec1) + .window(Seq(sum(c).as("sum_a_2")), partitionSpec4, orderSpec2) + .window(Seq(sum(c).as("sum_a_1")), partitionSpec3, orderSpec1) val analyzed = query.analyze val optimized = Optimize.execute(analyzed) val correctAnswer = testRelation - .window(Seq(sum(c).as(Symbol("sum_a_1"))), partitionSpec3, orderSpec1) - .window(Seq(sum(c).as(Symbol("sum_a_2"))), partitionSpec4, orderSpec2) + .window(Seq(sum(c).as("sum_a_1")), partitionSpec3, orderSpec1) + .window(Seq(sum(c).as("sum_a_2")), partitionSpec4, orderSpec2) .select(Symbol("a"), Symbol("b"), Symbol("c"), Symbol("d"), Symbol("sum_a_2"), Symbol("sum_a_1")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala index ea0fcac881c7a..3eba9eebc3d5f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala @@ -126,12 +126,12 @@ class LogicalPlanSuite extends SparkFunSuite { assert(sort2.maxRows === Some(100)) assert(sort2.maxRowsPerPartition === Some(100)) - val c1 = Literal(1).as(Symbol("a")).toAttribute.newInstance().withNullability(true) - val c2 = Literal(2).as(Symbol("b")).toAttribute.newInstance().withNullability(true) + val c1 = Literal(1).as("a").toAttribute.newInstance().withNullability(true) + val c2 = Literal(2).as("b").toAttribute.newInstance().withNullability(true) val expand = Expand( Seq(Seq(Literal(null), Symbol("b")), Seq(Symbol("a"), Literal(null))), Seq(c1, c2), - sort.select(Symbol("id") as Symbol("a"), Symbol("id") + 1 as Symbol("b"))) + sort.select(Symbol("id") as "a", Symbol("id") + 1 as "b")) assert(expand.maxRows === Some(200)) assert(expand.maxRowsPerPartition === Some(68)) From e331de06dd0526761c804b32640e3471ce772d38 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Wed, 8 Nov 2023 08:26:23 -0800 Subject: [PATCH 06/15] [MINOR][CORE][SQL] Clean up expired comments: `Note: this class supports Scala 2.13. A parallel source tree has a 2.12 implementation.` ### What changes were proposed in this pull request? This pr just clean up expired comments: `Note: this class supports Scala 2.13. A parallel source tree has a 2.12 implementation.` ### Why are the changes needed? Apache Spark 4.0 only support Scala 2.13, so these comments are no longer needed ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? No testing required ### Was this patch authored or co-authored using generative AI tooling? No Closes #43718 from LuciferYang/minor-comments. 
Lead-authored-by: yangjie01 Co-authored-by: YangJie Signed-off-by: Dongjoon Hyun --- .../main/scala/org/apache/spark/util/BoundedPriorityQueue.scala | 2 -- .../org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala | 2 -- .../apache/spark/sql/catalyst/expressions/AttributeMap.scala | 2 -- .../apache/spark/sql/execution/streaming/StreamProgress.scala | 2 -- 4 files changed, 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala b/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala index ccb4d2063ff3b..9fed2373ea552 100644 --- a/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala +++ b/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala @@ -31,8 +31,6 @@ import scala.jdk.CollectionConverters._ private[spark] class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A]) extends Iterable[A] with Growable[A] with Serializable { - // Note: this class supports Scala 2.13. A parallel source tree has a 2.12 implementation. - private val underlying = new JPriorityQueue[A](maxSize, ord) override def iterator: Iterator[A] = underlying.iterator.asScala diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala index e18a01810d2eb..640304efce4b4 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala @@ -30,8 +30,6 @@ import java.util.Locale class CaseInsensitiveMap[T] private (val originalMap: Map[String, T]) extends Map[String, T] with Serializable { - // Note: this class supports Scala 2.13. A parallel source tree has a 2.12 implementation. - val keyLowerCasedMap = originalMap.map(kv => kv.copy(_1 = kv._1.toLowerCase(Locale.ROOT))) override def get(k: String): Option[T] = keyLowerCasedMap.get(k.toLowerCase(Locale.ROOT)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala index ac6149f3acc4d..b317cacc061b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala @@ -41,8 +41,6 @@ object AttributeMap { class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)]) extends Map[Attribute, A] with Serializable { - // Note: this class supports Scala 2.13. A parallel source tree has a 2.12 implementation. - override def get(k: Attribute): Option[A] = baseMap.get(k.exprId).map(_._2) override def getOrElse[B1 >: A](k: Attribute, default: => B1): B1 = get(k).getOrElse(default) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala index 6aa1b46cbb94a..02f52bb30e1f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala @@ -29,8 +29,6 @@ class StreamProgress( new immutable.HashMap[SparkDataStream, OffsetV2]) extends scala.collection.immutable.Map[SparkDataStream, OffsetV2] { - // Note: this class supports Scala 2.13. A parallel source tree has a 2.12 implementation. 
- def toOffsetSeq(source: Seq[SparkDataStream], metadata: OffsetSeqMetadata): OffsetSeq = { OffsetSeq(source.map(get), Some(metadata)) } From 9d93b7112a31965447a34301889f90d14578e628 Mon Sep 17 00:00:00 2001 From: allisonwang-db Date: Wed, 8 Nov 2023 09:23:12 -0800 Subject: [PATCH 07/15] [SPARK-45639][SQL][PYTHON] Support loading Python data sources in DataFrameReader ### What changes were proposed in this pull request? This PR supports `spark.read.format(...).load()` for Python data sources. After this PR, users can use a Python data source directly like this: ```python from pyspark.sql.datasource import DataSource, DataSourceReader class MyReader(DataSourceReader): def read(self, partition): yield (0, 1) class MyDataSource(DataSource): classmethod def name(cls): return "my-source" def schema(self): return "id INT, value INT" def reader(self, schema): return MyReader() spark.dataSource.register(MyDataSource) df = spark.read.format("my-source").load() df.show() +---+-----+ | id|value| +---+-----+ | 0| 1| +---+-----+ ``` ### Why are the changes needed? To support Python data sources. ### Does this PR introduce _any_ user-facing change? Yes. After this PR, users can load a custom Python data source using `spark.read.format(...).load()`. ### How was this patch tested? New unit tests. ### Was this patch authored or co-authored using generative AI tooling? No Closes #43630 from allisonwang-db/spark-45639-ds-lookup. Authored-by: allisonwang-db Signed-off-by: Hyukjin Kwon --- .../main/resources/error/error-classes.json | 12 +++ dev/sparktestsupport/modules.py | 1 + docs/sql-error-conditions.md | 12 +++ python/pyspark/sql/session.py | 4 + .../sql/tests/test_python_datasource.py | 97 +++++++++++++++++-- .../pyspark/sql/worker/create_data_source.py | 16 ++- .../sql/errors/QueryCompilationErrors.scala | 12 +++ .../apache/spark/sql/DataFrameReader.scala | 48 +++++++-- .../datasources/DataSourceManager.scala | 31 +++++- .../python/UserDefinedPythonDataSource.scala | 15 ++- .../python/PythonDataSourceSuite.scala | 35 +++++++ 11 files changed, 255 insertions(+), 28 deletions(-) diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index db46ee8ca208c..c38171c3d9e63 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -850,6 +850,12 @@ ], "sqlState" : "42710" }, + "DATA_SOURCE_NOT_EXIST" : { + "message" : [ + "Data source '' not found. Please make sure the data source is registered." + ], + "sqlState" : "42704" + }, "DATA_SOURCE_NOT_FOUND" : { "message" : [ "Failed to find the data source: . Please find packages at `https://spark.apache.org/third-party-projects.html`." @@ -1095,6 +1101,12 @@ ], "sqlState" : "42809" }, + "FOUND_MULTIPLE_DATA_SOURCES" : { + "message" : [ + "Detected multiple data sources with the name ''. Please check the data source isn't simultaneously registered and located in the classpath." + ], + "sqlState" : "42710" + }, "GENERATED_COLUMN_WITH_DEFAULT_VALUE" : { "message" : [ "A column cannot have both a default value and a generation expression but column has default value: () and generation expression: ()." 
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 95c9069a83131..01757ba28dd23 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -511,6 +511,7 @@ def __hash__(self): "pyspark.sql.tests.pandas.test_pandas_udf_window", "pyspark.sql.tests.pandas.test_converter", "pyspark.sql.tests.test_pandas_sqlmetrics", + "pyspark.sql.tests.test_python_datasource", "pyspark.sql.tests.test_readwriter", "pyspark.sql.tests.test_serde", "pyspark.sql.tests.test_session", diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md index 7b0bc8ceb2b5a..8a5faa15dc9cd 100644 --- a/docs/sql-error-conditions.md +++ b/docs/sql-error-conditions.md @@ -454,6 +454,12 @@ DataType `` requires a length parameter, for example ``(10). Please Data source '``' already exists in the registry. Please use a different name for the new data source. +### DATA_SOURCE_NOT_EXIST + +[SQLSTATE: 42704](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +Data source '``' not found. Please make sure the data source is registered. + ### DATA_SOURCE_NOT_FOUND [SQLSTATE: 42K02](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) @@ -669,6 +675,12 @@ No such struct field `` in ``. The operation `` is not allowed on the ``: ``. +### FOUND_MULTIPLE_DATA_SOURCES + +[SQLSTATE: 42710](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +Detected multiple data sources with the name '``'. Please check the data source isn't simultaneously registered and located in the classpath. + ### GENERATED_COLUMN_WITH_DEFAULT_VALUE [SQLSTATE: 42623](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 4ab7281d7ac87..85aff09aa3df1 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -884,6 +884,10 @@ def dataSource(self) -> "DataSourceRegistration": Returns ------- :class:`DataSourceRegistration` + + Notes + ----- + This feature is experimental and unstable. """ from pyspark.sql.datasource import DataSourceRegistration diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py index b429d73fb7d77..fe6a841752746 100644 --- a/python/pyspark/sql/tests/test_python_datasource.py +++ b/python/pyspark/sql/tests/test_python_datasource.py @@ -14,10 +14,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import os import unittest from pyspark.sql.datasource import DataSource, DataSourceReader +from pyspark.sql.types import Row +from pyspark.testing import assertDataFrameEqual from pyspark.testing.sqlutils import ReusedSQLTestCase +from pyspark.testing.utils import SPARK_HOME class BasePythonDataSourceTestsMixin: @@ -45,16 +49,93 @@ def read(self, partition): self.assertEqual(list(reader.partitions()), [None]) self.assertEqual(list(reader.read(None)), [(None,)]) - def test_register_data_source(self): - class MyDataSource(DataSource): - ... 
+ def test_in_memory_data_source(self): + class InMemDataSourceReader(DataSourceReader): + DEFAULT_NUM_PARTITIONS: int = 3 + + def __init__(self, paths, options): + self.paths = paths + self.options = options + + def partitions(self): + if "num_partitions" in self.options: + num_partitions = int(self.options["num_partitions"]) + else: + num_partitions = self.DEFAULT_NUM_PARTITIONS + return range(num_partitions) + + def read(self, partition): + yield partition, str(partition) + + class InMemoryDataSource(DataSource): + @classmethod + def name(cls): + return "memory" + + def schema(self): + return "x INT, y STRING" + + def reader(self, schema) -> "DataSourceReader": + return InMemDataSourceReader(self.paths, self.options) + + self.spark.dataSource.register(InMemoryDataSource) + df = self.spark.read.format("memory").load() + self.assertEqual(df.rdd.getNumPartitions(), 3) + assertDataFrameEqual(df, [Row(x=0, y="0"), Row(x=1, y="1"), Row(x=2, y="2")]) - self.spark.dataSource.register(MyDataSource) + df = self.spark.read.format("memory").option("num_partitions", 2).load() + assertDataFrameEqual(df, [Row(x=0, y="0"), Row(x=1, y="1")]) + self.assertEqual(df.rdd.getNumPartitions(), 2) + + def test_custom_json_data_source(self): + import json + + class JsonDataSourceReader(DataSourceReader): + def __init__(self, paths, options): + self.paths = paths + self.options = options + + def partitions(self): + return iter(self.paths) + + def read(self, path): + with open(path, "r") as file: + for line in file.readlines(): + if line.strip(): + data = json.loads(line) + yield data.get("name"), data.get("age") + + class JsonDataSource(DataSource): + @classmethod + def name(cls): + return "my-json" + + def schema(self): + return "name STRING, age INT" + + def reader(self, schema) -> "DataSourceReader": + return JsonDataSourceReader(self.paths, self.options) + + self.spark.dataSource.register(JsonDataSource) + path1 = os.path.join(SPARK_HOME, "python/test_support/sql/people.json") + path2 = os.path.join(SPARK_HOME, "python/test_support/sql/people1.json") + df1 = self.spark.read.format("my-json").load(path1) + self.assertEqual(df1.rdd.getNumPartitions(), 1) + assertDataFrameEqual( + df1, + [Row(name="Michael", age=None), Row(name="Andy", age=30), Row(name="Justin", age=19)], + ) - self.assertTrue( - self.spark._jsparkSession.sharedState() - .dataSourceRegistry() - .dataSourceExists("MyDataSource") + df2 = self.spark.read.format("my-json").load([path1, path2]) + self.assertEqual(df2.rdd.getNumPartitions(), 2) + assertDataFrameEqual( + df2, + [ + Row(name="Michael", age=None), + Row(name="Andy", age=30), + Row(name="Justin", age=19), + Row(name="Jonathan", age=None), + ], ) diff --git a/python/pyspark/sql/worker/create_data_source.py b/python/pyspark/sql/worker/create_data_source.py index ea56d2cc75221..6a9ef79b7c18d 100644 --- a/python/pyspark/sql/worker/create_data_source.py +++ b/python/pyspark/sql/worker/create_data_source.py @@ -14,13 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# - +import inspect import os import sys from typing import IO, List from pyspark.accumulators import _accumulatorRegistry -from pyspark.errors import PySparkAssertionError, PySparkRuntimeError +from pyspark.errors import PySparkAssertionError, PySparkRuntimeError, PySparkTypeError from pyspark.java_gateway import local_connect_and_auth from pyspark.serializers import ( read_bool, @@ -84,8 +84,20 @@ def main(infile: IO, outfile: IO) -> None: }, ) + # Check the name method is a class method. + if not inspect.ismethod(data_source_cls.name): + raise PySparkTypeError( + error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH", + message_parameters={ + "expected": "'name()' method to be a classmethod", + "actual": f"'{type(data_source_cls.name).__name__}'", + }, + ) + # Receive the provider name. provider = utf8_deserializer.loads(infile) + + # Check if the provider name matches the data source's name. if provider.lower() != data_source_cls.name().lower(): raise PySparkAssertionError( error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 1925eddd2ce23..0c5dcb1ead01e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -3805,4 +3805,16 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat errorClass = "DATA_SOURCE_ALREADY_EXISTS", messageParameters = Map("provider" -> name)) } + + def dataSourceDoesNotExist(name: String): Throwable = { + new AnalysisException( + errorClass = "DATA_SOURCE_NOT_EXIST", + messageParameters = Map("provider" -> name)) + } + + def foundMultipleDataSources(provider: String): Throwable = { + new AnalysisException( + errorClass = "FOUND_MULTIPLE_DATA_SOURCES", + messageParameters = Map("provider" -> provider)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 9992d8cbba076..ef447e8a80102 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -17,11 +17,12 @@ package org.apache.spark.sql -import java.util.{Locale, Properties} +import java.util.{Locale, Properties, ServiceConfigurationError} import scala.jdk.CollectionConverters._ +import scala.util.{Failure, Success, Try} -import org.apache.spark.Partition +import org.apache.spark.{Partition, SparkClassNotFoundException, SparkThrowable} import org.apache.spark.annotation.Stable import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging @@ -208,10 +209,45 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { throw QueryCompilationErrors.pathOptionNotSetCorrectlyWhenReadingError() } - DataSource.lookupDataSourceV2(source, sparkSession.sessionState.conf).flatMap { provider => - DataSourceV2Utils.loadV2Source(sparkSession, provider, userSpecifiedSchema, extraOptions, - source, paths: _*) - }.getOrElse(loadV1Source(paths: _*)) + val isUserDefinedDataSource = + sparkSession.sharedState.dataSourceManager.dataSourceExists(source) + + Try(DataSource.lookupDataSourceV2(source, sparkSession.sessionState.conf)) match { + case Success(providerOpt) => + // The source can be successfully loaded as either a V1 or a V2 data source. 
+ // Check if it is also a user-defined data source. + if (isUserDefinedDataSource) { + throw QueryCompilationErrors.foundMultipleDataSources(source) + } + providerOpt.flatMap { provider => + DataSourceV2Utils.loadV2Source( + sparkSession, provider, userSpecifiedSchema, extraOptions, source, paths: _*) + }.getOrElse(loadV1Source(paths: _*)) + case Failure(exception) => + // Exceptions are thrown while trying to load the data source as a V1 or V2 data source. + // For the following not found exceptions, if the user-defined data source is defined, + // we can instead return the user-defined data source. + val isNotFoundError = exception match { + case _: NoClassDefFoundError | _: SparkClassNotFoundException => true + case e: SparkThrowable => e.getErrorClass == "DATA_SOURCE_NOT_FOUND" + case e: ServiceConfigurationError => e.getCause.isInstanceOf[NoClassDefFoundError] + case _ => false + } + if (isNotFoundError && isUserDefinedDataSource) { + loadUserDefinedDataSource(paths) + } else { + // Throw the original exception. + throw exception + } + } + } + + private def loadUserDefinedDataSource(paths: Seq[String]): DataFrame = { + val builder = sparkSession.sharedState.dataSourceManager.lookupDataSource(source) + // Unless the legacy path option behavior is enabled, the extraOptions here + // should not include "path" or "paths" as keys. + val plan = builder(sparkSession, source, paths, userSpecifiedSchema, extraOptions) + Dataset.ofRows(sparkSession, plan) } private def loadV1Source(paths: String*) = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala index 283ca2ac62edc..72a9e6497aca5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala @@ -22,10 +22,14 @@ import java.util.concurrent.ConcurrentHashMap import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.util.CaseInsensitiveStringMap +/** + * A manager for user-defined data sources. It is used to register and lookup data sources by + * their short names or fully qualified names. + */ class DataSourceManager { private type DataSourceBuilder = ( @@ -33,22 +37,41 @@ class DataSourceManager { String, // provider name Seq[String], // paths Option[StructType], // user specified schema - CaseInsensitiveStringMap // options + CaseInsensitiveMap[String] // options ) => LogicalPlan private val dataSourceBuilders = new ConcurrentHashMap[String, DataSourceBuilder]() private def normalize(name: String): String = name.toLowerCase(Locale.ROOT) + /** + * Register a data source builder for the given provider. + * Note that the provider name is case-insensitive. + */ def registerDataSource(name: String, builder: DataSourceBuilder): Unit = { val normalizedName = normalize(name) if (dataSourceBuilders.containsKey(normalizedName)) { throw QueryCompilationErrors.dataSourceAlreadyExists(name) } - // TODO(SPARK-45639): check if the data source is a DSv1 or DSv2 using loadDataSource. 
dataSourceBuilders.put(normalizedName, builder) } - def dataSourceExists(name: String): Boolean = + /** + * Returns a data source builder for the given provider and throw an exception if + * it does not exist. + */ + def lookupDataSource(name: String): DataSourceBuilder = { + if (dataSourceExists(name)) { + dataSourceBuilders.get(normalize(name)) + } else { + throw QueryCompilationErrors.dataSourceDoesNotExist(name) + } + } + + /** + * Checks if a data source with the specified name exists (case-insensitive). + */ + def dataSourceExists(name: String): Boolean = { dataSourceBuilders.containsKey(normalize(name)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala index dbff8eefcd5fb..703c1e10ce265 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.python import java.io.{DataInputStream, DataOutputStream} import scala.collection.mutable.ArrayBuffer -import scala.jdk.CollectionConverters._ import net.razorvine.pickle.Pickler @@ -28,9 +27,9 @@ import org.apache.spark.api.python.{PythonFunction, PythonWorkerUtils, SimplePyt import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, PythonDataSource} import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.{DataType, StructType} -import org.apache.spark.sql.util.CaseInsensitiveStringMap /** * A user-defined Python data source. This is used by the Python API. 
@@ -44,7 +43,7 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) { provider: String, paths: Seq[String], userSpecifiedSchema: Option[StructType], - options: CaseInsensitiveStringMap): LogicalPlan = { + options: CaseInsensitiveMap[String]): LogicalPlan = { val runner = new UserDefinedPythonDataSourceRunner( dataSourceCls, provider, paths, userSpecifiedSchema, options) @@ -70,7 +69,7 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) { provider: String, paths: Seq[String] = Seq.empty, userSpecifiedSchema: Option[StructType] = None, - options: CaseInsensitiveStringMap = CaseInsensitiveStringMap.empty): DataFrame = { + options: CaseInsensitiveMap[String] = CaseInsensitiveMap(Map.empty)): DataFrame = { val plan = builder(sparkSession, provider, paths, userSpecifiedSchema, options) Dataset.ofRows(sparkSession, plan) } @@ -91,7 +90,7 @@ class UserDefinedPythonDataSourceRunner( provider: String, paths: Seq[String], userSpecifiedSchema: Option[StructType], - options: CaseInsensitiveStringMap) + options: CaseInsensitiveMap[String]) extends PythonPlannerRunner[PythonDataSourceCreationResult](dataSourceCls) { override val workerModule = "pyspark.sql.worker.create_data_source" @@ -113,9 +112,9 @@ class UserDefinedPythonDataSourceRunner( // Send the options dataOut.writeInt(options.size) - options.entrySet.asScala.foreach { e => - PythonWorkerUtils.writeUTF(e.getKey, dataOut) - PythonWorkerUtils.writeUTF(e.getValue, dataOut) + options.iterator.foreach { case (key, value) => + PythonWorkerUtils.writeUTF(key, dataOut) + PythonWorkerUtils.writeUTF(value, dataOut) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala index 6c749c2c9b67a..22a1e5250cd95 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala @@ -155,6 +155,41 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession { parameters = Map("provider" -> dataSourceName)) } + test("load data source") { + assume(shouldTestPythonUDFs) + val dataSourceScript = + s""" + |from pyspark.sql.datasource import DataSource, DataSourceReader + |class SimpleDataSourceReader(DataSourceReader): + | def __init__(self, paths, options): + | self.paths = paths + | self.options = options + | + | def partitions(self): + | return iter(self.paths) + | + | def read(self, path): + | yield (path, 1) + | + |class $dataSourceName(DataSource): + | @classmethod + | def name(cls) -> str: + | return "test" + | + | def schema(self) -> str: + | return "id STRING, value INT" + | + | def reader(self, schema): + | return SimpleDataSourceReader(self.paths, self.options) + |""".stripMargin + val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript) + spark.dataSource.registerPython("test", dataSource) + + checkAnswer(spark.read.format("test").load(), Seq(Row(null, 1))) + checkAnswer(spark.read.format("test").load("1"), Seq(Row("1", 1))) + checkAnswer(spark.read.format("test").load("1", "2"), Seq(Row("1", 1), Row("2", 1))) + } + test("reader not implemented") { assume(shouldTestPythonUDFs) val dataSourceScript = From eabea643c7424340397fc91dd89329baf31b48dd Mon Sep 17 00:00:00 2001 From: panbingkun Date: Wed, 8 Nov 2023 14:58:36 -0600 Subject: [PATCH 08/15] [SPARK-42821][SQL] Remove unused parameters in splitFiles 
methods ### What changes were proposed in this pull request? The pr aims to remove unused parameters in PartitionedFileUtil.splitFiles methods ### Why are the changes needed? Make the code more concise. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA. Closes #40454 from panbingkun/minor_PartitionedFileUtil. Authored-by: panbingkun Signed-off-by: Sean Owen --- .../org/apache/spark/sql/execution/DataSourceScanExec.scala | 3 +-- .../org/apache/spark/sql/execution/PartitionedFileUtil.scala | 2 -- .../apache/spark/sql/execution/datasources/v2/FileScan.scala | 1 - 3 files changed, 1 insertion(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index e5a38967dc3e1..c7bb3b6719157 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -687,7 +687,7 @@ case class FileSourceScanExec( * @param selectedPartitions Hive-style partition that are part of the read. */ private def createReadRDD( - readFile: (PartitionedFile) => Iterator[InternalRow], + readFile: PartitionedFile => Iterator[InternalRow], selectedPartitions: Array[PartitionDirectory]): RDD[InternalRow] = { val openCostInBytes = relation.sparkSession.sessionState.conf.filesOpenCostInBytes val maxSplitBytes = @@ -711,7 +711,6 @@ case class FileSourceScanExec( val isSplitable = relation.fileFormat.isSplitable( relation.sparkSession, relation.options, file.getPath) PartitionedFileUtil.splitFiles( - sparkSession = relation.sparkSession, file = file, isSplitable = isSplitable, maxSplitBytes = maxSplitBytes, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/PartitionedFileUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/PartitionedFileUtil.scala index cc234565d1112..b31369b6768e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/PartitionedFileUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/PartitionedFileUtil.scala @@ -20,13 +20,11 @@ package org.apache.spark.sql.execution import org.apache.hadoop.fs.{BlockLocation, FileStatus, LocatedFileStatus} import org.apache.spark.paths.SparkPath -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources._ object PartitionedFileUtil { def splitFiles( - sparkSession: SparkSession, file: FileStatusWithMetadata, isSplitable: Boolean, maxSplitBytes: Long, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala index 71e86beefdaff..61d61ee7af250 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala @@ -152,7 +152,6 @@ trait FileScan extends Scan } partition.files.flatMap { file => PartitionedFileUtil.splitFiles( - sparkSession = sparkSession, file = file, isSplitable = isSplitable(file.getPath), maxSplitBytes = maxSplitBytes, From 4df4fec622f3f6926b979f89daa177ec5e53d4ad Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 8 Nov 2023 13:07:04 -0800 Subject: [PATCH 09/15] [SPARK-45843][CORE] Support `killall` in REST Submission API ### What changes were proposed in this pull request? 
This PR aims to add `killall` action in REST Submission API. ### Why are the changes needed? To help users to kill all submissions easily. **BEFORE: Script** ```bash for id in $(curl http://master:8080/json/activedrivers | grep id | sed 's/"/ /g' | awk '{print $3}') do curl -XPOST http://master:6066/v1/submissions/kill/$id done ``` **AFTER** ```bash $ curl -XPOST http://master:6066/v1/submissions/killall { "action" : "KillAllSubmissionResponse", "message" : "Kill request for all drivers submitted", "serverSparkVersion" : "4.0.0-SNAPSHOT", "success" : true } ``` ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass the CIs with the newly added test case. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43721 from dongjoon-hyun/SPARK-45843. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../apache/spark/deploy/DeployMessage.scala | 6 ++++ .../apache/spark/deploy/master/Master.scala | 25 +++++++++++++ .../deploy/rest/RestSubmissionClient.scala | 35 +++++++++++++++++++ .../deploy/rest/RestSubmissionServer.scala | 23 +++++++++++- .../deploy/rest/StandaloneRestServer.scala | 19 ++++++++++ .../rest/SubmitRestProtocolResponse.scala | 10 ++++++ .../rest/StandaloneRestSubmitSuite.scala | 31 ++++++++++++++++ 7 files changed, 148 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index f49530461b4d0..4ccc0bd7cdc26 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -233,6 +233,12 @@ private[deploy] object DeployMessages { master: RpcEndpointRef, driverId: String, success: Boolean, message: String) extends DeployMessage + case object RequestKillAllDrivers extends DeployMessage + + case class KillAllDriversResponse( + master: RpcEndpointRef, success: Boolean, message: String) + extends DeployMessage + case class RequestDriverStatus(driverId: String) extends DeployMessage case class DriverStatusResponse(found: Boolean, state: Option[DriverState], diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 3ba50318610ba..dbb647252c5f7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -460,6 +460,31 @@ private[deploy] class Master( } } + case RequestKillAllDrivers => + if (state != RecoveryState.ALIVE) { + val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + + s"Can only kill drivers in ALIVE state." + context.reply(KillAllDriversResponse(self, success = false, msg)) + } else { + logInfo("Asked to kill all drivers") + drivers.foreach { d => + val driverId = d.id + if (waitingDrivers.contains(d)) { + waitingDrivers -= d + self.send(DriverStateChanged(driverId, DriverState.KILLED, None)) + } else { + // We just notify the worker to kill the driver here. The final bookkeeping occurs + // on the return path when the worker submits a state change back to the master + // to notify it that the driver was successfully killed. 
+ d.worker.foreach { w => + w.endpoint.send(KillDriver(driverId)) + } + } + logInfo(s"Kill request for $driverId submitted") + } + context.reply(KillAllDriversResponse(self, true, "Kill request for all drivers submitted")) + } + case RequestClearCompletedDriversAndApps => val numDrivers = completedDrivers.length val numApps = completedApps.length diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index 3010efc936f97..286305bb76b84 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -135,6 +135,35 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { response } + /** Request that the server kill all submissions. */ + def killAllSubmissions(): SubmitRestProtocolResponse = { + logInfo(s"Submitting a request to kill all submissions in $master.") + var handled: Boolean = false + var response: SubmitRestProtocolResponse = null + for (m <- masters if !handled) { + validateMaster(m) + val url = getKillAllUrl(m) + try { + response = post(url) + response match { + case k: KillAllSubmissionResponse => + if (!Utils.responseFromBackup(k.message)) { + handleRestResponse(k) + handled = true + } + case unexpected => + handleUnexpectedRestResponse(unexpected) + } + } catch { + case e: SubmitRestConnectionException => + if (handleConnectionException(m)) { + throw new SubmitRestConnectionException("Unable to connect to server", e) + } + } + } + response + } + /** Request that the server clears all submissions and applications. */ def clear(): SubmitRestProtocolResponse = { logInfo(s"Submitting a request to clear $master.") @@ -329,6 +358,12 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { new URL(s"$baseUrl/kill/$submissionId") } + /** Return the REST URL for killing all submissions. */ + private def getKillAllUrl(master: String): URL = { + val baseUrl = getBaseUrl(master) + new URL(s"$baseUrl/killall") + } + /** Return the REST URL for clear all existing submissions and applications. 
*/ private def getClearUrl(master: String): URL = { val baseUrl = getBaseUrl(master) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala index 3323d0f529ebf..28197fd0a556d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala @@ -54,6 +54,7 @@ private[spark] abstract class RestSubmissionServer( protected val submitRequestServlet: SubmitRequestServlet protected val killRequestServlet: KillRequestServlet + protected val killAllRequestServlet: KillAllRequestServlet protected val statusRequestServlet: StatusRequestServlet protected val clearRequestServlet: ClearRequestServlet @@ -64,6 +65,7 @@ private[spark] abstract class RestSubmissionServer( protected lazy val contextToServlet = Map[String, RestServlet]( s"$baseContext/create/*" -> submitRequestServlet, s"$baseContext/kill/*" -> killRequestServlet, + s"$baseContext/killall/*" -> killAllRequestServlet, s"$baseContext/status/*" -> statusRequestServlet, s"$baseContext/clear/*" -> clearRequestServlet, "/*" -> new ErrorServlet // default handler @@ -229,6 +231,25 @@ private[rest] abstract class KillRequestServlet extends RestServlet { protected def handleKill(submissionId: String): KillSubmissionResponse } +/** + * A servlet for handling killAll requests passed to the [[RestSubmissionServer]]. + */ +private[rest] abstract class KillAllRequestServlet extends RestServlet { + + /** + * Have the Master kill all drivers and return an appropriate response to the client. + * Otherwise, return error. + */ + protected override def doPost( + request: HttpServletRequest, + response: HttpServletResponse): Unit = { + val responseMessage = handleKillAll() + sendResponse(responseMessage, response) + } + + protected def handleKillAll(): KillAllSubmissionResponse +} + /** * A servlet for handling clear requests passed to the [[RestSubmissionServer]]. */ @@ -331,7 +352,7 @@ private class ErrorServlet extends RestServlet { "Missing the /submissions prefix." case `serverVersion` :: "submissions" :: tail => // http://host:port/correct-version/submissions/* - "Missing an action: please specify one of /create, /kill, /clear or /status." + "Missing an action: please specify one of /create, /kill, /killall, /clear or /status." 
case unknownVersion :: tail => // http://host:port/unknown-version/* versionMismatch = true diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index 8ed716428dc28..d382ec12847dd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -63,6 +63,8 @@ private[deploy] class StandaloneRestServer( new StandaloneSubmitRequestServlet(masterEndpoint, masterUrl, masterConf) protected override val killRequestServlet = new StandaloneKillRequestServlet(masterEndpoint, masterConf) + protected override val killAllRequestServlet = + new StandaloneKillAllRequestServlet(masterEndpoint, masterConf) protected override val statusRequestServlet = new StandaloneStatusRequestServlet(masterEndpoint, masterConf) protected override val clearRequestServlet = @@ -87,6 +89,23 @@ private[rest] class StandaloneKillRequestServlet(masterEndpoint: RpcEndpointRef, } } +/** + * A servlet for handling killAll requests passed to the [[StandaloneRestServer]]. + */ +private[rest] class StandaloneKillAllRequestServlet(masterEndpoint: RpcEndpointRef, conf: SparkConf) + extends KillAllRequestServlet { + + protected def handleKillAll() : KillAllSubmissionResponse = { + val response = masterEndpoint.askSync[DeployMessages.KillAllDriversResponse]( + DeployMessages.RequestKillAllDrivers) + val k = new KillAllSubmissionResponse + k.serverSparkVersion = sparkVersion + k.message = response.message + k.success = response.success + k + } +} + /** * A servlet for handling status requests passed to the [[StandaloneRestServer]]. */ diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala index 21614c22285f8..b9e3b3028ac79 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala @@ -55,6 +55,16 @@ private[spark] class KillSubmissionResponse extends SubmitRestProtocolResponse { } } +/** + * A response to a killAll request in the REST application submission protocol. + */ +private[spark] class KillAllSubmissionResponse extends SubmitRestProtocolResponse { + protected override def doValidate(): Unit = { + super.doValidate() + assertFieldIsSet(success, "success") + } +} + /** * A response to a clear request in the REST application submission protocol. 
*/ diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index d775aa6542dcd..1cc2c873760df 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -236,6 +236,15 @@ class StandaloneRestSubmitSuite extends SparkFunSuite { assert(clearResponse.success) } + test("SPARK-45843: killAll") { + val masterUrl = startDummyServer() + val response = new RestSubmissionClient(masterUrl).killAllSubmissions() + val killAllResponse = getKillAllResponse(response) + assert(killAllResponse.action === Utils.getFormattedClassName(killAllResponse)) + assert(killAllResponse.serverSparkVersion === SPARK_VERSION) + assert(killAllResponse.success) + } + /* ---------------------------------------- * | Aberrant client / server behavior | * ---------------------------------------- */ @@ -514,6 +523,16 @@ class StandaloneRestSubmitSuite extends SparkFunSuite { } } + /** Return the response as a killAll response, or fail with error otherwise. */ + private def getKillAllResponse(response: SubmitRestProtocolResponse) + : KillAllSubmissionResponse = { + response match { + case k: KillAllSubmissionResponse => k + case e: ErrorResponse => fail(s"Server returned error: ${e.message}") + case r => fail(s"Expected killAll response. Actual: ${r.toJson}") + } + } + /** Return the response as a clear response, or fail with error otherwise. */ private def getClearResponse(response: SubmitRestProtocolResponse): ClearResponse = { response match { @@ -590,6 +609,8 @@ private class DummyMaster( context.reply(SubmitDriverResponse(self, success = true, Some(submitId), submitMessage)) case RequestKillDriver(driverId) => context.reply(KillDriverResponse(self, driverId, success = true, killMessage)) + case RequestKillAllDrivers => + context.reply(KillAllDriversResponse(self, success = true, killMessage)) case RequestDriverStatus(driverId) => context.reply(DriverStatusResponse(found = true, Some(state), None, None, exception)) case RequestClearCompletedDriversAndApps => @@ -636,6 +657,7 @@ private class SmarterMaster(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEn * * When handling a submit request, the server returns a malformed JSON. * When handling a kill request, the server returns an invalid JSON. + * When handling a killAll request, the server returns an invalid JSON. * When handling a status request, the server throws an internal exception. * When handling a clear request, the server throws an internal exception. * The purpose of this class is to test that client handles these cases gracefully. @@ -650,6 +672,7 @@ private class FaultyStandaloneRestServer( protected override val submitRequestServlet = new MalformedSubmitServlet protected override val killRequestServlet = new InvalidKillServlet + protected override val killAllRequestServlet = new InvalidKillAllServlet protected override val statusRequestServlet = new ExplodingStatusServlet protected override val clearRequestServlet = new ExplodingClearServlet @@ -673,6 +696,14 @@ private class FaultyStandaloneRestServer( } } + /** A faulty servlet that produces invalid responses. 
*/ + class InvalidKillAllServlet extends StandaloneKillAllRequestServlet(masterEndpoint, masterConf) { + protected override def handleKillAll(): KillAllSubmissionResponse = { + val k = super.handleKillAll() + k + } + } + /** A faulty status servlet that explodes. */ class ExplodingStatusServlet extends StandaloneStatusRequestServlet(masterEndpoint, masterConf) { private def explode: Int = 1 / 0 From dfd7cde91c5d6f034a11ea492be83afaf771ceb6 Mon Sep 17 00:00:00 2001 From: Yihong He Date: Wed, 8 Nov 2023 18:22:40 -0800 Subject: [PATCH 10/15] [SPARK-45842][SQL] Refactor Catalog Function APIs to use analyzer ### What changes were proposed in this pull request? - Refactor Catalog Function APIs to use analyzer ### Why are the changes needed? - Less duplicate logics. We should not directly invoke catalog APIs, but go through analyzer. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #43720 from heyihong/SPARK-45842. Authored-by: Yihong He Signed-off-by: Dongjoon Hyun --- .../spark/sql/internal/CatalogImpl.scala | 59 +++++++++++-------- 1 file changed, 35 insertions(+), 24 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 5650e9d2399cc..b1ad454fc041f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -22,14 +22,14 @@ import scala.util.control.NonFatal import org.apache.spark.sql._ import org.apache.spark.sql.catalog.{Catalog, CatalogMetadata, Column, Database, Function, Table} -import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, TableIdentifier} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} import org.apache.spark.sql.catalyst.plans.logical.{CreateTable, LocalRelation, LogicalPlan, OptionList, RecoverPartitions, ShowFunctions, ShowNamespaces, ShowTables, UnresolvedTableSpec, View} import org.apache.spark.sql.catalyst.types.DataTypeUtils -import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, CatalogV2Util, FunctionCatalog, Identifier, SupportsNamespaces, Table => V2Table, TableCatalog, V1Table} +import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, CatalogV2Util, Identifier, SupportsNamespaces, Table => V2Table, TableCatalog, V1Table} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.{CatalogHelper, MultipartIdentifierHelper, NamespaceHelper, TransformHelper} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.command.ShowTablesCommand @@ -284,6 +284,33 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { CatalogImpl.makeDataset(functions.result(), sparkSession) } + private def toFunctionIdent(functionName: String): Seq[String] = { + val parsed = parseIdent(functionName) + // For backward compatibility (Spark 3.3 and prior), we should check if the function exists in + // the Hive Metastore first. 
+ if (parsed.length <= 2 && + !sessionCatalog.isTemporaryFunction(parsed.asFunctionIdentifier) && + sessionCatalog.isPersistentFunction(parsed.asFunctionIdentifier)) { + qualifyV1Ident(parsed) + } else { + parsed + } + } + + private def functionExists(ident: Seq[String]): Boolean = { + val plan = + UnresolvedFunctionName(ident, CatalogImpl.FUNCTION_EXISTS_COMMAND_NAME, false, None) + try { + sparkSession.sessionState.executePlan(plan).analyzed match { + case _: ResolvedPersistentFunc => true + case _: ResolvedNonPersistentFunc => true + case _ => false + } + } catch { + case e: AnalysisException if e.getErrorClass == "UNRESOLVED_ROUTINE" => false + } + } + private def makeFunction(ident: Seq[String]): Function = { val plan = UnresolvedFunctionName(ident, "Catalog.makeFunction", false, None) sparkSession.sessionState.executePlan(plan).analyzed match { @@ -465,17 +492,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * function. This throws an `AnalysisException` when no `Function` can be found. */ override def getFunction(functionName: String): Function = { - val parsed = parseIdent(functionName) - // For backward compatibility (Spark 3.3 and prior), we should check if the function exists in - // the Hive Metastore first. - val nameParts = if (parsed.length <= 2 && - !sessionCatalog.isTemporaryFunction(parsed.asFunctionIdentifier) && - sessionCatalog.isPersistentFunction(parsed.asFunctionIdentifier)) { - qualifyV1Ident(parsed) - } else { - parsed - } - makeFunction(nameParts) + makeFunction(toFunctionIdent(functionName)) } /** @@ -540,23 +557,16 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * or a function. */ override def functionExists(functionName: String): Boolean = { - val parsed = parseIdent(functionName) - // For backward compatibility (Spark 3.3 and prior), we should check if the function exists in - // the Hive Metastore first. This also checks if it's a built-in/temp function. - (parsed.length <= 2 && sessionCatalog.functionExists(parsed.asFunctionIdentifier)) || { - val plan = UnresolvedIdentifier(parsed) - sparkSession.sessionState.executePlan(plan).analyzed match { - case ResolvedIdentifier(catalog: FunctionCatalog, ident) => catalog.functionExists(ident) - case _ => false - } - } + functionExists(toFunctionIdent(functionName)) } /** * Checks if the function with the specified name exists in the specified database. */ override def functionExists(dbName: String, functionName: String): Boolean = { - sessionCatalog.functionExists(FunctionIdentifier(functionName, Option(dbName))) + // For backward compatibility (Spark 3.3 and prior), here we always look up the function from + // the Hive Metastore. + functionExists(Seq(CatalogManager.SESSION_CATALOG_NAME, dbName, functionName)) } /** @@ -942,4 +952,5 @@ private[sql] object CatalogImpl { new Dataset[T](queryExecution, enc) } + private val FUNCTION_EXISTS_COMMAND_NAME = "Catalog.functionExists" } From 974313994a0594fde7b424e569febed89cafd9ca Mon Sep 17 00:00:00 2001 From: panbingkun Date: Wed, 8 Nov 2023 19:22:13 -0800 Subject: [PATCH 11/15] [SPARK-45835][INFRA] Make gitHub labeler more accurate and remove outdated comments ### What changes were proposed in this pull request? The pr aims to make gitHub labeler more accurate and remove outdated comments. ### Why are the changes needed? The functions mentioned in the comments have been released in the latest version of Github Action labeler. 
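With the released `any` syntax, a single entry can combine a positive glob with `!` exclusions: the label is applied if some changed file matches all of the globs in the list, i.e. it matches the positive pattern and none of the exclusions. A minimal sketch of the form this change switches to, mirroring the `BUILD` rule in the `labeler.yml` diff below:

```yaml
BUILD:
  - any: ['dev/**/*', '!dev/merge_spark_pr.py', '!dev/run-tests-jenkins*']
```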
https://github.com/actions/labeler/issues/111 https://github.com/actions/labeler/issues/111#issuecomment-1345989028 image According to the description of the original PR (https://github.com/apache/spark/pull/30244/files), after 'any/all' is released in the official version of `Github Action labeler`, we need to make subsequent updates to better identify the code `label`. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Continuous manual observation is required. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43716 from panbingkun/SPARK-45835. Authored-by: panbingkun Signed-off-by: Dongjoon Hyun --- .github/labeler.yml | 40 +++++------------------------------ .github/workflows/labeler.yml | 13 ------------ 2 files changed, 5 insertions(+), 48 deletions(-) diff --git a/.github/labeler.yml b/.github/labeler.yml index f21b90d460fb6..fc69733f4b66a 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -17,23 +17,6 @@ # under the License. # -# -# Pull Request Labeler Github Action Configuration: https://github.com/marketplace/actions/labeler -# -# Note that we currently cannot use the negatioon operator (i.e. `!`) for miniglob matches as they -# would match any file that doesn't touch them. What's needed is the concept of `any `, which takes a -# list of constraints / globs and then matches all of the constraints for either `any` of the files or -# `all` of the files in the change set. -# -# However, `any`/`all` are not supported in a released version and testing off of the `main` branch -# resulted in some other errors when testing. -# -# An issue has been opened upstream requesting that a release be cut that has support for all/any: -# - https://github.com/actions/labeler/issues/111 -# -# While we wait for this issue to be handled upstream, we can remove -# the negated / `!` matches for now and at least have labels again. -# INFRA: - ".github/**/*" - "appveyor.yml" @@ -45,9 +28,7 @@ INFRA: - "dev/merge_spark_pr.py" - "dev/run-tests-jenkins*" BUILD: - # Can be supported when a stable release with correct all/any is released - #- any: ['dev/**/*', '!dev/merge_spark_pr.py', '!dev/.rat-excludes'] - - "dev/**/*" + - any: ['dev/**/*', '!dev/merge_spark_pr.py', '!dev/run-tests-jenkins*'] - "build/**/*" - "project/**/*" - "assembly/**/*" @@ -55,22 +36,16 @@ BUILD: - "bin/docker-image-tool.sh" - "bin/find-spark-home*" - "scalastyle-config.xml" - # These can be added in the above `any` clause (and the /dev/**/* glob removed) when - # `any`/`all` support is released - # - "!dev/merge_spark_pr.py" - # - "!dev/run-tests-jenkins*" - # - "!dev/.rat-excludes" DOCS: - "docs/**/*" - "**/README.md" - "**/CONTRIBUTING.md" + - "python/docs/**/*" EXAMPLES: - "examples/**/*" - "bin/run-example*" -# CORE needs to be updated when all/any are released upstream. CORE: - # - any: ["core/**/*", "!**/*UI.scala", "!**/ui/**/*"] # If any file matches all of the globs defined in the list started by `any`, label is applied. 
- - "core/**/*" + - any: ["core/**/*", "!**/*UI.scala", "!**/ui/**/*"] - "common/kvstore/**/*" - "common/network-common/**/*" - "common/network-shuffle/**/*" @@ -82,12 +57,8 @@ SPARK SHELL: - "repl/**/*" - "bin/spark-shell*" SQL: -#- any: ["**/sql/**/*", "!python/pyspark/sql/avro/**/*", "!python/pyspark/sql/streaming/**/*", "!python/pyspark/sql/tests/streaming/test_streaming.py"] - - "**/sql/**/*" + - any: ["**/sql/**/*", "!python/pyspark/sql/avro/**/*", "!python/pyspark/sql/streaming/**/*", "!python/pyspark/sql/tests/streaming/test_streaming*.py"] - "common/unsafe/**/*" - #- "!python/pyspark/sql/avro/**/*" - #- "!python/pyspark/sql/streaming/**/*" - #- "!python/pyspark/sql/tests/streaming/test_streaming.py" - "bin/spark-sql*" - "bin/beeline*" - "sbin/*thriftserver*.sh" @@ -123,7 +94,7 @@ STRUCTURED STREAMING: - "**/sql/**/streaming/**/*" - "connector/kafka-0-10-sql/**/*" - "python/pyspark/sql/streaming/**/*" - - "python/pyspark/sql/tests/streaming/test_streaming.py" + - "python/pyspark/sql/tests/streaming/test_streaming*.py" - "**/*streaming.R" PYTHON: - "bin/pyspark*" @@ -148,7 +119,6 @@ DEPLOY: - "sbin/**/*" CONNECT: - "connector/connect/**/*" - - "**/sql/sparkconnect/**/*" - "python/pyspark/sql/**/connect/**/*" - "python/pyspark/ml/**/connect/**/*" PROTOBUF: diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index c6b6e65bc9fec..b55d28e5a6406 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -34,19 +34,6 @@ jobs: contents: read pull-requests: write steps: - # In order to get back the negated matches like in the old config, - # we need the actinons/labeler concept of `all` and `any` which matches - # all of the given constraints / glob patterns for either `all` - # files or `any` file in the change set. - # - # Github issue which requests a timeline for a release with any/all support: - # - https://github.com/actions/labeler/issues/111 - # This issue also references the issue that mentioned that any/all are only - # supported on main branch (previously called master): - # - https://github.com/actions/labeler/issues/73#issuecomment-639034278 - # - # However, these are not in a published release and the current `main` branch - # has some issues upon testing. - uses: actions/labeler@v4 with: repo-token: "${{ secrets.GITHUB_TOKEN }}" From 093fbf1aa8520193b8d929f9f855afe0aded20a1 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Wed, 8 Nov 2023 19:23:29 -0800 Subject: [PATCH 12/15] [SPARK-45831][CORE][SQL][DSTREAM] Use collection factory instead to create immutable Java collections ### What changes were proposed in this pull request? This pr change to use collection factory instread of `Collections.unmodifiable` to create an immutable Java collection(new collection API introduced after [JEP 269](https://openjdk.org/jeps/269)) ### Why are the changes needed? Make the relevant code look simple and clear. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Pass GitHub Actions ### Was this patch authored or co-authored using generative AI tooling? No Closes #43709 from LuciferYang/collection-factory. 
Authored-by: yangjie01 Signed-off-by: Dongjoon Hyun --- .../apache/spark/network/util/JavaUtils.java | 44 +++++++++---------- .../scala/org/apache/spark/FutureAction.scala | 5 +-- .../org/apache/spark/util/AccumulatorV2.scala | 3 +- .../SpecificParquetRecordReaderBase.java | 5 +-- .../apache/spark/streaming/JavaAPISuite.java | 4 +- 5 files changed, 26 insertions(+), 35 deletions(-) diff --git a/common/utils/src/main/java/org/apache/spark/network/util/JavaUtils.java b/common/utils/src/main/java/org/apache/spark/network/util/JavaUtils.java index bbe764b8366c8..fa0a2629f3502 100644 --- a/common/utils/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/common/utils/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -202,29 +202,27 @@ private static boolean isSymlink(File file) throws IOException { private static final Map byteSuffixes; static { - final Map timeSuffixesBuilder = new HashMap<>(); - timeSuffixesBuilder.put("us", TimeUnit.MICROSECONDS); - timeSuffixesBuilder.put("ms", TimeUnit.MILLISECONDS); - timeSuffixesBuilder.put("s", TimeUnit.SECONDS); - timeSuffixesBuilder.put("m", TimeUnit.MINUTES); - timeSuffixesBuilder.put("min", TimeUnit.MINUTES); - timeSuffixesBuilder.put("h", TimeUnit.HOURS); - timeSuffixesBuilder.put("d", TimeUnit.DAYS); - timeSuffixes = Collections.unmodifiableMap(timeSuffixesBuilder); - - final Map byteSuffixesBuilder = new HashMap<>(); - byteSuffixesBuilder.put("b", ByteUnit.BYTE); - byteSuffixesBuilder.put("k", ByteUnit.KiB); - byteSuffixesBuilder.put("kb", ByteUnit.KiB); - byteSuffixesBuilder.put("m", ByteUnit.MiB); - byteSuffixesBuilder.put("mb", ByteUnit.MiB); - byteSuffixesBuilder.put("g", ByteUnit.GiB); - byteSuffixesBuilder.put("gb", ByteUnit.GiB); - byteSuffixesBuilder.put("t", ByteUnit.TiB); - byteSuffixesBuilder.put("tb", ByteUnit.TiB); - byteSuffixesBuilder.put("p", ByteUnit.PiB); - byteSuffixesBuilder.put("pb", ByteUnit.PiB); - byteSuffixes = Collections.unmodifiableMap(byteSuffixesBuilder); + timeSuffixes = Map.of( + "us", TimeUnit.MICROSECONDS, + "ms", TimeUnit.MILLISECONDS, + "s", TimeUnit.SECONDS, + "m", TimeUnit.MINUTES, + "min", TimeUnit.MINUTES, + "h", TimeUnit.HOURS, + "d", TimeUnit.DAYS); + + byteSuffixes = Map.ofEntries( + Map.entry("b", ByteUnit.BYTE), + Map.entry("k", ByteUnit.KiB), + Map.entry("kb", ByteUnit.KiB), + Map.entry("m", ByteUnit.MiB), + Map.entry("mb", ByteUnit.MiB), + Map.entry("g", ByteUnit.GiB), + Map.entry("gb", ByteUnit.GiB), + Map.entry("t", ByteUnit.TiB), + Map.entry("tb", ByteUnit.TiB), + Map.entry("p", ByteUnit.PiB), + Map.entry("pb", ByteUnit.PiB)); } /** diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index 9100d4ce041bf..a68700421b8df 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -17,7 +17,6 @@ package org.apache.spark -import java.util.Collections import java.util.concurrent.TimeUnit import scala.concurrent._ @@ -255,8 +254,6 @@ private[spark] class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: S => T) extends JavaFutureAction[T] { - import scala.jdk.CollectionConverters._ - override def isCancelled: Boolean = futureAction.isCancelled override def isDone: Boolean = { @@ -266,7 +263,7 @@ class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: S } override def jobIds(): java.util.List[java.lang.Integer] = { - 
Collections.unmodifiableList(futureAction.jobIds.map(Integer.valueOf).asJava) + java.util.List.of(futureAction.jobIds.map(Integer.valueOf): _*) } private def getImpl(timeout: Duration): T = { diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index 181033c9d20c8..c6d8073a0c2fa 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -19,7 +19,6 @@ package org.apache.spark.util import java.{lang => jl} import java.io.ObjectInputStream -import java.util.ArrayList import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicLong @@ -505,7 +504,7 @@ class CollectionAccumulator[T] extends AccumulatorV2[T, java.util.List[T]] { } override def value: java.util.List[T] = this.synchronized { - java.util.Collections.unmodifiableList(new ArrayList[T](getOrCreate)) + java.util.List.copyOf(getOrCreate) } private[spark] def setValue(newValue: java.util.List[T]): Unit = this.synchronized { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index 4f2b65f36120a..6d00048154a56 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -23,7 +23,6 @@ import java.lang.reflect.InvocationTargetException; import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -238,9 +237,7 @@ public void close() throws IOException { private static Map> toSetMultiMap(Map map) { Map> setMultiMap = new HashMap<>(); for (Map.Entry entry : map.entrySet()) { - Set set = new HashSet<>(); - set.add(entry.getValue()); - setMultiMap.put(entry.getKey(), Collections.unmodifiableSet(set)); + setMultiMap.put(entry.getKey(), Set.of(entry.getValue())); } return Collections.unmodifiableMap(setMultiMap); } diff --git a/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java index b1f743b921969..f8d961fa8dd8e 100644 --- a/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java @@ -704,11 +704,11 @@ public static void assertOrderInvariantEquals( List> expected, List> actual) { List> expectedSets = new ArrayList<>(); for (List list: expected) { - expectedSets.add(Collections.unmodifiableSet(new HashSet<>(list))); + expectedSets.add(Set.copyOf(list)); } List> actualSets = new ArrayList<>(); for (List list: actual) { - actualSets.add(Collections.unmodifiableSet(new HashSet<>(list))); + actualSets.add(Set.copyOf(list)); } Assertions.assertEquals(expectedSets, actualSets); } From 24edc0ef5bee578de8eec3b032f993812e4303ea Mon Sep 17 00:00:00 2001 From: Rui Wang Date: Thu, 9 Nov 2023 15:25:52 +0800 Subject: [PATCH 13/15] [SPARK-45752][SQL] Unreferenced CTE should all be checked by CheckAnalysis0 ### What changes were proposed in this pull request? This PR fixes an issue that if a CTE is referenced by a non-referenced CTE, then this CTE should also have ref count as 0 and goes through CheckAnalysis0. 
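For illustration, the regression test added in this PR covers exactly this shape: `b` is never referenced by the main query (ref count 0), and `a` is referenced only by the unreferenced `b`, so `a` must also go through CheckAnalysis0 and surface the missing-table error for `non_exist`:

```sql
with
a as (select * from non_exist),
b as (select * from a)
select 2
```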
This will guarantee analyzer throw proper error message for problematic CTE which is not referenced. ### Why are the changes needed? To improve error message for non-referenced CTE case. ### Does this PR introduce _any_ user-facing change? NO ### How was this patch tested? UT ### Was this patch authored or co-authored using generative AI tooling? NO Closes #43614 from amaliujia/cte_ref. Lead-authored-by: Rui Wang Co-authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- .../sql/catalyst/analysis/CheckAnalysis.scala | 28 +++++++++++++++++-- .../org/apache/spark/sql/CTEInlineSuite.scala | 11 ++++++++ 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index cebaee2cdec9c..29d60ae0f41e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -148,15 +148,39 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB errorClass, missingCol, orderedCandidates, a.origin) } + private def checkUnreferencedCTERelations( + cteMap: mutable.Map[Long, (CTERelationDef, Int, mutable.Map[Long, Int])], + visited: mutable.Map[Long, Boolean], + cteId: Long): Unit = { + if (visited(cteId)) { + return + } + val (cteDef, _, refMap) = cteMap(cteId) + refMap.foreach { case (id, _) => + checkUnreferencedCTERelations(cteMap, visited, id) + } + checkAnalysis0(cteDef.child) + visited(cteId) = true + } + def checkAnalysis(plan: LogicalPlan): Unit = { val inlineCTE = InlineCTE(alwaysInline = true) val cteMap = mutable.HashMap.empty[Long, (CTERelationDef, Int, mutable.Map[Long, Int])] inlineCTE.buildCTEMap(plan, cteMap) - cteMap.values.foreach { case (relation, refCount, _) => + cteMap.values.foreach { case (relation, _, _) => // If a CTE relation is never used, it will disappear after inline. Here we explicitly check // analysis for it, to make sure the entire query plan is valid. try { - if (refCount == 0) checkAnalysis0(relation.child) + // If a CTE relation ref count is 0, the other CTE relations that reference it + // should also be checked by checkAnalysis0. This code will also guarantee the leaf + // relations that do not reference any others are checked first. 
+ val visited: mutable.Map[Long, Boolean] = mutable.Map.empty.withDefaultValue(false) + cteMap.foreach { case (cteId, _) => + val (_, refCount, _) = cteMap(cteId) + if (refCount == 0) { + checkUnreferencedCTERelations(cteMap, visited, cteId) + } + } } catch { case e: AnalysisException => throw new ExtendedAnalysisException(e, relation.child) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala index 5f6c44792658a..055c04992c009 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala @@ -678,6 +678,17 @@ abstract class CTEInlineSuiteBase }.isDefined, "CTE columns should not be pruned.") } } + + test("SPARK-45752: Unreferenced CTE should all be checked by CheckAnalysis0") { + val e = intercept[AnalysisException](sql( + s""" + |with + |a as (select * from non_exist), + |b as (select * from a) + |select 2 + |""".stripMargin)) + checkErrorTableNotFound(e, "`non_exist`", ExpectedContext("non_exist", 26, 34)) + } } class CTEInlineSuiteAEOff extends CTEInlineSuiteBase with DisableAdaptiveExecutionSuite From c128f811820e5a31ddd5bd1c95ed8dd49017eaea Mon Sep 17 00:00:00 2001 From: xieshuaihu Date: Thu, 9 Nov 2023 15:56:40 +0800 Subject: [PATCH 14/15] [SPARK-45814][CONNECT][SQL] Make ArrowConverters.createEmptyArrowBatch call close() to avoid memory leak ### What changes were proposed in this pull request? Make `ArrowBatchIterator` implement `AutoCloseable` and `ArrowConverters.createEmptyArrowBatch()` call close() to avoid memory leak. ### Why are the changes needed? `ArrowConverters.createEmptyArrowBatch` don't call `super.hasNext`, if `TaskContext.get` returns `None`, then memory allocated in `ArrowBatchIterator` is leaked. In spark connect, `createEmptyArrowBatch` is called in [SparkConnectPlanner](https://github.com/apache/spark/blob/master/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala#L2558) and [SparkConnectPlanExecution](https://github.com/apache/spark/blob/master/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala#L224), which cause a long running driver consume all off-heap memory specified by `-XX:MaxDirectMemorySize`. This is the exception stack: ``` org.apache.arrow.memory.OutOfMemoryException: Failure allocating buffer. 
at io.netty.buffer.PooledByteBufAllocatorL.allocate(PooledByteBufAllocatorL.java:67) at org.apache.arrow.memory.NettyAllocationManager.(NettyAllocationManager.java:77) at org.apache.arrow.memory.NettyAllocationManager.(NettyAllocationManager.java:84) at org.apache.arrow.memory.NettyAllocationManager$1.create(NettyAllocationManager.java:34) at org.apache.arrow.memory.BaseAllocator.newAllocationManager(BaseAllocator.java:354) at org.apache.arrow.memory.BaseAllocator.newAllocationManager(BaseAllocator.java:349) at org.apache.arrow.memory.BaseAllocator.bufferWithoutReservation(BaseAllocator.java:337) at org.apache.arrow.memory.BaseAllocator.buffer(BaseAllocator.java:315) at org.apache.arrow.memory.BaseAllocator.buffer(BaseAllocator.java:279) at org.apache.arrow.vector.BaseValueVector.allocFixedDataAndValidityBufs(BaseValueVector.java:192) at org.apache.arrow.vector.BaseFixedWidthVector.allocateBytes(BaseFixedWidthVector.java:338) at org.apache.arrow.vector.BaseFixedWidthVector.allocateNew(BaseFixedWidthVector.java:308) at org.apache.arrow.vector.BaseFixedWidthVector.allocateNew(BaseFixedWidthVector.java:273) at org.apache.spark.sql.execution.arrow.ArrowWriter$.$anonfun$create$1(ArrowWriter.scala:44) at scala.collection.StrictOptimizedIterableOps.map(StrictOptimizedIterableOps.scala:100) at scala.collection.StrictOptimizedIterableOps.map$(StrictOptimizedIterableOps.scala:87) at scala.collection.convert.JavaCollectionWrappers$JListWrapper.map(JavaCollectionWrappers.scala:103) at org.apache.spark.sql.execution.arrow.ArrowWriter$.create(ArrowWriter.scala:43) at org.apache.spark.sql.execution.arrow.ArrowConverters$ArrowBatchIterator.(ArrowConverters.scala:93) at org.apache.spark.sql.execution.arrow.ArrowConverters$ArrowBatchWithSchemaIterator.(ArrowConverters.scala:138) at org.apache.spark.sql.execution.arrow.ArrowConverters$$anon$1.(ArrowConverters.scala:231) at org.apache.spark.sql.execution.arrow.ArrowConverters$.createEmptyArrowBatch(ArrowConverters.scala:229) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.handleSqlCommand(SparkConnectPlanner.scala:2481) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.process(SparkConnectPlanner.scala:2426) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.handleCommand(ExecuteThreadRunner.scala:202) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.$anonfun$executeInternal$1(ExecuteThreadRunner.scala:158) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.$anonfun$executeInternal$1$adapted(ExecuteThreadRunner.scala:132) at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$2(SessionHolder.scala:189) at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:900) at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$1(SessionHolder.scala:189) at org.apache.spark.JobArtifactSet$.withActiveJobArtifactState(JobArtifactSet.scala:94) at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withContextClassLoader$1(SessionHolder.scala:176) at org.apache.spark.util.Utils$.withContextClassLoader(Utils.scala:178) at org.apache.spark.sql.connect.service.SessionHolder.withContextClassLoader(SessionHolder.scala:175) at org.apache.spark.sql.connect.service.SessionHolder.withSession(SessionHolder.scala:188) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.executeInternal(ExecuteThreadRunner.scala:132) at 
org.apache.spark.sql.connect.execution.ExecuteThreadRunner.org$apache$spark$sql$connect$execution$ExecuteThreadRunner$$execute(ExecuteThreadRunner.scala:84) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner$ExecutionThread.run(ExecuteThreadRunner.scala:228) Caused by: io.netty.util.internal.OutOfDirectMemoryError: failed to allocate 4194304 byte(s) of direct memory (used: 1069547799, max: 1073741824) at io.netty.util.internal.PlatformDependent.incrementMemoryCounter(PlatformDependent.java:845) at io.netty.util.internal.PlatformDependent.allocateDirectNoCleaner(PlatformDependent.java:774) at io.netty.buffer.PoolArena$DirectArena.allocateDirect(PoolArena.java:721) at io.netty.buffer.PoolArena$DirectArena.newChunk(PoolArena.java:696) at io.netty.buffer.PoolArena.allocateNormal(PoolArena.java:215) at io.netty.buffer.PoolArena.tcacheAllocateSmall(PoolArena.java:180) at io.netty.buffer.PoolArena.allocate(PoolArena.java:137) at io.netty.buffer.PoolArena.allocate(PoolArena.java:129) at io.netty.buffer.PooledByteBufAllocatorL$InnerAllocator.newDirectBufferL(PooledByteBufAllocatorL.java:181) at io.netty.buffer.PooledByteBufAllocatorL$InnerAllocator.directBuffer(PooledByteBufAllocatorL.java:214) at io.netty.buffer.PooledByteBufAllocatorL.allocate(PooledByteBufAllocatorL.java:58) ... 37 more ``` ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Manually test ### Was this patch authored or co-authored using generative AI tooling? No Closes #43691 from xieshuaihu/spark-45814. Authored-by: xieshuaihu Signed-off-by: yangjie01 --- .../sql/execution/arrow/ArrowConverters.scala | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index d6bf1e29edddd..9ddec74374abd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -80,7 +80,7 @@ private[sql] object ArrowConverters extends Logging { maxRecordsPerBatch: Long, timeZoneId: String, errorOnDuplicatedFieldNames: Boolean, - context: TaskContext) extends Iterator[Array[Byte]] { + context: TaskContext) extends Iterator[Array[Byte]] with AutoCloseable { protected val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames) @@ -93,13 +93,11 @@ private[sql] object ArrowConverters extends Logging { protected val arrowWriter = ArrowWriter.create(root) Option(context).foreach {_.addTaskCompletionListener[Unit] { _ => - root.close() - allocator.close() + close() }} override def hasNext: Boolean = rowIter.hasNext || { - root.close() - allocator.close() + close() false } @@ -124,6 +122,11 @@ private[sql] object ArrowConverters extends Logging { out.toByteArray } + + override def close(): Unit = { + root.close() + allocator.close() + } } private[sql] class ArrowBatchWithSchemaIterator( @@ -226,11 +229,19 @@ private[sql] object ArrowConverters extends Logging { schema: StructType, timeZoneId: String, errorOnDuplicatedFieldNames: Boolean): Array[Byte] = { - new ArrowBatchWithSchemaIterator( + val batches = new ArrowBatchWithSchemaIterator( Iterator.empty, schema, 0L, 0L, timeZoneId, errorOnDuplicatedFieldNames, TaskContext.get()) { override def hasNext: Boolean = true - }.next() + } + Utils.tryWithSafeFinally { + batches.next() + } { + // If taskContext is null, 
`batches.close()` should be called to avoid memory leak. + if (TaskContext.get() == null) { + batches.close() + } + } } /** From 06d8cbe073499ff16bca3165e2de1192daad3984 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Thu, 9 Nov 2023 16:23:38 +0800 Subject: [PATCH 15/15] [SPARK-45847][SQL][TESTS] CliSuite flakiness due to non-sequential guarantee for stdout&stderr ### What changes were proposed in this pull request? In CliSuite, This PR adds a retry for tests that write errors to STDERR. ### Why are the changes needed? To fix flakiness tests as below https://github.com/chenhao-db/apache-spark/actions/runs/6791437199/job/18463313766 https://github.com/dongjoon-hyun/spark/actions/runs/6753670527/job/18361206900 ```sql [info] Spark master: local, Application Id: local-1699402393189 [info] spark-sql> /* SELECT /*+ HINT() 4; */; [info] [info] [PARSE_SYNTAX_ERROR] Syntax error at or near ';'. SQLSTATE: 42601 (line 1, pos 26) [info] [info] == SQL == [info] /* SELECT /*+ HINT() 4; */; [info] --------------------------^^^ [info] [info] spark-sql> /* SELECT /*+ HINT() 4; */ SELECT 1; [info] 1 [info] Time taken: 1.499 seconds, Fetched 1 row(s) [info] [info] [UNCLOSED_BRACKETED_COMMENT] Found an unclosed bracketed comment. Please, append */ at the end of the comment. SQLSTATE: 42601 [info] == SQL == [info] /* Here is a unclosed bracketed comment SELECT 1; [info] spark-sql> /* Here is a unclosed bracketed comment SELECT 1; [info] spark-sql> /* SELECT /*+ HINT() */ 4; */; [info] spark-sql> ``` As you can see the fragment above, the query on the 3rd line from the bottom, came from STDOUT, was printed later than its error output, came from STDERR. In this scenario, the error output would not match anything and would simply go unnoticed. Finally, timed out and failed. ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? existing tests and CI ### Was this patch authored or co-authored using generative AI tooling? no Closes #43725 from yaooqinn/SPARK-45847. 
Authored-by: Kent Yao Signed-off-by: Kent Yao --- .../sql/hive/thriftserver/CliSuite.scala | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 5391965ded2e9..4f0d4dff566c4 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -383,7 +383,7 @@ class CliSuite extends SparkFunSuite { ) } - test("SPARK-11188 Analysis error reporting") { + testRetry("SPARK-11188 Analysis error reporting") { runCliWithin(timeout = 2.minute, errorResponses = Seq("AnalysisException"))( "select * from nonexistent_table;" -> "nonexistent_table" @@ -551,7 +551,7 @@ class CliSuite extends SparkFunSuite { ) } - test("SparkException with root cause will be printStacktrace") { + testRetry("SparkException with root cause will be printStacktrace") { // If it is not in silent mode, will print the stacktrace runCliWithin( 1.minute, @@ -575,8 +575,8 @@ class CliSuite extends SparkFunSuite { runCliWithin(1.minute)("SELECT MAKE_DATE(-44, 3, 15);" -> "-0044-03-15") } - test("SPARK-33100: Ignore a semicolon inside a bracketed comment in spark-sql") { - runCliWithin(4.minute)( + testRetry("SPARK-33100: Ignore a semicolon inside a bracketed comment in spark-sql") { + runCliWithin(1.minute)( "/* SELECT 'test';*/ SELECT 'test';" -> "test", ";;/* SELECT 'test';*/ SELECT 'test';" -> "test", "/* SELECT 'test';*/;; SELECT 'test';" -> "test", @@ -623,8 +623,8 @@ class CliSuite extends SparkFunSuite { ) } - test("SPARK-37555: spark-sql should pass last unclosed comment to backend") { - runCliWithin(5.minute)( + testRetry("SPARK-37555: spark-sql should pass last unclosed comment to backend") { + runCliWithin(1.minute)( // Only unclosed comment. "/* SELECT /*+ HINT() 4; */;".stripMargin -> "Syntax error at or near ';'", // Unclosed nested bracketed comment. @@ -637,7 +637,7 @@ class CliSuite extends SparkFunSuite { ) } - test("SPARK-37694: delete [jar|file|archive] shall use spark sql processor") { + testRetry("SPARK-37694: delete [jar|file|archive] shall use spark sql processor") { runCliWithin(2.minute, errorResponses = Seq("ParseException"))( "delete jar dummy.jar;" -> "Syntax error at or near 'jar': missing 'FROM'. 
SQLSTATE: 42601 (line 1, pos 7)") @@ -679,7 +679,7 @@ class CliSuite extends SparkFunSuite { SparkSQLEnv.stop() } - test("SPARK-39068: support in-memory catalog and running concurrently") { + testRetry("SPARK-39068: support in-memory catalog and running concurrently") { val extraConf = Seq("-c", s"${StaticSQLConf.CATALOG_IMPLEMENTATION.key}=in-memory") val cd = new CountDownLatch(2) def t: Thread = new Thread { @@ -699,7 +699,7 @@ class CliSuite extends SparkFunSuite { } // scalastyle:off line.size.limit - test("formats of error messages") { + testRetry("formats of error messages") { def check(format: ErrorMessageFormat.Value, errorMessage: String, silent: Boolean): Unit = { val expected = errorMessage.split(System.lineSeparator()).map("" -> _) runCliWithin( @@ -811,7 +811,6 @@ class CliSuite extends SparkFunSuite { s"spark.sql.catalog.$catalogName.url=jdbc:derby:memory:$catalogName;create=true" val catalogDriver = s"spark.sql.catalog.$catalogName.driver=org.apache.derby.jdbc.AutoloadedDriver" - val database = s"-database $catalogName.SYS" val catalogConfigs = Seq(catalogImpl, catalogDriver, catalogUrl, "spark.sql.catalogImplementation=in-memory") .flatMap(Seq("--conf", _))