From 0a7023f55ddc43ab9f0802b84b248fbc646a91aa Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Tue, 31 Oct 2023 22:01:38 -0700 Subject: [PATCH 01/13] [SPARK-45734][BUILD] Upgrade commons-io to 2.15.0 ### What changes were proposed in this pull request? This PR upgrades commons-io from 2.14.0 to 2.15.0. ### Why are the changes needed? The updates in `commons-io` 2.15.0 mainly focus on fixing bugs in file and stream handling, adding new file and stream handling features, and optimizing the performance of file content comparison: 1. Bug fixes: This version fixes multiple bugs, mainly in file and stream handling. For example, it fixes the encoding matching issue of `XmlStreamReader` (IO-810), the issue that the `FileUtils.listFiles` and `FileUtils.iterateFiles` methods failed to close their internal streams (IO-811), and the issue that `StreamIterator` failed to close its internal stream (IO-811). In addition, it also improves the null pointer exception message of `RandomAccessFileMode.create(Path)` and fixes the issue that `UnsynchronizedBufferedInputStream.read(byte[], int, int)` does not use the buffer (IO-816). 2. New features: This version adds some new classes and methods, such as `org.apache.commons.io.channels.FileChannels`, `RandomAccessFiles#contentEquals(RandomAccessFile, RandomAccessFile)`, `RandomAccessFiles#reset(RandomAccessFile)`, and `org.apache.commons.io.StreamIterator`. In addition, it also adds `MessageDigestInputStream` and deprecates `MessageDigestCalculatingInputStream`. 3. Performance optimization: This version optimizes the performance of `PathUtils.fileContentEquals(Path, Path, LinkOption[], OpenOption[])`, `PathUtils.fileContentEquals(Path, Path)`, and `FileUtils.contentEquals(File, File)`. According to the release notes, the related performance has improved by about 60%. The full release notes are as follows: - https://commons.apache.org/proper/commons-io/changes-report.html#a2.15.0 ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Pass GitHub Actions ### Was this patch authored or co-authored using generative AI tooling? No Closes #43592 from LuciferYang/commons-io-215. Lead-authored-by: yangjie01 Co-authored-by: YangJie Signed-off-by: Dongjoon Hyun --- dev/deps/spark-deps-hadoop-3-hive-2.3 | 2 +- pom.xml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index b5e345ecf0eb8..6fa0f738cf120 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -43,7 +43,7 @@ commons-compiler/3.1.9//commons-compiler-3.1.9.jar commons-compress/1.24.0//commons-compress-1.24.0.jar commons-crypto/1.1.0//commons-crypto-1.1.0.jar commons-dbcp/1.4//commons-dbcp-1.4.jar -commons-io/2.14.0//commons-io-2.14.0.jar +commons-io/2.15.0//commons-io-2.15.0.jar commons-lang/2.6//commons-lang-2.6.jar commons-lang3/3.13.0//commons-lang3-3.13.0.jar commons-logging/1.1.3//commons-logging-1.1.3.jar diff --git a/pom.xml b/pom.xml index 9f550456cbcbe..d545c74392800 100644 --- a/pom.xml +++ b/pom.xml @@ -192,7 +192,7 @@ 3.0.3 1.16.0 1.24.0 - 2.14.0 + 2.15.0 2.6 From e1bc48b729e40390a4b0f977eec4a9050c7cac77 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Tue, 31 Oct 2023 22:02:39 -0700 Subject: [PATCH 02/13] [SPARK-45704][BUILD] Fix compile warning - using symbols inherited from a superclass shadow symbols defined in an outer scope ### What changes were proposed in this pull request?
After upgrading to Scala 2.13, when a symbol inherited from a superclass shadows a symbol defined in an outer scope, the following warning appears: ``` [error] /Users/panbingkun/Developer/spark/spark-community/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala:1315:39: reference to child is ambiguous; [error] it is both defined in the enclosing method apply and inherited in the enclosing anonymous class as value child (defined in class IsNull) [error] In Scala 2, symbols inherited from a superclass shadow symbols defined in an outer scope. [error] Such references are ambiguous in Scala 3. To continue using the inherited symbol, write `this.child`. [error] Or use `-Wconf:msg=legacy-binding:s` to silence this warning. [quickfixable] [error] Applicable -Wconf / nowarn filters for this fatal warning: msg=, cat=other, site=org.apache.spark.sql.catalyst.expressions.IsUnknown.apply [error] override def sql: String = s"(${child.sql} IS UNKNOWN)" [error] ^ ``` This PR aims to fix it by qualifying such references with `this.` (see the standalone sketch below). ### Why are the changes needed? Prepare for upgrading to Scala 3. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? - Pass GA - Manually test: ``` build/sbt -Phadoop-3 -Pdocker-integration-tests -Pspark-ganglia-lgpl -Pkinesis-asl -Pkubernetes -Phive-thriftserver -Pconnect -Pyarn -Phive -Phadoop-cloud -Pvolcano -Pkubernetes-integration-tests Test/package streaming-kinesis-asl-assembly/assembly connect/assembly ``` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43593 from panbingkun/SPARK-45704. Authored-by: panbingkun Signed-off-by: Dongjoon Hyun --- .../deploy/client/StandaloneAppClient.scala | 6 +++--- .../CoarseGrainedSchedulerBackend.scala | 2 +- .../spark/storage/DiskBlockObjectWriter.scala | 2 +- .../CoarseGrainedExecutorBackendSuite.scala | 20 +++++++++---------- pom.xml | 7 ------- project/SparkBuild.scala | 5 ----- .../sql/catalyst/expressions/predicates.scala | 4 ++-- .../connector/catalog/InMemoryBaseTable.scala | 2 +- .../parquet/ParquetRowConverter.scala | 18 ++++++++--------- .../spark/sql/execution/python/RowQueue.scala | 2 +- .../internal/BaseSessionStateBuilder.scala | 2 +- .../command/AlignAssignmentsSuiteBase.scala | 2 +- .../command/PlanResolutionSuite.scala | 2 +- .../sql/hive/HiveSessionStateBuilder.scala | 2 +- 14 files changed, 32 insertions(+), 44 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala index a7e4c1fbab295..b0ee6018970ab 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala @@ -90,7 +90,7 @@ private[spark] class StandaloneAppClient( case e: Exception => logWarning("Failed to connect to master", e) markDisconnected() - stop() + this.stop() } } @@ -168,7 +168,7 @@ private[spark] class StandaloneAppClient( case ApplicationRemoved(message) => markDead("Master removed our application: %s".format(message)) - stop() + this.stop() case ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) => val fullId = s"$appId/$id" @@ -203,7 +203,7 @@ private[spark] class StandaloneAppClient( markDead("Application has been stopped.") sendToMaster(UnregisterApplication(appId.get)) context.reply(true) - stop() + this.stop() case r: RequestExecutors => master match { diff --git 
a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index c49b2411e7635..e02dd27937062 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -319,7 +319,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp case StopDriver => context.reply(true) - stop() + this.stop() case UpdateExecutorsLogLevel(logLevel) => currentLogLevel = Some(logLevel) diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index f8bd73e65617f..2096da2fca02d 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -63,7 +63,7 @@ private[spark] class DiskBlockObjectWriter( */ private trait ManualCloseOutputStream extends OutputStream { abstract override def close(): Unit = { - flush() + this.flush() } def manualClose(): Unit = { diff --git a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala index 3ef4da6d3d3f1..28af0656869b3 100644 --- a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala @@ -326,11 +326,11 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite def getFakeTaskRunner(taskDescription: TaskDescription): Executor#TaskRunner = { new executor.TaskRunner(backend, taskDescription, None) { override def run(): Unit = { - logInfo(s"task ${taskDescription.taskId} runs.") + logInfo(s"task ${this.taskDescription.taskId} runs.") } override def kill(interruptThread: Boolean, reason: String): Unit = { - logInfo(s"task ${taskDescription.taskId} killed.") + logInfo(s"task ${this.taskDescription.taskId} killed.") } } } @@ -434,13 +434,13 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite def getFakeTaskRunner(taskDescription: TaskDescription): Executor#TaskRunner = { new executor.TaskRunner(backend, taskDescription, None) { override def run(): Unit = { - tasksExecuted.put(taskDescription.taskId, true) - logInfo(s"task ${taskDescription.taskId} runs.") + tasksExecuted.put(this.taskDescription.taskId, true) + logInfo(s"task ${this.taskDescription.taskId} runs.") } override def kill(interruptThread: Boolean, reason: String): Unit = { - logInfo(s"task ${taskDescription.taskId} killed.") - tasksKilled.put(taskDescription.taskId, true) + logInfo(s"task ${this.taskDescription.taskId} killed.") + tasksKilled.put(this.taskDescription.taskId, true) } } } @@ -523,13 +523,13 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite def getFakeTaskRunner(taskDescription: TaskDescription): Executor#TaskRunner = { new executor.TaskRunner(backend, taskDescription, None) { override def run(): Unit = { - tasksExecuted.put(taskDescription.taskId, true) - logInfo(s"task ${taskDescription.taskId} runs.") + tasksExecuted.put(this.taskDescription.taskId, true) + logInfo(s"task ${this.taskDescription.taskId} runs.") } override def kill(interruptThread: Boolean, reason: String): Unit = { - logInfo(s"task ${taskDescription.taskId} killed.") - 
tasksKilled.put(taskDescription.taskId, true) + logInfo(s"task ${this.taskDescription.taskId} killed.") + tasksKilled.put(this.taskDescription.taskId, true) } } } diff --git a/pom.xml b/pom.xml index d545c74392800..e29d81f6887c5 100644 --- a/pom.xml +++ b/pom.xml @@ -2985,13 +2985,6 @@ SPARK-40497 Upgrade Scala to 2.13.11 and suppress `Implicit definition should have explicit type` --> -Wconf:msg=Implicit definition should have explicit type:s - - -Wconf:msg=legacy-binding:s diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index d9d4a836ab5d4..d76af6a06cfd5 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -249,11 +249,6 @@ object SparkBuild extends PomBuild { "-Wconf:cat=deprecation&msg=procedure syntax is deprecated:e", // SPARK-40497 Upgrade Scala to 2.13.11 and suppress `Implicit definition should have explicit type` "-Wconf:msg=Implicit definition should have explicit type:s", - // SPARK-45331 Upgrade Scala to 2.13.12 and suppress "In Scala 2, symbols inherited - // from a superclass shadow symbols defined in an outer scope. Such references are - // ambiguous in Scala 3. To continue using the inherited symbol, write `this.stop`. - // Or use `-Wconf:msg=legacy-binding:s` to silence this warning. [quickfixable]" - "-Wconf:msg=legacy-binding:s", // SPARK-45627 Symbol literals are deprecated in Scala 2.13 and it's a compile error in Scala 3. "-Wconf:cat=deprecation&msg=symbol literal is deprecated:e", // SPARK-45627 `enum`, `export` and `given` will become keywords in Scala 3, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 9eefcef8e17d2..761bd3f33586e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -1312,7 +1312,7 @@ object IsUnknown { def apply(child: Expression): Predicate = { new IsNull(child) with ExpectsInputTypes { override def inputTypes: Seq[DataType] = Seq(BooleanType) - override def sql: String = s"(${child.sql} IS UNKNOWN)" + override def sql: String = s"(${this.child.sql} IS UNKNOWN)" } } } @@ -1321,7 +1321,7 @@ object IsNotUnknown { def apply(child: Expression): Predicate = { new IsNotNull(child) with ExpectsInputTypes { override def inputTypes: Seq[DataType] = Seq(BooleanType) - override def sql: String = s"(${child.sql} IS NOT UNKNOWN)" + override def sql: String = s"(${this.child.sql} IS NOT UNKNOWN)" } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala index 7765bc26741be..cd7f7295d5cb9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala @@ -452,7 +452,7 @@ abstract class InMemoryBaseTable( val matchingKeys = values.map { value => if (value != null) value.toString else null }.toSet - data = data.filter(partition => { + this.data = this.data.filter(partition => { val rows = partition.asInstanceOf[BufferedRows] rows.key match { // null partitions are represented as Seq(null) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala index 936339e091d8f..89c7cae175aed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala @@ -311,7 +311,7 @@ private[parquet] class ParquetRowConverter( case LongType if isUnsignedIntTypeMatched(32) => new ParquetPrimitiveConverter(updater) { override def addInt(value: Int): Unit = - updater.setLong(Integer.toUnsignedLong(value)) + this.updater.setLong(Integer.toUnsignedLong(value)) } case BooleanType | IntegerType | LongType | FloatType | DoubleType | BinaryType | _: AnsiIntervalType => @@ -320,13 +320,13 @@ private[parquet] class ParquetRowConverter( case ByteType => new ParquetPrimitiveConverter(updater) { override def addInt(value: Int): Unit = - updater.setByte(value.asInstanceOf[PhysicalByteType#InternalType]) + this.updater.setByte(value.asInstanceOf[PhysicalByteType#InternalType]) } case ShortType => new ParquetPrimitiveConverter(updater) { override def addInt(value: Int): Unit = - updater.setShort(value.asInstanceOf[PhysicalShortType#InternalType]) + this.updater.setShort(value.asInstanceOf[PhysicalShortType#InternalType]) } // For INT32 backed decimals @@ -346,7 +346,7 @@ private[parquet] class ParquetRowConverter( case _: DecimalType if isUnsignedIntTypeMatched(64) => new ParquetPrimitiveConverter(updater) { override def addLong(value: Long): Unit = { - updater.set(Decimal(java.lang.Long.toUnsignedString(value))) + this.updater.set(Decimal(java.lang.Long.toUnsignedString(value))) } } @@ -391,7 +391,7 @@ private[parquet] class ParquetRowConverter( .asInstanceOf[TimestampLogicalTypeAnnotation].getUnit == TimeUnit.MICROS => new ParquetPrimitiveConverter(updater) { override def addLong(value: Long): Unit = { - updater.setLong(timestampRebaseFunc(value)) + this.updater.setLong(timestampRebaseFunc(value)) } } @@ -404,7 +404,7 @@ private[parquet] class ParquetRowConverter( new ParquetPrimitiveConverter(updater) { override def addLong(value: Long): Unit = { val micros = DateTimeUtils.millisToMicros(value) - updater.setLong(timestampRebaseFunc(micros)) + this.updater.setLong(timestampRebaseFunc(micros)) } } @@ -417,7 +417,7 @@ private[parquet] class ParquetRowConverter( val gregorianMicros = int96RebaseFunc(julianMicros) val adjTime = convertTz.map(DateTimeUtils.convertTz(gregorianMicros, _, ZoneOffset.UTC)) .getOrElse(gregorianMicros) - updater.setLong(adjTime) + this.updater.setLong(adjTime) } } @@ -434,14 +434,14 @@ private[parquet] class ParquetRowConverter( new ParquetPrimitiveConverter(updater) { override def addLong(value: Long): Unit = { val micros = DateTimeUtils.millisToMicros(value) - updater.setLong(micros) + this.updater.setLong(micros) } } case DateType => new ParquetPrimitiveConverter(updater) { override def addInt(value: Int): Unit = { - updater.set(dateRebaseFunc(value)) + this.updater.set(dateRebaseFunc(value)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala index 0e3243eac6230..5e0c5ff92fdab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala @@ -233,7 +233,7 @@ private[python] case class HybridRowQueue( val buffer = if (page != null) { new InMemoryRowQueue(page, 
numFields) { override def close(): Unit = { - freePage(page) + freePage(this.page) } } } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 5543b409d1702..1d496b027ef5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -188,7 +188,7 @@ abstract class BaseSessionStateBuilder( new ResolveSQLOnFile(session) +: new FallBackFileSourceV2(session) +: ResolveEncodersInScalaAgg +: - new ResolveSessionCatalog(catalogManager) +: + new ResolveSessionCatalog(this.catalogManager) +: ResolveWriteToStream +: new EvalSubqueriesForTimeTravel +: customResolutionRules diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuiteBase.scala index 6f9cc66f24769..2bc747c0abee4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuiteBase.scala @@ -191,7 +191,7 @@ abstract class AlignAssignmentsSuiteBase extends AnalysisTest { protected def parseAndResolve(query: String): LogicalPlan = { val analyzer = new CustomAnalyzer(catalogManager) { override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = Seq( - new ResolveSessionCatalog(catalogManager)) + new ResolveSessionCatalog(this.catalogManager)) } val analyzed = analyzer.execute(CatalystSqlParser.parsePlan(query)) analyzer.checkAnalysis(analyzed) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 4eb65305de838..e39cc91d5f048 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -233,7 +233,7 @@ class PlanResolutionSuite extends AnalysisTest { } val analyzer = new Analyzer(catalogManager) { override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = Seq( - new ResolveSessionCatalog(catalogManager)) + new ResolveSessionCatalog(this.catalogManager)) } // We don't check analysis here by default, as we expect the plan to be unresolved // such as `CreateTable`. 
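The shadowing that triggers the `legacy-binding` warning fixed throughout this patch can be reproduced outside Spark in a few lines of plain Scala. The sketch below is a minimal, hypothetical example (the `Box`, `label`, and `Demo` names are invented and are not part of this patch); it mirrors the `IsUnknown.apply` case above and shows why the patch qualifies such references with `this.`:

```scala
// Standalone illustration only -- not part of the patch.
class Box(val label: String)

object Demo {
  def wrap(label: String): Box = {
    // Inside the anonymous subclass, a bare `label` is ambiguous: it could be
    // the method parameter `label` (outer scope) or the inherited member
    // `label`. Scala 2.13 warns about it (legacy-binding) and Scala 3 rejects
    // it, so the intended binding is written out explicitly:
    new Box(label) {
      override def toString: String = s"Box(${this.label})"
    }
  }
}
```

Writing `this.label` binds unambiguously to the inherited member, which is exactly what the `this.child`, `this.stop`, and `this.updater` rewrites in this patch do.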
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 2d0bcdff07151..0b5e98d0a3e40 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -123,7 +123,7 @@ class HiveSessionStateBuilder( */ override protected def planner: SparkPlanner = { new SparkPlanner(session, experimentalMethods) with HiveStrategies { - override val sparkSession: SparkSession = session + override val sparkSession: SparkSession = this.session override def extraPlanningStrategies: Seq[Strategy] = super.extraPlanningStrategies ++ customPlanningStrategies ++ From feea99e9d8c18877875f3b8cae2ffc4a7e9f0f7c Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Wed, 1 Nov 2023 10:39:16 +0300 Subject: [PATCH 03/13] [SPARK-45022][SQL] Provide context for dataset API errors ### What changes were proposed in this pull request? This PR captures the dataset APIs used by the user code and the call site in the user code and provides better error messages. E.g. consider the following Spark app `SimpleApp.scala`: ```scala 1 import org.apache.spark.sql.SparkSession 2 import org.apache.spark.sql.functions._ 3 4 object SimpleApp { 5 def main(args: Array[String]) { 6 val spark = SparkSession.builder.appName("Simple Application").config("spark.sql.ansi.enabled", true).getOrCreate() 7 import spark.implicits._ 8 9 val c = col("a") / col("b") 10 11 Seq((1, 0)).toDF("a", "b").select(c).show() 12 13 spark.stop() 14 } 15 } ``` After this PR the error message contains the error context (which Spark Dataset API is called from where in the user code) in the following form: ``` Exception in thread "main" org.apache.spark.SparkArithmeticException: [DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error. == Dataset == "div" was called from SimpleApp$.main(SimpleApp.scala:9) at org.apache.spark.sql.errors.QueryExecutionErrors$.divideByZeroError(QueryExecutionErrors.scala:201) at org.apache.spark.sql.catalyst.expressions.DivModLike.eval(arithmetic.scala:672 ... ``` which is similar to the already provided context in case of SQL queries: ``` org.apache.spark.SparkArithmeticException: [DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error. == SQL(line 1, position 1) == a / b ^^^^^ at org.apache.spark.sql.errors.QueryExecutionErrors$.divideByZeroError(QueryExecutionErrors.scala:201) at org.apache.spark.sql.errors.QueryExecutionErrors.divideByZeroError(QueryExecutionErrors.scala) ... 
``` Please note that stack trace in `spark-shell` doesn't contain meaningful elements: ``` scala> Thread.currentThread().getStackTrace.foreach(println) java.base/java.lang.Thread.getStackTrace(Thread.java:1602) $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:23) $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:27) $line15.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:29) $line15.$read$$iw$$iw$$iw$$iw$$iw.(:31) $line15.$read$$iw$$iw$$iw$$iw.(:33) $line15.$read$$iw$$iw$$iw.(:35) $line15.$read$$iw$$iw.(:37) $line15.$read$$iw.(:39) $line15.$read.(:41) $line15.$read$.(:45) $line15.$read$.() $line15.$eval$.$print$lzycompute(:7) $line15.$eval$.$print(:6) $line15.$eval.$print() java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method) java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) ... ``` so this change doesn't help with that usecase. ### Why are the changes needed? To provide more user friendly errors. ### Does this PR introduce _any_ user-facing change? Yes. ### How was this patch tested? Added new UTs to `QueryExecutionAnsiErrorsSuite`. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43334 from MaxGekk/context-for-dataset-api-errors. Lead-authored-by: Max Gekk Co-authored-by: Peter Toth Signed-off-by: Max Gekk --- .../java/org/apache/spark/QueryContext.java | 9 + .../org/apache/spark/QueryContextType.java | 31 ++ .../apache/spark/SparkThrowableHelper.scala | 20 +- .../main/protobuf/spark/connect/base.proto | 13 + .../client/GrpcExceptionConverter.scala | 13 +- .../org/apache/spark/SparkFunSuite.scala | 60 ++- .../apache/spark/SparkThrowableSuite.scala | 51 +++ project/MimaExcludes.scala | 9 +- python/pyspark/sql/connect/proto/base_pb2.py | 24 +- python/pyspark/sql/connect/proto/base_pb2.pyi | 37 ++ .../spark/sql/catalyst/parser/parsers.scala | 7 +- ...QueryContext.scala => QueryContexts.scala} | 46 ++- .../spark/sql/catalyst/trees/origin.scala | 20 +- .../spark/sql/catalyst/util/MathUtils.scala | 16 +- .../catalyst/util/SparkDateTimeUtils.scala | 6 +- .../spark/sql/errors/DataTypeErrors.scala | 12 +- .../spark/sql/errors/DataTypeErrorsBase.scala | 7 +- .../spark/sql/errors/ExecutionErrors.scala | 11 +- .../org/apache/spark/sql/types/Decimal.scala | 6 +- .../spark/sql/catalyst/expressions/Cast.scala | 8 +- .../sql/catalyst/expressions/Expression.scala | 10 +- .../expressions/aggregate/Average.scala | 5 +- .../catalyst/expressions/aggregate/Sum.scala | 5 +- .../sql/catalyst/expressions/arithmetic.scala | 4 +- .../expressions/collectionOperations.scala | 7 +- .../expressions/complexTypeExtractors.scala | 4 +- .../expressions/decimalExpressions.scala | 12 +- .../expressions/higherOrderFunctions.scala | 10 +- .../expressions/intervalExpressions.scala | 6 +- .../expressions/mathExpressions.scala | 6 +- .../expressions/stringExpressions.scala | 5 +- .../sql/catalyst/util/DateTimeUtils.scala | 6 +- .../sql/catalyst/util/NumberConverter.scala | 6 +- .../sql/catalyst/util/UTF8StringUtils.scala | 12 +- .../sql/errors/QueryExecutionErrors.scala | 30 +- .../sql/catalyst/analysis/AnalysisTest.scala | 4 +- .../analysis/V2WriteAnalysisSuite.scala | 5 +- .../scala/org/apache/spark/sql/Column.scala | 28 +- .../spark/sql/DataFrameStatFunctions.scala | 28 +- .../scala/org/apache/spark/sql/Dataset.scala | 356 ++++++++++-------- .../apache/spark/sql/execution/subquery.scala | 5 +- .../org/apache/spark/sql/functions.scala | 78 
++-- .../scala/org/apache/spark/sql/package.scala | 40 ++ .../spark/sql/ColumnExpressionSuite.scala | 48 ++- .../apache/spark/sql/CsvFunctionsSuite.scala | 9 +- .../spark/sql/DataFrameAggregateSuite.scala | 11 +- .../spark/sql/DataFrameFunctionsSuite.scala | 205 +++++++--- .../spark/sql/DataFramePivotSuite.scala | 3 +- .../spark/sql/DataFrameSelfJoinSuite.scala | 3 +- .../sql/DataFrameSetOperationsSuite.scala | 5 +- .../apache/spark/sql/DataFrameStatSuite.scala | 8 +- .../org/apache/spark/sql/DataFrameSuite.scala | 6 +- .../sql/DataFrameWindowFramesSuite.scala | 27 +- .../sql/DataFrameWindowFunctionsSuite.scala | 5 +- .../org/apache/spark/sql/DatasetSuite.scala | 31 +- .../spark/sql/DatasetUnpivotSuite.scala | 9 +- .../spark/sql/GeneratorFunctionSuite.scala | 15 +- .../apache/spark/sql/JsonFunctionsSuite.scala | 11 +- .../apache/spark/sql/ParametersSuite.scala | 6 +- .../org/apache/spark/sql/QueryTest.scala | 12 + .../org/apache/spark/sql/SQLQuerySuite.scala | 5 +- .../spark/sql/StringFunctionsSuite.scala | 10 +- .../errors/QueryCompilationErrorsSuite.scala | 11 +- .../QueryExecutionAnsiErrorsSuite.scala | 79 ++++ .../spark/sql/execution/SQLViewSuite.scala | 4 +- .../datasources/FileMetadataStructSuite.scala | 14 +- .../execution/datasources/csv/CSVSuite.scala | 4 +- .../datasources/json/JsonSuite.scala | 4 +- ...rquetFileMetadataStructRowIndexSuite.scala | 3 +- .../spark/sql/sources/InsertSuite.scala | 4 +- .../spark/sql/streaming/StreamSuite.scala | 4 +- 71 files changed, 1163 insertions(+), 471 deletions(-) create mode 100644 common/utils/src/main/java/org/apache/spark/QueryContextType.java rename sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/{SQLQueryContext.scala => QueryContexts.scala} (78%) diff --git a/common/utils/src/main/java/org/apache/spark/QueryContext.java b/common/utils/src/main/java/org/apache/spark/QueryContext.java index de5b29d02951d..45e38c8cfe0f3 100644 --- a/common/utils/src/main/java/org/apache/spark/QueryContext.java +++ b/common/utils/src/main/java/org/apache/spark/QueryContext.java @@ -27,6 +27,9 @@ */ @Evolving public interface QueryContext { + // The type of this query context. + QueryContextType contextType(); + // The object type of the query which throws the exception. // If the exception is directly from the main query, it should be an empty string. // Otherwise, it should be the exact object type in upper case. For example, a "VIEW". @@ -45,4 +48,10 @@ public interface QueryContext { // The corresponding fragment of the query which throws the exception. String fragment(); + + // The user code (call site of the API) that caused throwing the exception. + String callSite(); + + // Summary of the exception cause. + String summary(); } diff --git a/common/utils/src/main/java/org/apache/spark/QueryContextType.java b/common/utils/src/main/java/org/apache/spark/QueryContextType.java new file mode 100644 index 0000000000000..171833162bafa --- /dev/null +++ b/common/utils/src/main/java/org/apache/spark/QueryContextType.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark; + +import org.apache.spark.annotation.Evolving; + +/** + * The type of {@link QueryContext}. + * + * @since 4.0.0 + */ +@Evolving +public enum QueryContextType { + SQL, + DataFrame +} diff --git a/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala b/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala index b312a1a7e2278..a44d36ff85b55 100644 --- a/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala +++ b/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala @@ -114,13 +114,19 @@ private[spark] object SparkThrowableHelper { g.writeArrayFieldStart("queryContext") e.getQueryContext.foreach { c => g.writeStartObject() - g.writeStringField("objectType", c.objectType()) - g.writeStringField("objectName", c.objectName()) - val startIndex = c.startIndex() + 1 - if (startIndex > 0) g.writeNumberField("startIndex", startIndex) - val stopIndex = c.stopIndex() + 1 - if (stopIndex > 0) g.writeNumberField("stopIndex", stopIndex) - g.writeStringField("fragment", c.fragment()) + c.contextType() match { + case QueryContextType.SQL => + g.writeStringField("objectType", c.objectType()) + g.writeStringField("objectName", c.objectName()) + val startIndex = c.startIndex() + 1 + if (startIndex > 0) g.writeNumberField("startIndex", startIndex) + val stopIndex = c.stopIndex() + 1 + if (stopIndex > 0) g.writeNumberField("stopIndex", stopIndex) + g.writeStringField("fragment", c.fragment()) + case QueryContextType.DataFrame => + g.writeStringField("fragment", c.fragment()) + g.writeStringField("callSite", c.callSite()) + } g.writeEndObject() } g.writeEndArray() diff --git a/connector/connect/common/src/main/protobuf/spark/connect/base.proto b/connector/connect/common/src/main/protobuf/spark/connect/base.proto index 5b94c6d663cca..27f51551ba921 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto @@ -823,6 +823,13 @@ message FetchErrorDetailsResponse { // QueryContext defines the schema for the query context of a SparkThrowable. // It helps users understand where the error occurs while executing queries. message QueryContext { + // The type of this query context. + enum ContextType { + SQL = 0; + DATAFRAME = 1; + } + ContextType context_type = 10; + // The object type of the query which throws the exception. // If the exception is directly from the main query, it should be an empty string. // Otherwise, it should be the exact object type in upper case. For example, a "VIEW". @@ -841,6 +848,12 @@ message FetchErrorDetailsResponse { // The corresponding fragment of the query which throws the exception. string fragment = 5; + + // The user code (call site of the API) that caused throwing the exception. + string callSite = 6; + + // Summary of the exception cause. + string summary = 7; } // SparkThrowable defines the schema for SparkThrowable exceptions. 
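As a rough illustration of how the two context types introduced here differ for consumers, the following sketch branches on `contextType()` the same way `SparkThrowableHelper` and the Connect client do in this change: SQL contexts carry object and position information, while DataFrame contexts carry the API name and the user call site. The `QueryContextRenderer`/`render` names are made up for this example and are not an API added by this patch.

```scala
import org.apache.spark.{QueryContext, QueryContextType}

// Hypothetical helper, for illustration only: summarize a QueryContext
// depending on its type, mirroring the branching introduced in this patch.
object QueryContextRenderer {
  def render(ctx: QueryContext): String = ctx.contextType() match {
    case QueryContextType.SQL =>
      // SQL contexts still describe the failing fragment and its position.
      s"'${ctx.fragment()}' in ${ctx.objectType()} ${ctx.objectName()} " +
        s"[${ctx.startIndex()}, ${ctx.stopIndex()}]"
    case QueryContextType.DataFrame =>
      // DataFrame contexts describe the Dataset API and the user call site.
      s"'${ctx.fragment()}' was called from ${ctx.callSite()}"
  }
}
```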
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala index b2782442f4a53..3e53722caeb07 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala @@ -28,7 +28,7 @@ import io.grpc.protobuf.StatusProto import org.json4s.DefaultFormats import org.json4s.jackson.JsonMethods -import org.apache.spark.{QueryContext, SparkArithmeticException, SparkArrayIndexOutOfBoundsException, SparkDateTimeException, SparkException, SparkIllegalArgumentException, SparkNumberFormatException, SparkRuntimeException, SparkUnsupportedOperationException, SparkUpgradeException} +import org.apache.spark.{QueryContext, QueryContextType, SparkArithmeticException, SparkArrayIndexOutOfBoundsException, SparkDateTimeException, SparkException, SparkIllegalArgumentException, SparkNumberFormatException, SparkRuntimeException, SparkUnsupportedOperationException, SparkUpgradeException} import org.apache.spark.connect.proto.{FetchErrorDetailsRequest, FetchErrorDetailsResponse, UserContext} import org.apache.spark.connect.proto.SparkConnectServiceGrpc.SparkConnectServiceBlockingStub import org.apache.spark.internal.Logging @@ -324,15 +324,18 @@ private[client] object GrpcExceptionConverter { val queryContext = error.getSparkThrowable.getQueryContextsList.asScala.map { queryCtx => new QueryContext { + override def contextType(): QueryContextType = queryCtx.getContextType match { + case FetchErrorDetailsResponse.QueryContext.ContextType.DATAFRAME => + QueryContextType.DataFrame + case _ => QueryContextType.SQL + } override def objectType(): String = queryCtx.getObjectType - override def objectName(): String = queryCtx.getObjectName - override def startIndex(): Int = queryCtx.getStartIndex - override def stopIndex(): Int = queryCtx.getStopIndex - override def fragment(): String = queryCtx.getFragment + override def callSite(): String = queryCtx.getCallSite + override def summary(): String = queryCtx.getSummary } }.toArray diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index e3792eb0d237b..518c0592488fc 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -342,7 +342,7 @@ abstract class SparkFunSuite sqlState: Option[String] = None, parameters: Map[String, String] = Map.empty, matchPVals: Boolean = false, - queryContext: Array[QueryContext] = Array.empty): Unit = { + queryContext: Array[ExpectedContext] = Array.empty): Unit = { assert(exception.getErrorClass === errorClass) sqlState.foreach(state => assert(exception.getSqlState === state)) val expectedParameters = exception.getMessageParameters.asScala @@ -364,16 +364,25 @@ abstract class SparkFunSuite val actualQueryContext = exception.getQueryContext() assert(actualQueryContext.length === queryContext.length, "Invalid length of the query context") actualQueryContext.zip(queryContext).foreach { case (actual, expected) => - assert(actual.objectType() === expected.objectType(), - "Invalid objectType of a query context Actual:" + actual.toString) - assert(actual.objectName() === expected.objectName(), - "Invalid objectName of a query context. 
Actual:" + actual.toString) - assert(actual.startIndex() === expected.startIndex(), - "Invalid startIndex of a query context. Actual:" + actual.toString) - assert(actual.stopIndex() === expected.stopIndex(), - "Invalid stopIndex of a query context. Actual:" + actual.toString) - assert(actual.fragment() === expected.fragment(), - "Invalid fragment of a query context. Actual:" + actual.toString) + assert(actual.contextType() === expected.contextType, + "Invalid contextType of a query context Actual:" + actual.toString) + if (actual.contextType() == QueryContextType.SQL) { + assert(actual.objectType() === expected.objectType, + "Invalid objectType of a query context Actual:" + actual.toString) + assert(actual.objectName() === expected.objectName, + "Invalid objectName of a query context. Actual:" + actual.toString) + assert(actual.startIndex() === expected.startIndex, + "Invalid startIndex of a query context. Actual:" + actual.toString) + assert(actual.stopIndex() === expected.stopIndex, + "Invalid stopIndex of a query context. Actual:" + actual.toString) + assert(actual.fragment() === expected.fragment, + "Invalid fragment of a query context. Actual:" + actual.toString) + } else if (actual.contextType() == QueryContextType.DataFrame) { + assert(actual.fragment() === expected.fragment, + "Invalid code fragment of a query context. Actual:" + actual.toString) + assert(actual.callSite().matches(expected.callSitePattern), + "Invalid callSite of a query context. Actual:" + actual.toString) + } } } @@ -389,21 +398,21 @@ abstract class SparkFunSuite errorClass: String, sqlState: String, parameters: Map[String, String], - context: QueryContext): Unit = + context: ExpectedContext): Unit = checkError(exception, errorClass, Some(sqlState), parameters, false, Array(context)) protected def checkError( exception: SparkThrowable, errorClass: String, parameters: Map[String, String], - context: QueryContext): Unit = + context: ExpectedContext): Unit = checkError(exception, errorClass, None, parameters, false, Array(context)) protected def checkError( exception: SparkThrowable, errorClass: String, sqlState: String, - context: QueryContext): Unit = + context: ExpectedContext): Unit = checkError(exception, errorClass, None, Map.empty, false, Array(context)) protected def checkError( @@ -411,7 +420,7 @@ abstract class SparkFunSuite errorClass: String, sqlState: Option[String], parameters: Map[String, String], - context: QueryContext): Unit = + context: ExpectedContext): Unit = checkError(exception, errorClass, sqlState, parameters, false, Array(context)) @@ -426,7 +435,7 @@ abstract class SparkFunSuite errorClass: String, sqlState: Option[String], parameters: Map[String, String], - context: QueryContext): Unit = + context: ExpectedContext): Unit = checkError(exception, errorClass, sqlState, parameters, matchPVals = true, Array(context)) @@ -453,16 +462,33 @@ abstract class SparkFunSuite parameters = Map("relationName" -> tableName)) case class ExpectedContext( + contextType: QueryContextType, objectType: String, objectName: String, startIndex: Int, stopIndex: Int, - fragment: String) extends QueryContext + fragment: String, + callSitePattern: String + ) object ExpectedContext { def apply(fragment: String, start: Int, stop: Int): ExpectedContext = { ExpectedContext("", "", start, stop, fragment) } + + def apply( + objectType: String, + objectName: String, + startIndex: Int, + stopIndex: Int, + fragment: String): ExpectedContext = { + new ExpectedContext(QueryContextType.SQL, objectType, objectName, startIndex, 
stopIndex, + fragment, "") + } + + def apply(fragment: String, callSitePattern: String): ExpectedContext = { + new ExpectedContext(QueryContextType.DataFrame, "", "", -1, -1, fragment, callSitePattern) + } } class LogAppender(msg: String = "", maxEvents: Int = 1000) diff --git a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala index 9f32d81f1ae3d..0206205c353a1 100644 --- a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala @@ -503,11 +503,14 @@ class SparkThrowableSuite extends SparkFunSuite { test("Get message in the specified format") { import ErrorMessageFormat._ class TestQueryContext extends QueryContext { + override val contextType = QueryContextType.SQL override val objectName = "v1" override val objectType = "VIEW" override val startIndex = 2 override val stopIndex = -1 override val fragment = "1 / 0" + override def callSite: String = throw new UnsupportedOperationException + override val summary = "" } val e = new SparkArithmeticException( errorClass = "DIVIDE_BY_ZERO", @@ -577,6 +580,54 @@ class SparkThrowableSuite extends SparkFunSuite { | "message" : "Test message" | } |}""".stripMargin) + + class TestQueryContext2 extends QueryContext { + override val contextType = QueryContextType.DataFrame + override def objectName: String = throw new UnsupportedOperationException + override def objectType: String = throw new UnsupportedOperationException + override def startIndex: Int = throw new UnsupportedOperationException + override def stopIndex: Int = throw new UnsupportedOperationException + override val fragment: String = "div" + override val callSite: String = "SimpleApp$.main(SimpleApp.scala:9)" + override val summary = "" + } + val e4 = new SparkArithmeticException( + errorClass = "DIVIDE_BY_ZERO", + messageParameters = Map("config" -> "CONFIG"), + context = Array(new TestQueryContext2), + summary = "Query summary") + + assert(SparkThrowableHelper.getMessage(e4, PRETTY) === + "[DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 " + + "and return NULL instead. If necessary set CONFIG to \"false\" to bypass this error." + + " SQLSTATE: 22012\nQuery summary") + // scalastyle:off line.size.limit + assert(SparkThrowableHelper.getMessage(e4, MINIMAL) === + """{ + | "errorClass" : "DIVIDE_BY_ZERO", + | "sqlState" : "22012", + | "messageParameters" : { + | "config" : "CONFIG" + | }, + | "queryContext" : [ { + | "fragment" : "div", + | "callSite" : "SimpleApp$.main(SimpleApp.scala:9)" + | } ] + |}""".stripMargin) + assert(SparkThrowableHelper.getMessage(e4, STANDARD) === + """{ + | "errorClass" : "DIVIDE_BY_ZERO", + | "messageTemplate" : "Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. 
If necessary set to \"false\" to bypass this error.", + | "sqlState" : "22012", + | "messageParameters" : { + | "config" : "CONFIG" + | }, + | "queryContext" : [ { + | "fragment" : "div", + | "callSite" : "SimpleApp$.main(SimpleApp.scala:9)" + | } ] + |}""".stripMargin) + // scalastyle:on line.size.limit } test("overwrite error classes") { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 10864390e3fc7..c0275e162722a 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -45,7 +45,14 @@ object MimaExcludes { // [SPARK-45427][CORE] Add RPC SSL settings to SSLOptions and SparkTransportConf ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.SparkTransportConf.fromSparkConf"), // [SPARK-45136][CONNECT] Enhance ClosureCleaner with Ammonite support - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.MethodIdentifier$") + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.MethodIdentifier$"), + // [SPARK-45022][SQL] Provide context for dataset API errors + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.QueryContext.contextType"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.QueryContext.code"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.QueryContext.callSite"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.QueryContext.summary"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.types.Decimal.fromStringANSI$default$3"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.types.Decimal.fromStringANSI") ) // Default exclude rules diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py index 05040d8135017..0ea02525f78ff 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.py +++ b/python/pyspark/sql/connect/proto/base_pb2.py @@ -37,7 +37,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\xf5\x12\n\x12\x41nalyzePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x01R\nclientType\x88\x01\x01\x12\x42\n\x06schema\x18\x04 \x01(\x0b\x32(.spark.connect.AnalyzePlanRequest.SchemaH\x00R\x06schema\x12\x45\n\x07\x65xplain\x18\x05 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.ExplainH\x00R\x07\x65xplain\x12O\n\x0btree_string\x18\x06 \x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.TreeStringH\x00R\ntreeString\x12\x46\n\x08is_local\x18\x07 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.IsLocalH\x00R\x07isLocal\x12R\n\x0cis_streaming\x18\x08 \x01(\x0b\x32-.spark.connect.AnalyzePlanRequest.IsStreamingH\x00R\x0bisStreaming\x12O\n\x0binput_files\x18\t 
\x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.InputFilesH\x00R\ninputFiles\x12U\n\rspark_version\x18\n \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SparkVersionH\x00R\x0csparkVersion\x12I\n\tddl_parse\x18\x0b \x01(\x0b\x32*.spark.connect.AnalyzePlanRequest.DDLParseH\x00R\x08\x64\x64lParse\x12X\n\x0esame_semantics\x18\x0c \x01(\x0b\x32/.spark.connect.AnalyzePlanRequest.SameSemanticsH\x00R\rsameSemantics\x12U\n\rsemantic_hash\x18\r \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SemanticHashH\x00R\x0csemanticHash\x12\x45\n\x07persist\x18\x0e \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.PersistH\x00R\x07persist\x12K\n\tunpersist\x18\x0f \x01(\x0b\x32+.spark.connect.AnalyzePlanRequest.UnpersistH\x00R\tunpersist\x12_\n\x11get_storage_level\x18\x10 \x01(\x0b\x32\x31.spark.connect.AnalyzePlanRequest.GetStorageLevelH\x00R\x0fgetStorageLevel\x1a\x31\n\x06Schema\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\xbb\x02\n\x07\x45xplain\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12X\n\x0c\x65xplain_mode\x18\x02 \x01(\x0e\x32\x35.spark.connect.AnalyzePlanRequest.Explain.ExplainModeR\x0b\x65xplainMode"\xac\x01\n\x0b\x45xplainMode\x12\x1c\n\x18\x45XPLAIN_MODE_UNSPECIFIED\x10\x00\x12\x17\n\x13\x45XPLAIN_MODE_SIMPLE\x10\x01\x12\x19\n\x15\x45XPLAIN_MODE_EXTENDED\x10\x02\x12\x18\n\x14\x45XPLAIN_MODE_CODEGEN\x10\x03\x12\x15\n\x11\x45XPLAIN_MODE_COST\x10\x04\x12\x1a\n\x16\x45XPLAIN_MODE_FORMATTED\x10\x05\x1aZ\n\nTreeString\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12\x19\n\x05level\x18\x02 \x01(\x05H\x00R\x05level\x88\x01\x01\x42\x08\n\x06_level\x1a\x32\n\x07IsLocal\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x36\n\x0bIsStreaming\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x35\n\nInputFiles\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x0e\n\x0cSparkVersion\x1a)\n\x08\x44\x44LParse\x12\x1d\n\nddl_string\x18\x01 \x01(\tR\tddlString\x1ay\n\rSameSemantics\x12\x34\n\x0btarget_plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\ntargetPlan\x12\x32\n\nother_plan\x18\x02 \x01(\x0b\x32\x13.spark.connect.PlanR\totherPlan\x1a\x37\n\x0cSemanticHash\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x97\x01\n\x07Persist\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x45\n\rstorage_level\x18\x02 \x01(\x0b\x32\x1b.spark.connect.StorageLevelH\x00R\x0cstorageLevel\x88\x01\x01\x42\x10\n\x0e_storage_level\x1an\n\tUnpersist\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x1f\n\x08\x62locking\x18\x02 \x01(\x08H\x00R\x08\x62locking\x88\x01\x01\x42\x0b\n\t_blocking\x1a\x46\n\x0fGetStorageLevel\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relationB\t\n\x07\x61nalyzeB\x0e\n\x0c_client_type"\x99\r\n\x13\x41nalyzePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x43\n\x06schema\x18\x02 \x01(\x0b\x32).spark.connect.AnalyzePlanResponse.SchemaH\x00R\x06schema\x12\x46\n\x07\x65xplain\x18\x03 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.ExplainH\x00R\x07\x65xplain\x12P\n\x0btree_string\x18\x04 \x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.TreeStringH\x00R\ntreeString\x12G\n\x08is_local\x18\x05 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.IsLocalH\x00R\x07isLocal\x12S\n\x0cis_streaming\x18\x06 \x01(\x0b\x32..spark.connect.AnalyzePlanResponse.IsStreamingH\x00R\x0bisStreaming\x12P\n\x0binput_files\x18\x07 
\x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.InputFilesH\x00R\ninputFiles\x12V\n\rspark_version\x18\x08 \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SparkVersionH\x00R\x0csparkVersion\x12J\n\tddl_parse\x18\t \x01(\x0b\x32+.spark.connect.AnalyzePlanResponse.DDLParseH\x00R\x08\x64\x64lParse\x12Y\n\x0esame_semantics\x18\n \x01(\x0b\x32\x30.spark.connect.AnalyzePlanResponse.SameSemanticsH\x00R\rsameSemantics\x12V\n\rsemantic_hash\x18\x0b \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SemanticHashH\x00R\x0csemanticHash\x12\x46\n\x07persist\x18\x0c \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.PersistH\x00R\x07persist\x12L\n\tunpersist\x18\r \x01(\x0b\x32,.spark.connect.AnalyzePlanResponse.UnpersistH\x00R\tunpersist\x12`\n\x11get_storage_level\x18\x0e \x01(\x0b\x32\x32.spark.connect.AnalyzePlanResponse.GetStorageLevelH\x00R\x0fgetStorageLevel\x1a\x39\n\x06Schema\x12/\n\x06schema\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1a\x30\n\x07\x45xplain\x12%\n\x0e\x65xplain_string\x18\x01 \x01(\tR\rexplainString\x1a-\n\nTreeString\x12\x1f\n\x0btree_string\x18\x01 \x01(\tR\ntreeString\x1a$\n\x07IsLocal\x12\x19\n\x08is_local\x18\x01 \x01(\x08R\x07isLocal\x1a\x30\n\x0bIsStreaming\x12!\n\x0cis_streaming\x18\x01 \x01(\x08R\x0bisStreaming\x1a"\n\nInputFiles\x12\x14\n\x05\x66iles\x18\x01 \x03(\tR\x05\x66iles\x1a(\n\x0cSparkVersion\x12\x18\n\x07version\x18\x01 \x01(\tR\x07version\x1a;\n\x08\x44\x44LParse\x12/\n\x06parsed\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06parsed\x1a\'\n\rSameSemantics\x12\x16\n\x06result\x18\x01 \x01(\x08R\x06result\x1a&\n\x0cSemanticHash\x12\x16\n\x06result\x18\x01 \x01(\x05R\x06result\x1a\t\n\x07Persist\x1a\x0b\n\tUnpersist\x1aS\n\x0fGetStorageLevel\x12@\n\rstorage_level\x18\x01 \x01(\x0b\x32\x1b.spark.connect.StorageLevelR\x0cstorageLevelB\x08\n\x06result"\xa0\x04\n\x12\x45xecutePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12&\n\x0coperation_id\x18\x06 \x01(\tH\x00R\x0boperationId\x88\x01\x01\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x01R\nclientType\x88\x01\x01\x12X\n\x0frequest_options\x18\x05 \x03(\x0b\x32/.spark.connect.ExecutePlanRequest.RequestOptionR\x0erequestOptions\x12\x12\n\x04tags\x18\x07 \x03(\tR\x04tags\x1a\xa5\x01\n\rRequestOption\x12K\n\x10reattach_options\x18\x01 \x01(\x0b\x32\x1e.spark.connect.ReattachOptionsH\x00R\x0freattachOptions\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textensionB\x10\n\x0erequest_optionB\x0f\n\r_operation_idB\x0e\n\x0c_client_type"\xe6\x0f\n\x13\x45xecutePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12!\n\x0coperation_id\x18\x0c \x01(\tR\x0boperationId\x12\x1f\n\x0bresponse_id\x18\r \x01(\tR\nresponseId\x12P\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32-.spark.connect.ExecutePlanResponse.ArrowBatchH\x00R\narrowBatch\x12\x63\n\x12sql_command_result\x18\x05 \x01(\x0b\x32\x33.spark.connect.ExecutePlanResponse.SqlCommandResultH\x00R\x10sqlCommandResult\x12~\n#write_stream_operation_start_result\x18\x08 \x01(\x0b\x32..spark.connect.WriteStreamOperationStartResultH\x00R\x1fwriteStreamOperationStartResult\x12q\n\x1estreaming_query_command_result\x18\t \x01(\x0b\x32*.spark.connect.StreamingQueryCommandResultH\x00R\x1bstreamingQueryCommandResult\x12k\n\x1cget_resources_command_result\x18\n 
\x01(\x0b\x32(.spark.connect.GetResourcesCommandResultH\x00R\x19getResourcesCommandResult\x12\x87\x01\n&streaming_query_manager_command_result\x18\x0b \x01(\x0b\x32\x31.spark.connect.StreamingQueryManagerCommandResultH\x00R"streamingQueryManagerCommandResult\x12\\\n\x0fresult_complete\x18\x0e \x01(\x0b\x32\x31.spark.connect.ExecutePlanResponse.ResultCompleteH\x00R\x0eresultComplete\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12\x44\n\x07metrics\x18\x04 \x01(\x0b\x32*.spark.connect.ExecutePlanResponse.MetricsR\x07metrics\x12]\n\x10observed_metrics\x18\x06 \x03(\x0b\x32\x32.spark.connect.ExecutePlanResponse.ObservedMetricsR\x0fobservedMetrics\x12/\n\x06schema\x18\x07 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1aG\n\x10SqlCommandResult\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x1av\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x12&\n\x0cstart_offset\x18\x03 \x01(\x03H\x00R\x0bstartOffset\x88\x01\x01\x42\x0f\n\r_start_offset\x1a\x85\x04\n\x07Metrics\x12Q\n\x07metrics\x18\x01 \x03(\x0b\x32\x37.spark.connect.ExecutePlanResponse.Metrics.MetricObjectR\x07metrics\x1a\xcc\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12z\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32M.spark.connect.ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1a{\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ExecutePlanResponse.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricType\x1at\n\x0fObservedMetrics\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x39\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x12\x12\n\x04keys\x18\x03 \x03(\tR\x04keys\x1a\x10\n\x0eResultCompleteB\x0f\n\rresponse_type"A\n\x08KeyValue\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x19\n\x05value\x18\x02 \x01(\tH\x00R\x05value\x88\x01\x01\x42\x08\n\x06_value"\x84\x08\n\rConfigRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\x44\n\toperation\x18\x03 \x01(\x0b\x32&.spark.connect.ConfigRequest.OperationR\toperation\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x1a\xf2\x03\n\tOperation\x12\x34\n\x03set\x18\x01 \x01(\x0b\x32 .spark.connect.ConfigRequest.SetH\x00R\x03set\x12\x34\n\x03get\x18\x02 \x01(\x0b\x32 .spark.connect.ConfigRequest.GetH\x00R\x03get\x12W\n\x10get_with_default\x18\x03 \x01(\x0b\x32+.spark.connect.ConfigRequest.GetWithDefaultH\x00R\x0egetWithDefault\x12G\n\nget_option\x18\x04 \x01(\x0b\x32&.spark.connect.ConfigRequest.GetOptionH\x00R\tgetOption\x12>\n\x07get_all\x18\x05 \x01(\x0b\x32#.spark.connect.ConfigRequest.GetAllH\x00R\x06getAll\x12:\n\x05unset\x18\x06 \x01(\x0b\x32".spark.connect.ConfigRequest.UnsetH\x00R\x05unset\x12P\n\ris_modifiable\x18\x07 \x01(\x0b\x32).spark.connect.ConfigRequest.IsModifiableH\x00R\x0cisModifiableB\t\n\x07op_type\x1a\x34\n\x03Set\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x19\n\x03Get\x12\x12\n\x04keys\x18\x01 
\x03(\tR\x04keys\x1a?\n\x0eGetWithDefault\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x1f\n\tGetOption\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a\x30\n\x06GetAll\x12\x1b\n\x06prefix\x18\x01 \x01(\tH\x00R\x06prefix\x88\x01\x01\x42\t\n\x07_prefix\x1a\x1b\n\x05Unset\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a"\n\x0cIsModifiable\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keysB\x0e\n\x0c_client_type"z\n\x0e\x43onfigResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12-\n\x05pairs\x18\x02 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x12\x1a\n\x08warnings\x18\x03 \x03(\tR\x08warnings"\xe7\x06\n\x13\x41\x64\x64\x41rtifactsRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x06 \x01(\tH\x01R\nclientType\x88\x01\x01\x12@\n\x05\x62\x61tch\x18\x03 \x01(\x0b\x32(.spark.connect.AddArtifactsRequest.BatchH\x00R\x05\x62\x61tch\x12Z\n\x0b\x62\x65gin_chunk\x18\x04 \x01(\x0b\x32\x37.spark.connect.AddArtifactsRequest.BeginChunkedArtifactH\x00R\nbeginChunk\x12H\n\x05\x63hunk\x18\x05 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkH\x00R\x05\x63hunk\x1a\x35\n\rArtifactChunk\x12\x12\n\x04\x64\x61ta\x18\x01 \x01(\x0cR\x04\x64\x61ta\x12\x10\n\x03\x63rc\x18\x02 \x01(\x03R\x03\x63rc\x1ao\n\x13SingleChunkArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x44\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x04\x64\x61ta\x1a]\n\x05\x42\x61tch\x12T\n\tartifacts\x18\x01 \x03(\x0b\x32\x36.spark.connect.AddArtifactsRequest.SingleChunkArtifactR\tartifacts\x1a\xc1\x01\n\x14\x42\x65ginChunkedArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x1f\n\x0btotal_bytes\x18\x02 \x01(\x03R\ntotalBytes\x12\x1d\n\nnum_chunks\x18\x03 \x01(\x03R\tnumChunks\x12U\n\rinitial_chunk\x18\x04 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x0cinitialChunkB\t\n\x07payloadB\x0e\n\x0c_client_type"\xbc\x01\n\x14\x41\x64\x64\x41rtifactsResponse\x12Q\n\tartifacts\x18\x01 \x03(\x0b\x32\x33.spark.connect.AddArtifactsResponse.ArtifactSummaryR\tartifacts\x1aQ\n\x0f\x41rtifactSummary\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12*\n\x11is_crc_successful\x18\x02 \x01(\x08R\x0fisCrcSuccessful"\xc3\x01\n\x17\x41rtifactStatusesRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x00R\nclientType\x88\x01\x01\x12\x14\n\x05names\x18\x04 \x03(\tR\x05namesB\x0e\n\x0c_client_type"\x8c\x02\n\x18\x41rtifactStatusesResponse\x12Q\n\x08statuses\x18\x01 \x03(\x0b\x32\x35.spark.connect.ArtifactStatusesResponse.StatusesEntryR\x08statuses\x1a(\n\x0e\x41rtifactStatus\x12\x16\n\x06\x65xists\x18\x01 \x01(\x08R\x06\x65xists\x1as\n\rStatusesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ArtifactStatusesResponse.ArtifactStatusR\x05value:\x02\x38\x01"\xd8\x03\n\x10InterruptRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x01R\nclientType\x88\x01\x01\x12T\n\x0einterrupt_type\x18\x04 \x01(\x0e\x32-.spark.connect.InterruptRequest.InterruptTypeR\rinterruptType\x12%\n\roperation_tag\x18\x05 \x01(\tH\x00R\x0coperationTag\x12#\n\x0coperation_id\x18\x06 
\x01(\tH\x00R\x0boperationId"\x80\x01\n\rInterruptType\x12\x1e\n\x1aINTERRUPT_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12INTERRUPT_TYPE_ALL\x10\x01\x12\x16\n\x12INTERRUPT_TYPE_TAG\x10\x02\x12\x1f\n\x1bINTERRUPT_TYPE_OPERATION_ID\x10\x03\x42\x0b\n\tinterruptB\x0e\n\x0c_client_type"[\n\x11InterruptResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\'\n\x0finterrupted_ids\x18\x02 \x03(\tR\x0einterruptedIds"5\n\x0fReattachOptions\x12"\n\x0creattachable\x18\x01 \x01(\x08R\x0creattachable"\x93\x02\n\x16ReattachExecuteRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12!\n\x0coperation_id\x18\x03 \x01(\tR\x0boperationId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x12-\n\x10last_response_id\x18\x05 \x01(\tH\x01R\x0elastResponseId\x88\x01\x01\x42\x0e\n\x0c_client_typeB\x13\n\x11_last_response_id"\xc6\x03\n\x15ReleaseExecuteRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12!\n\x0coperation_id\x18\x03 \x01(\tR\x0boperationId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x01R\nclientType\x88\x01\x01\x12R\n\x0brelease_all\x18\x05 \x01(\x0b\x32/.spark.connect.ReleaseExecuteRequest.ReleaseAllH\x00R\nreleaseAll\x12X\n\rrelease_until\x18\x06 \x01(\x0b\x32\x31.spark.connect.ReleaseExecuteRequest.ReleaseUntilH\x00R\x0creleaseUntil\x1a\x0c\n\nReleaseAll\x1a/\n\x0cReleaseUntil\x12\x1f\n\x0bresponse_id\x18\x01 \x01(\tR\nresponseIdB\t\n\x07releaseB\x0e\n\x0c_client_type"p\n\x16ReleaseExecuteResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12&\n\x0coperation_id\x18\x02 \x01(\tH\x00R\x0boperationId\x88\x01\x01\x42\x0f\n\r_operation_id"\xc9\x01\n\x18\x46\x65tchErrorDetailsRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\x19\n\x08\x65rror_id\x18\x03 \x01(\tR\x07\x65rrorId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x42\x0e\n\x0c_client_type"\xfb\t\n\x19\x46\x65tchErrorDetailsResponse\x12)\n\x0eroot_error_idx\x18\x01 \x01(\x05H\x00R\x0crootErrorIdx\x88\x01\x01\x12\x46\n\x06\x65rrors\x18\x02 \x03(\x0b\x32..spark.connect.FetchErrorDetailsResponse.ErrorR\x06\x65rrors\x1a\xae\x01\n\x11StackTraceElement\x12\'\n\x0f\x64\x65\x63laring_class\x18\x01 \x01(\tR\x0e\x64\x65\x63laringClass\x12\x1f\n\x0bmethod_name\x18\x02 \x01(\tR\nmethodName\x12 \n\tfile_name\x18\x03 \x01(\tH\x00R\x08\x66ileName\x88\x01\x01\x12\x1f\n\x0bline_number\x18\x04 \x01(\x05R\nlineNumberB\x0c\n\n_file_name\x1a\xac\x01\n\x0cQueryContext\x12\x1f\n\x0bobject_type\x18\x01 \x01(\tR\nobjectType\x12\x1f\n\x0bobject_name\x18\x02 \x01(\tR\nobjectName\x12\x1f\n\x0bstart_index\x18\x03 \x01(\x05R\nstartIndex\x12\x1d\n\nstop_index\x18\x04 \x01(\x05R\tstopIndex\x12\x1a\n\x08\x66ragment\x18\x05 \x01(\tR\x08\x66ragment\x1a\x99\x03\n\x0eSparkThrowable\x12$\n\x0b\x65rror_class\x18\x01 \x01(\tH\x00R\nerrorClass\x88\x01\x01\x12}\n\x12message_parameters\x18\x02 \x03(\x0b\x32N.spark.connect.FetchErrorDetailsResponse.SparkThrowable.MessageParametersEntryR\x11messageParameters\x12\\\n\x0equery_contexts\x18\x03 \x03(\x0b\x32\x35.spark.connect.FetchErrorDetailsResponse.QueryContextR\rqueryContexts\x12 \n\tsql_state\x18\x04 \x01(\tH\x01R\x08sqlState\x88\x01\x01\x1a\x44\n\x16MessageParametersEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 
\x01(\tR\x05value:\x02\x38\x01\x42\x0e\n\x0c_error_classB\x0c\n\n_sql_state\x1a\xdb\x02\n\x05\x45rror\x12\x30\n\x14\x65rror_type_hierarchy\x18\x01 \x03(\tR\x12\x65rrorTypeHierarchy\x12\x18\n\x07message\x18\x02 \x01(\tR\x07message\x12[\n\x0bstack_trace\x18\x03 \x03(\x0b\x32:.spark.connect.FetchErrorDetailsResponse.StackTraceElementR\nstackTrace\x12 \n\tcause_idx\x18\x04 \x01(\x05H\x00R\x08\x63\x61useIdx\x88\x01\x01\x12\x65\n\x0fspark_throwable\x18\x05 \x01(\x0b\x32\x37.spark.connect.FetchErrorDetailsResponse.SparkThrowableH\x01R\x0esparkThrowable\x88\x01\x01\x42\x0c\n\n_cause_idxB\x12\n\x10_spark_throwableB\x11\n\x0f_root_error_idx2\xd1\x06\n\x13SparkConnectService\x12X\n\x0b\x45xecutePlan\x12!.spark.connect.ExecutePlanRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12V\n\x0b\x41nalyzePlan\x12!.spark.connect.AnalyzePlanRequest\x1a".spark.connect.AnalyzePlanResponse"\x00\x12G\n\x06\x43onfig\x12\x1c.spark.connect.ConfigRequest\x1a\x1d.spark.connect.ConfigResponse"\x00\x12[\n\x0c\x41\x64\x64\x41rtifacts\x12".spark.connect.AddArtifactsRequest\x1a#.spark.connect.AddArtifactsResponse"\x00(\x01\x12\x63\n\x0e\x41rtifactStatus\x12&.spark.connect.ArtifactStatusesRequest\x1a\'.spark.connect.ArtifactStatusesResponse"\x00\x12P\n\tInterrupt\x12\x1f.spark.connect.InterruptRequest\x1a .spark.connect.InterruptResponse"\x00\x12`\n\x0fReattachExecute\x12%.spark.connect.ReattachExecuteRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12_\n\x0eReleaseExecute\x12$.spark.connect.ReleaseExecuteRequest\x1a%.spark.connect.ReleaseExecuteResponse"\x00\x12h\n\x11\x46\x65tchErrorDetails\x12\'.spark.connect.FetchErrorDetailsRequest\x1a(.spark.connect.FetchErrorDetailsResponse"\x00\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' + b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\xf5\x12\n\x12\x41nalyzePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x01R\nclientType\x88\x01\x01\x12\x42\n\x06schema\x18\x04 \x01(\x0b\x32(.spark.connect.AnalyzePlanRequest.SchemaH\x00R\x06schema\x12\x45\n\x07\x65xplain\x18\x05 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.ExplainH\x00R\x07\x65xplain\x12O\n\x0btree_string\x18\x06 \x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.TreeStringH\x00R\ntreeString\x12\x46\n\x08is_local\x18\x07 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.IsLocalH\x00R\x07isLocal\x12R\n\x0cis_streaming\x18\x08 \x01(\x0b\x32-.spark.connect.AnalyzePlanRequest.IsStreamingH\x00R\x0bisStreaming\x12O\n\x0binput_files\x18\t \x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.InputFilesH\x00R\ninputFiles\x12U\n\rspark_version\x18\n \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SparkVersionH\x00R\x0csparkVersion\x12I\n\tddl_parse\x18\x0b 
\x01(\x0b\x32*.spark.connect.AnalyzePlanRequest.DDLParseH\x00R\x08\x64\x64lParse\x12X\n\x0esame_semantics\x18\x0c \x01(\x0b\x32/.spark.connect.AnalyzePlanRequest.SameSemanticsH\x00R\rsameSemantics\x12U\n\rsemantic_hash\x18\r \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SemanticHashH\x00R\x0csemanticHash\x12\x45\n\x07persist\x18\x0e \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.PersistH\x00R\x07persist\x12K\n\tunpersist\x18\x0f \x01(\x0b\x32+.spark.connect.AnalyzePlanRequest.UnpersistH\x00R\tunpersist\x12_\n\x11get_storage_level\x18\x10 \x01(\x0b\x32\x31.spark.connect.AnalyzePlanRequest.GetStorageLevelH\x00R\x0fgetStorageLevel\x1a\x31\n\x06Schema\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\xbb\x02\n\x07\x45xplain\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12X\n\x0c\x65xplain_mode\x18\x02 \x01(\x0e\x32\x35.spark.connect.AnalyzePlanRequest.Explain.ExplainModeR\x0b\x65xplainMode"\xac\x01\n\x0b\x45xplainMode\x12\x1c\n\x18\x45XPLAIN_MODE_UNSPECIFIED\x10\x00\x12\x17\n\x13\x45XPLAIN_MODE_SIMPLE\x10\x01\x12\x19\n\x15\x45XPLAIN_MODE_EXTENDED\x10\x02\x12\x18\n\x14\x45XPLAIN_MODE_CODEGEN\x10\x03\x12\x15\n\x11\x45XPLAIN_MODE_COST\x10\x04\x12\x1a\n\x16\x45XPLAIN_MODE_FORMATTED\x10\x05\x1aZ\n\nTreeString\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12\x19\n\x05level\x18\x02 \x01(\x05H\x00R\x05level\x88\x01\x01\x42\x08\n\x06_level\x1a\x32\n\x07IsLocal\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x36\n\x0bIsStreaming\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x35\n\nInputFiles\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x0e\n\x0cSparkVersion\x1a)\n\x08\x44\x44LParse\x12\x1d\n\nddl_string\x18\x01 \x01(\tR\tddlString\x1ay\n\rSameSemantics\x12\x34\n\x0btarget_plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\ntargetPlan\x12\x32\n\nother_plan\x18\x02 \x01(\x0b\x32\x13.spark.connect.PlanR\totherPlan\x1a\x37\n\x0cSemanticHash\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x97\x01\n\x07Persist\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x45\n\rstorage_level\x18\x02 \x01(\x0b\x32\x1b.spark.connect.StorageLevelH\x00R\x0cstorageLevel\x88\x01\x01\x42\x10\n\x0e_storage_level\x1an\n\tUnpersist\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x1f\n\x08\x62locking\x18\x02 \x01(\x08H\x00R\x08\x62locking\x88\x01\x01\x42\x0b\n\t_blocking\x1a\x46\n\x0fGetStorageLevel\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relationB\t\n\x07\x61nalyzeB\x0e\n\x0c_client_type"\x99\r\n\x13\x41nalyzePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x43\n\x06schema\x18\x02 \x01(\x0b\x32).spark.connect.AnalyzePlanResponse.SchemaH\x00R\x06schema\x12\x46\n\x07\x65xplain\x18\x03 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.ExplainH\x00R\x07\x65xplain\x12P\n\x0btree_string\x18\x04 \x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.TreeStringH\x00R\ntreeString\x12G\n\x08is_local\x18\x05 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.IsLocalH\x00R\x07isLocal\x12S\n\x0cis_streaming\x18\x06 \x01(\x0b\x32..spark.connect.AnalyzePlanResponse.IsStreamingH\x00R\x0bisStreaming\x12P\n\x0binput_files\x18\x07 \x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.InputFilesH\x00R\ninputFiles\x12V\n\rspark_version\x18\x08 \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SparkVersionH\x00R\x0csparkVersion\x12J\n\tddl_parse\x18\t 
\x01(\x0b\x32+.spark.connect.AnalyzePlanResponse.DDLParseH\x00R\x08\x64\x64lParse\x12Y\n\x0esame_semantics\x18\n \x01(\x0b\x32\x30.spark.connect.AnalyzePlanResponse.SameSemanticsH\x00R\rsameSemantics\x12V\n\rsemantic_hash\x18\x0b \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SemanticHashH\x00R\x0csemanticHash\x12\x46\n\x07persist\x18\x0c \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.PersistH\x00R\x07persist\x12L\n\tunpersist\x18\r \x01(\x0b\x32,.spark.connect.AnalyzePlanResponse.UnpersistH\x00R\tunpersist\x12`\n\x11get_storage_level\x18\x0e \x01(\x0b\x32\x32.spark.connect.AnalyzePlanResponse.GetStorageLevelH\x00R\x0fgetStorageLevel\x1a\x39\n\x06Schema\x12/\n\x06schema\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1a\x30\n\x07\x45xplain\x12%\n\x0e\x65xplain_string\x18\x01 \x01(\tR\rexplainString\x1a-\n\nTreeString\x12\x1f\n\x0btree_string\x18\x01 \x01(\tR\ntreeString\x1a$\n\x07IsLocal\x12\x19\n\x08is_local\x18\x01 \x01(\x08R\x07isLocal\x1a\x30\n\x0bIsStreaming\x12!\n\x0cis_streaming\x18\x01 \x01(\x08R\x0bisStreaming\x1a"\n\nInputFiles\x12\x14\n\x05\x66iles\x18\x01 \x03(\tR\x05\x66iles\x1a(\n\x0cSparkVersion\x12\x18\n\x07version\x18\x01 \x01(\tR\x07version\x1a;\n\x08\x44\x44LParse\x12/\n\x06parsed\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06parsed\x1a\'\n\rSameSemantics\x12\x16\n\x06result\x18\x01 \x01(\x08R\x06result\x1a&\n\x0cSemanticHash\x12\x16\n\x06result\x18\x01 \x01(\x05R\x06result\x1a\t\n\x07Persist\x1a\x0b\n\tUnpersist\x1aS\n\x0fGetStorageLevel\x12@\n\rstorage_level\x18\x01 \x01(\x0b\x32\x1b.spark.connect.StorageLevelR\x0cstorageLevelB\x08\n\x06result"\xa0\x04\n\x12\x45xecutePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12&\n\x0coperation_id\x18\x06 \x01(\tH\x00R\x0boperationId\x88\x01\x01\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x01R\nclientType\x88\x01\x01\x12X\n\x0frequest_options\x18\x05 \x03(\x0b\x32/.spark.connect.ExecutePlanRequest.RequestOptionR\x0erequestOptions\x12\x12\n\x04tags\x18\x07 \x03(\tR\x04tags\x1a\xa5\x01\n\rRequestOption\x12K\n\x10reattach_options\x18\x01 \x01(\x0b\x32\x1e.spark.connect.ReattachOptionsH\x00R\x0freattachOptions\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textensionB\x10\n\x0erequest_optionB\x0f\n\r_operation_idB\x0e\n\x0c_client_type"\xe6\x0f\n\x13\x45xecutePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12!\n\x0coperation_id\x18\x0c \x01(\tR\x0boperationId\x12\x1f\n\x0bresponse_id\x18\r \x01(\tR\nresponseId\x12P\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32-.spark.connect.ExecutePlanResponse.ArrowBatchH\x00R\narrowBatch\x12\x63\n\x12sql_command_result\x18\x05 \x01(\x0b\x32\x33.spark.connect.ExecutePlanResponse.SqlCommandResultH\x00R\x10sqlCommandResult\x12~\n#write_stream_operation_start_result\x18\x08 \x01(\x0b\x32..spark.connect.WriteStreamOperationStartResultH\x00R\x1fwriteStreamOperationStartResult\x12q\n\x1estreaming_query_command_result\x18\t \x01(\x0b\x32*.spark.connect.StreamingQueryCommandResultH\x00R\x1bstreamingQueryCommandResult\x12k\n\x1cget_resources_command_result\x18\n \x01(\x0b\x32(.spark.connect.GetResourcesCommandResultH\x00R\x19getResourcesCommandResult\x12\x87\x01\n&streaming_query_manager_command_result\x18\x0b \x01(\x0b\x32\x31.spark.connect.StreamingQueryManagerCommandResultH\x00R"streamingQueryManagerCommandResult\x12\\\n\x0fresult_complete\x18\x0e 
\x01(\x0b\x32\x31.spark.connect.ExecutePlanResponse.ResultCompleteH\x00R\x0eresultComplete\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12\x44\n\x07metrics\x18\x04 \x01(\x0b\x32*.spark.connect.ExecutePlanResponse.MetricsR\x07metrics\x12]\n\x10observed_metrics\x18\x06 \x03(\x0b\x32\x32.spark.connect.ExecutePlanResponse.ObservedMetricsR\x0fobservedMetrics\x12/\n\x06schema\x18\x07 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1aG\n\x10SqlCommandResult\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x1av\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x12&\n\x0cstart_offset\x18\x03 \x01(\x03H\x00R\x0bstartOffset\x88\x01\x01\x42\x0f\n\r_start_offset\x1a\x85\x04\n\x07Metrics\x12Q\n\x07metrics\x18\x01 \x03(\x0b\x32\x37.spark.connect.ExecutePlanResponse.Metrics.MetricObjectR\x07metrics\x1a\xcc\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12z\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32M.spark.connect.ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1a{\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ExecutePlanResponse.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricType\x1at\n\x0fObservedMetrics\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x39\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x12\x12\n\x04keys\x18\x03 \x03(\tR\x04keys\x1a\x10\n\x0eResultCompleteB\x0f\n\rresponse_type"A\n\x08KeyValue\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x19\n\x05value\x18\x02 \x01(\tH\x00R\x05value\x88\x01\x01\x42\x08\n\x06_value"\x84\x08\n\rConfigRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\x44\n\toperation\x18\x03 \x01(\x0b\x32&.spark.connect.ConfigRequest.OperationR\toperation\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x1a\xf2\x03\n\tOperation\x12\x34\n\x03set\x18\x01 \x01(\x0b\x32 .spark.connect.ConfigRequest.SetH\x00R\x03set\x12\x34\n\x03get\x18\x02 \x01(\x0b\x32 .spark.connect.ConfigRequest.GetH\x00R\x03get\x12W\n\x10get_with_default\x18\x03 \x01(\x0b\x32+.spark.connect.ConfigRequest.GetWithDefaultH\x00R\x0egetWithDefault\x12G\n\nget_option\x18\x04 \x01(\x0b\x32&.spark.connect.ConfigRequest.GetOptionH\x00R\tgetOption\x12>\n\x07get_all\x18\x05 \x01(\x0b\x32#.spark.connect.ConfigRequest.GetAllH\x00R\x06getAll\x12:\n\x05unset\x18\x06 \x01(\x0b\x32".spark.connect.ConfigRequest.UnsetH\x00R\x05unset\x12P\n\ris_modifiable\x18\x07 \x01(\x0b\x32).spark.connect.ConfigRequest.IsModifiableH\x00R\x0cisModifiableB\t\n\x07op_type\x1a\x34\n\x03Set\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x19\n\x03Get\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a?\n\x0eGetWithDefault\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x1f\n\tGetOption\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a\x30\n\x06GetAll\x12\x1b\n\x06prefix\x18\x01 \x01(\tH\x00R\x06prefix\x88\x01\x01\x42\t\n\x07_prefix\x1a\x1b\n\x05Unset\x12\x12\n\x04keys\x18\x01 
\x03(\tR\x04keys\x1a"\n\x0cIsModifiable\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keysB\x0e\n\x0c_client_type"z\n\x0e\x43onfigResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12-\n\x05pairs\x18\x02 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x12\x1a\n\x08warnings\x18\x03 \x03(\tR\x08warnings"\xe7\x06\n\x13\x41\x64\x64\x41rtifactsRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x06 \x01(\tH\x01R\nclientType\x88\x01\x01\x12@\n\x05\x62\x61tch\x18\x03 \x01(\x0b\x32(.spark.connect.AddArtifactsRequest.BatchH\x00R\x05\x62\x61tch\x12Z\n\x0b\x62\x65gin_chunk\x18\x04 \x01(\x0b\x32\x37.spark.connect.AddArtifactsRequest.BeginChunkedArtifactH\x00R\nbeginChunk\x12H\n\x05\x63hunk\x18\x05 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkH\x00R\x05\x63hunk\x1a\x35\n\rArtifactChunk\x12\x12\n\x04\x64\x61ta\x18\x01 \x01(\x0cR\x04\x64\x61ta\x12\x10\n\x03\x63rc\x18\x02 \x01(\x03R\x03\x63rc\x1ao\n\x13SingleChunkArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x44\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x04\x64\x61ta\x1a]\n\x05\x42\x61tch\x12T\n\tartifacts\x18\x01 \x03(\x0b\x32\x36.spark.connect.AddArtifactsRequest.SingleChunkArtifactR\tartifacts\x1a\xc1\x01\n\x14\x42\x65ginChunkedArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x1f\n\x0btotal_bytes\x18\x02 \x01(\x03R\ntotalBytes\x12\x1d\n\nnum_chunks\x18\x03 \x01(\x03R\tnumChunks\x12U\n\rinitial_chunk\x18\x04 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x0cinitialChunkB\t\n\x07payloadB\x0e\n\x0c_client_type"\xbc\x01\n\x14\x41\x64\x64\x41rtifactsResponse\x12Q\n\tartifacts\x18\x01 \x03(\x0b\x32\x33.spark.connect.AddArtifactsResponse.ArtifactSummaryR\tartifacts\x1aQ\n\x0f\x41rtifactSummary\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12*\n\x11is_crc_successful\x18\x02 \x01(\x08R\x0fisCrcSuccessful"\xc3\x01\n\x17\x41rtifactStatusesRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x00R\nclientType\x88\x01\x01\x12\x14\n\x05names\x18\x04 \x03(\tR\x05namesB\x0e\n\x0c_client_type"\x8c\x02\n\x18\x41rtifactStatusesResponse\x12Q\n\x08statuses\x18\x01 \x03(\x0b\x32\x35.spark.connect.ArtifactStatusesResponse.StatusesEntryR\x08statuses\x1a(\n\x0e\x41rtifactStatus\x12\x16\n\x06\x65xists\x18\x01 \x01(\x08R\x06\x65xists\x1as\n\rStatusesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ArtifactStatusesResponse.ArtifactStatusR\x05value:\x02\x38\x01"\xd8\x03\n\x10InterruptRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x01R\nclientType\x88\x01\x01\x12T\n\x0einterrupt_type\x18\x04 \x01(\x0e\x32-.spark.connect.InterruptRequest.InterruptTypeR\rinterruptType\x12%\n\roperation_tag\x18\x05 \x01(\tH\x00R\x0coperationTag\x12#\n\x0coperation_id\x18\x06 \x01(\tH\x00R\x0boperationId"\x80\x01\n\rInterruptType\x12\x1e\n\x1aINTERRUPT_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12INTERRUPT_TYPE_ALL\x10\x01\x12\x16\n\x12INTERRUPT_TYPE_TAG\x10\x02\x12\x1f\n\x1bINTERRUPT_TYPE_OPERATION_ID\x10\x03\x42\x0b\n\tinterruptB\x0e\n\x0c_client_type"[\n\x11InterruptResponse\x12\x1d\n\nsession_id\x18\x01 
\x01(\tR\tsessionId\x12\'\n\x0finterrupted_ids\x18\x02 \x03(\tR\x0einterruptedIds"5\n\x0fReattachOptions\x12"\n\x0creattachable\x18\x01 \x01(\x08R\x0creattachable"\x93\x02\n\x16ReattachExecuteRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12!\n\x0coperation_id\x18\x03 \x01(\tR\x0boperationId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x12-\n\x10last_response_id\x18\x05 \x01(\tH\x01R\x0elastResponseId\x88\x01\x01\x42\x0e\n\x0c_client_typeB\x13\n\x11_last_response_id"\xc6\x03\n\x15ReleaseExecuteRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12!\n\x0coperation_id\x18\x03 \x01(\tR\x0boperationId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x01R\nclientType\x88\x01\x01\x12R\n\x0brelease_all\x18\x05 \x01(\x0b\x32/.spark.connect.ReleaseExecuteRequest.ReleaseAllH\x00R\nreleaseAll\x12X\n\rrelease_until\x18\x06 \x01(\x0b\x32\x31.spark.connect.ReleaseExecuteRequest.ReleaseUntilH\x00R\x0creleaseUntil\x1a\x0c\n\nReleaseAll\x1a/\n\x0cReleaseUntil\x12\x1f\n\x0bresponse_id\x18\x01 \x01(\tR\nresponseIdB\t\n\x07releaseB\x0e\n\x0c_client_type"p\n\x16ReleaseExecuteResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12&\n\x0coperation_id\x18\x02 \x01(\tH\x00R\x0boperationId\x88\x01\x01\x42\x0f\n\r_operation_id"\xc9\x01\n\x18\x46\x65tchErrorDetailsRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\x19\n\x08\x65rror_id\x18\x03 \x01(\tR\x07\x65rrorId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x42\x0e\n\x0c_client_type"\xbe\x0b\n\x19\x46\x65tchErrorDetailsResponse\x12)\n\x0eroot_error_idx\x18\x01 \x01(\x05H\x00R\x0crootErrorIdx\x88\x01\x01\x12\x46\n\x06\x65rrors\x18\x02 \x03(\x0b\x32..spark.connect.FetchErrorDetailsResponse.ErrorR\x06\x65rrors\x1a\xae\x01\n\x11StackTraceElement\x12\'\n\x0f\x64\x65\x63laring_class\x18\x01 \x01(\tR\x0e\x64\x65\x63laringClass\x12\x1f\n\x0bmethod_name\x18\x02 \x01(\tR\nmethodName\x12 \n\tfile_name\x18\x03 \x01(\tH\x00R\x08\x66ileName\x88\x01\x01\x12\x1f\n\x0bline_number\x18\x04 \x01(\x05R\nlineNumberB\x0c\n\n_file_name\x1a\xef\x02\n\x0cQueryContext\x12\x64\n\x0c\x63ontext_type\x18\n \x01(\x0e\x32\x41.spark.connect.FetchErrorDetailsResponse.QueryContext.ContextTypeR\x0b\x63ontextType\x12\x1f\n\x0bobject_type\x18\x01 \x01(\tR\nobjectType\x12\x1f\n\x0bobject_name\x18\x02 \x01(\tR\nobjectName\x12\x1f\n\x0bstart_index\x18\x03 \x01(\x05R\nstartIndex\x12\x1d\n\nstop_index\x18\x04 \x01(\x05R\tstopIndex\x12\x1a\n\x08\x66ragment\x18\x05 \x01(\tR\x08\x66ragment\x12\x1a\n\x08\x63\x61llSite\x18\x06 \x01(\tR\x08\x63\x61llSite\x12\x18\n\x07summary\x18\x07 \x01(\tR\x07summary"%\n\x0b\x43ontextType\x12\x07\n\x03SQL\x10\x00\x12\r\n\tDATAFRAME\x10\x01\x1a\x99\x03\n\x0eSparkThrowable\x12$\n\x0b\x65rror_class\x18\x01 \x01(\tH\x00R\nerrorClass\x88\x01\x01\x12}\n\x12message_parameters\x18\x02 \x03(\x0b\x32N.spark.connect.FetchErrorDetailsResponse.SparkThrowable.MessageParametersEntryR\x11messageParameters\x12\\\n\x0equery_contexts\x18\x03 \x03(\x0b\x32\x35.spark.connect.FetchErrorDetailsResponse.QueryContextR\rqueryContexts\x12 \n\tsql_state\x18\x04 \x01(\tH\x01R\x08sqlState\x88\x01\x01\x1a\x44\n\x16MessageParametersEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 
\x01(\tR\x05value:\x02\x38\x01\x42\x0e\n\x0c_error_classB\x0c\n\n_sql_state\x1a\xdb\x02\n\x05\x45rror\x12\x30\n\x14\x65rror_type_hierarchy\x18\x01 \x03(\tR\x12\x65rrorTypeHierarchy\x12\x18\n\x07message\x18\x02 \x01(\tR\x07message\x12[\n\x0bstack_trace\x18\x03 \x03(\x0b\x32:.spark.connect.FetchErrorDetailsResponse.StackTraceElementR\nstackTrace\x12 \n\tcause_idx\x18\x04 \x01(\x05H\x00R\x08\x63\x61useIdx\x88\x01\x01\x12\x65\n\x0fspark_throwable\x18\x05 \x01(\x0b\x32\x37.spark.connect.FetchErrorDetailsResponse.SparkThrowableH\x01R\x0esparkThrowable\x88\x01\x01\x42\x0c\n\n_cause_idxB\x12\n\x10_spark_throwableB\x11\n\x0f_root_error_idx2\xd1\x06\n\x13SparkConnectService\x12X\n\x0b\x45xecutePlan\x12!.spark.connect.ExecutePlanRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12V\n\x0b\x41nalyzePlan\x12!.spark.connect.AnalyzePlanRequest\x1a".spark.connect.AnalyzePlanResponse"\x00\x12G\n\x06\x43onfig\x12\x1c.spark.connect.ConfigRequest\x1a\x1d.spark.connect.ConfigResponse"\x00\x12[\n\x0c\x41\x64\x64\x41rtifacts\x12".spark.connect.AddArtifactsRequest\x1a#.spark.connect.AddArtifactsResponse"\x00(\x01\x12\x63\n\x0e\x41rtifactStatus\x12&.spark.connect.ArtifactStatusesRequest\x1a\'.spark.connect.ArtifactStatusesResponse"\x00\x12P\n\tInterrupt\x12\x1f.spark.connect.InterruptRequest\x1a .spark.connect.InterruptResponse"\x00\x12`\n\x0fReattachExecute\x12%.spark.connect.ReattachExecuteRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12_\n\x0eReleaseExecute\x12$.spark.connect.ReleaseExecuteRequest\x1a%.spark.connect.ReleaseExecuteResponse"\x00\x12h\n\x11\x46\x65tchErrorDetails\x12\'.spark.connect.FetchErrorDetailsRequest\x1a(.spark.connect.FetchErrorDetailsResponse"\x00\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -202,17 +202,19 @@ _FETCHERRORDETAILSREQUEST._serialized_start = 11378 _FETCHERRORDETAILSREQUEST._serialized_end = 11579 _FETCHERRORDETAILSRESPONSE._serialized_start = 11582 - _FETCHERRORDETAILSRESPONSE._serialized_end = 12857 + _FETCHERRORDETAILSRESPONSE._serialized_end = 13052 _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_start = 11727 _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_end = 11901 _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_start = 11904 - _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_end = 12076 - _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_start = 12079 - _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_end = 12488 - _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_start = 12390 - _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_end = 12458 - _FETCHERRORDETAILSRESPONSE_ERROR._serialized_start = 12491 - _FETCHERRORDETAILSRESPONSE_ERROR._serialized_end = 12838 - _SPARKCONNECTSERVICE._serialized_start = 12860 - _SPARKCONNECTSERVICE._serialized_end = 13709 + _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_end = 12271 + _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_start = 12234 + _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_end = 12271 + _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_start = 12274 + _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_end = 12683 + _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_start = 12585 + _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_end = 12653 + _FETCHERRORDETAILSRESPONSE_ERROR._serialized_start = 12686 + 
_FETCHERRORDETAILSRESPONSE_ERROR._serialized_end = 13033 + _SPARKCONNECTSERVICE._serialized_start = 13055 + _SPARKCONNECTSERVICE._serialized_end = 13904 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi index 5d2ebeb573990..c29feb4164cf1 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.pyi +++ b/python/pyspark/sql/connect/proto/base_pb2.pyi @@ -2885,11 +2885,35 @@ class FetchErrorDetailsResponse(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor + class _ContextType: + ValueType = typing.NewType("ValueType", builtins.int) + V: typing_extensions.TypeAlias = ValueType + + class _ContextTypeEnumTypeWrapper( + google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[ + FetchErrorDetailsResponse.QueryContext._ContextType.ValueType + ], + builtins.type, + ): # noqa: F821 + DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor + SQL: FetchErrorDetailsResponse.QueryContext._ContextType.ValueType # 0 + DATAFRAME: FetchErrorDetailsResponse.QueryContext._ContextType.ValueType # 1 + + class ContextType(_ContextType, metaclass=_ContextTypeEnumTypeWrapper): + """The type of this query context.""" + + SQL: FetchErrorDetailsResponse.QueryContext.ContextType.ValueType # 0 + DATAFRAME: FetchErrorDetailsResponse.QueryContext.ContextType.ValueType # 1 + + CONTEXT_TYPE_FIELD_NUMBER: builtins.int OBJECT_TYPE_FIELD_NUMBER: builtins.int OBJECT_NAME_FIELD_NUMBER: builtins.int START_INDEX_FIELD_NUMBER: builtins.int STOP_INDEX_FIELD_NUMBER: builtins.int FRAGMENT_FIELD_NUMBER: builtins.int + CALLSITE_FIELD_NUMBER: builtins.int + SUMMARY_FIELD_NUMBER: builtins.int + context_type: global___FetchErrorDetailsResponse.QueryContext.ContextType.ValueType object_type: builtins.str """The object type of the query which throws the exception. If the exception is directly from the main query, it should be an empty string. @@ -2906,18 +2930,29 @@ class FetchErrorDetailsResponse(google.protobuf.message.Message): """The stopping index in the query which throws the exception. The index starts from 0.""" fragment: builtins.str """The corresponding fragment of the query which throws the exception.""" + callSite: builtins.str + """The user code (call site of the API) that caused throwing the exception.""" + summary: builtins.str + """Summary of the exception cause.""" def __init__( self, *, + context_type: global___FetchErrorDetailsResponse.QueryContext.ContextType.ValueType = ..., object_type: builtins.str = ..., object_name: builtins.str = ..., start_index: builtins.int = ..., stop_index: builtins.int = ..., fragment: builtins.str = ..., + callSite: builtins.str = ..., + summary: builtins.str = ..., ) -> None: ... def ClearField( self, field_name: typing_extensions.Literal[ + "callSite", + b"callSite", + "context_type", + b"context_type", "fragment", b"fragment", "object_name", @@ -2928,6 +2963,8 @@ class FetchErrorDetailsResponse(google.protobuf.message.Message): b"start_index", "stop_index", b"stop_index", + "summary", + b"summary", ], ) -> None: ... 
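
For context, a minimal sketch (not part of the patch itself) of how the regenerated Python stubs above could be exercised. The message path, field names, and enum values are taken from the diff; the concrete fragment, call-site, and summary strings are invented purely for illustration.

    from pyspark.sql.connect.proto import base_pb2 as pb2

    # A DataFrame-originated context carries the API fragment, the user call site,
    # and a pre-rendered summary; the positional SQL fields stay at their defaults.
    df_ctx = pb2.FetchErrorDetailsResponse.QueryContext(
        context_type=pb2.FetchErrorDetailsResponse.QueryContext.DATAFRAME,
        fragment="divide",              # hypothetical API name
        callSite="App.scala:21",        # hypothetical user call site
        summary='== DataFrame ==\n"divide" was called from App.scala:21\n',
    )

    # A SQL-originated context keeps using the object/position information as before.
    sql_ctx = pb2.FetchErrorDetailsResponse.QueryContext(
        context_type=pb2.FetchErrorDetailsResponse.QueryContext.SQL,
        start_index=0,
        stop_index=2,
        fragment="1/0",                 # hypothetical SQL fragment
    )

    assert df_ctx.context_type == pb2.FetchErrorDetailsResponse.QueryContext.DATAFRAME

These fields mirror the SQLQueryContext and DataFrameQueryContext classes introduced in the Scala changes below (QueryContexts.scala), which expose the same contextType, fragment, callSite, and summary information.
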
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala index 51d2b4beab227..22e6c67090b4d 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala @@ -26,7 +26,7 @@ import org.antlr.v4.runtime.tree.TerminalNodeImpl import org.apache.spark.{QueryContext, SparkThrowable, SparkThrowableHelper} import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, WithOrigin} +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, SQLQueryContext, WithOrigin} import org.apache.spark.sql.catalyst.util.SparkParserUtils import org.apache.spark.sql.errors.QueryParsingErrors import org.apache.spark.sql.internal.SqlApiConf @@ -229,7 +229,7 @@ class ParseException( val builder = new StringBuilder builder ++= "\n" ++= message start match { - case Origin(Some(l), Some(p), _, _, _, _, _) => + case Origin(Some(l), Some(p), _, _, _, _, _, _) => builder ++= s" (line $l, pos $p)\n" command.foreach { cmd => val (above, below) = cmd.split("\n").splitAt(l) @@ -262,8 +262,7 @@ class ParseException( object ParseException { def getQueryContext(): Array[QueryContext] = { - val context = CurrentOrigin.get.context - if (context.isValid) Array(context) else Array.empty + Some(CurrentOrigin.get.context).collect { case b: SQLQueryContext if b.isValid => b }.toArray } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/SQLQueryContext.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala similarity index 78% rename from sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/SQLQueryContext.scala rename to sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala index 5b29cb3dde74f..b8288b24535e8 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/SQLQueryContext.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.trees -import org.apache.spark.QueryContext +import org.apache.spark.{QueryContext, QueryContextType} /** The class represents error context of a SQL query. */ case class SQLQueryContext( @@ -28,6 +28,7 @@ case class SQLQueryContext( sqlText: Option[String], originObjectType: Option[String], originObjectName: Option[String]) extends QueryContext { + override val contextType = QueryContextType.SQL override val objectType = originObjectType.getOrElse("") override val objectName = originObjectName.getOrElse("") @@ -40,7 +41,7 @@ case class SQLQueryContext( * SELECT '' AS five, i.f1, i.f1 - int('2') AS x FROM INT4_TBL i * ^^^^^^^^^^^^^^^ */ - lazy val summary: String = { + override lazy val summary: String = { // If the query context is missing or incorrect, simply return an empty string. if (!isValid) { "" @@ -116,7 +117,7 @@ case class SQLQueryContext( } /** Gets the textual fragment of a SQL query. 
*/ - override lazy val fragment: String = { + lazy val fragment: String = { if (!isValid) { "" } else { @@ -128,6 +129,45 @@ case class SQLQueryContext( sqlText.isDefined && originStartIndex.isDefined && originStopIndex.isDefined && originStartIndex.get >= 0 && originStopIndex.get < sqlText.get.length && originStartIndex.get <= originStopIndex.get + } + + override def callSite: String = throw new UnsupportedOperationException +} + +case class DataFrameQueryContext( + override val fragment: String, + override val callSite: String) extends QueryContext { + override val contextType = QueryContextType.DataFrame + + override def objectType: String = throw new UnsupportedOperationException + override def objectName: String = throw new UnsupportedOperationException + override def startIndex: Int = throw new UnsupportedOperationException + override def stopIndex: Int = throw new UnsupportedOperationException + + override lazy val summary: String = { + val builder = new StringBuilder + builder ++= "== DataFrame ==\n" + builder ++= "\"" + + builder ++= fragment + builder ++= "\"" + builder ++= " was called from " + builder ++= callSite + builder += '\n' + builder.result() + } +} + +object DataFrameQueryContext { + def apply(elements: Array[StackTraceElement]): DataFrameQueryContext = { + val methodName = elements(0).getMethodName + val code = if (methodName.length > 1 && methodName(0) == '$') { + methodName.substring(1) + } else { + methodName + } + val callSite = elements(1).toString + DataFrameQueryContext(code, callSite) } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala index ec3e627ac9585..dd24dae16ba8c 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala @@ -30,15 +30,21 @@ case class Origin( stopIndex: Option[Int] = None, sqlText: Option[String] = None, objectType: Option[String] = None, - objectName: Option[String] = None) { + objectName: Option[String] = None, + stackTrace: Option[Array[StackTraceElement]] = None) { - lazy val context: SQLQueryContext = SQLQueryContext( - line, startPosition, startIndex, stopIndex, sqlText, objectType, objectName) - - def getQueryContext: Array[QueryContext] = if (context.isValid) { - Array(context) + lazy val context: QueryContext = if (stackTrace.isDefined) { + DataFrameQueryContext(stackTrace.get) } else { - Array.empty + SQLQueryContext( + line, startPosition, startIndex, stopIndex, sqlText, objectType, objectName) + } + + def getQueryContext: Array[QueryContext] = { + Some(context).filter { + case s: SQLQueryContext => s.isValid + case _ => true + }.toArray } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala index 7c1b37e9e5815..99caef978bb4a 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.util -import org.apache.spark.sql.catalyst.trees.SQLQueryContext +import org.apache.spark.QueryContext import org.apache.spark.sql.errors.ExecutionErrors /** @@ -27,37 +27,37 @@ object MathUtils { def addExact(a: Int, b: Int): Int = withOverflow(Math.addExact(a, b)) - def addExact(a: Int, b: Int, context: SQLQueryContext): Int = { + def addExact(a: Int, b: Int, 
context: QueryContext): Int = { withOverflow(Math.addExact(a, b), hint = "try_add", context) } def addExact(a: Long, b: Long): Long = withOverflow(Math.addExact(a, b)) - def addExact(a: Long, b: Long, context: SQLQueryContext): Long = { + def addExact(a: Long, b: Long, context: QueryContext): Long = { withOverflow(Math.addExact(a, b), hint = "try_add", context) } def subtractExact(a: Int, b: Int): Int = withOverflow(Math.subtractExact(a, b)) - def subtractExact(a: Int, b: Int, context: SQLQueryContext): Int = { + def subtractExact(a: Int, b: Int, context: QueryContext): Int = { withOverflow(Math.subtractExact(a, b), hint = "try_subtract", context) } def subtractExact(a: Long, b: Long): Long = withOverflow(Math.subtractExact(a, b)) - def subtractExact(a: Long, b: Long, context: SQLQueryContext): Long = { + def subtractExact(a: Long, b: Long, context: QueryContext): Long = { withOverflow(Math.subtractExact(a, b), hint = "try_subtract", context) } def multiplyExact(a: Int, b: Int): Int = withOverflow(Math.multiplyExact(a, b)) - def multiplyExact(a: Int, b: Int, context: SQLQueryContext): Int = { + def multiplyExact(a: Int, b: Int, context: QueryContext): Int = { withOverflow(Math.multiplyExact(a, b), hint = "try_multiply", context) } def multiplyExact(a: Long, b: Long): Long = withOverflow(Math.multiplyExact(a, b)) - def multiplyExact(a: Long, b: Long, context: SQLQueryContext): Long = { + def multiplyExact(a: Long, b: Long, context: QueryContext): Long = { withOverflow(Math.multiplyExact(a, b), hint = "try_multiply", context) } @@ -78,7 +78,7 @@ object MathUtils { def withOverflow[A]( f: => A, hint: String = "", - context: SQLQueryContext = null): A = { + context: QueryContext = null): A = { try { f } catch { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala index 698e7b37a9ef0..f8a9274a5646c 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala @@ -25,7 +25,7 @@ import scala.util.control.NonFatal import sun.util.calendar.ZoneInfo -import org.apache.spark.sql.catalyst.trees.SQLQueryContext +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.catalyst.util.RebaseDateTime.{rebaseGregorianToJulianDays, rebaseGregorianToJulianMicros, rebaseJulianToGregorianDays, rebaseJulianToGregorianMicros} import org.apache.spark.sql.errors.ExecutionErrors @@ -355,7 +355,7 @@ trait SparkDateTimeUtils { def stringToDateAnsi( s: UTF8String, - context: SQLQueryContext = null): Int = { + context: QueryContext = null): Int = { stringToDate(s).getOrElse { throw ExecutionErrors.invalidInputInCastToDatetimeError(s, DateType, context) } @@ -567,7 +567,7 @@ trait SparkDateTimeUtils { def stringToTimestampAnsi( s: UTF8String, timeZoneId: ZoneId, - context: SQLQueryContext = null): Long = { + context: QueryContext = null): Long = { stringToTimestamp(s, timeZoneId).getOrElse { throw ExecutionErrors.invalidInputInCastToDatetimeError(s, TimestampType, context) } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala index 5e52e283338d3..b30f7b7a00e91 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala +++ 
b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala @@ -16,9 +16,9 @@ */ package org.apache.spark.sql.errors -import org.apache.spark.{SparkArithmeticException, SparkException, SparkIllegalArgumentException, SparkNumberFormatException, SparkRuntimeException, SparkUnsupportedOperationException} +import org.apache.spark.{QueryContext, SparkArithmeticException, SparkException, SparkIllegalArgumentException, SparkNumberFormatException, SparkRuntimeException, SparkUnsupportedOperationException} import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.trees.{Origin, SQLQueryContext} +import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.catalyst.util.QuotingUtils import org.apache.spark.sql.catalyst.util.QuotingUtils.toSQLSchema import org.apache.spark.sql.types.{DataType, Decimal, StringType} @@ -191,7 +191,7 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase { value: Decimal, decimalPrecision: Int, decimalScale: Int, - context: SQLQueryContext = null): ArithmeticException = { + context: QueryContext = null): ArithmeticException = { numericValueOutOfRange(value, decimalPrecision, decimalScale, context) } @@ -199,7 +199,7 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase { value: Decimal, decimalPrecision: Int, decimalScale: Int, - context: SQLQueryContext = null): ArithmeticException = { + context: QueryContext = null): ArithmeticException = { numericValueOutOfRange(value, decimalPrecision, decimalScale, context) } @@ -207,7 +207,7 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase { value: Decimal, decimalPrecision: Int, decimalScale: Int, - context: SQLQueryContext): ArithmeticException = { + context: QueryContext): ArithmeticException = { new SparkArithmeticException( errorClass = "NUMERIC_VALUE_OUT_OF_RANGE", messageParameters = Map( @@ -222,7 +222,7 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase { def invalidInputInCastToNumberError( to: DataType, s: UTF8String, - context: SQLQueryContext): SparkNumberFormatException = { + context: QueryContext): SparkNumberFormatException = { val convertedValueStr = "'" + s.toString.replace("\\", "\\\\").replace("'", "\\'") + "'" new SparkNumberFormatException( errorClass = "CAST_INVALID_INPUT", diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrorsBase.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrorsBase.scala index 911d900053cf9..d1d9dd806b3b8 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrorsBase.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrorsBase.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.errors import java.util.Locale import org.apache.spark.QueryContext -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.util.{AttributeNameParser, QuotingUtils} import org.apache.spark.sql.types.{AbstractDataType, DataType, TypeCollection} import org.apache.spark.unsafe.types.UTF8String @@ -89,11 +88,11 @@ private[sql] trait DataTypeErrorsBase { "\"" + elem + "\"" } - def getSummary(sqlContext: SQLQueryContext): String = { + def getSummary(sqlContext: QueryContext): String = { if (sqlContext == null) "" else sqlContext.summary } - def getQueryContext(sqlContext: SQLQueryContext): Array[QueryContext] = { - if (sqlContext == null) Array.empty else Array(sqlContext.asInstanceOf[QueryContext]) + def getQueryContext(context: QueryContext): Array[QueryContext] = { + if (context == 
null) Array.empty else Array(context) } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala index c8321e81027ba..394e56062071b 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala @@ -21,9 +21,8 @@ import java.time.temporal.ChronoField import org.apache.arrow.vector.types.pojo.ArrowType -import org.apache.spark.{SparkArithmeticException, SparkBuildInfo, SparkDateTimeException, SparkException, SparkRuntimeException, SparkUnsupportedOperationException, SparkUpgradeException} +import org.apache.spark.{QueryContext, SparkArithmeticException, SparkBuildInfo, SparkDateTimeException, SparkException, SparkRuntimeException, SparkUnsupportedOperationException, SparkUpgradeException} import org.apache.spark.sql.catalyst.WalkedTypePath -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.types.{DataType, DoubleType, StringType, UserDefinedType} import org.apache.spark.unsafe.types.UTF8String @@ -83,14 +82,14 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase { def invalidInputInCastToDatetimeError( value: UTF8String, to: DataType, - context: SQLQueryContext): SparkDateTimeException = { + context: QueryContext): SparkDateTimeException = { invalidInputInCastToDatetimeErrorInternal(toSQLValue(value), StringType, to, context) } def invalidInputInCastToDatetimeError( value: Double, to: DataType, - context: SQLQueryContext): SparkDateTimeException = { + context: QueryContext): SparkDateTimeException = { invalidInputInCastToDatetimeErrorInternal(toSQLValue(value), DoubleType, to, context) } @@ -98,7 +97,7 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase { sqlValue: String, from: DataType, to: DataType, - context: SQLQueryContext): SparkDateTimeException = { + context: QueryContext): SparkDateTimeException = { new SparkDateTimeException( errorClass = "CAST_INVALID_INPUT", messageParameters = Map( @@ -113,7 +112,7 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase { def arithmeticOverflowError( message: String, hint: String = "", - context: SQLQueryContext = null): ArithmeticException = { + context: QueryContext = null): ArithmeticException = { val alternative = if (hint.nonEmpty) { s" Use '$hint' to tolerate overflow and return NULL instead." 
} else "" diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 3c386b20a7912..5652e5adda9d4 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -21,8 +21,8 @@ import java.math.{BigDecimal => JavaBigDecimal, BigInteger, MathContext, Roundin import scala.util.Try +import org.apache.spark.QueryContext import org.apache.spark.annotation.Unstable -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.errors.DataTypeErrors import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.unsafe.types.UTF8String @@ -341,7 +341,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { scale: Int, roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP, nullOnOverflow: Boolean = true, - context: SQLQueryContext = null): Decimal = { + context: QueryContext = null): Decimal = { val copy = clone() if (copy.changePrecision(precision, scale, roundMode)) { copy @@ -617,7 +617,7 @@ object Decimal { def fromStringANSI( str: UTF8String, to: DecimalType = DecimalType.USER_DEFAULT, - context: SQLQueryContext = null): Decimal = { + context: QueryContext = null): Decimal = { try { val bigDecimal = stringToJavaBigDecimal(str) // We fast fail because constructing a very large JavaBigDecimal to Decimal is very slow. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 99117d81b34ad..62295fe260535 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -21,13 +21,13 @@ import java.time.{ZoneId, ZoneOffset} import java.util.Locale import java.util.concurrent.TimeUnit._ -import org.apache.spark.SparkArithmeticException +import org.apache.spark.{QueryContext, SparkArithmeticException} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.{SQLQueryContext, TreeNodeTag} +import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.types.{PhysicalFractionalType, PhysicalIntegralType, PhysicalNumericType} import org.apache.spark.sql.catalyst.util._ @@ -527,7 +527,7 @@ case class Cast( } } - override def initQueryContext(): Option[SQLQueryContext] = if (ansiEnabled) { + override def initQueryContext(): Option[QueryContext] = if (ansiEnabled) { Some(origin.context) } else { None @@ -945,7 +945,7 @@ case class Cast( private[this] def toPrecision( value: Decimal, decimalType: DecimalType, - context: SQLQueryContext): Decimal = + context: QueryContext): Decimal = value.toPrecision( decimalType.precision, decimalType.scale, Decimal.ROUND_HALF_UP, !ansiEnabled, context) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 46c36ab8e3c31..0dc70c6c3947c 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -19,14 +19,14 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Locale -import org.apache.spark.SparkException +import org.apache.spark.{QueryContext, SparkException} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.{BinaryLike, CurrentOrigin, LeafLike, QuaternaryLike, SQLQueryContext, TernaryLike, TreeNode, UnaryLike} +import org.apache.spark.sql.catalyst.trees.{BinaryLike, CurrentOrigin, LeafLike, QuaternaryLike, TernaryLike, TreeNode, UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern.{RUNTIME_REPLACEABLE, TreePattern} import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.truncatedString @@ -613,11 +613,11 @@ abstract class UnaryExpression extends Expression with UnaryLike[Expression] { * to executors. It will also be kept after rule transforms. */ trait SupportQueryContext extends Expression with Serializable { - protected var queryContext: Option[SQLQueryContext] = initQueryContext() + protected var queryContext: Option[QueryContext] = initQueryContext() - def initQueryContext(): Option[SQLQueryContext] + def initQueryContext(): Option[QueryContext] - def getContextOrNull(): SQLQueryContext = queryContext.orNull + def getContextOrNull(): QueryContext = queryContext.orNull def getContextOrNullCode(ctx: CodegenContext, withErrorContext: Boolean = true): String = { if (withErrorContext && queryContext.isDefined) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index fd6131f185606..fe30e2ea6f3ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.trees.{SQLQueryContext, UnaryLike} +import org.apache.spark.sql.catalyst.trees.{UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern.{AVERAGE, TreePattern} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.errors.QueryCompilationErrors @@ -134,7 +135,7 @@ case class Average( override protected def withNewChildInternal(newChild: Expression): Average = copy(child = newChild) - override def initQueryContext(): Option[SQLQueryContext] = if (evalMode == EvalMode.ANSI) { + override def initQueryContext(): Option[QueryContext] = if (evalMode == EvalMode.ANSI) { Some(origin.context) } else { None diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index e3881520e4902..dfd41ad12a280 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{EvalMode, _} -import org.apache.spark.sql.catalyst.trees.{SQLQueryContext, UnaryLike} +import org.apache.spark.sql.catalyst.trees.{UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern.{SUM, TreePattern} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.errors.QueryCompilationErrors @@ -186,7 +187,7 @@ case class Sum( // The flag `evalMode` won't be shown in the `toString` or `toAggString` methods override def flatArguments: Iterator[Any] = Iterator(child) - override def initQueryContext(): Option[SQLQueryContext] = if (evalMode == EvalMode.ANSI) { + override def initQueryContext(): Option[QueryContext] = if (evalMode == EvalMode.ANSI) { Some(origin.context) } else { None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index a556ac9f12947..e3c5184c5acc5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -19,13 +19,13 @@ package org.apache.spark.sql.catalyst.expressions import scala.math.{max, min} +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.Cast.{toSQLId, toSQLType} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_ARITHMETIC, TreePattern, UNARY_POSITIVE} import org.apache.spark.sql.catalyst.types.{PhysicalDecimalType, PhysicalFractionalType, PhysicalIntegerType, PhysicalIntegralType, PhysicalLongType} import org.apache.spark.sql.catalyst.util.{IntervalMathUtils, IntervalUtils, MathUtils, TypeUtils} @@ -266,7 +266,7 @@ abstract class BinaryArithmetic extends BinaryOperator final override val nodePatterns: Seq[TreePattern] = Seq(BINARY_ARITHMETIC) - override def initQueryContext(): Option[SQLQueryContext] = { + override def initQueryContext(): Option[QueryContext] = { if (failOnError) { Some(origin.context) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index adaed0ff819be..25da787b8874f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -22,13 +22,14 @@ import java.util.Comparator import scala.collection.mutable import scala.reflect.ClassTag +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedSeed} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.{BinaryLike, SQLQueryContext, UnaryLike} +import org.apache.spark.sql.catalyst.trees.{BinaryLike, UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern.{ARRAYS_ZIP, CONCAT, TreePattern} import org.apache.spark.sql.catalyst.types.{DataTypeUtils, PhysicalDataType, PhysicalIntegralType} import org.apache.spark.sql.catalyst.util._ @@ -2526,7 +2527,7 @@ case class ElementAt( override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): ElementAt = copy(left = newLeft, right = newRight) - override def initQueryContext(): Option[SQLQueryContext] = { + override def initQueryContext(): Option[QueryContext] = { if (failOnError && left.resolved && left.dataType.isInstanceOf[ArrayType]) { Some(origin.context) } else { @@ -5046,7 +5047,7 @@ case class ArrayInsert( newSrcArrayExpr: Expression, newPosExpr: Expression, newItemExpr: Expression): ArrayInsert = copy(srcArrayExpr = newSrcArrayExpr, posExpr = newPosExpr, itemExpr = newItemExpr) - override def initQueryContext(): Option[SQLQueryContext] = Some(origin.context) + override def initQueryContext(): Option[QueryContext] = Some(origin.context) } @ExpressionDescription( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 3885a5b9f5b32..a801d0367080d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.trees.TreePattern.{EXTRACT_VALUE, TreePattern} import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} @@ -316,7 +316,7 @@ case class GetArrayItem( newLeft: Expression, newRight: Expression): GetArrayItem = copy(child = newLeft, ordinal = newRight) - override def initQueryContext(): Option[SQLQueryContext] = if (failOnError) { + override def initQueryContext(): Option[QueryContext] = if (failOnError) { Some(origin.context) } else { None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index 
378920856eb11..5f13d397d1bf9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.types.PhysicalDecimalType import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.errors.QueryExecutionErrors @@ -146,7 +146,7 @@ case class CheckOverflow( override protected def withNewChildInternal(newChild: Expression): CheckOverflow = copy(child = newChild) - override def initQueryContext(): Option[SQLQueryContext] = if (!nullOnOverflow) { + override def initQueryContext(): Option[QueryContext] = if (!nullOnOverflow) { Some(origin.context) } else { None @@ -158,7 +158,7 @@ case class CheckOverflowInSum( child: Expression, dataType: DecimalType, nullOnOverflow: Boolean, - context: SQLQueryContext) extends UnaryExpression with SupportQueryContext { + context: QueryContext) extends UnaryExpression with SupportQueryContext { override def nullable: Boolean = true @@ -210,7 +210,7 @@ case class CheckOverflowInSum( override protected def withNewChildInternal(newChild: Expression): CheckOverflowInSum = copy(child = newChild) - override def initQueryContext(): Option[SQLQueryContext] = Option(context) + override def initQueryContext(): Option[QueryContext] = Option(context) } /** @@ -256,12 +256,12 @@ case class DecimalDivideWithOverflowCheck( left: Expression, right: Expression, override val dataType: DecimalType, - context: SQLQueryContext, + context: QueryContext, nullOnOverflow: Boolean) extends BinaryExpression with ExpectsInputTypes with SupportQueryContext { override def nullable: Boolean = nullOnOverflow override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType, DecimalType) - override def initQueryContext(): Option[SQLQueryContext] = Option(context) + override def initQueryContext(): Option[QueryContext] = Option(context) def decimalMethod: String = "$div" override def eval(input: InternalRow): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 3750a9271cff0..aa1f6159def8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, Un import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.trees.{BinaryLike, QuaternaryLike, TernaryLike} +import org.apache.spark.sql.catalyst.trees.{BinaryLike, CurrentOrigin, QuaternaryLike, TernaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util._ @@ -200,9 +200,11 @@ trait 
HigherOrderFunction extends Expression with ExpectsInputTypes { */ final def bind( f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): HigherOrderFunction = { - val res = bindInternal(f) - res.copyTagsFrom(this) - res + CurrentOrigin.withOrigin(origin) { + val res = bindInternal(f) + res.copyTagsFrom(this) + res + } } protected def bindInternal( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index 5378639e6838b..13676733a9bad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -22,8 +22,8 @@ import java.util.Locale import com.google.common.math.{DoubleMath, IntMath, LongMath} +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.util.DateTimeConstants.MONTHS_PER_YEAR import org.apache.spark.sql.catalyst.util.IntervalUtils import org.apache.spark.sql.catalyst.util.IntervalUtils._ @@ -604,7 +604,7 @@ trait IntervalDivide { minValue: Any, num: Expression, numValue: Any, - context: SQLQueryContext): Unit = { + context: QueryContext): Unit = { if (value == minValue && num.dataType.isInstanceOf[IntegralType]) { if (numValue.asInstanceOf[Number].longValue() == -1) { throw QueryExecutionErrors.intervalArithmeticOverflowError( @@ -616,7 +616,7 @@ trait IntervalDivide { def divideByZeroCheck( dataType: DataType, num: Any, - context: SQLQueryContext): Unit = dataType match { + context: QueryContext): Unit = dataType match { case _: DecimalType => if (num.asInstanceOf[Decimal].isZero) { throw QueryExecutionErrors.intervalDividedByZeroError(context) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index c320d98d9fd1c..0c09e9be12e94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql.catalyst.expressions import java.{lang => jl} import java.util.Locale +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.util.{MathUtils, NumberConverter, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf @@ -480,7 +480,7 @@ case class Conv( newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(numExpr = newFirst, fromBaseExpr = newSecond, toBaseExpr = newThird) - override def initQueryContext(): Option[SQLQueryContext] = if (ansiEnabled) { + override def 
initQueryContext(): Option[QueryContext] = if (ansiEnabled) { Some(origin.context) } else { None @@ -1523,7 +1523,7 @@ abstract class RoundBase(child: Expression, scale: Expression, private lazy val scaleV: Any = scale.eval(EmptyRow) protected lazy val _scale: Int = scaleV.asInstanceOf[Int] - override def initQueryContext(): Option[SQLQueryContext] = { + override def initQueryContext(): Option[QueryContext] = { if (ansiEnabled) { Some(origin.context) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index f5c34d67d4db1..cf6c7780cc82f 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -23,6 +23,7 @@ import java.util.{HashMap, Locale, Map => JMap} import scala.collection.mutable.ArrayBuffer +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch @@ -30,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke -import org.apache.spark.sql.catalyst.trees.{BinaryLike, SQLQueryContext} +import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LOWER} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} @@ -411,7 +412,7 @@ case class Elt( override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Elt = copy(children = newChildren) - override def initQueryContext(): Option[SQLQueryContext] = if (failOnError) { + override def initQueryContext(): Option[QueryContext] = if (failOnError) { Some(origin.context) } else { None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 23bbc91c16d54..8fabb44876208 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -24,7 +24,7 @@ import java.util.concurrent.TimeUnit._ import scala.util.control.NonFatal -import org.apache.spark.sql.catalyst.trees.SQLQueryContext +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types.{Decimal, DoubleExactNumeric, TimestampNTZType, TimestampType} @@ -70,7 +70,7 @@ object DateTimeUtils extends SparkDateTimeUtils { // the "GMT" string. For example, it returns 2000-01-01T00:00+01:00 for 2000-01-01T00:00GMT+01:00. 
def cleanLegacyTimestampStr(s: UTF8String): UTF8String = s.replace(gmtUtf8, UTF8String.EMPTY_UTF8) - def doubleToTimestampAnsi(d: Double, context: SQLQueryContext): Long = { + def doubleToTimestampAnsi(d: Double, context: QueryContext): Long = { if (d.isNaN || d.isInfinite) { throw QueryExecutionErrors.invalidInputInCastToDatetimeError(d, TimestampType, context) } else { @@ -91,7 +91,7 @@ object DateTimeUtils extends SparkDateTimeUtils { def stringToTimestampWithoutTimeZoneAnsi( s: UTF8String, - context: SQLQueryContext): Long = { + context: QueryContext): Long = { stringToTimestampWithoutTimeZone(s, true).getOrElse { throw QueryExecutionErrors.invalidInputInCastToDatetimeError(s, TimestampNTZType, context) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala index 59765cde1f926..2730ab8f4b890 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.util -import org.apache.spark.sql.catalyst.trees.SQLQueryContext +import org.apache.spark.QueryContext import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.unsafe.types.UTF8String @@ -54,7 +54,7 @@ object NumberConverter { fromPos: Int, value: Array[Byte], ansiEnabled: Boolean, - context: SQLQueryContext): Long = { + context: QueryContext): Long = { var v: Long = 0L // bound will always be positive since radix >= 2 // Note that: -1 is equivalent to 11111111...1111 which is the largest unsigned long value @@ -134,7 +134,7 @@ object NumberConverter { fromBase: Int, toBase: Int, ansiEnabled: Boolean, - context: SQLQueryContext): UTF8String = { + context: QueryContext): UTF8String = { if (fromBase < Character.MIN_RADIX || fromBase > Character.MAX_RADIX || Math.abs(toBase) < Character.MIN_RADIX || Math.abs(toBase) > Character.MAX_RADIX) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala index f7800469c3528..1c3a5075dab2c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.util -import org.apache.spark.sql.catalyst.trees.SQLQueryContext +import org.apache.spark.QueryContext import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types.{ByteType, DataType, IntegerType, LongType, ShortType} import org.apache.spark.unsafe.types.UTF8String @@ -27,21 +27,21 @@ import org.apache.spark.unsafe.types.UTF8String */ object UTF8StringUtils { - def toLongExact(s: UTF8String, context: SQLQueryContext): Long = + def toLongExact(s: UTF8String, context: QueryContext): Long = withException(s.toLongExact, context, LongType, s) - def toIntExact(s: UTF8String, context: SQLQueryContext): Int = + def toIntExact(s: UTF8String, context: QueryContext): Int = withException(s.toIntExact, context, IntegerType, s) - def toShortExact(s: UTF8String, context: SQLQueryContext): Short = + def toShortExact(s: UTF8String, context: QueryContext): Short = withException(s.toShortExact, context, ShortType, s) - def toByteExact(s: UTF8String, context: SQLQueryContext): Byte = + def toByteExact(s: 
UTF8String, context: QueryContext): Byte = withException(s.toByteExact, context, ByteType, s) private def withException[A]( f: => A, - context: SQLQueryContext, + context: QueryContext, to: DataType, s: UTF8String): A = { try { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index afc244509c41d..30dfe8eebe6cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.ValueInterval -import org.apache.spark.sql.catalyst.trees.{Origin, SQLQueryContext, TreeNode} +import org.apache.spark.sql.catalyst.trees.{Origin, TreeNode} import org.apache.spark.sql.catalyst.util.{sideBySide, BadRecordException, DateTimeUtils, FailFastMode, MapData} import org.apache.spark.sql.connector.catalog.{CatalogNotFoundException, Table, TableProvider} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ @@ -104,7 +104,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE value: Decimal, decimalPrecision: Int, decimalScale: Int, - context: SQLQueryContext = null): ArithmeticException = { + context: QueryContext = null): ArithmeticException = { new SparkArithmeticException( errorClass = "NUMERIC_VALUE_OUT_OF_RANGE", messageParameters = Map( @@ -118,7 +118,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE def invalidInputSyntaxForBooleanError( s: UTF8String, - context: SQLQueryContext): SparkRuntimeException = { + context: QueryContext): SparkRuntimeException = { new SparkRuntimeException( errorClass = "CAST_INVALID_INPUT", messageParameters = Map( @@ -133,7 +133,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE def invalidInputInCastToNumberError( to: DataType, s: UTF8String, - context: SQLQueryContext): SparkNumberFormatException = { + context: QueryContext): SparkNumberFormatException = { new SparkNumberFormatException( errorClass = "CAST_INVALID_INPUT", messageParameters = Map( @@ -194,15 +194,15 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE cause = e) } - def divideByZeroError(context: SQLQueryContext): ArithmeticException = { + def divideByZeroError(context: QueryContext): ArithmeticException = { new SparkArithmeticException( errorClass = "DIVIDE_BY_ZERO", messageParameters = Map("config" -> toSQLConf(SQLConf.ANSI_ENABLED.key)), - context = getQueryContext(context), + context = Array(context), summary = getSummary(context)) } - def intervalDividedByZeroError(context: SQLQueryContext): ArithmeticException = { + def intervalDividedByZeroError(context: QueryContext): ArithmeticException = { new SparkArithmeticException( errorClass = "INTERVAL_DIVIDED_BY_ZERO", messageParameters = Map.empty, @@ -213,7 +213,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE def invalidArrayIndexError( index: Int, numElements: Int, - context: SQLQueryContext): ArrayIndexOutOfBoundsException = { + context: QueryContext): ArrayIndexOutOfBoundsException = { new SparkArrayIndexOutOfBoundsException( errorClass = "INVALID_ARRAY_INDEX", 
messageParameters = Map( @@ -227,7 +227,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE def invalidElementAtIndexError( index: Int, numElements: Int, - context: SQLQueryContext): ArrayIndexOutOfBoundsException = { + context: QueryContext): ArrayIndexOutOfBoundsException = { new SparkArrayIndexOutOfBoundsException( errorClass = "INVALID_ARRAY_INDEX_IN_ELEMENT_AT", messageParameters = Map( @@ -292,15 +292,15 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE ansiIllegalArgumentError(e.getMessage) } - def overflowInSumOfDecimalError(context: SQLQueryContext): ArithmeticException = { + def overflowInSumOfDecimalError(context: QueryContext): ArithmeticException = { arithmeticOverflowError("Overflow in sum of decimals", context = context) } - def overflowInIntegralDivideError(context: SQLQueryContext): ArithmeticException = { + def overflowInIntegralDivideError(context: QueryContext): ArithmeticException = { arithmeticOverflowError("Overflow in integral divide", "try_divide", context) } - def overflowInConvError(context: SQLQueryContext): ArithmeticException = { + def overflowInConvError(context: QueryContext): ArithmeticException = { arithmeticOverflowError("Overflow in function conv()", context = context) } @@ -625,7 +625,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE def intervalArithmeticOverflowError( message: String, hint: String = "", - context: SQLQueryContext): ArithmeticException = { + context: QueryContext): ArithmeticException = { val alternative = if (hint.nonEmpty) { s" Use '$hint' to tolerate overflow and return NULL instead." } else "" @@ -1391,7 +1391,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE "functionName" -> toSQLId(prettyName))) } - def invalidIndexOfZeroError(context: SQLQueryContext): RuntimeException = { + def invalidIndexOfZeroError(context: QueryContext): RuntimeException = { new SparkRuntimeException( errorClass = "INVALID_INDEX_OF_ZERO", cause = null, @@ -2556,7 +2556,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE cause = null) } - def multipleRowScalarSubqueryError(context: SQLQueryContext): Throwable = { + def multipleRowScalarSubqueryError(context: QueryContext): Throwable = { new SparkException( errorClass = "SCALAR_SUBQUERY_TOO_MANY_ROWS", messageParameters = Map.empty, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 997308c6ef44f..ba4e7b279f512 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -35,8 +35,6 @@ import org.apache.spark.sql.types.StructType trait AnalysisTest extends PlanTest { - import org.apache.spark.QueryContext - protected def extendedAnalysisRules: Seq[Rule[LogicalPlan]] = Nil protected def createTempView( @@ -177,7 +175,7 @@ trait AnalysisTest extends PlanTest { inputPlan: LogicalPlan, expectedErrorClass: String, expectedMessageParameters: Map[String, String], - queryContext: Array[QueryContext] = Array.empty, + queryContext: Array[ExpectedContext] = Array.empty, caseSensitive: Boolean = true): Unit = { withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { val analyzer = getAnalyzer diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala index d91a080d8fe89..3fd0c1ee5de4b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.analysis import java.util.Locale -import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast, CreateNamedStruct, GetStructField, If, IsNull, LessThanOrEqual, Literal} @@ -159,7 +158,7 @@ abstract class V2ANSIWriteAnalysisSuiteBase extends V2WriteAnalysisSuiteBase { inputPlan: LogicalPlan, expectedErrorClass: String, expectedMessageParameters: Map[String, String], - queryContext: Array[QueryContext] = Array.empty, + queryContext: Array[ExpectedContext] = Array.empty, caseSensitive: Boolean = true): Unit = { withSQLConf(SQLConf.STORE_ASSIGNMENT_POLICY.key -> StoreAssignmentPolicy.ANSI.toString) { super.assertAnalysisErrorClass( @@ -196,7 +195,7 @@ abstract class V2StrictWriteAnalysisSuiteBase extends V2WriteAnalysisSuiteBase { inputPlan: LogicalPlan, expectedErrorClass: String, expectedMessageParameters: Map[String, String], - queryContext: Array[QueryContext] = Array.empty, + queryContext: Array[ExpectedContext] = Array.empty, caseSensitive: Boolean = true): Unit = { withSQLConf(SQLConf.STORE_ASSIGNMENT_POLICY.key -> StoreAssignmentPolicy.STRICT.toString) { super.assertAnalysisErrorClass( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 9bb35a8b0b3d1..0ca55ef67fd38 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -70,8 +70,10 @@ private[sql] object Column { name: String, isDistinct: Boolean, ignoreNulls: Boolean, - inputs: Column*): Column = Column { - UnresolvedFunction(Seq(name), inputs.map(_.expr), isDistinct, ignoreNulls = ignoreNulls) + inputs: Column*): Column = withOrigin { + Column { + UnresolvedFunction(Seq(name), inputs.map(_.expr), isDistinct, ignoreNulls = ignoreNulls) + } } } @@ -148,12 +150,14 @@ class TypedColumn[-T, U]( @Stable class Column(val expr: Expression) extends Logging { - def this(name: String) = this(name match { - case "*" => UnresolvedStar(None) - case _ if name.endsWith(".*") => - val parts = UnresolvedAttribute.parseAttributeName(name.substring(0, name.length - 2)) - UnresolvedStar(Some(parts)) - case _ => UnresolvedAttribute.quotedString(name) + def this(name: String) = this(withOrigin { + name match { + case "*" => UnresolvedStar(None) + case _ if name.endsWith(".*") => + val parts = UnresolvedAttribute.parseAttributeName(name.substring(0, name.length - 2)) + UnresolvedStar(Some(parts)) + case _ => UnresolvedAttribute.quotedString(name) + } }) private def fn(name: String): Column = { @@ -180,7 +184,9 @@ class Column(val expr: Expression) extends Logging { } /** Creates a column based on the given expression. 
*/ - private def withExpr(newExpr: Expression): Column = new Column(newExpr) + private def withExpr(newExpr: => Expression): Column = withOrigin { + new Column(newExpr) + } /** * Returns the expression for this column either with an existing or auto assigned name. @@ -1370,7 +1376,9 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def over(window: expressions.WindowSpec): Column = window.withAggregate(this) + def over(window: expressions.WindowSpec): Column = withOrigin { + window.withAggregate(this) + } /** * Defines an empty analytic clause. In this case the analytic function is applied diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index f3690773f6ddd..a8c4d4f8d2ba7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -72,7 +72,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { def approxQuantile( col: String, probabilities: Array[Double], - relativeError: Double): Array[Double] = { + relativeError: Double): Array[Double] = withOrigin { approxQuantile(Array(col), probabilities, relativeError).head } @@ -97,7 +97,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { def approxQuantile( cols: Array[String], probabilities: Array[Double], - relativeError: Double): Array[Array[Double]] = { + relativeError: Double): Array[Array[Double]] = withOrigin { StatFunctions.multipleApproxQuantiles( df.select(cols.map(col): _*), cols, @@ -132,7 +132,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * * @since 1.4.0 */ - def cov(col1: String, col2: String): Double = { + def cov(col1: String, col2: String): Double = withOrigin { StatFunctions.calculateCov(df, Seq(col1, col2)) } @@ -154,7 +154,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * * @since 1.4.0 */ - def corr(col1: String, col2: String, method: String): Double = { + def corr(col1: String, col2: String, method: String): Double = withOrigin { require(method == "pearson", "Currently only the calculation of the Pearson Correlation " + "coefficient is supported.") StatFunctions.pearsonCorrelation(df, Seq(col1, col2)) @@ -210,7 +210,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * * @since 1.4.0 */ - def crosstab(col1: String, col2: String): DataFrame = { + def crosstab(col1: String, col2: String): DataFrame = withOrigin { StatFunctions.crossTabulate(df, col1, col2) } @@ -257,7 +257,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * * @since 1.4.0 */ - def freqItems(cols: Array[String], support: Double): DataFrame = { + def freqItems(cols: Array[String], support: Double): DataFrame = withOrigin { FrequentItems.singlePassFreqItems(df, cols, support) } @@ -276,7 +276,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * * @since 1.4.0 */ - def freqItems(cols: Array[String]): DataFrame = { + def freqItems(cols: Array[String]): DataFrame = withOrigin { FrequentItems.singlePassFreqItems(df, cols, 0.01) } @@ -320,7 +320,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * * @since 1.4.0 */ - def freqItems(cols: Seq[String], support: Double): DataFrame = { + def freqItems(cols: Seq[String], support: Double): DataFrame = withOrigin { FrequentItems.singlePassFreqItems(df, cols, support) } @@ -339,7 +339,7 @@ final class 
DataFrameStatFunctions private[sql](df: DataFrame) { * * @since 1.4.0 */ - def freqItems(cols: Seq[String]): DataFrame = { + def freqItems(cols: Seq[String]): DataFrame = withOrigin { FrequentItems.singlePassFreqItems(df, cols, 0.01) } @@ -415,7 +415,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * * @since 3.0.0 */ - def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): DataFrame = { + def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): DataFrame = withOrigin { require(fractions.values.forall(p => p >= 0.0 && p <= 1.0), s"Fractions must be in [0, 1], but got $fractions.") import org.apache.spark.sql.functions.{rand, udf} @@ -497,7 +497,11 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * @return a `CountMinSketch` over column `colName` * @since 2.0.0 */ - def countMinSketch(col: Column, eps: Double, confidence: Double, seed: Int): CountMinSketch = { + def countMinSketch( + col: Column, + eps: Double, + confidence: Double, + seed: Int): CountMinSketch = withOrigin { val countMinSketchAgg = new CountMinSketchAgg( col.expr, Literal(eps, DoubleType), @@ -555,7 +559,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * @param numBits expected number of bits of the filter. * @since 2.0.0 */ - def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = { + def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = withOrigin { val bloomFilterAgg = new BloomFilterAggregate( col.expr, Literal(expectedNumItems, LongType), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 4f07133bb7617..ba5eb790cea9c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -508,9 +508,11 @@ class Dataset[T] private[sql]( * @group basic * @since 3.4.0 */ - def to(schema: StructType): DataFrame = withPlan { - val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType] - Project.matchSchema(logicalPlan, replaced, sparkSession.sessionState.conf) + def to(schema: StructType): DataFrame = withOrigin { + withPlan { + val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType] + Project.matchSchema(logicalPlan, replaced, sparkSession.sessionState.conf) + } } /** @@ -770,12 +772,14 @@ class Dataset[T] private[sql]( */ // We only accept an existing column name, not a derived column here as a watermark that is // defined on a derived column cannot referenced elsewhere in the plan. 
- def withWatermark(eventTime: String, delayThreshold: String): Dataset[T] = withTypedPlan { - val parsedDelay = IntervalUtils.fromIntervalString(delayThreshold) - require(!IntervalUtils.isNegative(parsedDelay), - s"delay threshold ($delayThreshold) should not be negative.") - EliminateEventTimeWatermark( - EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan)) + def withWatermark(eventTime: String, delayThreshold: String): Dataset[T] = withOrigin { + withTypedPlan { + val parsedDelay = IntervalUtils.fromIntervalString(delayThreshold) + require(!IntervalUtils.isNegative(parsedDelay), + s"delay threshold ($delayThreshold) should not be negative.") + EliminateEventTimeWatermark( + EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan)) + } } /** @@ -947,8 +951,10 @@ class Dataset[T] private[sql]( * @group untypedrel * @since 2.0.0 */ - def join(right: Dataset[_]): DataFrame = withPlan { - Join(logicalPlan, right.logicalPlan, joinType = Inner, None, JoinHint.NONE) + def join(right: Dataset[_]): DataFrame = withOrigin { + withPlan { + Join(logicalPlan, right.logicalPlan, joinType = Inner, None, JoinHint.NONE) + } } /** @@ -1081,22 +1087,23 @@ class Dataset[T] private[sql]( * @group untypedrel * @since 2.0.0 */ - def join(right: Dataset[_], usingColumns: Seq[String], joinType: String): DataFrame = { - // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right - // by creating a new instance for one of the branch. - val joined = sparkSession.sessionState.executePlan( - Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None, JoinHint.NONE)) - .analyzed.asInstanceOf[Join] - - withPlan { - Join( - joined.left, - joined.right, - UsingJoin(JoinType(joinType), usingColumns.toIndexedSeq), - None, - JoinHint.NONE) + def join(right: Dataset[_], usingColumns: Seq[String], joinType: String): DataFrame = + withOrigin { + // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right + // by creating a new instance for one of the branch. + val joined = sparkSession.sessionState.executePlan( + Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None, JoinHint.NONE)) + .analyzed.asInstanceOf[Join] + + withPlan { + Join( + joined.left, + joined.right, + UsingJoin(JoinType(joinType), usingColumns.toIndexedSeq), + None, + JoinHint.NONE) + } } - } /** * Inner join with another `DataFrame`, using the given join expression. @@ -1177,7 +1184,7 @@ class Dataset[T] private[sql]( * @group untypedrel * @since 2.0.0 */ - def join(right: Dataset[_], joinExprs: Column, joinType: String): DataFrame = { + def join(right: Dataset[_], joinExprs: Column, joinType: String): DataFrame = withOrigin { withPlan { resolveSelfJoinCondition(right, Some(joinExprs), joinType) } @@ -1193,8 +1200,10 @@ class Dataset[T] private[sql]( * @group untypedrel * @since 2.1.0 */ - def crossJoin(right: Dataset[_]): DataFrame = withPlan { - Join(logicalPlan, right.logicalPlan, joinType = Cross, None, JoinHint.NONE) + def crossJoin(right: Dataset[_]): DataFrame = withOrigin { + withPlan { + Join(logicalPlan, right.logicalPlan, joinType = Cross, None, JoinHint.NONE) + } } /** @@ -1218,27 +1227,28 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = { - // Creates a Join node and resolve it first, to get join condition resolved, self-join resolved, - // etc. 
- val joined = sparkSession.sessionState.executePlan( - Join( - this.logicalPlan, - other.logicalPlan, - JoinType(joinType), - Some(condition.expr), - JoinHint.NONE)).analyzed.asInstanceOf[Join] - - implicit val tuple2Encoder: Encoder[(T, U)] = - ExpressionEncoder.tuple(this.exprEnc, other.exprEnc) - - withTypedPlan(JoinWith.typedJoinWith( - joined, - sqlContext.conf.dataFrameSelfJoinAutoResolveAmbiguity, - sparkSession.sessionState.analyzer.resolver, - this.exprEnc.isSerializedAsStructForTopLevel, - other.exprEnc.isSerializedAsStructForTopLevel)) - } + def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = + withOrigin { + // Creates a Join node and resolve it first, to get join condition resolved, self-join + // resolved, etc. + val joined = sparkSession.sessionState.executePlan( + Join( + this.logicalPlan, + other.logicalPlan, + JoinType(joinType), + Some(condition.expr), + JoinHint.NONE)).analyzed.asInstanceOf[Join] + + implicit val tuple2Encoder: Encoder[(T, U)] = + ExpressionEncoder.tuple(this.exprEnc, other.exprEnc) + + withTypedPlan(JoinWith.typedJoinWith( + joined, + sqlContext.conf.dataFrameSelfJoinAutoResolveAmbiguity, + sparkSession.sessionState.analyzer.resolver, + this.exprEnc.isSerializedAsStructForTopLevel, + other.exprEnc.isSerializedAsStructForTopLevel)) + } /** * Using inner equi-join to join this Dataset returning a `Tuple2` for each pair @@ -1416,14 +1426,16 @@ class Dataset[T] private[sql]( * @since 2.2.0 */ @scala.annotation.varargs - def hint(name: String, parameters: Any*): Dataset[T] = withTypedPlan { - val exprs = parameters.map { - case c: Column => c.expr - case s: Symbol => Column(s.name).expr - case e: Expression => e - case literal => Literal(literal) - }.toSeq - UnresolvedHint(name, exprs, logicalPlan) + def hint(name: String, parameters: Any*): Dataset[T] = withOrigin { + withTypedPlan { + val exprs = parameters.map { + case c: Column => c.expr + case s: Symbol => Column(s.name).expr + case e: Expression => e + case literal => Literal(literal) + }.toSeq + UnresolvedHint(name, exprs, logicalPlan) + } } /** @@ -1499,8 +1511,10 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def as(alias: String): Dataset[T] = withTypedPlan { - SubqueryAlias(alias, logicalPlan) + def as(alias: String): Dataset[T] = withOrigin { + withTypedPlan { + SubqueryAlias(alias, logicalPlan) + } } /** @@ -1537,25 +1551,28 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @scala.annotation.varargs - def select(cols: Column*): DataFrame = withPlan { - val untypedCols = cols.map { - case typedCol: TypedColumn[_, _] => - // Checks if a `TypedColumn` has been inserted with - // specific input type and schema by `withInputType`. - val needInputType = typedCol.expr.exists { - case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty => true - case _ => false - } + def select(cols: Column*): DataFrame = withOrigin { + withPlan { + val untypedCols = cols.map { + case typedCol: TypedColumn[_, _] => + // Checks if a `TypedColumn` has been inserted with + // specific input type and schema by `withInputType`. 
+ val needInputType = typedCol.expr.exists { + case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty => true + case _ => false + } - if (!needInputType) { - typedCol - } else { - throw QueryCompilationErrors.cannotPassTypedColumnInUntypedSelectError(typedCol.toString) - } + if (!needInputType) { + typedCol + } else { + throw + QueryCompilationErrors.cannotPassTypedColumnInUntypedSelectError(typedCol.toString) + } - case other => other + case other => other + } + Project(untypedCols.map(_.named), logicalPlan) } - Project(untypedCols.map(_.named), logicalPlan) } /** @@ -1572,7 +1589,9 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @scala.annotation.varargs - def select(col: String, cols: String*): DataFrame = select((col +: cols).map(Column(_)) : _*) + def select(col: String, cols: String*): DataFrame = withOrigin { + select((col +: cols).map(Column(_)) : _*) + } /** * Selects a set of SQL expressions. This is a variant of `select` that accepts @@ -1588,10 +1607,12 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @scala.annotation.varargs - def selectExpr(exprs: String*): DataFrame = sparkSession.withActive { - select(exprs.map { expr => - Column(sparkSession.sessionState.sqlParser.parseExpression(expr)) - }: _*) + def selectExpr(exprs: String*): DataFrame = withOrigin { + sparkSession.withActive { + select(exprs.map { expr => + Column(sparkSession.sessionState.sqlParser.parseExpression(expr)) + }: _*) + } } /** @@ -1605,7 +1626,7 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = { + def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = withOrigin { implicit val encoder = c1.encoder val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, logicalPlan) @@ -1689,8 +1710,10 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def filter(condition: Column): Dataset[T] = withTypedPlan { - Filter(condition.expr, logicalPlan) + def filter(condition: Column): Dataset[T] = withOrigin { + withTypedPlan { + Filter(condition.expr, logicalPlan) + } } /** @@ -2049,15 +2072,17 @@ class Dataset[T] private[sql]( ids: Array[Column], values: Array[Column], variableColumnName: String, - valueColumnName: String): DataFrame = withPlan { - Unpivot( - Some(ids.map(_.named)), - Some(values.map(v => Seq(v.named))), - None, - variableColumnName, - Seq(valueColumnName), - logicalPlan - ) + valueColumnName: String): DataFrame = withOrigin { + withPlan { + Unpivot( + Some(ids.map(_.named)), + Some(values.map(v => Seq(v.named))), + None, + variableColumnName, + Seq(valueColumnName), + logicalPlan + ) + } } /** @@ -2080,15 +2105,17 @@ class Dataset[T] private[sql]( def unpivot( ids: Array[Column], variableColumnName: String, - valueColumnName: String): DataFrame = withPlan { - Unpivot( - Some(ids.map(_.named)), - None, - None, - variableColumnName, - Seq(valueColumnName), - logicalPlan - ) + valueColumnName: String): DataFrame = withOrigin { + withPlan { + Unpivot( + Some(ids.map(_.named)), + None, + None, + variableColumnName, + Seq(valueColumnName), + logicalPlan + ) + } } /** @@ -2205,8 +2232,10 @@ class Dataset[T] private[sql]( * @since 3.0.0 */ @varargs - def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = withTypedPlan { - CollectMetrics(name, (expr +: exprs).map(_.named), logicalPlan, id) + def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = withOrigin { + withTypedPlan { + CollectMetrics(name, (expr +: exprs).map(_.named), logicalPlan, id) + 
} } /** @@ -2243,8 +2272,10 @@ class Dataset[T] private[sql]( * @group typedrel * @since 2.0.0 */ - def limit(n: Int): Dataset[T] = withTypedPlan { - Limit(Literal(n), logicalPlan) + def limit(n: Int): Dataset[T] = withOrigin { + withTypedPlan { + Limit(Literal(n), logicalPlan) + } } /** @@ -2253,8 +2284,10 @@ class Dataset[T] private[sql]( * @group typedrel * @since 3.4.0 */ - def offset(n: Int): Dataset[T] = withTypedPlan { - Offset(Literal(n), logicalPlan) + def offset(n: Int): Dataset[T] = withOrigin { + withTypedPlan { + Offset(Literal(n), logicalPlan) + } } // This breaks caching, but it's usually ok because it addresses a very specific use case: @@ -2664,20 +2697,20 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @deprecated("use flatMap() or select() with functions.explode() instead", "2.0.0") - def explode[A <: Product : TypeTag](input: Column*)(f: Row => IterableOnce[A]): DataFrame = { - val elementSchema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] - - val convert = CatalystTypeConverters.createToCatalystConverter(elementSchema) - - val rowFunction = - f.andThen(_.map(convert(_).asInstanceOf[InternalRow])) - val generator = UserDefinedGenerator(elementSchema, rowFunction, input.map(_.expr)) - - withPlan { - Generate(generator, unrequiredChildIndex = Nil, outer = false, - qualifier = None, generatorOutput = Nil, logicalPlan) + def explode[A <: Product : TypeTag](input: Column*)(f: Row => IterableOnce[A]): DataFrame = + withOrigin { + val elementSchema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] + val convert = CatalystTypeConverters.createToCatalystConverter(elementSchema) + + val rowFunction = + f.andThen(_.map(convert(_).asInstanceOf[InternalRow])) + val generator = UserDefinedGenerator(elementSchema, rowFunction, input.map(_.expr)) + + withPlan { + Generate(generator, unrequiredChildIndex = Nil, outer = false, + qualifier = None, generatorOutput = Nil, logicalPlan) + } } - } /** * (Scala-specific) Returns a new Dataset where a single column has been expanded to zero @@ -2702,7 +2735,7 @@ class Dataset[T] private[sql]( */ @deprecated("use flatMap() or select() with functions.explode() instead", "2.0.0") def explode[A, B : TypeTag](inputColumn: String, outputColumn: String)(f: A => IterableOnce[B]) - : DataFrame = { + : DataFrame = withOrigin { val dataType = ScalaReflection.schemaFor[B].dataType val attributes = AttributeReference(outputColumn, dataType)() :: Nil // TODO handle the metadata? 
@@ -2859,7 +2892,7 @@ class Dataset[T] private[sql]( * @since 3.4.0 */ @throws[AnalysisException] - def withColumnsRenamed(colsMap: Map[String, String]): DataFrame = { + def withColumnsRenamed(colsMap: Map[String, String]): DataFrame = withOrigin { val resolver = sparkSession.sessionState.analyzer.resolver val output: Seq[NamedExpression] = queryExecution.analyzed.output @@ -3073,9 +3106,11 @@ class Dataset[T] private[sql]( * @group typedrel * @since 2.0.0 */ - def dropDuplicates(colNames: Seq[String]): Dataset[T] = withTypedPlan { - val groupCols = groupColsFromDropDuplicates(colNames) - Deduplicate(groupCols, logicalPlan) + def dropDuplicates(colNames: Seq[String]): Dataset[T] = withOrigin { + withTypedPlan { + val groupCols = groupColsFromDropDuplicates(colNames) + Deduplicate(groupCols, logicalPlan) + } } /** @@ -3151,10 +3186,12 @@ class Dataset[T] private[sql]( * @group typedrel * @since 3.5.0 */ - def dropDuplicatesWithinWatermark(colNames: Seq[String]): Dataset[T] = withTypedPlan { - val groupCols = groupColsFromDropDuplicates(colNames) - // UnsupportedOperationChecker will fail the query if this is called with batch Dataset. - DeduplicateWithinWatermark(groupCols, logicalPlan) + def dropDuplicatesWithinWatermark(colNames: Seq[String]): Dataset[T] = withOrigin { + withTypedPlan { + val groupCols = groupColsFromDropDuplicates(colNames) + // UnsupportedOperationChecker will fail the query if this is called with batch Dataset. + DeduplicateWithinWatermark(groupCols, logicalPlan) + } } /** @@ -3378,7 +3415,7 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def filter(func: T => Boolean): Dataset[T] = { + def filter(func: T => Boolean): Dataset[T] = withOrigin { withTypedPlan(TypedFilter(func, logicalPlan)) } @@ -3389,7 +3426,7 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def filter(func: FilterFunction[T]): Dataset[T] = { + def filter(func: FilterFunction[T]): Dataset[T] = withOrigin { withTypedPlan(TypedFilter(func, logicalPlan)) } @@ -3400,8 +3437,10 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def map[U : Encoder](func: T => U): Dataset[U] = withTypedPlan { - MapElements[T, U](func, logicalPlan) + def map[U : Encoder](func: T => U): Dataset[U] = withOrigin { + withTypedPlan { + MapElements[T, U](func, logicalPlan) + } } /** @@ -3411,7 +3450,7 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { + def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = withOrigin { implicit val uEnc = encoder withTypedPlan(MapElements[T, U](func, logicalPlan)) } @@ -3574,8 +3613,9 @@ class Dataset[T] private[sql]( * @group action * @since 3.0.0 */ - def tail(n: Int): Array[T] = withAction( - "tail", withTypedPlan(Tail(Literal(n), logicalPlan)).queryExecution)(collectFromPlan) + def tail(n: Int): Array[T] = withOrigin { + withAction("tail", withTypedPlan(Tail(Literal(n), logicalPlan)).queryExecution)(collectFromPlan) + } /** * Returns the first `n` rows in the Dataset as a list. 
@@ -3639,8 +3679,10 @@ class Dataset[T] private[sql]( * @group action * @since 1.6.0 */ - def count(): Long = withAction("count", groupBy().count().queryExecution) { plan => - plan.executeCollect().head.getLong(0) + def count(): Long = withOrigin { + withAction("count", groupBy().count().queryExecution) { plan => + plan.executeCollect().head.getLong(0) + } } /** @@ -3649,13 +3691,15 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def repartition(numPartitions: Int): Dataset[T] = withTypedPlan { - Repartition(numPartitions, shuffle = true, logicalPlan) + def repartition(numPartitions: Int): Dataset[T] = withOrigin { + withTypedPlan { + Repartition(numPartitions, shuffle = true, logicalPlan) + } } private def repartitionByExpression( numPartitions: Option[Int], - partitionExprs: Seq[Column]): Dataset[T] = { + partitionExprs: Seq[Column]): Dataset[T] = withOrigin { // The underlying `LogicalPlan` operator special-cases all-`SortOrder` arguments. // However, we don't want to complicate the semantics of this API method. // Instead, let's give users a friendly error message, pointing them to the new method. @@ -3700,7 +3744,7 @@ class Dataset[T] private[sql]( private def repartitionByRange( numPartitions: Option[Int], - partitionExprs: Seq[Column]): Dataset[T] = { + partitionExprs: Seq[Column]): Dataset[T] = withOrigin { require(partitionExprs.nonEmpty, "At least one partition-by expression must be specified.") val sortOrder: Seq[SortOrder] = partitionExprs.map(_.expr match { case expr: SortOrder => expr @@ -3772,8 +3816,10 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def coalesce(numPartitions: Int): Dataset[T] = withTypedPlan { - Repartition(numPartitions, shuffle = false, logicalPlan) + def coalesce(numPartitions: Int): Dataset[T] = withOrigin { + withTypedPlan { + Repartition(numPartitions, shuffle = false, logicalPlan) + } } /** @@ -3917,8 +3963,10 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @throws[AnalysisException] - def createTempView(viewName: String): Unit = withPlan { - createTempViewCommand(viewName, replace = false, global = false) + def createTempView(viewName: String): Unit = withOrigin { + withPlan { + createTempViewCommand(viewName, replace = false, global = false) + } } @@ -3930,8 +3978,10 @@ class Dataset[T] private[sql]( * @group basic * @since 2.0.0 */ - def createOrReplaceTempView(viewName: String): Unit = withPlan { - createTempViewCommand(viewName, replace = true, global = false) + def createOrReplaceTempView(viewName: String): Unit = withOrigin { + withPlan { + createTempViewCommand(viewName, replace = true, global = false) + } } /** @@ -3949,8 +3999,10 @@ class Dataset[T] private[sql]( * @since 2.1.0 */ @throws[AnalysisException] - def createGlobalTempView(viewName: String): Unit = withPlan { - createTempViewCommand(viewName, replace = false, global = true) + def createGlobalTempView(viewName: String): Unit = withOrigin { + withPlan { + createTempViewCommand(viewName, replace = false, global = true) + } } /** @@ -4358,7 +4410,7 @@ class Dataset[T] private[sql]( plan.executeCollect().map(fromRow) } - private def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = { + private def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = withOrigin { val sortOrder: Seq[SortOrder] = sortExprs.map { col => col.expr match { case expr: SortOrder => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index 58f720154df52..771c743f70629 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql.execution +import org.apache.spark.QueryContext import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{expressions, InternalRow} import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, ExprId, InSet, ListQuery, Literal, PlanExpression, Predicate, SupportQueryContext} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.{LeafLike, SQLQueryContext, UnaryLike} +import org.apache.spark.sql.catalyst.trees.{LeafLike, UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf @@ -68,7 +69,7 @@ case class ScalarSubquery( override def nullable: Boolean = true override def toString: String = plan.simpleString(SQLConf.get.maxToStringFields) override def withNewPlan(query: BaseSubqueryExec): ScalarSubquery = copy(plan = query) - def initQueryContext(): Option[SQLQueryContext] = Some(origin.context) + def initQueryContext(): Option[QueryContext] = Some(origin.context) override lazy val canonicalized: Expression = { ScalarSubquery(plan.canonicalized.asInstanceOf[BaseSubqueryExec], ExprId(0)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index b5e40fe35cfe1..a42df5bbcc292 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -93,11 +93,13 @@ import org.apache.spark.util.Utils object functions { // scalastyle:on - private def withExpr(expr: Expression): Column = Column(expr) + private def withExpr(expr: => Expression): Column = withOrigin { + Column(expr) + } private def withAggregateFunction( - func: AggregateFunction, - isDistinct: Boolean = false): Column = { + func: => AggregateFunction, + isDistinct: Boolean = false): Column = withOrigin { Column(func.toAggregateExpression(isDistinct)) } @@ -127,16 +129,18 @@ object functions { * @group normal_funcs * @since 1.3.0 */ - def lit(literal: Any): Column = literal match { - case c: Column => c - case s: Symbol => new ColumnName(s.name) - case _ => - // This is different from `typedlit`. `typedlit` calls `Literal.create` to use - // `ScalaReflection` to get the type of `literal`. However, since we use `Any` in this method, - // `typedLit[Any](literal)` will always fail and fallback to `Literal.apply`. Hence, we can - // just manually call `Literal.apply` to skip the expensive `ScalaReflection` code. This is - // significantly better when there are many threads calling `lit` concurrently. + def lit(literal: Any): Column = withOrigin { + literal match { + case c: Column => c + case s: Symbol => new ColumnName(s.name) + case _ => + // This is different from `typedlit`. `typedlit` calls `Literal.create` to use + // `ScalaReflection` to get the type of `literal`. However, since we use `Any` in this + // method, `typedLit[Any](literal)` will always fail and fallback to `Literal.apply`. 
Hence, + // we can just manually call `Literal.apply` to skip the expensive `ScalaReflection` code. + // This is significantly better when there are many threads calling `lit` concurrently. Column(Literal(literal)) + } } /** @@ -147,7 +151,9 @@ object functions { * @group normal_funcs * @since 2.2.0 */ - def typedLit[T : TypeTag](literal: T): Column = typedlit(literal) + def typedLit[T : TypeTag](literal: T): Column = withOrigin { + typedlit(literal) + } /** * Creates a [[Column]] of literal value. @@ -164,10 +170,12 @@ object functions { * @group normal_funcs * @since 3.2.0 */ - def typedlit[T : TypeTag](literal: T): Column = literal match { - case c: Column => c - case s: Symbol => new ColumnName(s.name) - case _ => Column(Literal.create(literal)) + def typedlit[T : TypeTag](literal: T): Column = withOrigin { + literal match { + case c: Column => c + case s: Symbol => new ColumnName(s.name) + case _ => Column(Literal.create(literal)) + } } ////////////////////////////////////////////////////////////////////////////////////////////// @@ -5965,25 +5973,31 @@ object functions { def array_except(col1: Column, col2: Column): Column = Column.fn("array_except", col1, col2) - private def createLambda(f: Column => Column) = Column { - val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) - val function = f(Column(x)).expr - LambdaFunction(function, Seq(x)) + private def createLambda(f: Column => Column) = withOrigin { + Column { + val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) + val function = f(Column(x)).expr + LambdaFunction(function, Seq(x)) + } } - private def createLambda(f: (Column, Column) => Column) = Column { - val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) - val y = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("y"))) - val function = f(Column(x), Column(y)).expr - LambdaFunction(function, Seq(x, y)) + private def createLambda(f: (Column, Column) => Column) = withOrigin { + Column { + val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) + val y = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("y"))) + val function = f(Column(x), Column(y)).expr + LambdaFunction(function, Seq(x, y)) + } } - private def createLambda(f: (Column, Column, Column) => Column) = Column { - val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) - val y = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("y"))) - val z = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("z"))) - val function = f(Column(x), Column(y), Column(z)).expr - LambdaFunction(function, Seq(x, y, z)) + private def createLambda(f: (Column, Column, Column) => Column) = withOrigin { + Column { + val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) + val y = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("y"))) + val z = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("z"))) + val function = f(Column(x), Column(y), Column(z)).expr + LambdaFunction(function, Seq(x, y, z)) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 1794ac513749f..7f00f6d6317c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -17,7 +17,10 @@ package org.apache.spark +import java.util.regex.Pattern + import org.apache.spark.annotation.{DeveloperApi, Unstable} +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} import org.apache.spark.sql.execution.SparkStrategy /** @@ -73,4 +76,41 @@ package object sql { * with rebasing. */ private[sql] val SPARK_LEGACY_INT96_METADATA_KEY = "org.apache.spark.legacyINT96" + + /** + * This helper function captures the Spark API and its call site in the user code from the current + * stacktrace. + * + * As adding `withOrigin` explicitly to all Spark API definition would be a huge change, + * `withOrigin` is used only at certain places where all API implementation surely pass through + * and the current stacktrace is filtered to the point where first Spark API code is invoked from + * the user code. + * + * As there might be multiple nested `withOrigin` calls (e.g. any Spark API implementations can + * invoke other APIs) only the first `withOrigin` is captured because that is closer to the user + * code. + * + * @param f The function that can use the origin. + * @return The result of `f`. + */ + private[sql] def withOrigin[T](f: => T): T = { + if (CurrentOrigin.get.stackTrace.isDefined) { + f + } else { + val st = Thread.currentThread().getStackTrace + var i = 3 + while (i < st.length && sparkCode(st(i))) i += 1 + val origin = + Origin(stackTrace = Some(Thread.currentThread().getStackTrace.slice(i - 1, i + 1))) + CurrentOrigin.withOrigin(origin)(f) + } + } + + private val sparkCodePattern = Pattern.compile("org\\.apache\\.spark\\.sql\\." + + "(?:functions|Column|ColumnName|SQLImplicits|Dataset|DataFrameStatFunctions)" + + "(?:|\\..*|\\$.*)") + + private def sparkCode(ste: StackTraceElement): Boolean = { + sparkCodePattern.matcher(ste.getClassName).matches() + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 8c9ad2180faa3..140daced32234 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -458,7 +458,8 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { parameters = Map( "functionName" -> "`in`", "dataType" -> "[\"INT\", \"ARRAY\"]", - "sqlExpr" -> "\"(a IN (b))\"") + "sqlExpr" -> "\"(a IN (b))\""), + context = ExpectedContext(fragment = "isin", callSitePattern = getCurrentClassCallSitePattern) ) } @@ -525,7 +526,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { parameters = Map( "functionName" -> "`in`", "dataType" -> "[\"INT\", \"ARRAY\"]", - "sqlExpr" -> "\"(a IN (b))\"") + "sqlExpr" -> "\"(a IN (b))\""), + context = ExpectedContext( + fragment = "isInCollection", + callSitePattern = getCurrentClassCallSitePattern) ) } } @@ -1056,7 +1060,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"key\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"STRUCT\"") + "requiredType" -> "\"STRUCT\""), + context = ExpectedContext( + fragment = "withField", + callSitePattern = getCurrentClassCallSitePattern) ) } @@ -1101,7 +1108,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"a.b\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"STRUCT\"") + "requiredType" -> "\"STRUCT\""), + context = 
ExpectedContext( + fragment = "withField", + callSitePattern = getCurrentClassCallSitePattern) ) } @@ -1849,7 +1859,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"key\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"STRUCT\"") + "requiredType" -> "\"STRUCT\""), + context = ExpectedContext( + fragment = "dropFields", + callSitePattern = getCurrentClassCallSitePattern) ) } @@ -1886,7 +1899,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"a.b\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"STRUCT\"") + "requiredType" -> "\"STRUCT\""), + context = ExpectedContext( + fragment = "dropFields", + callSitePattern = getCurrentClassCallSitePattern) ) } @@ -1952,7 +1968,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { structLevel1.withColumn("a", $"a".dropFields("a", "b", "c")) }, errorClass = "DATATYPE_MISMATCH.CANNOT_DROP_ALL_FIELDS", - parameters = Map("sqlExpr" -> "\"update_fields(a, dropfield(), dropfield(), dropfield())\"") + parameters = Map("sqlExpr" -> "\"update_fields(a, dropfield(), dropfield(), dropfield())\""), + context = ExpectedContext( + fragment = "dropFields", + callSitePattern = getCurrentClassCallSitePattern) ) } @@ -2224,7 +2243,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { .select($"struct_col".dropFields("a", "b")) }, errorClass = "DATATYPE_MISMATCH.CANNOT_DROP_ALL_FIELDS", - parameters = Map("sqlExpr" -> "\"update_fields(struct_col, dropfield(), dropfield())\"") + parameters = Map("sqlExpr" -> "\"update_fields(struct_col, dropfield(), dropfield())\""), + context = ExpectedContext( + fragment = "dropFields", + callSitePattern = getCurrentClassCallSitePattern) ) checkAnswer( @@ -2398,7 +2420,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { structLevel1.select($"a".withField("d", lit(4)).withField("e", $"a.d" + 1).as("a")) }, errorClass = "FIELD_NOT_FOUND", - parameters = Map("fieldName" -> "`d`", "fields" -> "`a`, `b`, `c`")) + parameters = Map("fieldName" -> "`d`", "fields" -> "`a`, `b`, `c`"), + context = ExpectedContext( + fragment = "$", + callSitePattern = getCurrentClassCallSitePattern)) checkAnswer( structLevel1 @@ -2451,7 +2476,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { .select($"a".withField("z", $"a.c")).as("a") }, errorClass = "FIELD_NOT_FOUND", - parameters = Map("fieldName" -> "`c`", "fields" -> "`a`, `b`")) + parameters = Map("fieldName" -> "`c`", "fields" -> "`a`, `b`"), + context = ExpectedContext( + fragment = "$", + callSitePattern = getCurrentClassCallSitePattern)) } test("nestedDf should generate nested DataFrames") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala index c40ecb88257f2..e7c1f0414b619 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala @@ -52,7 +52,8 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "inputSchema" -> "\"ARRAY\"", "dataType" -> "\"ARRAY\"" - ) + ), + context = ExpectedContext(fragment = "from_csv", getCurrentClassCallSitePattern) ) checkError( @@ -395,7 +396,8 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession { .select(from_csv($"csv", $"schema", options)).collect() }, errorClass = 
"INVALID_SCHEMA.NON_STRING_LITERAL", - parameters = Map("inputSchema" -> "\"schema\"") + parameters = Map("inputSchema" -> "\"schema\""), + context = ExpectedContext(fragment = "from_csv", getCurrentClassCallSitePattern) ) checkError( @@ -403,7 +405,8 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession { Seq("1").toDF("csv").select(from_csv($"csv", lit(1), options)).collect() }, errorClass = "INVALID_SCHEMA.NON_STRING_LITERAL", - parameters = Map("inputSchema" -> "\"1\"") + parameters = Map("inputSchema" -> "\"1\""), + context = ExpectedContext(fragment = "from_csv", getCurrentClassCallSitePattern) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index f34d7cf368072..c8eea985c1065 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -633,7 +633,10 @@ class DataFrameAggregateSuite extends QueryTest "functionName" -> "`collect_set`", "dataType" -> "\"MAP\"", "sqlExpr" -> "\"collect_set(b)\"" - ) + ), + context = ExpectedContext( + fragment = "collect_set", + callSitePattern = getCurrentClassCallSitePattern) ) } @@ -706,7 +709,8 @@ class DataFrameAggregateSuite extends QueryTest testData.groupBy(sum($"key")).count() }, errorClass = "GROUP_BY_AGGREGATE", - parameters = Map("sqlExpr" -> "sum(key)") + parameters = Map("sqlExpr" -> "sum(key)"), + context = ExpectedContext(fragment = "sum", callSitePattern = getCurrentClassCallSitePattern) ) } @@ -1302,7 +1306,8 @@ class DataFrameAggregateSuite extends QueryTest "paramIndex" -> "2", "inputSql" -> "\"a\"", "inputType" -> "\"STRING\"", - "requiredType" -> "\"INTEGRAL\"")) + "requiredType" -> "\"INTEGRAL\""), + context = ExpectedContext(fragment = "$", callSitePattern = getCurrentClassCallSitePattern)) } test("SPARK-34716: Support ANSI SQL intervals by the aggregate function `sum`") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 789196583c600..135ce834bfe5a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -171,7 +171,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "requiredType" -> "\"ARRAY\"", "inputSql" -> "\"k\"", "inputType" -> "\"INT\"" - ) + ), + queryContext = Array( + ExpectedContext( + fragment = "map_from_arrays", + callSitePattern = getCurrentClassCallSitePattern)) ) val df5 = Seq((Seq("a", null), Seq(1, 2))).toDF("k", "v") @@ -758,7 +762,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { test("The given function only supports array input") { val df = Seq(1, 2, 3).toDF("a") - checkErrorMatchPVals( + checkError( exception = intercept[AnalysisException] { df.select(array_sort(col("a"), (x, y) => x - y)) }, @@ -769,7 +773,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "requiredType" -> "\"ARRAY\"", "inputSql" -> "\"a\"", "inputType" -> "\"INT\"" - )) + ), + matchPVals = true, + queryContext = Array( + ExpectedContext( + fragment = "array_sort", + callSitePattern = getCurrentClassCallSitePattern)) + ) } test("sort_array/array_sort functions") { @@ -1305,7 +1315,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( 
"sqlExpr" -> "\"map_concat(map1, map2)\"", "dataType" -> "(\"MAP, INT>\" or \"MAP\")", - "functionName" -> "`map_concat`") + "functionName" -> "`map_concat`"), + context = + ExpectedContext( + fragment = "map_concat", + callSitePattern = getCurrentClassCallSitePattern) ) checkError( @@ -1333,7 +1347,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"map_concat(map1, 12)\"", "dataType" -> "[\"MAP, INT>\", \"INT\"]", - "functionName" -> "`map_concat`") + "functionName" -> "`map_concat`"), + context = + ExpectedContext( + fragment = "map_concat", + callSitePattern = getCurrentClassCallSitePattern) ) } @@ -1402,7 +1420,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"a\"", "inputType" -> "\"INT\"", "requiredType" -> "\"ARRAY\" of pair \"STRUCT\"" - ) + ), + context = + ExpectedContext( + fragment = "map_from_entries", + callSitePattern = getCurrentClassCallSitePattern) ) } @@ -1439,7 +1461,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"array_contains(a, NULL)\"", "functionName" -> "`array_contains`" - ) + ), + context = + ExpectedContext( + fragment = "array_contains", + callSitePattern = getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { @@ -2348,7 +2374,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "functionName" -> "`array_union`", "arrayType" -> "\"ARRAY\"", "leftType" -> "\"VOID\"", - "rightType" -> "\"ARRAY\"")) + "rightType" -> "\"ARRAY\""), + context = + ExpectedContext( + fragment = "array_union", + callSitePattern = getCurrentClassCallSitePattern)) checkError( exception = intercept[AnalysisException] { @@ -2379,7 +2409,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "functionName" -> "`array_union`", "arrayType" -> "\"ARRAY\"", "leftType" -> "\"VOID\"", - "rightType" -> "\"VOID\"") + "rightType" -> "\"VOID\""), + context = ExpectedContext( + fragment = "array_union", callSitePattern = getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { @@ -2410,7 +2442,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "functionName" -> "`array_union`", "arrayType" -> "\"ARRAY\"", "leftType" -> "\"ARRAY>\"", - "rightType" -> "\"ARRAY\"") + "rightType" -> "\"ARRAY\""), + queryContext = Array(ExpectedContext( + fragment = "array_union", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -2647,7 +2681,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"arr\"", "inputType" -> "\"ARRAY\"", "requiredType" -> "\"ARRAY\" of \"ARRAY\"" - ) + ), + context = ExpectedContext( + fragment = "flatten", callSitePattern = getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { @@ -2660,7 +2696,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"i\"", "inputType" -> "\"INT\"", "requiredType" -> "\"ARRAY\" of \"ARRAY\"" - ) + ), + queryContext = Array(ExpectedContext( + fragment = "flatten", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -2673,7 +2711,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"s\"", "inputType" -> "\"STRING\"", "requiredType" -> "\"ARRAY\" of \"ARRAY\"" - ) + ), + 
queryContext = Array(ExpectedContext( + fragment = "flatten", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -2782,7 +2822,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"b\"", "inputType" -> "\"STRING\"", "requiredType" -> "\"INT\"" - ) + ), + context = ExpectedContext( + fragment = "array_repeat", callSitePattern = getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { @@ -2795,7 +2837,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"1\"", "inputType" -> "\"STRING\"", "requiredType" -> "\"INT\"" - ) + ), + queryContext = Array(ExpectedContext( + fragment = "array_repeat", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -3123,7 +3167,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "arrayType" -> "\"ARRAY\"", "leftType" -> "\"VOID\"", "rightType" -> "\"VOID\"" - ) + ), + context = ExpectedContext( + fragment = "array_except", callSitePattern = getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { @@ -3151,7 +3197,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "arrayType" -> "\"ARRAY\"", "leftType" -> "\"ARRAY\"", "rightType" -> "\"ARRAY\"" - ) + ), + queryContext = Array(ExpectedContext( + fragment = "array_except", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -3179,7 +3227,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "arrayType" -> "\"ARRAY\"", "leftType" -> "\"ARRAY\"", "rightType" -> "\"VOID\"" - ) + ), + queryContext = Array(ExpectedContext( + fragment = "array_except", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -3207,7 +3257,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "arrayType" -> "\"ARRAY\"", "leftType" -> "\"VOID\"", "rightType" -> "\"ARRAY\"" - ) + ), + queryContext = Array(ExpectedContext( + fragment = "array_except", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -3276,7 +3328,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "arrayType" -> "\"ARRAY\"", "leftType" -> "\"VOID\"", "rightType" -> "\"VOID\"" - ) + ), + context = ExpectedContext( + fragment = "array_intersect", callSitePattern = getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { @@ -3305,7 +3359,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "arrayType" -> "\"ARRAY\"", "leftType" -> "\"ARRAY\"", "rightType" -> "\"ARRAY\"" - ) + ), + queryContext = Array(ExpectedContext( + fragment = "array_intersect", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -3334,7 +3390,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "arrayType" -> "\"ARRAY\"", "leftType" -> "\"VOID\"", "rightType" -> "\"ARRAY\"" - ) + ), + queryContext = Array( + ExpectedContext( + fragment = "array_intersect", + callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -3750,7 +3810,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"i\"", 
"inputType" -> "\"INT\"", - "requiredType" -> "\"MAP\"")) + "requiredType" -> "\"MAP\""), + queryContext = Array( + ExpectedContext(fragment = "map_filter", callSitePattern = getCurrentClassCallSitePattern))) checkError( exception = @@ -3933,7 +3995,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"i\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"ARRAY\"")) + "requiredType" -> "\"ARRAY\""), + queryContext = Array( + ExpectedContext(fragment = "filter", callSitePattern = getCurrentClassCallSitePattern))) checkError( exception = intercept[AnalysisException] { @@ -3945,7 +4009,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "2", "inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable())\"", "inputType" -> "\"STRING\"", - "requiredType" -> "\"BOOLEAN\"")) + "requiredType" -> "\"BOOLEAN\""), + ExpectedContext( + fragment = "filter(s, x -> x)", + start = 0, + stop = 16)) checkError( exception = intercept[AnalysisException] { @@ -3957,7 +4025,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "2", "inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable())\"", "inputType" -> "\"STRING\"", - "requiredType" -> "\"BOOLEAN\"")) + "requiredType" -> "\"BOOLEAN\""), + context = ExpectedContext( + fragment = "filter", + callSitePattern = getCurrentClassCallSitePattern)) checkError( exception = @@ -4112,7 +4183,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"i\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"ARRAY\"")) + "requiredType" -> "\"ARRAY\""), + queryContext = Array( + ExpectedContext(fragment = "exists", callSitePattern = getCurrentClassCallSitePattern))) checkError( exception = intercept[AnalysisException] { @@ -4124,7 +4197,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "2", "inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable())\"", "inputType" -> "\"STRING\"", - "requiredType" -> "\"BOOLEAN\"")) + "requiredType" -> "\"BOOLEAN\""), + context = ExpectedContext( + fragment = "exists(s, x -> x)", + start = 0, + stop = 16) + ) checkError( exception = intercept[AnalysisException] { @@ -4136,7 +4214,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "2", "inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable())\"", "inputType" -> "\"STRING\"", - "requiredType" -> "\"BOOLEAN\"")) + "requiredType" -> "\"BOOLEAN\""), + context = + ExpectedContext(fragment = "exists", callSitePattern = getCurrentClassCallSitePattern)) checkError( exception = intercept[AnalysisException](df.selectExpr("exists(a, x -> x)")), @@ -4304,7 +4384,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"i\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"ARRAY\"")) + "requiredType" -> "\"ARRAY\""), + queryContext = Array( + ExpectedContext(fragment = "forall", callSitePattern = getCurrentClassCallSitePattern))) checkError( exception = intercept[AnalysisException] { @@ -4316,7 +4398,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "2", "inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable())\"", "inputType" -> "\"STRING\"", - "requiredType" -> "\"BOOLEAN\"")) + "requiredType" -> "\"BOOLEAN\""), + 
context = ExpectedContext( + fragment = "forall(s, x -> x)", + start = 0, + stop = 16)) checkError( exception = intercept[AnalysisException] { @@ -4328,7 +4414,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "2", "inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable())\"", "inputType" -> "\"STRING\"", - "requiredType" -> "\"BOOLEAN\"")) + "requiredType" -> "\"BOOLEAN\""), + context = + ExpectedContext(fragment = "forall", callSitePattern = getCurrentClassCallSitePattern)) checkError( exception = intercept[AnalysisException](df.selectExpr("forall(a, x -> x)")), @@ -4343,7 +4431,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { checkError( exception = intercept[AnalysisException](df.select(forall(col("a"), x => x))), errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", - parameters = Map("objectName" -> "`a`", "proposal" -> "`i`, `s`")) + parameters = Map("objectName" -> "`a`", "proposal" -> "`i`, `s`"), + queryContext = Array( + ExpectedContext(fragment = "col", callSitePattern = getCurrentClassCallSitePattern))) } test("aggregate function - array for primitive type not containing null") { @@ -4581,7 +4671,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"i\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"ARRAY\"")) + "requiredType" -> "\"ARRAY\""), + queryContext = Array( + ExpectedContext(fragment = "aggregate", callSitePattern = getCurrentClassCallSitePattern))) // scalastyle:on line.size.limit // scalastyle:off line.size.limit @@ -4597,7 +4689,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable(), namedlambdavariable())\"", "inputType" -> "\"STRING\"", "requiredType" -> "\"INT\"" - )) + ), + context = ExpectedContext( + fragment = s"$agg(s, 0, (acc, x) -> x)", + start = 0, + stop = agg.length + 20)) } // scalastyle:on line.size.limit @@ -4613,7 +4709,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable(), namedlambdavariable())\"", "inputType" -> "\"STRING\"", "requiredType" -> "\"INT\"" - )) + ), + context = + ExpectedContext(fragment = "aggregate", callSitePattern = getCurrentClassCallSitePattern)) // scalastyle:on line.size.limit Seq("aggregate", "reduce").foreach { agg => @@ -4719,7 +4817,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "sqlExpr" -> """"map_zip_with\(mis, mmi, lambdafunction\(concat\(x_\d+, y_\d+, z_\d+\), x_\d+, y_\d+, z_\d+\)\)"""", "functionName" -> "`map_zip_with`", "leftType" -> "\"INT\"", - "rightType" -> "\"MAP\"")) + "rightType" -> "\"MAP\""), + queryContext = Array( + ExpectedContext(fragment = "map_zip_with", callSitePattern = getCurrentClassCallSitePattern))) // scalastyle:on line.size.limit checkError( @@ -4749,7 +4849,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "sqlExpr" -> """"map_zip_with\(i, mis, lambdafunction\(concat\(x_\d+, y_\d+, z_\d+\), x_\d+, y_\d+, z_\d+\)\)"""", "paramIndex" -> "1", "inputSql" -> "\"i\"", - "inputType" -> "\"INT\"", "requiredType" -> "\"MAP\"")) + "inputType" -> "\"INT\"", "requiredType" -> "\"MAP\""), + queryContext = Array( + ExpectedContext(fragment = "map_zip_with", callSitePattern = getCurrentClassCallSitePattern))) // scalastyle:on line.size.limit checkError( @@ -4779,7 +4881,9 @@ class 
DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "sqlExpr" -> """"map_zip_with\(mis, i, lambdafunction\(concat\(x_\d+, y_\d+, z_\d+\), x_\d+, y_\d+, z_\d+\)\)"""", "paramIndex" -> "2", "inputSql" -> "\"i\"", - "inputType" -> "\"INT\"", "requiredType" -> "\"MAP\"")) + "inputType" -> "\"INT\"", "requiredType" -> "\"MAP\""), + queryContext = Array( + ExpectedContext(fragment = "map_zip_with", callSitePattern = getCurrentClassCallSitePattern))) // scalastyle:on line.size.limit checkError( @@ -5235,7 +5339,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"x\"", "inputType" -> "\"ARRAY\"", - "requiredType" -> "\"MAP\"")) + "requiredType" -> "\"MAP\""), + queryContext = Array( + ExpectedContext( + fragment = "transform_values", + callSitePattern = getCurrentClassCallSitePattern))) } testInvalidLambdaFunctions() @@ -5375,7 +5483,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"i\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"ARRAY\"")) + "requiredType" -> "\"ARRAY\""), + queryContext = Array( + ExpectedContext(fragment = "zip_with", callSitePattern = getCurrentClassCallSitePattern))) checkError( exception = @@ -5631,7 +5741,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"map(m, 1)\"", "keyType" -> "\"MAP\"" - ) + ), + context = + ExpectedContext(fragment = "map", callSitePattern = getCurrentClassCallSitePattern) ) checkAnswer( df.select(map(map_entries($"m"), lit(1))), @@ -5753,7 +5865,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "requiredType" -> "\"ARRAY\"", "inputSql" -> "\"a\"", "inputType" -> "\"INT\"" - )) + ), + context = ExpectedContext( + fragment = "array_compact", + callSitePattern = getCurrentClassCallSitePattern)) } test("array_append -> Unit Test cases for the function ") { @@ -5772,7 +5887,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "dataType" -> "\"ARRAY\"", "leftType" -> "\"ARRAY\"", "rightType" -> "\"INT\"", - "sqlExpr" -> "\"array_append(a, b)\"") + "sqlExpr" -> "\"array_append(a, b)\""), + context = + ExpectedContext(fragment = "array_append", callSitePattern = getCurrentClassCallSitePattern) ) checkAnswer(df1.selectExpr("array_append(a, 3)"), Seq(Row(Seq(3, 2, 5, 1, 2, 3)))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index 237915fb63fa8..b3bf9405a99f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -310,7 +310,8 @@ class DataFramePivotSuite extends QueryTest with SharedSparkSession { .agg(sum($"sales.earnings")) }, errorClass = "GROUP_BY_AGGREGATE", - parameters = Map("sqlExpr" -> "min(training)") + parameters = Map("sqlExpr" -> "min(training)"), + context = ExpectedContext(fragment = "min", callSitePattern = getCurrentClassCallSitePattern) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala index 88ef5936264de..c777d2207584d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala @@ -484,7 +484,8 @@ class DataFrameSelfJoinSuite extends 
QueryTest with SharedSparkSession { checkError(ex, errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map("objectName" -> "`df1`.`timeStr`", - "proposal" -> "`df3`.`timeStr`, `df1`.`tsStr`")) + "proposal" -> "`df3`.`timeStr`, `df1`.`tsStr`"), + context = ExpectedContext(fragment = "$", getCurrentClassCallSitePattern)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala index ab8aab0713a44..e7c1d2c772c08 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala @@ -373,7 +373,10 @@ class DataFrameSetOperationsSuite extends QueryTest errorClass = "UNSUPPORTED_FEATURE.SET_OPERATION_ON_MAP_TYPE", parameters = Map( "colName" -> "`m`", - "dataType" -> "\"MAP\"") + "dataType" -> "\"MAP\""), + context = ExpectedContext( + fragment = "distinct", + callSitePattern = getCurrentClassCallSitePattern) ) withTempView("v") { df.createOrReplaceTempView("v") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 430e36221025a..20ac2a9e9461d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -146,7 +146,9 @@ class DataFrameStatSuite extends QueryTest with SharedSparkSession { parameters = Map( "name" -> "`num`", "referenceNames" -> "[`table1`.`num`, `table2`.`num`]" - ) + ), + context = + ExpectedContext(fragment = "freqItems", callSitePattern = getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { @@ -156,7 +158,9 @@ class DataFrameStatSuite extends QueryTest with SharedSparkSession { parameters = Map( "name" -> "`num`", "referenceNames" -> "[`table1`.`num`, `table2`.`num`]" - ) + ), + context = ExpectedContext( + fragment = "approxQuantile", callSitePattern = getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index b7450e5648727..b0a0b189cb7f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -354,7 +354,8 @@ class DataFrameSuite extends QueryTest "paramIndex" -> "1", "inputSql"-> "\"csv\"", "inputType" -> "\"STRING\"", - "requiredType" -> "(\"ARRAY\" or \"MAP\")") + "requiredType" -> "(\"ARRAY\" or \"MAP\")"), + context = ExpectedContext(fragment = "explode", getCurrentClassCallSitePattern) ) val df2 = Seq(Array("1", "2"), Array("4"), Array("7", "8", "9")).toDF("csv") @@ -2947,7 +2948,8 @@ class DataFrameSuite extends QueryTest df.groupBy($"d", $"b").as[GroupByKey, Row] }, errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", - parameters = Map("objectName" -> "`d`", "proposal" -> "`a`, `b`, `c`")) + parameters = Map("objectName" -> "`d`", "proposal" -> "`a`, `b`, `c`"), + context = ExpectedContext(fragment = "$", callSitePattern = getCurrentClassCallSitePattern)) } test("SPARK-40601: flatMapCoGroupsInPandas should fail with different number of keys") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala 
index 2a81f7e7c2f34..bb744cfd8ab4f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala @@ -184,7 +184,9 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { "sqlExpr" -> (""""\(ORDER BY key ASC NULLS FIRST, value ASC NULLS FIRST RANGE """ + """BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING\)"""") ), - matchPVals = true + matchPVals = true, + queryContext = + Array(ExpectedContext(fragment = "over", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( @@ -198,7 +200,9 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { "sqlExpr" -> (""""\(ORDER BY key ASC NULLS FIRST, value ASC NULLS FIRST RANGE """ + """BETWEEN -1 FOLLOWING AND UNBOUNDED FOLLOWING\)"""") ), - matchPVals = true + matchPVals = true, + queryContext = + Array(ExpectedContext(fragment = "over", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( @@ -212,7 +216,9 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { "sqlExpr" -> (""""\(ORDER BY key ASC NULLS FIRST, value ASC NULLS FIRST RANGE """ + """BETWEEN -1 FOLLOWING AND 1 FOLLOWING\)"""") ), - matchPVals = true + matchPVals = true, + queryContext = + Array(ExpectedContext(fragment = "over", callSitePattern = getCurrentClassCallSitePattern)) ) } @@ -240,7 +246,8 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { "expectedType" -> ("(\"NUMERIC\" or \"INTERVAL DAY TO SECOND\" or \"INTERVAL YEAR " + "TO MONTH\" or \"INTERVAL\")"), "sqlExpr" -> "\"RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING\"" - ) + ), + context = ExpectedContext(fragment = "over", callSitePattern = getCurrentClassCallSitePattern) ) checkError( @@ -255,7 +262,8 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { "expectedType" -> ("(\"NUMERIC\" or \"INTERVAL DAY TO SECOND\" or \"INTERVAL YEAR " + "TO MONTH\" or \"INTERVAL\")"), "sqlExpr" -> "\"RANGE BETWEEN -1 FOLLOWING AND UNBOUNDED FOLLOWING\"" - ) + ), + context = ExpectedContext(fragment = "over", callSitePattern = getCurrentClassCallSitePattern) ) checkError( @@ -270,7 +278,8 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { "expectedType" -> ("(\"NUMERIC\" or \"INTERVAL DAY TO SECOND\" or \"INTERVAL YEAR " + "TO MONTH\" or \"INTERVAL\")"), "sqlExpr" -> "\"RANGE BETWEEN -1 FOLLOWING AND 1 FOLLOWING\"" - ) + ), + context = ExpectedContext(fragment = "over", callSitePattern = getCurrentClassCallSitePattern) ) } @@ -462,7 +471,8 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { "upper" -> "\"2\"", "lowerType" -> "\"INTERVAL\"", "upperType" -> "\"BIGINT\"" - ) + ), + context = ExpectedContext(fragment = "over", callSitePattern = getCurrentClassCallSitePattern) ) } @@ -481,7 +491,8 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"RANGE BETWEEN nonfoldableliteral() FOLLOWING AND 2 FOLLOWING\"", "location" -> "lower", - "expression" -> "\"nonfoldableliteral()\"") + "expression" -> "\"nonfoldableliteral()\""), + context = ExpectedContext(fragment = "over", callSitePattern = getCurrentClassCallSitePattern) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 84133eb485f0f..6969c4303e01e 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -412,7 +412,10 @@ class DataFrameWindowFunctionsSuite extends QueryTest errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map( "objectName" -> "`invalid`", - "proposal" -> "`value`, `key`")) + "proposal" -> "`value`, `key`"), + context = ExpectedContext( + fragment = "count", + callSitePattern = getCurrentClassCallSitePattern)) } test("numerical aggregate functions on string column") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index bf78e6e11fe99..66105d2ac429f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -606,7 +606,8 @@ class DatasetSuite extends QueryTest } }, errorClass = "INVALID_USAGE_OF_STAR_OR_REGEX", - parameters = Map("elem" -> "'*'", "prettyName" -> "MapGroups")) + parameters = Map("elem" -> "'*'", "prettyName" -> "MapGroups"), + context = ExpectedContext(fragment = "$", getCurrentClassCallSitePattern)) } test("groupBy function, flatMapSorted") { @@ -634,7 +635,8 @@ class DatasetSuite extends QueryTest } }, errorClass = "INVALID_USAGE_OF_STAR_OR_REGEX", - parameters = Map("elem" -> "'*'", "prettyName" -> "MapGroups")) + parameters = Map("elem" -> "'*'", "prettyName" -> "MapGroups"), + context = ExpectedContext(fragment = "$", getCurrentClassCallSitePattern)) } test("groupBy, flatMapSorted desc") { @@ -2290,7 +2292,8 @@ class DatasetSuite extends QueryTest sqlState = None, parameters = Map( "objectName" -> s"`${colName.replace(".", "`.`")}`", - "proposal" -> "`field.1`, `field 2`")) + "proposal" -> "`field.1`, `field 2`"), + context = ExpectedContext(fragment = "select", getCurrentClassCallSitePattern)) } } } @@ -2304,7 +2307,8 @@ class DatasetSuite extends QueryTest sqlState = None, parameters = Map( "objectName" -> "`the`.`id`", - "proposal" -> "`the.id`")) + "proposal" -> "`the.id`"), + context = ExpectedContext(fragment = "select", getCurrentClassCallSitePattern)) } test("SPARK-39783: backticks in error message for map candidate key with dots") { @@ -2318,7 +2322,8 @@ class DatasetSuite extends QueryTest sqlState = None, parameters = Map( "objectName" -> "`nonexisting`", - "proposal" -> "`map`, `other.column`")) + "proposal" -> "`map`, `other.column`"), + context = ExpectedContext(fragment = "$", getCurrentClassCallSitePattern)) } test("groupBy.as") { @@ -2659,6 +2664,22 @@ class DatasetSuite extends QueryTest assert(join.count() == 1000000) } + test("SPARK-45022: exact DatasetQueryContext call site") { + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + val df = Seq(1).toDS() + var callSitePattern: String = null + checkError( + exception = intercept[AnalysisException] { + callSitePattern = getNextLineCallSitePattern() + val c = col("a") + df.select(c) + }, + errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + sqlState = "42703", + parameters = Map("objectName" -> "`a`", "proposal" -> "`value`"), + context = ExpectedContext(fragment = "col", callSitePattern = callSitePattern)) + } + } } class DatasetLargeResultCollectingSuite extends QueryTest diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetUnpivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetUnpivotSuite.scala index 4117ea63bdd8c..49811d8ac61bc 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/DatasetUnpivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetUnpivotSuite.scala @@ -373,7 +373,8 @@ class DatasetUnpivotSuite extends QueryTest errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map( "objectName" -> "`1`", - "proposal" -> "`id`, `int1`, `str1`, `long1`, `str2`")) + "proposal" -> "`id`, `int1`, `str1`, `long1`, `str2`"), + context = ExpectedContext(fragment = "$", callSitePattern = getCurrentClassCallSitePattern)) // unpivoting where value column does not exist val e2 = intercept[AnalysisException] { @@ -389,7 +390,8 @@ class DatasetUnpivotSuite extends QueryTest errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map( "objectName" -> "`does`", - "proposal" -> "`id`, `int1`, `long1`, `str1`, `str2`")) + "proposal" -> "`id`, `int1`, `long1`, `str1`, `str2`"), + context = ExpectedContext(fragment = "$", callSitePattern = getCurrentClassCallSitePattern)) // unpivoting without values where potential value columns are of incompatible types val e3 = intercept[AnalysisException] { @@ -506,7 +508,8 @@ class DatasetUnpivotSuite extends QueryTest errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map( "objectName" -> "`an`.`id`", - "proposal" -> "`an.id`, `int1`, `long1`, `str.one`, `str.two`")) + "proposal" -> "`an.id`, `int1`, `long1`, `str.one`, `str.two`"), + context = ExpectedContext(fragment = "$", callSitePattern = getCurrentClassCallSitePattern)) } test("unpivot with struct fields") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index 9b4ad76881864..2ab651237206a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -293,7 +293,10 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"array()\"", "inputType" -> "\"ARRAY\"", - "requiredType" -> "\"ARRAY\"") + "requiredType" -> "\"ARRAY\""), + context = ExpectedContext( + fragment = "inline", + callSitePattern = getCurrentClassCallSitePattern) ) } @@ -331,7 +334,10 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"array(struct(a), struct(b))\"", "functionName" -> "`array`", - "dataType" -> "(\"STRUCT\" or \"STRUCT\")")) + "dataType" -> "(\"STRUCT\" or \"STRUCT\")"), + context = ExpectedContext( + fragment = "array", + callSitePattern = getCurrentClassCallSitePattern)) checkAnswer( df.select(inline(array(struct(Symbol("a")), struct(Symbol("b").alias("a"))))), @@ -346,7 +352,10 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"array(struct(a), struct(2))\"", "functionName" -> "`array`", - "dataType" -> "(\"STRUCT\" or \"STRUCT\")")) + "dataType" -> "(\"STRUCT\" or \"STRUCT\")"), + context = ExpectedContext( + fragment = "array", + callSitePattern = getCurrentClassCallSitePattern)) checkAnswer( df.select(inline(array(struct(Symbol("a")), struct(lit(2).alias("a"))))), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 5effa2edf585c..933f362db663f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -195,7 
+195,9 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"json_tuple(a, 1)\"", "funcName" -> "`json_tuple`" - ) + ), + context = + ExpectedContext(fragment = "json_tuple", callSitePattern = getCurrentClassCallSitePattern) ) } @@ -648,7 +650,9 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { errorClass = "DATATYPE_MISMATCH.INVALID_JSON_MAP_KEY_TYPE", parameters = Map( "schema" -> "\"MAP, STRING>\"", - "sqlExpr" -> "\"entries\"")) + "sqlExpr" -> "\"entries\""), + context = + ExpectedContext(fragment = "from_json", callSitePattern = getCurrentClassCallSitePattern)) } test("SPARK-24709: infers schemas of json strings and pass them to from_json") { @@ -958,7 +962,8 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { .select(from_json($"json", $"schema", options)).collect() }, errorClass = "INVALID_SCHEMA.NON_STRING_LITERAL", - parameters = Map("inputSchema" -> "\"schema\"") + parameters = Map("inputSchema" -> "\"schema\""), + context = ExpectedContext(fragment = "from_json", getCurrentClassCallSitePattern) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala index 2a24f0cc39965..afbe9cdac6366 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala @@ -600,6 +600,10 @@ class ParametersSuite extends QueryTest with SharedSparkSession { array(str_to_map(Column(Literal("a:1,b:2,c:3"))))))) }, errorClass = "INVALID_SQL_ARG", - parameters = Map("name" -> "m")) + parameters = Map("name" -> "m"), + context = ExpectedContext( + fragment = "map_from_arrays", + callSitePattern = getCurrentClassCallSitePattern) + ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index b5ae9c7f35200..8668d61317409 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import java.util.TimeZone +import java.util.regex.Pattern import scala.jdk.CollectionConverters._ @@ -229,6 +230,17 @@ abstract class QueryTest extends PlanTest { assert(query.queryExecution.executedPlan.missingInput.isEmpty, s"The physical plan has missing inputs:\n${query.queryExecution.executedPlan}") } + + protected def getCurrentClassCallSitePattern: String = { + val cs = Thread.currentThread().getStackTrace()(2) + s"${cs.getClassName}\\..*\\(${cs.getFileName}:\\d+\\)" + } + + protected def getNextLineCallSitePattern(lines: Int = 1): String = { + val cs = Thread.currentThread().getStackTrace()(2) + Pattern.quote( + s"${cs.getClassName}.${cs.getMethodName}(${cs.getFileName}:${cs.getLineNumber + lines})") + } } object QueryTest extends Assertions { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 3612f4a7eda8d..b7201c2d96d77 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1920,7 +1920,10 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark dfNoCols.select($"b.*") }, errorClass = "CANNOT_RESOLVE_STAR_EXPAND", - parameters = Map("targetString" -> "`b`", "columns" -> "")) + parameters = Map("targetString" -> "`b`", 
"columns" -> ""), + context = ExpectedContext( + fragment = "$", + callSitePattern = getCurrentClassCallSitePattern)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 179f40742c28f..38a6b9a50272b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -879,7 +879,10 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "funcName" -> s"`$funcName`", "paramName" -> "`format`", - "paramType" -> "\"STRING\"")) + "paramType" -> "\"STRING\""), + context = ExpectedContext( + fragment = funcName, + callSitePattern = getCurrentClassCallSitePattern)) checkError( exception = intercept[AnalysisException] { df2.select(func(col("input"), lit("invalid_format"))).collect() @@ -888,7 +891,10 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "parameter" -> "`format`", "functionName" -> s"`$funcName`", - "invalidFormat" -> "'invalid_format'")) + "invalidFormat" -> "'invalid_format'"), + context = ExpectedContext( + fragment = funcName, + callSitePattern = getCurrentClassCallSitePattern)) checkError( exception = intercept[AnalysisException] { sql(s"select $funcName('a', 'b', 'c')") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala index 5c12ba3078069..30a5bf709066d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala @@ -696,7 +696,9 @@ class QueryCompilationErrorsSuite Seq("""{"a":1}""").toDF("a").select(from_json($"a", IntegerType)).collect() }, errorClass = "DATATYPE_MISMATCH.INVALID_JSON_SCHEMA", - parameters = Map("schema" -> "\"INT\"", "sqlExpr" -> "\"from_json(a)\"")) + parameters = Map("schema" -> "\"INT\"", "sqlExpr" -> "\"from_json(a)\""), + context = + ExpectedContext(fragment = "from_json", callSitePattern = getCurrentClassCallSitePattern)) } test("WRONG_NUM_ARGS.WITHOUT_SUGGESTION: wrong args of CAST(parameter types contains DataType)") { @@ -767,7 +769,8 @@ class QueryCompilationErrorsSuite }, errorClass = "AMBIGUOUS_REFERENCE_TO_FIELDS", sqlState = "42000", - parameters = Map("field" -> "`firstname`", "count" -> "2") + parameters = Map("field" -> "`firstname`", "count" -> "2"), + context = ExpectedContext(fragment = "$", callSitePattern = getCurrentClassCallSitePattern) ) } @@ -780,7 +783,9 @@ class QueryCompilationErrorsSuite }, errorClass = "INVALID_EXTRACT_BASE_FIELD_TYPE", sqlState = "42000", - parameters = Map("base" -> "\"firstname\"", "other" -> "\"STRING\"")) + parameters = Map("base" -> "\"firstname\"", "other" -> "\"STRING\""), + context = ExpectedContext(fragment = "$", callSitePattern = getCurrentClassCallSitePattern) + ) } test("INVALID_EXTRACT_FIELD_TYPE: extract not string literal field") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala index ee28a90aed9af..eafa89e8e007e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala @@ 
-19,6 +19,8 @@ package org.apache.spark.sql.errors import org.apache.spark._ import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, CheckOverflowInTableInsert, ExpressionProxy, Literal, SubExprEvaluationRuntime} +import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation +import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.ByteType @@ -53,6 +55,26 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest sqlState = "22012", parameters = Map("config" -> ansiConf), context = ExpectedContext(fragment = "6/0", start = 7, stop = 9)) + + checkError( + exception = intercept[SparkArithmeticException] { + OneRowRelation().select(lit(5) / lit(0)).collect() + }, + errorClass = "DIVIDE_BY_ZERO", + sqlState = "22012", + parameters = Map("config" -> ansiConf), + context = ExpectedContext(fragment = "div", callSitePattern = getCurrentClassCallSitePattern)) + + checkError( + exception = intercept[SparkArithmeticException] { + OneRowRelation().select(lit(5).divide(lit(0))).collect() + }, + errorClass = "DIVIDE_BY_ZERO", + sqlState = "22012", + parameters = Map("config" -> ansiConf), + context = ExpectedContext( + fragment = "divide", + callSitePattern = getCurrentClassCallSitePattern)) } test("INTERVAL_DIVIDED_BY_ZERO: interval divided by zero") { @@ -92,6 +114,21 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest fragment = "CAST('66666666666666.666' AS DECIMAL(8, 1))", start = 7, stop = 49)) + + checkError( + exception = intercept[SparkArithmeticException] { + OneRowRelation().select(lit("66666666666666.666").cast("DECIMAL(8, 1)")).collect() + }, + errorClass = "NUMERIC_VALUE_OUT_OF_RANGE", + sqlState = "22003", + parameters = Map( + "value" -> "66666666666666.666", + "precision" -> "8", + "scale" -> "1", + "config" -> ansiConf), + context = ExpectedContext( + fragment = "cast", + callSitePattern = getCurrentClassCallSitePattern)) } test("INVALID_ARRAY_INDEX: get element from array") { @@ -102,6 +139,16 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest errorClass = "INVALID_ARRAY_INDEX", parameters = Map("indexValue" -> "8", "arraySize" -> "5", "ansiConfig" -> ansiConf), context = ExpectedContext(fragment = "array(1, 2, 3, 4, 5)[8]", start = 7, stop = 29)) + + checkError( + exception = intercept[SparkArrayIndexOutOfBoundsException] { + OneRowRelation().select(lit(Array(1, 2, 3, 4, 5))(8)).collect() + }, + errorClass = "INVALID_ARRAY_INDEX", + parameters = Map("indexValue" -> "8", "arraySize" -> "5", "ansiConfig" -> ansiConf), + context = ExpectedContext( + fragment = "apply", + callSitePattern = getCurrentClassCallSitePattern)) } test("INVALID_ARRAY_INDEX_IN_ELEMENT_AT: element_at from array") { @@ -115,6 +162,15 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest fragment = "element_at(array(1, 2, 3, 4, 5), 8)", start = 7, stop = 41)) + + checkError( + exception = intercept[SparkArrayIndexOutOfBoundsException] { + OneRowRelation().select(element_at(lit(Array(1, 2, 3, 4, 5)), 8)).collect() + }, + errorClass = "INVALID_ARRAY_INDEX_IN_ELEMENT_AT", + parameters = Map("indexValue" -> "8", "arraySize" -> "5", "ansiConfig" -> ansiConf), + context = + ExpectedContext(fragment = "element_at", callSitePattern = getCurrentClassCallSitePattern)) } test("INVALID_INDEX_OF_ZERO: element_at from array by index zero") { @@ -129,6 +185,15 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest start = 7, stop = 41) ) + + 
checkError( + exception = intercept[SparkRuntimeException]( + OneRowRelation().select(element_at(lit(Array(1, 2, 3, 4, 5)), 0)).collect() + ), + errorClass = "INVALID_INDEX_OF_ZERO", + parameters = Map.empty, + context = + ExpectedContext(fragment = "element_at", callSitePattern = getCurrentClassCallSitePattern)) } test("CAST_INVALID_INPUT: cast string to double") { @@ -146,6 +211,20 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest fragment = "CAST('111111111111xe23' AS DOUBLE)", start = 7, stop = 40)) + + checkError( + exception = intercept[SparkNumberFormatException] { + OneRowRelation().select(lit("111111111111xe23").cast("DOUBLE")).collect() + }, + errorClass = "CAST_INVALID_INPUT", + parameters = Map( + "expression" -> "'111111111111xe23'", + "sourceType" -> "\"STRING\"", + "targetType" -> "\"DOUBLE\"", + "ansiConfig" -> ansiConf), + context = ExpectedContext( + fragment = "cast", + callSitePattern = getCurrentClassCallSitePattern)) } test("CANNOT_PARSE_TIMESTAMP: parse string to timestamp") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index 75e5d4d452e15..a7cab381c7f6b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -1206,7 +1206,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { }, errorClass = "DIVIDE_BY_ZERO", parameters = Map("config" -> "\"spark.sql.ansi.enabled\""), - context = new ExpectedContext( + context = ExpectedContext( objectType = "VIEW", objectName = s"$SESSION_CATALOG_NAME.default.v5", fragment = "1/0", @@ -1225,7 +1225,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { }, errorClass = "DIVIDE_BY_ZERO", parameters = Map("config" -> "\"spark.sql.ansi.enabled\""), - context = new ExpectedContext( + context = ExpectedContext( objectType = "VIEW", objectName = s"$SESSION_CATALOG_NAME.default.v1", fragment = "1/0", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala index 9f2d202299557..0e4985bac9941 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala @@ -244,7 +244,9 @@ class FileMetadataStructSuite extends QueryTest with SharedSparkSession { df.select("name", METADATA_FILE_NAME).collect() }, errorClass = "FIELD_NOT_FOUND", - parameters = Map("fieldName" -> "`file_name`", "fields" -> "`id`, `university`")) + parameters = Map("fieldName" -> "`file_name`", "fields" -> "`id`, `university`"), + context = + ExpectedContext(fragment = "select", callSitePattern = getCurrentClassCallSitePattern)) } metadataColumnsTest("SPARK-42683: df metadataColumn - schema conflict", @@ -522,14 +524,20 @@ class FileMetadataStructSuite extends QueryTest with SharedSparkSession { df.select("name", "_metadata.file_name").collect() }, errorClass = "FIELD_NOT_FOUND", - parameters = Map("fieldName" -> "`file_name`", "fields" -> "`id`, `university`")) + parameters = Map("fieldName" -> "`file_name`", "fields" -> "`id`, `university`"), + context = ExpectedContext( + fragment = "select", + callSitePattern = getCurrentClassCallSitePattern)) checkError( exception = intercept[AnalysisException] { 
df.select("name", "_METADATA.file_NAME").collect() }, errorClass = "FIELD_NOT_FOUND", - parameters = Map("fieldName" -> "`file_NAME`", "fields" -> "`id`, `university`")) + parameters = Map("fieldName" -> "`file_NAME`", "fields" -> "`id`, `university`"), + context = ExpectedContext( + fragment = "select", + callSitePattern = getCurrentClassCallSitePattern)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 111e88d57c784..a84aea2786823 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -2697,7 +2697,9 @@ abstract class CSVSuite readback.filter($"AAA" === 2 && $"bbb" === 3).collect() }, errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", - parameters = Map("objectName" -> "`AAA`", "proposal" -> "`BBB`, `aaa`")) + parameters = Map("objectName" -> "`AAA`", "proposal" -> "`BBB`, `aaa`"), + context = + ExpectedContext(fragment = "$", callSitePattern = getCurrentClassCallSitePattern)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index f0561a30727b3..2f8b0a323dc8c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -3105,7 +3105,9 @@ abstract class JsonSuite readback.filter($"AAA" === 0 && $"bbb" === 1).collect() }, errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", - parameters = Map("objectName" -> "`AAA`", "proposal" -> "`BBB`, `aaa`")) + parameters = Map("objectName" -> "`AAA`", "proposal" -> "`BBB`, `aaa`"), + context = + ExpectedContext(fragment = "$", callSitePattern = getCurrentClassCallSitePattern)) // Schema inferring val readback2 = spark.read.json(path.getCanonicalPath) checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileMetadataStructRowIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileMetadataStructRowIndexSuite.scala index 2465dee230de9..d3e9819b9a054 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileMetadataStructRowIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileMetadataStructRowIndexSuite.scala @@ -133,7 +133,8 @@ class ParquetFileMetadataStructRowIndexSuite extends QueryTest with SharedSparkS parameters = Map( "fieldName" -> "`row_index`", "fields" -> ("`file_path`, `file_name`, `file_size`, " + - "`file_block_start`, `file_block_length`, `file_modification_time`"))) + "`file_block_start`, `file_block_length`, `file_modification_time`")), + context = ExpectedContext(fragment = "select", getCurrentClassCallSitePattern)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index baffc5088eb24..94535bc84a4c6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -1918,7 +1918,9 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { 
Seq("xyz").toDF().select("value", "default").write.insertInto("t") }, errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", - parameters = Map("objectName" -> "`default`", "proposal" -> "`value`")) + parameters = Map("objectName" -> "`default`", "proposal" -> "`value`"), + context = + ExpectedContext(fragment = "select", callSitePattern = getCurrentClassCallSitePattern)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 2174e91cb4435..66d37e996a6cf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -713,7 +713,9 @@ class StreamSuite extends StreamTest { "columnName" -> "`rn_col`", "windowSpec" -> ("(PARTITION BY COL1 ORDER BY COL2 ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING " + - "AND CURRENT ROW)"))) + "AND CURRENT ROW)")), + queryContext = Array( + ExpectedContext(fragment = "withColumn", callSitePattern = getCurrentClassCallSitePattern))) } From e6b4fa835de3f6d0057bf3809ea369d785967bcd Mon Sep 17 00:00:00 2001 From: chenyu <119398199+chenyu-opensource@users.noreply.github.com> Date: Wed, 1 Nov 2023 17:12:45 +0800 Subject: [PATCH 04/13] [SPARK-45751][DOCS] Update the default value for spark.executor.logs.rolling.maxRetainedFile **What changes were proposed in this pull request?** The PR updates the default value of 'spark.executor.logs.rolling.maxRetainedFiles' in configuration.html on the website **Why are the changes needed?** The default value of 'spark.executor.logs.rolling.maxRetainedFiles' is -1, but the website is wrong. **Does this PR introduce any user-facing change?** No **How was this patch tested?** It doesn't need to. **Was this patch authored or co-authored using generative AI tooling?** No Closes #43618 from chenyu-opensource/branch-SPARK-45751. Authored-by: chenyu <119398199+chenyu-opensource@users.noreply.github.com> Signed-off-by: Kent Yao --- docs/configuration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/configuration.md b/docs/configuration.md index bd908f3b34d78..60cad24e71c44 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -666,7 +666,7 @@ Apart from these, the following properties are also available, and may be useful spark.executor.logs.rolling.maxRetainedFiles - (none) + -1 Sets the number of latest rolling log files that are going to be retained by the system. Older log files will be deleted. Disabled by default. From c7bba9bfcc350bd3508dd6bb41da6f0c1fef63c6 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 1 Nov 2023 19:24:57 +0800 Subject: [PATCH 05/13] [SPARK-45755][SQL] Improve `Dataset.isEmpty()` by applying global limit `1` ### What changes were proposed in this pull request? This PR makes `Dataset.isEmpty()` to execute global limit 1 first. `LimitPushDown` may push down global limit 1 to lower nodes to improve query performance. Note that we use global limit 1 here, because the local limit cannot be pushed down the group only case: https://github.com/apache/spark/blob/89ca8b6065e9f690a492c778262080741d50d94d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala#L766-L770 ### Why are the changes needed? Improve query performance. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? 
Manual testing: ```scala spark.range(300000000).selectExpr("id", "array(id, id % 10, id % 100) as eo").write.saveAsTable("t1") spark.range(100000000).selectExpr("id", "array(id, id % 10, id % 1000) as eo").write.saveAsTable("t2") println(spark.sql("SELECT * FROM t1 LATERAL VIEW explode_outer(eo) AS e UNION SELECT * FROM t2 LATERAL VIEW explode_outer(eo) AS e").isEmpty) ``` Before this PR | After this PR -- | -- image | image ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43617 from wangyum/SPARK-45755. Lead-authored-by: Yuming Wang Co-authored-by: Yuming Wang Signed-off-by: Jiaan Geng --- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index ba5eb790cea9c..a567a915daf66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -652,7 +652,7 @@ class Dataset[T] private[sql]( * @group basic * @since 2.4.0 */ - def isEmpty: Boolean = withAction("isEmpty", select().queryExecution) { plan => + def isEmpty: Boolean = withAction("isEmpty", select().limit(1).queryExecution) { plan => plan.executeTake(1).isEmpty } From 86510dbd8949408d14d5c3f51b2d66fd44a46d05 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Wed, 1 Nov 2023 08:12:56 -0700 Subject: [PATCH 06/13] [SPARK-45743][BUILD] Upgrade dropwizard metrics 4.2.21 ### What changes were proposed in this pull request? This pr upgrade dropwizard metrics from 4.2.19 to 4.2.21. ### Why are the changes needed? The new version includes the following major updates: - https://github.com/dropwizard/metrics/pull/2652 - https://github.com/dropwizard/metrics/pull/3515 - https://github.com/dropwizard/metrics/pull/3523 - https://github.com/dropwizard/metrics/pull/3570 The full release notes as follows: - https://github.com/dropwizard/metrics/releases/tag/v4.2.20 - https://github.com/dropwizard/metrics/releases/tag/v4.2.21 ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Pass GitHub Actions ### Was this patch authored or co-authored using generative AI tooling? No Closes #43608 from LuciferYang/SPARK-45743. 
Authored-by: yangjie01 Signed-off-by: Dongjoon Hyun --- dev/deps/spark-deps-hadoop-3-hive-2.3 | 10 +++++----- pom.xml | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 6fa0f738cf120..b1c1ed8d81dae 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -181,11 +181,11 @@ log4j-core/2.21.0//log4j-core-2.21.0.jar log4j-slf4j2-impl/2.21.0//log4j-slf4j2-impl-2.21.0.jar logging-interceptor/3.12.12//logging-interceptor-3.12.12.jar lz4-java/1.8.0//lz4-java-1.8.0.jar -metrics-core/4.2.19//metrics-core-4.2.19.jar -metrics-graphite/4.2.19//metrics-graphite-4.2.19.jar -metrics-jmx/4.2.19//metrics-jmx-4.2.19.jar -metrics-json/4.2.19//metrics-json-4.2.19.jar -metrics-jvm/4.2.19//metrics-jvm-4.2.19.jar +metrics-core/4.2.21//metrics-core-4.2.21.jar +metrics-graphite/4.2.21//metrics-graphite-4.2.21.jar +metrics-jmx/4.2.21//metrics-jmx-4.2.21.jar +metrics-json/4.2.21//metrics-json-4.2.21.jar +metrics-jvm/4.2.21//metrics-jvm-4.2.21.jar minlog/1.3.0//minlog-1.3.0.jar netty-all/4.1.100.Final//netty-all-4.1.100.Final.jar netty-buffer/4.1.100.Final//netty-buffer-4.1.100.Final.jar diff --git a/pom.xml b/pom.xml index e29d81f6887c5..811d075e203a1 100644 --- a/pom.xml +++ b/pom.xml @@ -156,7 +156,7 @@ If you change codahale.metrics.version, you also need to change the link to metrics.dropwizard.io in docs/monitoring.md. --> - 4.2.19 + 4.2.21 1.11.3 1.12.0 From 63cf0fec05e6caf53450dc53828ed6e95190664d Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 1 Nov 2023 08:17:33 -0700 Subject: [PATCH 07/13] [SPARK-45753][CORE] Support `spark.deploy.driverIdPattern` ### What changes were proposed in this pull request? This PR aims to support `spark.deploy.driverIdPattern` for Apache Spark 4.0.0. ### Why are the changes needed? This allows the users to be able to control driver ID pattern. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass the CIs with the newly added test case. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43615 from dongjoon-hyun/SPARK-45753. 
Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../org/apache/spark/deploy/master/Master.scala | 4 +++- .../org/apache/spark/internal/config/Deploy.scala | 8 ++++++++ .../apache/spark/deploy/master/MasterSuite.scala | 15 +++++++++++++++ 3 files changed, 26 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 29022c7419b4b..0a66cc974da7c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -53,6 +53,8 @@ private[deploy] class Master( private val forwardMessageThread = ThreadUtils.newDaemonSingleThreadScheduledExecutor("master-forward-message-thread") + private val driverIdPattern = conf.get(DRIVER_ID_PATTERN) + // For application IDs private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) @@ -1175,7 +1177,7 @@ private[deploy] class Master( } private def newDriverId(submitDate: Date): String = { - val appId = "driver-%s-%04d".format(createDateFormat.format(submitDate), nextDriverNumber) + val appId = driverIdPattern.format(createDateFormat.format(submitDate), nextDriverNumber) nextDriverNumber += 1 appId } diff --git a/core/src/main/scala/org/apache/spark/internal/config/Deploy.scala b/core/src/main/scala/org/apache/spark/internal/config/Deploy.scala index aaeb37a17249a..bffdc79175bd9 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/Deploy.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/Deploy.scala @@ -82,4 +82,12 @@ private[spark] object Deploy { .checkValue(_ > 0, "The maximum number of running drivers should be positive.") .createWithDefault(Int.MaxValue) + val DRIVER_ID_PATTERN = ConfigBuilder("spark.deploy.driverIdPattern") + .doc("The pattern for driver ID generation based on Java `String.format` method. " + + "The default value is `driver-%s-%04d` which represents the existing driver id string " + + ", e.g., `driver-20231031224459-0019`. 
Please be careful to generate unique IDs") + .version("4.0.0") + .stringConf + .checkValue(!_.format("20231101000000", 0).exists(_.isWhitespace), "Whitespace is not allowed.") + .createWithDefault("driver-%s-%04d") } diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index fc6c7d267e6a5..cef0e84f20f7a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -802,6 +802,7 @@ class MasterSuite extends SparkFunSuite private val _waitingDrivers = PrivateMethod[mutable.ArrayBuffer[DriverInfo]](Symbol("waitingDrivers")) private val _state = PrivateMethod[RecoveryState.Value](Symbol("state")) + private val _newDriverId = PrivateMethod[String](Symbol("newDriverId")) private val workerInfo = makeWorkerInfo(4096, 10) private val workerInfos = Array(workerInfo, workerInfo, workerInfo) @@ -1236,6 +1237,20 @@ class MasterSuite extends SparkFunSuite private def getState(master: Master): RecoveryState.Value = { master.invokePrivate(_state()) } + + test("SPARK-45753: Support driver id pattern") { + val master = makeMaster(new SparkConf().set(DRIVER_ID_PATTERN, "my-driver-%2$05d")) + val submitDate = new Date() + assert(master.invokePrivate(_newDriverId(submitDate)) === "my-driver-00000") + assert(master.invokePrivate(_newDriverId(submitDate)) === "my-driver-00001") + } + + test("SPARK-45753: Prevent invalid driver id patterns") { + val m = intercept[IllegalArgumentException] { + makeMaster(new SparkConf().set(DRIVER_ID_PATTERN, "my driver")) + }.getMessage + assert(m.contains("Whitespace is not allowed")) + } } private class FakeRecoveryModeFactory(conf: SparkConf, ser: serializer.Serializer) From 3cac0df1f5f818e5e4722a435fb83c44ca155883 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Wed, 1 Nov 2023 23:45:52 +0800 Subject: [PATCH 08/13] [SPARK-45327][BUILD] Upgrade zstd-jni to 1.5.5-7 ### What changes were proposed in this pull request? The pr aims to upgrade zstd-jni from 1.5.5-5 to 1.5.5-7. ### Why are the changes needed? 1.Version compare: - v1.5.5-6 VS v1.5.5-7: https://github.com/luben/zstd-jni/compare/v1.5.5-6...v1.5.5-7 - v1.5.5-5 VS v1.5.5-6: https://github.com/luben/zstd-jni/compare/v.1.5.5-5...v1.5.5-6 2.Some changes include the following: - Add new method `getFrameContentSize` that will return also the error. https://github.com/luben/zstd-jni/commit/3f6c55e87a38255bdbeeca3e773f99bf9e8a7b7f - Fix error. https://github.com/luben/zstd-jni/commit/e0c66f1230377d4980ad6abc0767b36c860af538 ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43113 from panbingkun/SPARK-45327. 
Authored-by: panbingkun Signed-off-by: yangjie01 --- .../ZStandardBenchmark-jdk21-results.txt | 32 +++++++++---------- .../benchmarks/ZStandardBenchmark-results.txt | 28 ++++++++-------- dev/deps/spark-deps-hadoop-3-hive-2.3 | 2 +- pom.xml | 6 +++- 4 files changed, 36 insertions(+), 32 deletions(-) diff --git a/core/benchmarks/ZStandardBenchmark-jdk21-results.txt b/core/benchmarks/ZStandardBenchmark-jdk21-results.txt index 0915e73385686..19bd2faa9a5a5 100644 --- a/core/benchmarks/ZStandardBenchmark-jdk21-results.txt +++ b/core/benchmarks/ZStandardBenchmark-jdk21-results.txt @@ -2,26 +2,26 @@ Benchmark ZStandardCompressionCodec ================================================================================================ -OpenJDK 64-Bit Server VM 21+35 on Linux 5.15.0-1046-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 21.0.1+12-LTS on Linux 5.15.0-1050-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Benchmark ZStandardCompressionCodec: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------------------- -Compression 10000 times at level 1 without buffer pool 2649 2653 6 0.0 264919.3 1.0X -Compression 10000 times at level 2 without buffer pool 2785 2785 1 0.0 278455.9 1.0X -Compression 10000 times at level 3 without buffer pool 3078 3082 6 0.0 307845.7 0.9X -Compression 10000 times at level 1 with buffer pool 2353 2378 35 0.0 235340.6 1.1X -Compression 10000 times at level 2 with buffer pool 2462 2466 6 0.0 246194.4 1.1X -Compression 10000 times at level 3 with buffer pool 2761 2765 6 0.0 276095.6 1.0X +Compression 10000 times at level 1 without buffer pool 1817 1819 3 0.0 181709.6 1.0X +Compression 10000 times at level 2 without buffer pool 2081 2083 3 0.0 208053.7 0.9X +Compression 10000 times at level 3 without buffer pool 2288 2290 3 0.0 228795.4 0.8X +Compression 10000 times at level 1 with buffer pool 1997 1998 1 0.0 199686.9 0.9X +Compression 10000 times at level 2 with buffer pool 2062 2063 1 0.0 206209.3 0.9X +Compression 10000 times at level 3 with buffer pool 2243 2243 1 0.0 224271.8 0.8X -OpenJDK 64-Bit Server VM 21+35 on Linux 5.15.0-1046-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 21.0.1+12-LTS on Linux 5.15.0-1050-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Benchmark ZStandardCompressionCodec: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------ -Decompression 10000 times from level 1 without buffer pool 2644 2644 0 0.0 264430.7 1.0X -Decompression 10000 times from level 2 without buffer pool 2597 2608 15 0.0 259715.2 1.0X -Decompression 10000 times from level 3 without buffer pool 2586 2600 20 0.0 258633.7 1.0X -Decompression 10000 times from level 1 with buffer pool 2290 2296 9 0.0 228957.6 1.2X -Decompression 10000 times from level 2 with buffer pool 2315 2319 6 0.0 231535.4 1.1X -Decompression 10000 times from level 3 with buffer pool 2283 2302 27 0.0 228308.3 1.2X +Decompression 10000 times from level 1 without buffer pool 1980 1981 1 0.0 197970.8 1.0X +Decompression 10000 times from level 2 without buffer pool 1978 1979 1 0.0 197813.4 1.0X +Decompression 10000 times from level 3 without buffer pool 1981 1983 2 0.0 198141.7 1.0X +Decompression 10000 times from level 1 with buffer pool 1825 1827 3 0.0 182475.2 
1.1X +Decompression 10000 times from level 2 with buffer pool 1827 1827 0 0.0 182667.1 1.1X +Decompression 10000 times from level 3 with buffer pool 1826 1826 0 0.0 182579.8 1.1X diff --git a/core/benchmarks/ZStandardBenchmark-results.txt b/core/benchmarks/ZStandardBenchmark-results.txt index 5299a52dc7b8a..6f7efdd6c94c5 100644 --- a/core/benchmarks/ZStandardBenchmark-results.txt +++ b/core/benchmarks/ZStandardBenchmark-results.txt @@ -2,26 +2,26 @@ Benchmark ZStandardCompressionCodec ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.8+7-LTS on Linux 5.15.0-1046-azure +OpenJDK 64-Bit Server VM 17.0.9+8-LTS on Linux 5.15.0-1050-azure Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz Benchmark ZStandardCompressionCodec: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------------------- -Compression 10000 times at level 1 without buffer pool 2800 2801 2 0.0 279995.2 1.0X -Compression 10000 times at level 2 without buffer pool 2832 2846 19 0.0 283227.4 1.0X -Compression 10000 times at level 3 without buffer pool 2978 3003 35 0.0 297782.6 0.9X -Compression 10000 times at level 1 with buffer pool 2650 2652 2 0.0 265042.4 1.1X -Compression 10000 times at level 2 with buffer pool 2684 2688 5 0.0 268419.4 1.0X -Compression 10000 times at level 3 with buffer pool 2811 2816 8 0.0 281069.3 1.0X +Compression 10000 times at level 1 without buffer pool 2786 2787 2 0.0 278560.1 1.0X +Compression 10000 times at level 2 without buffer pool 2831 2833 3 0.0 283091.0 1.0X +Compression 10000 times at level 3 without buffer pool 2958 2959 2 0.0 295806.3 0.9X +Compression 10000 times at level 1 with buffer pool 211 214 4 0.0 21145.3 13.2X +Compression 10000 times at level 2 with buffer pool 253 255 1 0.0 25328.1 11.0X +Compression 10000 times at level 3 with buffer pool 370 371 1 0.0 37046.1 7.5X -OpenJDK 64-Bit Server VM 17.0.8+7-LTS on Linux 5.15.0-1046-azure +OpenJDK 64-Bit Server VM 17.0.9+8-LTS on Linux 5.15.0-1050-azure Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz Benchmark ZStandardCompressionCodec: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------ -Decompression 10000 times from level 1 without buffer pool 2747 2752 7 0.0 274669.2 1.0X -Decompression 10000 times from level 2 without buffer pool 2743 2748 8 0.0 274265.3 1.0X -Decompression 10000 times from level 3 without buffer pool 2743 2750 10 0.0 274344.4 1.0X -Decompression 10000 times from level 1 with buffer pool 2608 2608 0 0.0 260803.0 1.1X -Decompression 10000 times from level 2 with buffer pool 2608 2608 0 0.0 260804.2 1.1X -Decompression 10000 times from level 3 with buffer pool 2605 2607 3 0.0 260514.9 1.1X +Decompression 10000 times from level 1 without buffer pool 2745 2748 5 0.0 274454.0 1.0X +Decompression 10000 times from level 2 without buffer pool 2744 2745 1 0.0 274438.1 1.0X +Decompression 10000 times from level 3 without buffer pool 2746 2746 1 0.0 274586.0 1.0X +Decompression 10000 times from level 1 with buffer pool 2587 2588 1 0.0 258707.4 1.1X +Decompression 10000 times from level 2 with buffer pool 2586 2586 1 0.0 258566.8 1.1X +Decompression 10000 times from level 3 with buffer pool 2589 2589 0 0.0 258870.6 1.1X diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 
b/dev/deps/spark-deps-hadoop-3-hive-2.3 index b1c1ed8d81dae..6364ec48fb664 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -264,4 +264,4 @@ xz/1.9//xz-1.9.jar zjsonpatch/0.3.0//zjsonpatch-0.3.0.jar zookeeper-jute/3.6.3//zookeeper-jute-3.6.3.jar zookeeper/3.6.3//zookeeper-3.6.3.jar -zstd-jni/1.5.5-5//zstd-jni-1.5.5-5.jar +zstd-jni/1.5.5-7//zstd-jni-1.5.5-7.jar diff --git a/pom.xml b/pom.xml index 811d075e203a1..2e0c95516c177 100644 --- a/pom.xml +++ b/pom.xml @@ -799,7 +799,7 @@ com.github.luben zstd-jni - 1.5.5-5 + 1.5.5-7 com.clearspring.analytics @@ -2662,6 +2662,10 @@ javax.annotation javax.annotation-api + + com.github.luben + zstd-jni + From d6b53c34c5c586fe04e000929412d54383202c0f Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 1 Nov 2023 11:53:32 -0700 Subject: [PATCH 09/13] [SPARK-45754][CORE] Support `spark.deploy.appIdPattern` ### What changes were proposed in this pull request? This PR aims to support `spark.deploy.appIdPattern` for Apache Spark 4.0.0. ### Why are the changes needed? This allows the users to be able to control app ID pattern. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass the CIs with the newly added test case. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43616 from dongjoon-hyun/SPARK-45754. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../org/apache/spark/deploy/master/Master.scala | 3 ++- .../org/apache/spark/internal/config/Deploy.scala | 9 +++++++++ .../apache/spark/deploy/master/MasterSuite.scala | 15 +++++++++++++++ 3 files changed, 26 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 0a66cc974da7c..058b944c591ad 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -54,6 +54,7 @@ private[deploy] class Master( ThreadUtils.newDaemonSingleThreadScheduledExecutor("master-forward-message-thread") private val driverIdPattern = conf.get(DRIVER_ID_PATTERN) + private val appIdPattern = conf.get(APP_ID_PATTERN) // For application IDs private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) @@ -1152,7 +1153,7 @@ private[deploy] class Master( /** Generate a new app ID given an app's submission date */ private def newApplicationId(submitDate: Date): String = { - val appId = "app-%s-%04d".format(createDateFormat.format(submitDate), nextAppNumber) + val appId = appIdPattern.format(createDateFormat.format(submitDate), nextAppNumber) nextAppNumber += 1 appId } diff --git a/core/src/main/scala/org/apache/spark/internal/config/Deploy.scala b/core/src/main/scala/org/apache/spark/internal/config/Deploy.scala index bffdc79175bd9..c6ccf9550bc91 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/Deploy.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/Deploy.scala @@ -90,4 +90,13 @@ private[spark] object Deploy { .stringConf .checkValue(!_.format("20231101000000", 0).exists(_.isWhitespace), "Whitespace is not allowed.") .createWithDefault("driver-%s-%04d") + + val APP_ID_PATTERN = ConfigBuilder("spark.deploy.appIdPattern") + .doc("The pattern for app ID generation based on Java `String.format` method. " + "The default value is `app-%s-%04d` which represents the existing app id string, " + "e.g., `app-20231031224509-0008`. 
Please be careful to generate unique IDs.") + .version("4.0.0") + .stringConf + .checkValue(!_.format("20231101000000", 0).exists(_.isWhitespace), "Whitespace is not allowed.") + .createWithDefault("app-%s-%04d") } diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index cef0e84f20f7a..e8615cdbdd559 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -803,6 +803,7 @@ class MasterSuite extends SparkFunSuite PrivateMethod[mutable.ArrayBuffer[DriverInfo]](Symbol("waitingDrivers")) private val _state = PrivateMethod[RecoveryState.Value](Symbol("state")) private val _newDriverId = PrivateMethod[String](Symbol("newDriverId")) + private val _newApplicationId = PrivateMethod[String](Symbol("newApplicationId")) private val workerInfo = makeWorkerInfo(4096, 10) private val workerInfos = Array(workerInfo, workerInfo, workerInfo) @@ -1251,6 +1252,20 @@ class MasterSuite extends SparkFunSuite }.getMessage assert(m.contains("Whitespace is not allowed")) } + + test("SPARK-45754: Support app id pattern") { + val master = makeMaster(new SparkConf().set(APP_ID_PATTERN, "my-app-%2$05d")) + val submitDate = new Date() + assert(master.invokePrivate(_newApplicationId(submitDate)) === "my-app-00000") + assert(master.invokePrivate(_newApplicationId(submitDate)) === "my-app-00001") + } + + test("SPARK-45754: Prevent invalid app id patterns") { + val m = intercept[IllegalArgumentException] { + makeMaster(new SparkConf().set(APP_ID_PATTERN, "my app")) + }.getMessage + assert(m.contains("Whitespace is not allowed")) + } } private class FakeRecoveryModeFactory(conf: SparkConf, ser: serializer.Serializer) From b14c1f036f8f394ad1903998128c05d04dd584a9 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 1 Nov 2023 13:31:12 -0700 Subject: [PATCH 10/13] [SPARK-45763][CORE][UI] Improve `MasterPage` to show `Resource` column only when it exists MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? This PR aims to improve `MasterPage` to show `Resource` column only when it exists. ### Why are the changes needed? For non-GPU clusters, the `Resource` column is always empty. ### Does this PR introduce _any_ user-facing change? After this PR, `MasterPage` still shows the `Resource` column if the resource exists, like the following. ![Screenshot 2023-11-01 at 11 02 43 AM](https://github.com/apache/spark/assets/9700541/104dd4e7-938b-4269-8952-512e8fb5fa39) If there is no resource on all workers, the `Resource` column is omitted. ![Screenshot 2023-11-01 at 11 03 20 AM](https://github.com/apache/spark/assets/9700541/12c9d4b2-330a-4e36-a6eb-ac2813e0649a) ### How was this patch tested? Manual test. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43628 from dongjoon-hyun/SPARK-45763. 
Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../spark/deploy/master/ui/MasterPage.scala | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 48c0c9601c14b..cb325b37958ec 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -98,10 +98,15 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { def render(request: HttpServletRequest): Seq[Node] = { val state = getMasterState - val workerHeaders = Seq("Worker Id", "Address", "State", "Cores", "Memory", "Resources") + val showResourceColumn = state.workers.filter(_.resourcesInfoUsed.nonEmpty).nonEmpty + val workerHeaders = if (showResourceColumn) { + Seq("Worker Id", "Address", "State", "Cores", "Memory", "Resources") + } else { + Seq("Worker Id", "Address", "State", "Cores", "Memory") + } val workers = state.workers.sortBy(_.id) val aliveWorkers = state.workers.filter(_.state == WorkerState.ALIVE) - val workerTable = UIUtils.listingTable(workerHeaders, workerRow, workers) + val workerTable = UIUtils.listingTable(workerHeaders, workerRow(showResourceColumn), workers) val appHeaders = Seq("Application ID", "Name", "Cores", "Memory per Executor", "Resources Per Executor", "Submitted Time", "User", "State", "Duration") @@ -256,7 +261,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { UIUtils.basicSparkPage(request, content, "Spark Master at " + state.uri) } - private def workerRow(worker: WorkerInfo): Seq[Node] = { + private def workerRow(showResourceColumn: Boolean): WorkerInfo => Seq[Node] = worker => { { @@ -276,7 +281,9 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { {Utils.megabytesToString(worker.memory)} ({Utils.megabytesToString(worker.memoryUsed)} Used) - {formatWorkerResourcesDetails(worker)} + {if (showResourceColumn) { + {formatWorkerResourcesDetails(worker)} + }} } From 59e291d36c4a9d956b993968a324359b3d75fe5f Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Thu, 2 Nov 2023 09:11:48 +0900 Subject: [PATCH 11/13] [SPARK-45680][CONNECT] Release session ### What changes were proposed in this pull request? Introduce a new `ReleaseSession` Spark Connect RPC, which cancels everything running in the session and removes the session server side. Refactor code around managing the cache of sessions into `SparkConnectSessionManager`. ### Why are the changes needed? Better session management. ### Does this PR introduce _any_ user-facing change? Not really. `SparkSession.stop()` API already existed on the client side. It was closing the client's network connection, but the Session was still there cached for 1 hour on the server side. Caveats, which were not really supported user behaviour: * After `session.stop()`, user could have created a new session with the same session_id in Configuration. That session would be a new session on the client side, but connect to the old cached session in the server. It could therefore e.g. access that old session's state like views or artifacts. * If a session timed out and was removed in the server, it used to be that a new request would re-create the session. The client would then see this as the old session, but the server would see a new one, and e.g. not have access to old session state that was removed. 
* User is no longer allowed to create a new session with the same session_id as before. ### How was this patch tested? Tests added. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43546 from juliuszsompolski/release-session. Lead-authored-by: Juliusz Sompolski Co-authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- .../main/resources/error/error-classes.json | 5 + .../org/apache/spark/sql/SparkSession.scala | 8 + .../spark/sql/PlanGenerationTestSuite.scala | 4 +- .../apache/spark/sql/SparkSessionSuite.scala | 38 ++-- .../main/protobuf/spark/connect/base.proto | 30 +++ .../CustomSparkConnectBlockingStub.scala | 11 ++ .../connect/client/SparkConnectClient.scala | 10 + .../spark/sql/connect/config/Connect.scala | 18 ++ .../sql/connect/service/SessionHolder.scala | 79 +++++++- .../SparkConnectExecutionManager.scala | 23 ++- .../SparkConnectReleaseExecuteHandler.scala | 4 +- .../SparkConnectReleaseSessionHandler.scala | 40 ++++ .../connect/service/SparkConnectService.scala | 117 +++--------- .../service/SparkConnectSessionManager.scala | 177 ++++++++++++++++++ .../spark/sql/connect/utils/ErrorUtils.scala | 27 +-- .../sql/connect/SparkConnectServerTest.scala | 21 ++- .../execution/ReattachableExecuteSuite.scala | 4 + .../planner/SparkConnectServiceSuite.scala | 4 +- .../service/SparkConnectServiceE2ESuite.scala | 158 ++++++++++++++++ ...r-conditions-invalid-handle-error-class.md | 4 + python/pyspark/sql/connect/client/core.py | 23 ++- python/pyspark/sql/connect/proto/base_pb2.py | 42 +++-- python/pyspark/sql/connect/proto/base_pb2.pyi | 78 ++++++++ .../sql/connect/proto/base_pb2_grpc.py | 49 +++++ python/pyspark/sql/connect/session.py | 12 +- .../sql/tests/connect/test_connect_basic.py | 1 + 26 files changed, 819 insertions(+), 168 deletions(-) create mode 100644 connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseSessionHandler.scala create mode 100644 connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index 278011b8cc8f4..af32bcf129c08 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -1737,6 +1737,11 @@ "Session already exists." ] }, + "SESSION_CLOSED" : { + "message" : [ + "Session was closed." + ] + }, "SESSION_NOT_FOUND" : { "message" : [ "Session not found." diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 969ac017ecb1d..1cc1c8400fa89 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -665,6 +665,9 @@ class SparkSession private[sql] ( * @since 3.4.0 */ override def close(): Unit = { + if (releaseSessionOnClose) { + client.releaseSession() + } client.shutdown() allocator.close() SparkSession.onSessionClose(this) @@ -735,6 +738,11 @@ class SparkSession private[sql] ( * We null out the instance for now. */ private def writeReplace(): Any = null + + /** + * Set to false to prevent client.releaseSession on close() (testing only) + */ + private[sql] var releaseSessionOnClose = true } // The minimal builder needed to create a spark session. 
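A minimal sketch of how a Spark Connect Scala client is expected to behave after this change (illustrative only, not part of the diffs; the connection string is a placeholder for a reachable Spark Connect server):

```scala
// Sketch only: assumes a Spark Connect server is reachable at the placeholder address.
import org.apache.spark.sql.SparkSession

object ReleaseSessionSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().remote("sc://localhost:15002").getOrCreate()
    spark.range(10).count() // runs against the server-side session

    // close() now sends a ReleaseSession RPC before shutting down the client, so the
    // server-side session state (executions, views, artifacts) is cleaned up right away
    // instead of lingering until the idle-session timeout.
    spark.close()
  }
}
```

Setting the test-only flag `releaseSessionOnClose = false`, as the test suites updated later in this patch do, restores the old close-without-release behaviour.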
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala index cf287088b59fb..5cc63bc45a04a 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -120,7 +120,9 @@ class PlanGenerationTestSuite } override protected def afterAll(): Unit = { - session.close() + // Don't call client.releaseSession on close(), because the connection details are dummy. + session.releaseSessionOnClose = false + session.stop() if (cleanOrphanedGoldenFiles) { cleanOrphanedGoldenFile() } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala index 4c858262c6ef5..8abc41639fdd2 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala @@ -33,18 +33,24 @@ class SparkSessionSuite extends ConnectFunSuite { private val connectionString2: String = "sc://test.me:14099" private val connectionString3: String = "sc://doit:16845" + private def closeSession(session: SparkSession): Unit = { + // Don't call client.releaseSession on close(), because the connection details are dummy. + session.releaseSessionOnClose = false + session.close() + } + test("default") { val session = SparkSession.builder().getOrCreate() assert(session.client.configuration.host == "localhost") assert(session.client.configuration.port == 15002) - session.close() + closeSession(session) } test("remote") { val session = SparkSession.builder().remote(connectionString2).getOrCreate() assert(session.client.configuration.host == "test.me") assert(session.client.configuration.port == 14099) - session.close() + closeSession(session) } test("getOrCreate") { @@ -53,8 +59,8 @@ class SparkSessionSuite extends ConnectFunSuite { try { assert(session1 eq session2) } finally { - session1.close() - session2.close() + closeSession(session1) + closeSession(session2) } } @@ -65,8 +71,8 @@ class SparkSessionSuite extends ConnectFunSuite { assert(session1 ne session2) assert(session1.client.configuration == session2.client.configuration) } finally { - session1.close() - session2.close() + closeSession(session1) + closeSession(session2) } } @@ -77,8 +83,8 @@ class SparkSessionSuite extends ConnectFunSuite { assert(session1 ne session2) assert(session1.client.configuration == session2.client.configuration) } finally { - session1.close() - session2.close() + closeSession(session1) + closeSession(session2) } } @@ -98,7 +104,7 @@ class SparkSessionSuite extends ConnectFunSuite { assertThrows[RuntimeException] { session.range(10).count() } - session.close() + closeSession(session) } test("Default/Active session") { @@ -136,12 +142,12 @@ class SparkSessionSuite extends ConnectFunSuite { assert(SparkSession.getActiveSession.contains(session1)) // Close session1 - session1.close() + closeSession(session1) assert(SparkSession.getDefaultSession.contains(session2)) assert(SparkSession.getActiveSession.isEmpty) // Close session2 - session2.close() + closeSession(session2) assert(SparkSession.getDefaultSession.isEmpty) assert(SparkSession.getActiveSession.isEmpty) } @@ -187,7 +193,7 @@ class SparkSessionSuite extends 
ConnectFunSuite { // Step 3 - close session 1, no more default session in both scripts phaser.arriveAndAwaitAdvance() - session1.close() + closeSession(session1) // Step 4 - no default session, same active session. phaser.arriveAndAwaitAdvance() @@ -240,13 +246,13 @@ class SparkSessionSuite extends ConnectFunSuite { // Step 7 - close active session in script2 phaser.arriveAndAwaitAdvance() - internalSession.close() + closeSession(internalSession) assert(SparkSession.getActiveSession.isEmpty) } assert(script1.get()) assert(script2.get()) assert(SparkSession.getActiveSession.contains(session2)) - session2.close() + closeSession(session2) assert(SparkSession.getActiveSession.isEmpty) } finally { executor.shutdown() @@ -254,13 +260,13 @@ class SparkSessionSuite extends ConnectFunSuite { } test("deprecated methods") { - SparkSession + val session = SparkSession .builder() .master("yayay") .appName("bob") .enableHiveSupport() .create() - .close() + closeSession(session) } test("serialize as null") { diff --git a/connector/connect/common/src/main/protobuf/spark/connect/base.proto b/connector/connect/common/src/main/protobuf/spark/connect/base.proto index 27f51551ba921..19a94a5a429f0 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto @@ -784,6 +784,30 @@ message ReleaseExecuteResponse { optional string operation_id = 2; } +message ReleaseSessionRequest { + // (Required) + // + // The session_id of the request to reattach to. + // This must be an id of existing session. + string session_id = 1; + + // (Required) User context + // + // user_context.user_id and session+id both identify a unique remote spark session on the + // server side. + UserContext user_context = 2; + + // Provides optional information about the client sending the request. This field + // can be used for language or version specific information and is only intended for + // logging purposes and will not be interpreted by the server. + optional string client_type = 3; +} + +message ReleaseSessionResponse { + // Session id of the session on which the release executed. + string session_id = 1; +} + message FetchErrorDetailsRequest { // (Required) @@ -934,6 +958,12 @@ service SparkConnectService { // RPC and ReleaseExecute may not be used. rpc ReleaseExecute(ReleaseExecuteRequest) returns (ReleaseExecuteResponse) {} + // Release a session. + // All the executions in the session will be released. Any further requests for the session with + // that session_id for the given user_id will fail. If the session didn't exist or was already + // released, this is a noop. + rpc ReleaseSession(ReleaseSessionRequest) returns (ReleaseSessionResponse) {} + // FetchErrorDetails retrieves the matched exception with details based on a provided error id. 
rpc FetchErrorDetails(FetchErrorDetailsRequest) returns (FetchErrorDetailsResponse) {} } diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala index f2efa26f6b609..e963b4136160f 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala @@ -96,6 +96,17 @@ private[connect] class CustomSparkConnectBlockingStub( } } + def releaseSession(request: ReleaseSessionRequest): ReleaseSessionResponse = { + grpcExceptionConverter.convert( + request.getSessionId, + request.getUserContext, + request.getClientType) { + retryHandler.retry { + stub.releaseSession(request) + } + } + } + def artifactStatus(request: ArtifactStatusesRequest): ArtifactStatusesResponse = { grpcExceptionConverter.convert( request.getSessionId, diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala index 42ace003da89f..6d3d9420e2263 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala @@ -243,6 +243,16 @@ private[sql] class SparkConnectClient( bstub.interrupt(request) } + private[sql] def releaseSession(): proto.ReleaseSessionResponse = { + val builder = proto.ReleaseSessionRequest.newBuilder() + val request = builder + .setUserContext(userContext) + .setSessionId(sessionId) + .setClientType(userAgent) + .build() + bstub.releaseSession(request) + } + private[this] val tags = new InheritableThreadLocal[mutable.Set[String]] { override def childValue(parent: mutable.Set[String]): mutable.Set[String] = { // Note: make a clone such that changes in the parent tags aren't reflected in diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala index 2b3f218362cd3..1a5944676f5fb 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala @@ -74,6 +74,24 @@ object Connect { .intConf .createWithDefault(1024) + val CONNECT_SESSION_MANAGER_DEFAULT_SESSION_TIMEOUT = + buildStaticConf("spark.connect.session.manager.defaultSessionTimeout") + .internal() + .doc("Timeout after which sessions without any new incoming RPC will be removed.") + .version("4.0.0") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefaultString("60m") + + val CONNECT_SESSION_MANAGER_CLOSED_SESSIONS_TOMBSTONES_SIZE = + buildStaticConf("spark.connect.session.manager.closedSessionsTombstonesSize") + .internal() + .doc( + "Maximum size of the cache of sessions after which sessions that did not receive any " + + "requests will be removed.") + .version("4.0.0") + .intConf + .createWithDefaultString("1000") + val CONNECT_EXECUTE_MANAGER_DETACHED_TIMEOUT = buildStaticConf("spark.connect.execute.manager.detachedTimeout") .internal() diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala index dcced21f37148..792012a682b28 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala @@ -27,7 +27,7 @@ import scala.jdk.CollectionConverters._ import com.google.common.base.Ticker import com.google.common.cache.CacheBuilder -import org.apache.spark.{JobArtifactSet, SparkException} +import org.apache.spark.{JobArtifactSet, SparkException, SparkSQLException} import org.apache.spark.internal.Logging import org.apache.spark.sql.DataFrame import org.apache.spark.sql.SparkSession @@ -40,12 +40,19 @@ import org.apache.spark.sql.streaming.StreamingQueryListener import org.apache.spark.util.SystemClock import org.apache.spark.util.Utils +// Unique key identifying session by combination of user, and session id +case class SessionKey(userId: String, sessionId: String) + /** * Object used to hold the Spark Connect session state. */ case class SessionHolder(userId: String, sessionId: String, session: SparkSession) extends Logging { + @volatile private var lastRpcAccessTime: Option[Long] = None + + @volatile private var isClosing: Boolean = false + private val executions: ConcurrentMap[String, ExecuteHolder] = new ConcurrentHashMap[String, ExecuteHolder]() @@ -73,8 +80,21 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio private[connect] lazy val streamingForeachBatchRunnerCleanerCache = new StreamingForeachBatchHelper.CleanerCache(this) - /** Add ExecuteHolder to this session. Called only by SparkConnectExecutionManager. */ + def key: SessionKey = SessionKey(userId, sessionId) + + /** + * Add ExecuteHolder to this session. + * + * Called only by SparkConnectExecutionManager under executionsLock. + */ private[service] def addExecuteHolder(executeHolder: ExecuteHolder): Unit = { + if (isClosing) { + // Do not accept new executions if the session is closing. + throw new SparkSQLException( + errorClass = "INVALID_HANDLE.SESSION_CLOSED", + messageParameters = Map("handle" -> sessionId)) + } + val oldExecute = executions.putIfAbsent(executeHolder.operationId, executeHolder) if (oldExecute != null) { // the existence of this should alrady be checked by SparkConnectExecutionManager @@ -160,21 +180,55 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio */ def classloader: ClassLoader = artifactManager.classloader + private[connect] def updateAccessTime(): Unit = { + lastRpcAccessTime = Some(System.currentTimeMillis()) + } + + /** + * Initialize the session. + * + * Called only by SparkConnectSessionManager. + */ private[connect] def initializeSession(): Unit = { + updateAccessTime() eventManager.postStarted() } /** * Expire this session and trigger state cleanup mechanisms. + * + * Called only by SparkConnectSessionManager. */ - private[connect] def expireSession(): Unit = { - logDebug(s"Expiring session with userId: $userId and sessionId: $sessionId") + private[connect] def close(): Unit = { + logInfo(s"Closing session with userId: $userId and sessionId: $sessionId") + + // After isClosing=true, SessionHolder.addExecuteHolder() will not allow new executions for + // this session. 
Because both SessionHolder.addExecuteHolder() and + // SparkConnectExecutionManager.removeAllExecutionsForSession() are executed under + // executionsLock, this guarantees that removeAllExecutionsForSession triggered below will + // remove all executions and no new executions will be added in the meanwhile. + isClosing = true + + // Note on the below notes about concurrency: + // While closing the session can potentially race with operations started on the session, the + // intended use is that the client session will get closed when it's really not used anymore, + // or that it expires due to inactivity, in which case there should be no races. + + // Clean up all artifacts. + // Note: there can be concurrent AddArtifact calls still adding something. artifactManager.cleanUpResources() - eventManager.postClosed() - // Clean up running queries + + // Clean up running streaming queries. + // Note: there can be concurrent streaming queries being started. SparkConnectService.streamingSessionManager.cleanupRunningQueries(this) streamingForeachBatchRunnerCleanerCache.cleanUpAll() // Clean up any streaming workers. removeAllListeners() // removes all listener and stop python listener processes if necessary. + + // Clean up all executions + // It is guaranteed at this point that no new addExecuteHolder are getting started. + SparkConnectService.executionManager.removeAllExecutionsForSession(this.key) + + eventManager.postClosed() } /** @@ -204,6 +258,10 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio } } + /** Get SessionInfo with information about this SessionHolder. */ + def getSessionHolderInfo: SessionHolderInfo = + SessionHolderInfo(userId, sessionId, eventManager.status, lastRpcAccessTime) + /** * Caches given DataFrame with the ID. The cache does not expire. The entry needs to be * explicitly removed by the owners of the DataFrame once it is not needed. @@ -291,7 +349,14 @@ object SessionHolder { userId = "testUser", sessionId = UUID.randomUUID().toString, session = session) - SparkConnectService.putSessionForTesting(ret) + SparkConnectService.sessionManager.putSessionForTesting(ret) ret } } + +/** Basic information about SessionHolder. */ +case class SessionHolderInfo( + userId: String, + sessionId: String, + status: SessionStatus, + lastRpcAccesTime: Option[Long]) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala index 3c72548978222..c004358e1cf18 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala @@ -95,11 +95,16 @@ private[connect] class SparkConnectExecutionManager() extends Logging { * Remove an ExecuteHolder from this global manager and from its session. Interrupt the * execution if still running, free all resources. 
*/ - private[connect] def removeExecuteHolder(key: ExecuteKey): Unit = { + private[connect] def removeExecuteHolder(key: ExecuteKey, abandoned: Boolean = false): Unit = { var executeHolder: Option[ExecuteHolder] = None executionsLock.synchronized { executeHolder = executions.remove(key) - executeHolder.foreach(e => e.sessionHolder.removeExecuteHolder(e.operationId)) + executeHolder.foreach { e => + if (abandoned) { + abandonedTombstones.put(key, e.getExecuteInfo) + } + e.sessionHolder.removeExecuteHolder(e.operationId) + } if (executions.isEmpty) { lastExecutionTime = Some(System.currentTimeMillis()) } @@ -115,6 +120,17 @@ private[connect] class SparkConnectExecutionManager() extends Logging { } } + private[connect] def removeAllExecutionsForSession(key: SessionKey): Unit = { + val sessionExecutionHolders = executionsLock.synchronized { + executions.filter(_._2.sessionHolder.key == key) + } + sessionExecutionHolders.foreach { case (_, executeHolder) => + val info = executeHolder.getExecuteInfo + logInfo(s"Execution $info removed in removeSessionExecutions.") + removeExecuteHolder(executeHolder.key, abandoned = true) + } + } + /** Get info about abandoned execution, if there is one. */ private[connect] def getAbandonedTombstone(key: ExecuteKey): Option[ExecuteInfo] = { Option(abandonedTombstones.getIfPresent(key)) @@ -204,8 +220,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging { toRemove.foreach { executeHolder => val info = executeHolder.getExecuteInfo logInfo(s"Found execution $info that was abandoned and expired and will be removed.") - removeExecuteHolder(executeHolder.key) - abandonedTombstones.put(executeHolder.key, info) + removeExecuteHolder(executeHolder.key, abandoned = true) } } logInfo("Finished periodic run of SparkConnectExecutionManager maintenance.") diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseExecuteHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseExecuteHandler.scala index a3a7815609e40..1ca886960d536 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseExecuteHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseExecuteHandler.scala @@ -28,8 +28,8 @@ class SparkConnectReleaseExecuteHandler( extends Logging { def handle(v: proto.ReleaseExecuteRequest): Unit = { - val sessionHolder = SparkConnectService - .getIsolatedSession(v.getUserContext.getUserId, v.getSessionId) + val sessionHolder = SparkConnectService.sessionManager + .getIsolatedSession(SessionKey(v.getUserContext.getUserId, v.getSessionId)) val responseBuilder = proto.ReleaseExecuteResponse .newBuilder() diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseSessionHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseSessionHandler.scala new file mode 100644 index 0000000000000..a32852bac45ea --- /dev/null +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseSessionHandler.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.service + +import io.grpc.stub.StreamObserver + +import org.apache.spark.connect.proto +import org.apache.spark.internal.Logging + +class SparkConnectReleaseSessionHandler( + responseObserver: StreamObserver[proto.ReleaseSessionResponse]) + extends Logging { + + def handle(v: proto.ReleaseSessionRequest): Unit = { + val responseBuilder = proto.ReleaseSessionResponse.newBuilder() + responseBuilder.setSessionId(v.getSessionId) + + // If the session doesn't exist, this will just be a noop. + val key = SessionKey(v.getUserContext.getUserId, v.getSessionId) + SparkConnectService.sessionManager.closeSession(key) + + responseObserver.onNext(responseBuilder.build()) + responseObserver.onCompleted() + } +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala index e82c9cba56264..e4b60eeeff0d6 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala @@ -18,13 +18,10 @@ package org.apache.spark.sql.connect.service import java.net.InetSocketAddress -import java.util.UUID -import java.util.concurrent.{Callable, TimeUnit} +import java.util.concurrent.TimeUnit import scala.jdk.CollectionConverters._ -import com.google.common.base.Ticker -import com.google.common.cache.{CacheBuilder, RemovalListener, RemovalNotification} import com.google.protobuf.MessageLite import io.grpc.{BindableService, MethodDescriptor, Server, ServerMethodDefinition, ServerServiceDefinition} import io.grpc.MethodDescriptor.PrototypeMarshaller @@ -34,13 +31,12 @@ import io.grpc.protobuf.services.ProtoReflectionService import io.grpc.stub.StreamObserver import org.apache.commons.lang3.StringUtils -import org.apache.spark.{SparkContext, SparkEnv, SparkSQLException} +import org.apache.spark.{SparkContext, SparkEnv} import org.apache.spark.connect.proto import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse, SparkConnectServiceGrpc} import org.apache.spark.connect.proto.SparkConnectServiceGrpc.AsyncService import org.apache.spark.internal.Logging import org.apache.spark.internal.config.UI.UI_ENABLED -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_ADDRESS, CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE} import org.apache.spark.sql.connect.ui.{SparkConnectServerAppStatusStore, SparkConnectServerListener, SparkConnectServerTab} import org.apache.spark.sql.connect.utils.ErrorUtils @@ -201,6 +197,22 @@ class SparkConnectService(debug: Boolean) extends AsyncService with BindableServ sessionId = request.getSessionId) } + /** + * Release session. 
+ */ + override def releaseSession( + request: proto.ReleaseSessionRequest, + responseObserver: StreamObserver[proto.ReleaseSessionResponse]): Unit = { + try { + new SparkConnectReleaseSessionHandler(responseObserver).handle(request) + } catch + ErrorUtils.handleError( + "releaseSession", + observer = responseObserver, + userId = request.getUserContext.getUserId, + sessionId = request.getSessionId) + } + override def fetchErrorDetails( request: proto.FetchErrorDetailsRequest, responseObserver: StreamObserver[proto.FetchErrorDetailsResponse]): Unit = { @@ -268,14 +280,6 @@ class SparkConnectService(debug: Boolean) extends AsyncService with BindableServ */ object SparkConnectService extends Logging { - private val CACHE_SIZE = 100 - - private val CACHE_TIMEOUT_SECONDS = 3600 - - // Type alias for the SessionCacheKey. Right now this is a String but allows us to switch to a - // different or complex type easily. - private type SessionCacheKey = (String, String) - private[connect] var server: Server = _ private[connect] var uiTab: Option[SparkConnectServerTab] = None @@ -289,77 +293,18 @@ object SparkConnectService extends Logging { server.getPort } - private val userSessionMapping = - cacheBuilder(CACHE_SIZE, CACHE_TIMEOUT_SECONDS).build[SessionCacheKey, SessionHolder]() - private[connect] lazy val executionManager = new SparkConnectExecutionManager() + private[connect] lazy val sessionManager = new SparkConnectSessionManager() + private[connect] val streamingSessionManager = new SparkConnectStreamingQueryCache() - private class RemoveSessionListener extends RemovalListener[SessionCacheKey, SessionHolder] { - override def onRemoval( - notification: RemovalNotification[SessionCacheKey, SessionHolder]): Unit = { - notification.getValue.expireSession() - } - } - - // Simple builder for creating the cache of Sessions. - private def cacheBuilder(cacheSize: Int, timeoutSeconds: Int): CacheBuilder[Object, Object] = { - var cacheBuilder = CacheBuilder.newBuilder().ticker(Ticker.systemTicker()) - if (cacheSize >= 0) { - cacheBuilder = cacheBuilder.maximumSize(cacheSize) - } - if (timeoutSeconds >= 0) { - cacheBuilder.expireAfterAccess(timeoutSeconds, TimeUnit.SECONDS) - } - cacheBuilder.removalListener(new RemoveSessionListener) - cacheBuilder - } - /** * Based on the userId and sessionId, find or create a new SparkSession. */ def getOrCreateIsolatedSession(userId: String, sessionId: String): SessionHolder = { - getSessionOrDefault( - userId, - sessionId, - () => { - val holder = SessionHolder(userId, sessionId, newIsolatedSession()) - holder.initializeSession() - holder - }) - } - - /** - * Based on the userId and sessionId, find an existing SparkSession or throw error. - */ - def getIsolatedSession(userId: String, sessionId: String): SessionHolder = { - getSessionOrDefault( - userId, - sessionId, - () => { - logDebug(s"Session not found: ($userId, $sessionId)") - throw new SparkSQLException( - errorClass = "INVALID_HANDLE.SESSION_NOT_FOUND", - messageParameters = Map("handle" -> sessionId)) - }) - } - - private def getSessionOrDefault( - userId: String, - sessionId: String, - default: Callable[SessionHolder]): SessionHolder = { - // Validate that sessionId is formatted like UUID before creating session. 
- try { - UUID.fromString(sessionId).toString - } catch { - case _: IllegalArgumentException => - throw new SparkSQLException( - errorClass = "INVALID_HANDLE.FORMAT", - messageParameters = Map("handle" -> sessionId)) - } - userSessionMapping.get((userId, sessionId), default) + sessionManager.getOrCreateIsolatedSession(SessionKey(userId, sessionId)) } /** @@ -368,24 +313,6 @@ object SparkConnectService extends Logging { */ def listActiveExecutions: Either[Long, Seq[ExecuteInfo]] = executionManager.listActiveExecutions - /** - * Used for testing - */ - private[connect] def invalidateAllSessions(): Unit = { - userSessionMapping.invalidateAll() - } - - /** - * Used for testing. - */ - private[connect] def putSessionForTesting(sessionHolder: SessionHolder): Unit = { - userSessionMapping.put((sessionHolder.userId, sessionHolder.sessionId), sessionHolder) - } - - private def newIsolatedSession(): SparkSession = { - SparkSession.active.newSession() - } - private def createListenerAndUI(sc: SparkContext): Unit = { val kvStore = sc.statusStore.store.asInstanceOf[ElementTrackingStore] listener = new SparkConnectServerListener(kvStore, sc.conf) @@ -445,7 +372,7 @@ object SparkConnectService extends Logging { } streamingSessionManager.shutdown() executionManager.shutdown() - userSessionMapping.invalidateAll() + sessionManager.shutdown() uiTab.foreach(_.detach()) } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala new file mode 100644 index 0000000000000..5c8e3c611586c --- /dev/null +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.service + +import java.util.UUID +import java.util.concurrent.{Callable, TimeUnit} + +import com.google.common.base.Ticker +import com.google.common.cache.{CacheBuilder, RemovalListener, RemovalNotification} + +import org.apache.spark.{SparkEnv, SparkSQLException} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connect.config.Connect.{CONNECT_SESSION_MANAGER_CLOSED_SESSIONS_TOMBSTONES_SIZE, CONNECT_SESSION_MANAGER_DEFAULT_SESSION_TIMEOUT} + +/** + * Global tracker of all SessionHolders holding Spark Connect sessions. 
+ */ +class SparkConnectSessionManager extends Logging { + + private val sessionsLock = new Object + + private val sessionStore = + CacheBuilder + .newBuilder() + .ticker(Ticker.systemTicker()) + .expireAfterAccess( + SparkEnv.get.conf.get(CONNECT_SESSION_MANAGER_DEFAULT_SESSION_TIMEOUT), + TimeUnit.MILLISECONDS) + .removalListener(new RemoveSessionListener) + .build[SessionKey, SessionHolder]() + + private val closedSessionsCache = + CacheBuilder + .newBuilder() + .maximumSize(SparkEnv.get.conf.get(CONNECT_SESSION_MANAGER_CLOSED_SESSIONS_TOMBSTONES_SIZE)) + .build[SessionKey, SessionHolderInfo]() + + /** + * Based on the userId and sessionId, find or create a new SparkSession. + */ + private[connect] def getOrCreateIsolatedSession(key: SessionKey): SessionHolder = { + // Lock to guard against concurrent removal and insertion into closedSessionsCache. + sessionsLock.synchronized { + getSession( + key, + Some(() => { + validateSessionCreate(key) + val holder = SessionHolder(key.userId, key.sessionId, newIsolatedSession()) + holder.initializeSession() + holder + })) + } + } + + /** + * Based on the userId and sessionId, find an existing SparkSession or throw error. + */ + private[connect] def getIsolatedSession(key: SessionKey): SessionHolder = { + getSession( + key, + Some(() => { + logDebug(s"Session not found: $key") + if (closedSessionsCache.getIfPresent(key) != null) { + throw new SparkSQLException( + errorClass = "INVALID_HANDLE.SESSION_CLOSED", + messageParameters = Map("handle" -> key.sessionId)) + } else { + throw new SparkSQLException( + errorClass = "INVALID_HANDLE.SESSION_NOT_FOUND", + messageParameters = Map("handle" -> key.sessionId)) + } + })) + } + + /** + * Based on the userId and sessionId, get an existing SparkSession if present. + */ + private[connect] def getIsolatedSessionIfPresent(key: SessionKey): Option[SessionHolder] = { + Option(getSession(key, None)) + } + + private def getSession( + key: SessionKey, + default: Option[Callable[SessionHolder]]): SessionHolder = { + val session = default match { + case Some(callable) => sessionStore.get(key, callable) + case None => sessionStore.getIfPresent(key) + } + // record access time before returning + session match { + case null => + null + case s: SessionHolder => + s.updateAccessTime() + s + } + } + + def closeSession(key: SessionKey): Unit = { + // Invalidate will trigger RemoveSessionListener + sessionStore.invalidate(key) + } + + private class RemoveSessionListener extends RemovalListener[SessionKey, SessionHolder] { + override def onRemoval(notification: RemovalNotification[SessionKey, SessionHolder]): Unit = { + val sessionHolder = notification.getValue + sessionsLock.synchronized { + // First put into closedSessionsCache, so that it cannot get accidentally recreated by + // getOrCreateIsolatedSession. + closedSessionsCache.put(sessionHolder.key, sessionHolder.getSessionHolderInfo) + } + // Rest of the cleanup outside sessionLock - the session cannot be accessed anymore by + // getOrCreateIsolatedSession. + sessionHolder.close() + } + } + + def shutdown(): Unit = { + sessionsLock.synchronized { + sessionStore.invalidateAll() + closedSessionsCache.invalidateAll() + } + } + + private def newIsolatedSession(): SparkSession = { + SparkSession.active.newSession() + } + + private def validateSessionCreate(key: SessionKey): Unit = { + // Validate that sessionId is formatted like UUID before creating session. 
+ try { + UUID.fromString(key.sessionId).toString + } catch { + case _: IllegalArgumentException => + throw new SparkSQLException( + errorClass = "INVALID_HANDLE.FORMAT", + messageParameters = Map("handle" -> key.sessionId)) + } + // Validate that session with that key has not been already closed. + if (closedSessionsCache.getIfPresent(key) != null) { + throw new SparkSQLException( + errorClass = "INVALID_HANDLE.SESSION_CLOSED", + messageParameters = Map("handle" -> key.sessionId)) + } + } + + /** + * Used for testing + */ + private[connect] def invalidateAllSessions(): Unit = { + sessionStore.invalidateAll() + closedSessionsCache.invalidateAll() + } + + /** + * Used for testing. + */ + private[connect] def putSessionForTesting(sessionHolder: SessionHolder): Unit = { + sessionStore.put(sessionHolder.key, sessionHolder) + } +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala index 741fa97f17878..837ee5a00227c 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala @@ -41,7 +41,7 @@ import org.apache.spark.api.python.PythonException import org.apache.spark.connect.proto.FetchErrorDetailsResponse import org.apache.spark.internal.Logging import org.apache.spark.sql.connect.config.Connect -import org.apache.spark.sql.connect.service.{ExecuteEventsManager, SessionHolder, SparkConnectService} +import org.apache.spark.sql.connect.service.{ExecuteEventsManager, SessionHolder, SessionKey, SparkConnectService} import org.apache.spark.sql.internal.SQLConf private[connect] object ErrorUtils extends Logging { @@ -153,7 +153,9 @@ private[connect] object ErrorUtils extends Logging { .build() } - private def buildStatusFromThrowable(st: Throwable, sessionHolder: SessionHolder): RPCStatus = { + private def buildStatusFromThrowable( + st: Throwable, + sessionHolderOpt: Option[SessionHolder]): RPCStatus = { val errorInfo = ErrorInfo .newBuilder() .setReason(st.getClass.getName) @@ -162,20 +164,20 @@ private[connect] object ErrorUtils extends Logging { "classes", JsonMethods.compact(JsonMethods.render(allClasses(st.getClass).map(_.getName)))) - if (sessionHolder.session.conf.get(Connect.CONNECT_ENRICH_ERROR_ENABLED)) { + if (sessionHolderOpt.exists(_.session.conf.get(Connect.CONNECT_ENRICH_ERROR_ENABLED))) { // Generate a new unique key for this exception. 
val errorId = UUID.randomUUID().toString errorInfo.putMetadata("errorId", errorId) - sessionHolder.errorIdToError + sessionHolderOpt.get.errorIdToError .put(errorId, st) } lazy val stackTrace = Option(ExceptionUtils.getStackTrace(st)) val withStackTrace = - if (sessionHolder.session.conf.get( - SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED) && stackTrace.nonEmpty) { + if (sessionHolderOpt.exists( + _.session.conf.get(SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED) && stackTrace.nonEmpty)) { val maxSize = SparkEnv.get.conf.get(Connect.CONNECT_JVM_STACK_TRACE_MAX_SIZE) errorInfo.putMetadata("stackTrace", StringUtils.abbreviate(stackTrace.get, maxSize)) } else { @@ -215,19 +217,22 @@ private[connect] object ErrorUtils extends Logging { sessionId: String, events: Option[ExecuteEventsManager] = None, isInterrupted: Boolean = false): PartialFunction[Throwable, Unit] = { - val sessionHolder = - SparkConnectService - .getOrCreateIsolatedSession(userId, sessionId) + + // SessionHolder may not be present, e.g. if the session was already closed. + // When SessionHolder is not present error details will not be available for FetchErrorDetails. + val sessionHolderOpt = + SparkConnectService.sessionManager.getIsolatedSessionIfPresent( + SessionKey(userId, sessionId)) val partial: PartialFunction[Throwable, (Throwable, Throwable)] = { case se: SparkException if isPythonExecutionException(se) => ( se, StatusProto.toStatusRuntimeException( - buildStatusFromThrowable(se.getCause, sessionHolder))) + buildStatusFromThrowable(se.getCause, sessionHolderOpt))) case e: Throwable if e.isInstanceOf[SparkThrowable] || NonFatal.apply(e) => - (e, StatusProto.toStatusRuntimeException(buildStatusFromThrowable(e, sessionHolder))) + (e, StatusProto.toStatusRuntimeException(buildStatusFromThrowable(e, sessionHolderOpt))) case e: Throwable => ( diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala index 7b02377f4847c..120126f20ec24 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala @@ -59,10 +59,6 @@ trait SparkConnectServerTest extends SharedSparkSession { withSparkEnvConfs((Connect.CONNECT_GRPC_BINDING_PORT.key, serverPort.toString)) { SparkConnectService.start(spark.sparkContext) } - // register udf directly on the server, we're not testing client UDFs here... 
- val serverSession = - SparkConnectService.getOrCreateIsolatedSession(defaultUserId, defaultSessionId).session - serverSession.udf.register("sleep", ((ms: Int) => { Thread.sleep(ms); ms })) } override def afterAll(): Unit = { @@ -84,6 +80,7 @@ trait SparkConnectServerTest extends SharedSparkSession { protected def clearAllExecutions(): Unit = { SparkConnectService.executionManager.listExecuteHolders.foreach(_.close()) SparkConnectService.executionManager.periodicMaintenance(0) + SparkConnectService.sessionManager.invalidateAllSessions() assertNoActiveExecutions() } @@ -215,12 +212,24 @@ trait SparkConnectServerTest extends SharedSparkSession { } } + protected def withClient(sessionId: String = defaultSessionId, userId: String = defaultUserId)( + f: SparkConnectClient => Unit): Unit = { + withClient(f, sessionId, userId) + } + protected def withClient(f: SparkConnectClient => Unit): Unit = { + withClient(f, defaultSessionId, defaultUserId) + } + + protected def withClient( + f: SparkConnectClient => Unit, + sessionId: String, + userId: String): Unit = { val client = SparkConnectClient .builder() .port(serverPort) - .sessionId(defaultSessionId) - .userId(defaultUserId) + .sessionId(sessionId) + .userId(userId) .enableReattachableExecute() .build() try f(client) diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala index 0e29a07b719af..784b978f447df 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala @@ -347,6 +347,10 @@ class ReattachableExecuteSuite extends SparkConnectServerTest { } test("long sleeping query") { + // register udf directly on the server, we're not testing client UDFs here... 
+ val serverSession = + SparkConnectService.getOrCreateIsolatedSession(defaultUserId, defaultSessionId).session + serverSession.udf.register("sleep", ((ms: Int) => { Thread.sleep(ms); ms })) // query will be sleeping and not returning results, while having multiple reattach withSparkEnvConfs( (Connect.CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_DURATION.key, "1s")) { diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala index ce452623e6b84..b314e7d8d4834 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -841,12 +841,12 @@ class SparkConnectServiceSuite spark.sparkContext.addSparkListener(verifyEvents.listener) Utils.tryWithSafeFinally({ f(verifyEvents) - SparkConnectService.invalidateAllSessions() + SparkConnectService.sessionManager.invalidateAllSessions() verifyEvents.onSessionClosed() }) { verifyEvents.waitUntilEmpty() spark.sparkContext.removeSparkListener(verifyEvents.listener) - SparkConnectService.invalidateAllSessions() + SparkConnectService.sessionManager.invalidateAllSessions() SparkConnectPluginRegistry.reset() } } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala index 14ecc9a2e95e4..cc0481dab0f4f 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala @@ -16,13 +16,171 @@ */ package org.apache.spark.sql.connect.service +import java.util.UUID + import org.scalatest.concurrent.Eventually import org.scalatest.time.SpanSugar._ +import org.apache.spark.SparkException import org.apache.spark.sql.connect.SparkConnectServerTest class SparkConnectServiceE2ESuite extends SparkConnectServerTest { + // Making results of these queries large enough, so that all the results do not fit in the + // buffers and are not pushed out immediately even when the client doesn't consume them, so that + // even if the connection got closed, the client would see it as succeeded because the results + // were all already in the buffer. + val BIG_ENOUGH_QUERY = "select * from range(1000000)" + + test("ReleaseSession releases all queries and does not allow more requests in the session") { + withClient { client => + val query1 = client.execute(buildPlan(BIG_ENOUGH_QUERY)) + val query2 = client.execute(buildPlan(BIG_ENOUGH_QUERY)) + val query3 = client.execute(buildPlan("select 1")) + // just creating the iterator is lazy, trigger query1 and query2 to be sent. + query1.hasNext + query2.hasNext + Eventually.eventually(timeout(eventuallyTimeout)) { + SparkConnectService.executionManager.listExecuteHolders.length == 2 + } + + // Close session + client.releaseSession() + + // Check that queries get cancelled + Eventually.eventually(timeout(eventuallyTimeout)) { + SparkConnectService.executionManager.listExecuteHolders.length == 0 + // SparkConnectService.sessionManager. 
+ } + + // query1 and query2 could get either an: + // OPERATION_CANCELED if it happens fast - when closing the session interrupted the queries, + // and that error got pushed to the client buffers before the client got disconnected. + // OPERATION_ABANDONED if it happens slow - when closing the session interrupted the client + // RPCs before it pushed out the error above. The client would then get an + // INVALID_CURSOR.DISCONNECTED, which it will retry with a ReattachExecute, and then get an + // INVALID_HANDLE.OPERATION_ABANDONED. + val query1Error = intercept[SparkException] { + while (query1.hasNext) query1.next() + } + assert( + query1Error.getMessage.contains("OPERATION_CANCELED") || + query1Error.getMessage.contains("INVALID_HANDLE.OPERATION_ABANDONED")) + val query2Error = intercept[SparkException] { + while (query2.hasNext) query2.next() + } + assert( + query2Error.getMessage.contains("OPERATION_CANCELED") || + query2Error.getMessage.contains("INVALID_HANDLE.OPERATION_ABANDONED")) + + // query3 has not been submitted before, so it should now fail with SESSION_CLOSED + val query3Error = intercept[SparkException] { + query3.hasNext + } + assert(query3Error.getMessage.contains("INVALID_HANDLE.SESSION_CLOSED")) + + // No other requests should be allowed in the session, failing with SESSION_CLOSED + val requestError = intercept[SparkException] { + client.interruptAll() + } + assert(requestError.getMessage.contains("INVALID_HANDLE.SESSION_CLOSED")) + } + } + + private def testReleaseSessionTwoSessions( + sessionIdA: String, + userIdA: String, + sessionIdB: String, + userIdB: String): Unit = { + withClient(sessionId = sessionIdA, userId = userIdA) { clientA => + withClient(sessionId = sessionIdB, userId = userIdB) { clientB => + val queryA = clientA.execute(buildPlan(BIG_ENOUGH_QUERY)) + val queryB = clientB.execute(buildPlan(BIG_ENOUGH_QUERY)) + // just creating the iterator is lazy, trigger query1 and query2 to be sent. + queryA.hasNext + queryB.hasNext + Eventually.eventually(timeout(eventuallyTimeout)) { + SparkConnectService.executionManager.listExecuteHolders.length == 2 + } + // Close session A + clientA.releaseSession() + + // A's query gets kicked out. + Eventually.eventually(timeout(eventuallyTimeout)) { + SparkConnectService.executionManager.listExecuteHolders.length == 1 + } + val queryAError = intercept[SparkException] { + while (queryA.hasNext) queryA.next() + } + assert( + queryAError.getMessage.contains("OPERATION_CANCELED") || + queryAError.getMessage.contains("INVALID_HANDLE.OPERATION_ABANDONED")) + + // B's query can run. + while (queryB.hasNext) queryB.next() + + // B can submit more queries. + val queryB2 = clientB.execute(buildPlan("SELECT 1")) + while (queryB2.hasNext) queryB2.next() + // A can't submit more queries. 
+ val queryA2 = clientA.execute(buildPlan("SELECT 1")) + val queryA2Error = intercept[SparkException] { + clientA.interruptAll() + } + assert(queryA2Error.getMessage.contains("INVALID_HANDLE.SESSION_CLOSED")) + } + } + } + + test("ReleaseSession for different user_id with same session_id do not affect each other") { + testReleaseSessionTwoSessions(defaultSessionId, "A", defaultSessionId, "B") + } + + test("ReleaseSession for different session_id with same user_id do not affect each other") { + val sessionIdA = UUID.randomUUID.toString() + val sessionIdB = UUID.randomUUID.toString() + testReleaseSessionTwoSessions(sessionIdA, "X", sessionIdB, "X") + } + + test("ReleaseSession: can't create a new session with the same id and user after release") { + val sessionId = UUID.randomUUID.toString() + val userId = "Y" + withClient(sessionId = sessionId, userId = userId) { client => + // this will create the session, and then ReleaseSession at the end of withClient. + val query = client.execute(buildPlan("SELECT 1")) + query.hasNext // trigger execution + client.releaseSession() + } + withClient(sessionId = sessionId, userId = userId) { client => + // shall not be able to create a new session with the same id and user. + val query = client.execute(buildPlan("SELECT 1")) + val queryError = intercept[SparkException] { + while (query.hasNext) query.next() + } + assert(queryError.getMessage.contains("INVALID_HANDLE.SESSION_CLOSED")) + } + } + + test("ReleaseSession: session with different session_id or user_id allowed after release") { + val sessionId = UUID.randomUUID.toString() + val userId = "Y" + withClient(sessionId = sessionId, userId = userId) { client => + val query = client.execute(buildPlan("SELECT 1")) + query.hasNext // trigger execution + client.releaseSession() + } + withClient(sessionId = UUID.randomUUID.toString, userId = userId) { client => + val query = client.execute(buildPlan("SELECT 1")) + query.hasNext // trigger execution + client.releaseSession() + } + withClient(sessionId = sessionId, userId = "YY") { client => + val query = client.execute(buildPlan("SELECT 1")) + query.hasNext // trigger execution + client.releaseSession() + } + } + test("SPARK-45133 query should reach FINISHED state when results are not consumed") { withRawBlockingStub { stub => val iter = diff --git a/docs/sql-error-conditions-invalid-handle-error-class.md b/docs/sql-error-conditions-invalid-handle-error-class.md index c4cbb48035ff5..14526cd53724f 100644 --- a/docs/sql-error-conditions-invalid-handle-error-class.md +++ b/docs/sql-error-conditions-invalid-handle-error-class.md @@ -45,6 +45,10 @@ Operation not found. Session already exists. +## SESSION_CLOSED + +Session was closed. + ## SESSION_NOT_FOUND Session not found. 
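The new SESSION_CLOSED error condition documented above is backed by the closed-session tombstones in SparkConnectSessionManager: the manager keeps a store of live sessions plus a bounded record of keys that were explicitly released, and getOrCreateIsolatedSession refuses to recreate a key found in that record. Below is a minimal sketch of that pattern for reference only; it is not part of the patch, and the names (SimpleSessionStore, Key, tombstoneLimit) are illustrative stand-ins for the actual Guava caches keyed by SessionKey and sized by CONNECT_SESSION_MANAGER_CLOSED_SESSIONS_TOMBSTONES_SIZE.

```
import scala.collection.mutable

// Illustrative sketch only: a live-session map plus a bounded set of "closed" keys,
// so that a (userId, sessionId) pair that was released can never be silently recreated.
final case class Key(userId: String, sessionId: String)

final class SimpleSessionStore[S](newSession: Key => S, tombstoneLimit: Int = 1000) {
  private val live = mutable.Map.empty[Key, S]
  private val closed = mutable.LinkedHashSet.empty[Key]

  def getOrCreate(key: Key): S = synchronized {
    if (closed.contains(key)) {
      // Corresponds to rejecting the request with INVALID_HANDLE.SESSION_CLOSED.
      throw new IllegalStateException(s"Session was closed: $key")
    }
    live.getOrElseUpdate(key, newSession(key))
  }

  def close(key: Key): Unit = synchronized {
    // The real manager also interrupts running executions and cleans up session resources here.
    live.remove(key)
    closed += key
    if (closed.size > tombstoneLimit) closed -= closed.head // keep the tombstone record bounded
  }
}
```

The tombstone record is what makes ReleaseSession terminal: without it, a later RPC with the same user_id and session_id would simply recreate the session, which is exactly the behavior the E2E tests above rule out.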
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 318f7d7ade4a2..11a1112ad1fe7 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -19,7 +19,7 @@ "SparkConnectClient", ] -from pyspark.loose_version import LooseVersion + from pyspark.sql.connect.utils import check_dependencies check_dependencies(__name__) @@ -61,6 +61,7 @@ from google.protobuf import text_format from google.rpc import error_details_pb2 +from pyspark.loose_version import LooseVersion from pyspark.version import __version__ from pyspark.resource.information import ResourceInformation from pyspark.sql.connect.client.artifact import ArtifactManager @@ -1471,6 +1472,26 @@ def interrupt_operation(self, op_id: str) -> Optional[List[str]]: except Exception as error: self._handle_error(error) + def release_session(self) -> None: + req = pb2.ReleaseSessionRequest() + req.session_id = self._session_id + req.client_type = self._builder.userAgent + if self._user_id: + req.user_context.user_id = self._user_id + try: + for attempt in self._retrying(): + with attempt: + resp = self._stub.ReleaseSession(req, metadata=self._builder.metadata()) + if resp.session_id != self._session_id: + raise SparkConnectException( + "Received incorrect session identifier for request:" + f"{resp.session_id} != {self._session_id}" + ) + return + raise SparkConnectException("Invalid state during retry exception handling.") + except Exception as error: + self._handle_error(error) + def add_tag(self, tag: str) -> None: self._throw_if_invalid_tag(tag) if not hasattr(self.thread_local, "tags"): diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py index 0ea02525f78ff..0e374e7aa2ccb 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.py +++ b/python/pyspark/sql/connect/proto/base_pb2.py @@ -37,7 +37,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\xf5\x12\n\x12\x41nalyzePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x01R\nclientType\x88\x01\x01\x12\x42\n\x06schema\x18\x04 \x01(\x0b\x32(.spark.connect.AnalyzePlanRequest.SchemaH\x00R\x06schema\x12\x45\n\x07\x65xplain\x18\x05 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.ExplainH\x00R\x07\x65xplain\x12O\n\x0btree_string\x18\x06 \x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.TreeStringH\x00R\ntreeString\x12\x46\n\x08is_local\x18\x07 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.IsLocalH\x00R\x07isLocal\x12R\n\x0cis_streaming\x18\x08 \x01(\x0b\x32-.spark.connect.AnalyzePlanRequest.IsStreamingH\x00R\x0bisStreaming\x12O\n\x0binput_files\x18\t 
\x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.InputFilesH\x00R\ninputFiles\x12U\n\rspark_version\x18\n \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SparkVersionH\x00R\x0csparkVersion\x12I\n\tddl_parse\x18\x0b \x01(\x0b\x32*.spark.connect.AnalyzePlanRequest.DDLParseH\x00R\x08\x64\x64lParse\x12X\n\x0esame_semantics\x18\x0c \x01(\x0b\x32/.spark.connect.AnalyzePlanRequest.SameSemanticsH\x00R\rsameSemantics\x12U\n\rsemantic_hash\x18\r \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SemanticHashH\x00R\x0csemanticHash\x12\x45\n\x07persist\x18\x0e \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.PersistH\x00R\x07persist\x12K\n\tunpersist\x18\x0f \x01(\x0b\x32+.spark.connect.AnalyzePlanRequest.UnpersistH\x00R\tunpersist\x12_\n\x11get_storage_level\x18\x10 \x01(\x0b\x32\x31.spark.connect.AnalyzePlanRequest.GetStorageLevelH\x00R\x0fgetStorageLevel\x1a\x31\n\x06Schema\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\xbb\x02\n\x07\x45xplain\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12X\n\x0c\x65xplain_mode\x18\x02 \x01(\x0e\x32\x35.spark.connect.AnalyzePlanRequest.Explain.ExplainModeR\x0b\x65xplainMode"\xac\x01\n\x0b\x45xplainMode\x12\x1c\n\x18\x45XPLAIN_MODE_UNSPECIFIED\x10\x00\x12\x17\n\x13\x45XPLAIN_MODE_SIMPLE\x10\x01\x12\x19\n\x15\x45XPLAIN_MODE_EXTENDED\x10\x02\x12\x18\n\x14\x45XPLAIN_MODE_CODEGEN\x10\x03\x12\x15\n\x11\x45XPLAIN_MODE_COST\x10\x04\x12\x1a\n\x16\x45XPLAIN_MODE_FORMATTED\x10\x05\x1aZ\n\nTreeString\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12\x19\n\x05level\x18\x02 \x01(\x05H\x00R\x05level\x88\x01\x01\x42\x08\n\x06_level\x1a\x32\n\x07IsLocal\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x36\n\x0bIsStreaming\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x35\n\nInputFiles\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x0e\n\x0cSparkVersion\x1a)\n\x08\x44\x44LParse\x12\x1d\n\nddl_string\x18\x01 \x01(\tR\tddlString\x1ay\n\rSameSemantics\x12\x34\n\x0btarget_plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\ntargetPlan\x12\x32\n\nother_plan\x18\x02 \x01(\x0b\x32\x13.spark.connect.PlanR\totherPlan\x1a\x37\n\x0cSemanticHash\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x97\x01\n\x07Persist\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x45\n\rstorage_level\x18\x02 \x01(\x0b\x32\x1b.spark.connect.StorageLevelH\x00R\x0cstorageLevel\x88\x01\x01\x42\x10\n\x0e_storage_level\x1an\n\tUnpersist\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x1f\n\x08\x62locking\x18\x02 \x01(\x08H\x00R\x08\x62locking\x88\x01\x01\x42\x0b\n\t_blocking\x1a\x46\n\x0fGetStorageLevel\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relationB\t\n\x07\x61nalyzeB\x0e\n\x0c_client_type"\x99\r\n\x13\x41nalyzePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x43\n\x06schema\x18\x02 \x01(\x0b\x32).spark.connect.AnalyzePlanResponse.SchemaH\x00R\x06schema\x12\x46\n\x07\x65xplain\x18\x03 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.ExplainH\x00R\x07\x65xplain\x12P\n\x0btree_string\x18\x04 \x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.TreeStringH\x00R\ntreeString\x12G\n\x08is_local\x18\x05 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.IsLocalH\x00R\x07isLocal\x12S\n\x0cis_streaming\x18\x06 \x01(\x0b\x32..spark.connect.AnalyzePlanResponse.IsStreamingH\x00R\x0bisStreaming\x12P\n\x0binput_files\x18\x07 
\x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.InputFilesH\x00R\ninputFiles\x12V\n\rspark_version\x18\x08 \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SparkVersionH\x00R\x0csparkVersion\x12J\n\tddl_parse\x18\t \x01(\x0b\x32+.spark.connect.AnalyzePlanResponse.DDLParseH\x00R\x08\x64\x64lParse\x12Y\n\x0esame_semantics\x18\n \x01(\x0b\x32\x30.spark.connect.AnalyzePlanResponse.SameSemanticsH\x00R\rsameSemantics\x12V\n\rsemantic_hash\x18\x0b \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SemanticHashH\x00R\x0csemanticHash\x12\x46\n\x07persist\x18\x0c \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.PersistH\x00R\x07persist\x12L\n\tunpersist\x18\r \x01(\x0b\x32,.spark.connect.AnalyzePlanResponse.UnpersistH\x00R\tunpersist\x12`\n\x11get_storage_level\x18\x0e \x01(\x0b\x32\x32.spark.connect.AnalyzePlanResponse.GetStorageLevelH\x00R\x0fgetStorageLevel\x1a\x39\n\x06Schema\x12/\n\x06schema\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1a\x30\n\x07\x45xplain\x12%\n\x0e\x65xplain_string\x18\x01 \x01(\tR\rexplainString\x1a-\n\nTreeString\x12\x1f\n\x0btree_string\x18\x01 \x01(\tR\ntreeString\x1a$\n\x07IsLocal\x12\x19\n\x08is_local\x18\x01 \x01(\x08R\x07isLocal\x1a\x30\n\x0bIsStreaming\x12!\n\x0cis_streaming\x18\x01 \x01(\x08R\x0bisStreaming\x1a"\n\nInputFiles\x12\x14\n\x05\x66iles\x18\x01 \x03(\tR\x05\x66iles\x1a(\n\x0cSparkVersion\x12\x18\n\x07version\x18\x01 \x01(\tR\x07version\x1a;\n\x08\x44\x44LParse\x12/\n\x06parsed\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06parsed\x1a\'\n\rSameSemantics\x12\x16\n\x06result\x18\x01 \x01(\x08R\x06result\x1a&\n\x0cSemanticHash\x12\x16\n\x06result\x18\x01 \x01(\x05R\x06result\x1a\t\n\x07Persist\x1a\x0b\n\tUnpersist\x1aS\n\x0fGetStorageLevel\x12@\n\rstorage_level\x18\x01 \x01(\x0b\x32\x1b.spark.connect.StorageLevelR\x0cstorageLevelB\x08\n\x06result"\xa0\x04\n\x12\x45xecutePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12&\n\x0coperation_id\x18\x06 \x01(\tH\x00R\x0boperationId\x88\x01\x01\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x01R\nclientType\x88\x01\x01\x12X\n\x0frequest_options\x18\x05 \x03(\x0b\x32/.spark.connect.ExecutePlanRequest.RequestOptionR\x0erequestOptions\x12\x12\n\x04tags\x18\x07 \x03(\tR\x04tags\x1a\xa5\x01\n\rRequestOption\x12K\n\x10reattach_options\x18\x01 \x01(\x0b\x32\x1e.spark.connect.ReattachOptionsH\x00R\x0freattachOptions\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textensionB\x10\n\x0erequest_optionB\x0f\n\r_operation_idB\x0e\n\x0c_client_type"\xe6\x0f\n\x13\x45xecutePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12!\n\x0coperation_id\x18\x0c \x01(\tR\x0boperationId\x12\x1f\n\x0bresponse_id\x18\r \x01(\tR\nresponseId\x12P\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32-.spark.connect.ExecutePlanResponse.ArrowBatchH\x00R\narrowBatch\x12\x63\n\x12sql_command_result\x18\x05 \x01(\x0b\x32\x33.spark.connect.ExecutePlanResponse.SqlCommandResultH\x00R\x10sqlCommandResult\x12~\n#write_stream_operation_start_result\x18\x08 \x01(\x0b\x32..spark.connect.WriteStreamOperationStartResultH\x00R\x1fwriteStreamOperationStartResult\x12q\n\x1estreaming_query_command_result\x18\t \x01(\x0b\x32*.spark.connect.StreamingQueryCommandResultH\x00R\x1bstreamingQueryCommandResult\x12k\n\x1cget_resources_command_result\x18\n 
\x01(\x0b\x32(.spark.connect.GetResourcesCommandResultH\x00R\x19getResourcesCommandResult\x12\x87\x01\n&streaming_query_manager_command_result\x18\x0b \x01(\x0b\x32\x31.spark.connect.StreamingQueryManagerCommandResultH\x00R"streamingQueryManagerCommandResult\x12\\\n\x0fresult_complete\x18\x0e \x01(\x0b\x32\x31.spark.connect.ExecutePlanResponse.ResultCompleteH\x00R\x0eresultComplete\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12\x44\n\x07metrics\x18\x04 \x01(\x0b\x32*.spark.connect.ExecutePlanResponse.MetricsR\x07metrics\x12]\n\x10observed_metrics\x18\x06 \x03(\x0b\x32\x32.spark.connect.ExecutePlanResponse.ObservedMetricsR\x0fobservedMetrics\x12/\n\x06schema\x18\x07 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1aG\n\x10SqlCommandResult\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x1av\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x12&\n\x0cstart_offset\x18\x03 \x01(\x03H\x00R\x0bstartOffset\x88\x01\x01\x42\x0f\n\r_start_offset\x1a\x85\x04\n\x07Metrics\x12Q\n\x07metrics\x18\x01 \x03(\x0b\x32\x37.spark.connect.ExecutePlanResponse.Metrics.MetricObjectR\x07metrics\x1a\xcc\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12z\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32M.spark.connect.ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1a{\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ExecutePlanResponse.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricType\x1at\n\x0fObservedMetrics\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x39\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x12\x12\n\x04keys\x18\x03 \x03(\tR\x04keys\x1a\x10\n\x0eResultCompleteB\x0f\n\rresponse_type"A\n\x08KeyValue\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x19\n\x05value\x18\x02 \x01(\tH\x00R\x05value\x88\x01\x01\x42\x08\n\x06_value"\x84\x08\n\rConfigRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\x44\n\toperation\x18\x03 \x01(\x0b\x32&.spark.connect.ConfigRequest.OperationR\toperation\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x1a\xf2\x03\n\tOperation\x12\x34\n\x03set\x18\x01 \x01(\x0b\x32 .spark.connect.ConfigRequest.SetH\x00R\x03set\x12\x34\n\x03get\x18\x02 \x01(\x0b\x32 .spark.connect.ConfigRequest.GetH\x00R\x03get\x12W\n\x10get_with_default\x18\x03 \x01(\x0b\x32+.spark.connect.ConfigRequest.GetWithDefaultH\x00R\x0egetWithDefault\x12G\n\nget_option\x18\x04 \x01(\x0b\x32&.spark.connect.ConfigRequest.GetOptionH\x00R\tgetOption\x12>\n\x07get_all\x18\x05 \x01(\x0b\x32#.spark.connect.ConfigRequest.GetAllH\x00R\x06getAll\x12:\n\x05unset\x18\x06 \x01(\x0b\x32".spark.connect.ConfigRequest.UnsetH\x00R\x05unset\x12P\n\ris_modifiable\x18\x07 \x01(\x0b\x32).spark.connect.ConfigRequest.IsModifiableH\x00R\x0cisModifiableB\t\n\x07op_type\x1a\x34\n\x03Set\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x19\n\x03Get\x12\x12\n\x04keys\x18\x01 
\x03(\tR\x04keys\x1a?\n\x0eGetWithDefault\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x1f\n\tGetOption\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a\x30\n\x06GetAll\x12\x1b\n\x06prefix\x18\x01 \x01(\tH\x00R\x06prefix\x88\x01\x01\x42\t\n\x07_prefix\x1a\x1b\n\x05Unset\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a"\n\x0cIsModifiable\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keysB\x0e\n\x0c_client_type"z\n\x0e\x43onfigResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12-\n\x05pairs\x18\x02 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x12\x1a\n\x08warnings\x18\x03 \x03(\tR\x08warnings"\xe7\x06\n\x13\x41\x64\x64\x41rtifactsRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x06 \x01(\tH\x01R\nclientType\x88\x01\x01\x12@\n\x05\x62\x61tch\x18\x03 \x01(\x0b\x32(.spark.connect.AddArtifactsRequest.BatchH\x00R\x05\x62\x61tch\x12Z\n\x0b\x62\x65gin_chunk\x18\x04 \x01(\x0b\x32\x37.spark.connect.AddArtifactsRequest.BeginChunkedArtifactH\x00R\nbeginChunk\x12H\n\x05\x63hunk\x18\x05 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkH\x00R\x05\x63hunk\x1a\x35\n\rArtifactChunk\x12\x12\n\x04\x64\x61ta\x18\x01 \x01(\x0cR\x04\x64\x61ta\x12\x10\n\x03\x63rc\x18\x02 \x01(\x03R\x03\x63rc\x1ao\n\x13SingleChunkArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x44\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x04\x64\x61ta\x1a]\n\x05\x42\x61tch\x12T\n\tartifacts\x18\x01 \x03(\x0b\x32\x36.spark.connect.AddArtifactsRequest.SingleChunkArtifactR\tartifacts\x1a\xc1\x01\n\x14\x42\x65ginChunkedArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x1f\n\x0btotal_bytes\x18\x02 \x01(\x03R\ntotalBytes\x12\x1d\n\nnum_chunks\x18\x03 \x01(\x03R\tnumChunks\x12U\n\rinitial_chunk\x18\x04 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x0cinitialChunkB\t\n\x07payloadB\x0e\n\x0c_client_type"\xbc\x01\n\x14\x41\x64\x64\x41rtifactsResponse\x12Q\n\tartifacts\x18\x01 \x03(\x0b\x32\x33.spark.connect.AddArtifactsResponse.ArtifactSummaryR\tartifacts\x1aQ\n\x0f\x41rtifactSummary\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12*\n\x11is_crc_successful\x18\x02 \x01(\x08R\x0fisCrcSuccessful"\xc3\x01\n\x17\x41rtifactStatusesRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x00R\nclientType\x88\x01\x01\x12\x14\n\x05names\x18\x04 \x03(\tR\x05namesB\x0e\n\x0c_client_type"\x8c\x02\n\x18\x41rtifactStatusesResponse\x12Q\n\x08statuses\x18\x01 \x03(\x0b\x32\x35.spark.connect.ArtifactStatusesResponse.StatusesEntryR\x08statuses\x1a(\n\x0e\x41rtifactStatus\x12\x16\n\x06\x65xists\x18\x01 \x01(\x08R\x06\x65xists\x1as\n\rStatusesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ArtifactStatusesResponse.ArtifactStatusR\x05value:\x02\x38\x01"\xd8\x03\n\x10InterruptRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x01R\nclientType\x88\x01\x01\x12T\n\x0einterrupt_type\x18\x04 \x01(\x0e\x32-.spark.connect.InterruptRequest.InterruptTypeR\rinterruptType\x12%\n\roperation_tag\x18\x05 \x01(\tH\x00R\x0coperationTag\x12#\n\x0coperation_id\x18\x06 
\x01(\tH\x00R\x0boperationId"\x80\x01\n\rInterruptType\x12\x1e\n\x1aINTERRUPT_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12INTERRUPT_TYPE_ALL\x10\x01\x12\x16\n\x12INTERRUPT_TYPE_TAG\x10\x02\x12\x1f\n\x1bINTERRUPT_TYPE_OPERATION_ID\x10\x03\x42\x0b\n\tinterruptB\x0e\n\x0c_client_type"[\n\x11InterruptResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\'\n\x0finterrupted_ids\x18\x02 \x03(\tR\x0einterruptedIds"5\n\x0fReattachOptions\x12"\n\x0creattachable\x18\x01 \x01(\x08R\x0creattachable"\x93\x02\n\x16ReattachExecuteRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12!\n\x0coperation_id\x18\x03 \x01(\tR\x0boperationId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x12-\n\x10last_response_id\x18\x05 \x01(\tH\x01R\x0elastResponseId\x88\x01\x01\x42\x0e\n\x0c_client_typeB\x13\n\x11_last_response_id"\xc6\x03\n\x15ReleaseExecuteRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12!\n\x0coperation_id\x18\x03 \x01(\tR\x0boperationId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x01R\nclientType\x88\x01\x01\x12R\n\x0brelease_all\x18\x05 \x01(\x0b\x32/.spark.connect.ReleaseExecuteRequest.ReleaseAllH\x00R\nreleaseAll\x12X\n\rrelease_until\x18\x06 \x01(\x0b\x32\x31.spark.connect.ReleaseExecuteRequest.ReleaseUntilH\x00R\x0creleaseUntil\x1a\x0c\n\nReleaseAll\x1a/\n\x0cReleaseUntil\x12\x1f\n\x0bresponse_id\x18\x01 \x01(\tR\nresponseIdB\t\n\x07releaseB\x0e\n\x0c_client_type"p\n\x16ReleaseExecuteResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12&\n\x0coperation_id\x18\x02 \x01(\tH\x00R\x0boperationId\x88\x01\x01\x42\x0f\n\r_operation_id"\xc9\x01\n\x18\x46\x65tchErrorDetailsRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\x19\n\x08\x65rror_id\x18\x03 \x01(\tR\x07\x65rrorId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x42\x0e\n\x0c_client_type"\xbe\x0b\n\x19\x46\x65tchErrorDetailsResponse\x12)\n\x0eroot_error_idx\x18\x01 \x01(\x05H\x00R\x0crootErrorIdx\x88\x01\x01\x12\x46\n\x06\x65rrors\x18\x02 \x03(\x0b\x32..spark.connect.FetchErrorDetailsResponse.ErrorR\x06\x65rrors\x1a\xae\x01\n\x11StackTraceElement\x12\'\n\x0f\x64\x65\x63laring_class\x18\x01 \x01(\tR\x0e\x64\x65\x63laringClass\x12\x1f\n\x0bmethod_name\x18\x02 \x01(\tR\nmethodName\x12 \n\tfile_name\x18\x03 \x01(\tH\x00R\x08\x66ileName\x88\x01\x01\x12\x1f\n\x0bline_number\x18\x04 \x01(\x05R\nlineNumberB\x0c\n\n_file_name\x1a\xef\x02\n\x0cQueryContext\x12\x64\n\x0c\x63ontext_type\x18\n \x01(\x0e\x32\x41.spark.connect.FetchErrorDetailsResponse.QueryContext.ContextTypeR\x0b\x63ontextType\x12\x1f\n\x0bobject_type\x18\x01 \x01(\tR\nobjectType\x12\x1f\n\x0bobject_name\x18\x02 \x01(\tR\nobjectName\x12\x1f\n\x0bstart_index\x18\x03 \x01(\x05R\nstartIndex\x12\x1d\n\nstop_index\x18\x04 \x01(\x05R\tstopIndex\x12\x1a\n\x08\x66ragment\x18\x05 \x01(\tR\x08\x66ragment\x12\x1a\n\x08\x63\x61llSite\x18\x06 \x01(\tR\x08\x63\x61llSite\x12\x18\n\x07summary\x18\x07 \x01(\tR\x07summary"%\n\x0b\x43ontextType\x12\x07\n\x03SQL\x10\x00\x12\r\n\tDATAFRAME\x10\x01\x1a\x99\x03\n\x0eSparkThrowable\x12$\n\x0b\x65rror_class\x18\x01 \x01(\tH\x00R\nerrorClass\x88\x01\x01\x12}\n\x12message_parameters\x18\x02 
\x03(\x0b\x32N.spark.connect.FetchErrorDetailsResponse.SparkThrowable.MessageParametersEntryR\x11messageParameters\x12\\\n\x0equery_contexts\x18\x03 \x03(\x0b\x32\x35.spark.connect.FetchErrorDetailsResponse.QueryContextR\rqueryContexts\x12 \n\tsql_state\x18\x04 \x01(\tH\x01R\x08sqlState\x88\x01\x01\x1a\x44\n\x16MessageParametersEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x0e\n\x0c_error_classB\x0c\n\n_sql_state\x1a\xdb\x02\n\x05\x45rror\x12\x30\n\x14\x65rror_type_hierarchy\x18\x01 \x03(\tR\x12\x65rrorTypeHierarchy\x12\x18\n\x07message\x18\x02 \x01(\tR\x07message\x12[\n\x0bstack_trace\x18\x03 \x03(\x0b\x32:.spark.connect.FetchErrorDetailsResponse.StackTraceElementR\nstackTrace\x12 \n\tcause_idx\x18\x04 \x01(\x05H\x00R\x08\x63\x61useIdx\x88\x01\x01\x12\x65\n\x0fspark_throwable\x18\x05 \x01(\x0b\x32\x37.spark.connect.FetchErrorDetailsResponse.SparkThrowableH\x01R\x0esparkThrowable\x88\x01\x01\x42\x0c\n\n_cause_idxB\x12\n\x10_spark_throwableB\x11\n\x0f_root_error_idx2\xd1\x06\n\x13SparkConnectService\x12X\n\x0b\x45xecutePlan\x12!.spark.connect.ExecutePlanRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12V\n\x0b\x41nalyzePlan\x12!.spark.connect.AnalyzePlanRequest\x1a".spark.connect.AnalyzePlanResponse"\x00\x12G\n\x06\x43onfig\x12\x1c.spark.connect.ConfigRequest\x1a\x1d.spark.connect.ConfigResponse"\x00\x12[\n\x0c\x41\x64\x64\x41rtifacts\x12".spark.connect.AddArtifactsRequest\x1a#.spark.connect.AddArtifactsResponse"\x00(\x01\x12\x63\n\x0e\x41rtifactStatus\x12&.spark.connect.ArtifactStatusesRequest\x1a\'.spark.connect.ArtifactStatusesResponse"\x00\x12P\n\tInterrupt\x12\x1f.spark.connect.InterruptRequest\x1a .spark.connect.InterruptResponse"\x00\x12`\n\x0fReattachExecute\x12%.spark.connect.ReattachExecuteRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12_\n\x0eReleaseExecute\x12$.spark.connect.ReleaseExecuteRequest\x1a%.spark.connect.ReleaseExecuteResponse"\x00\x12h\n\x11\x46\x65tchErrorDetails\x12\'.spark.connect.FetchErrorDetailsRequest\x1a(.spark.connect.FetchErrorDetailsResponse"\x00\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' + b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\xf5\x12\n\x12\x41nalyzePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x01R\nclientType\x88\x01\x01\x12\x42\n\x06schema\x18\x04 \x01(\x0b\x32(.spark.connect.AnalyzePlanRequest.SchemaH\x00R\x06schema\x12\x45\n\x07\x65xplain\x18\x05 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.ExplainH\x00R\x07\x65xplain\x12O\n\x0btree_string\x18\x06 \x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.TreeStringH\x00R\ntreeString\x12\x46\n\x08is_local\x18\x07 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.IsLocalH\x00R\x07isLocal\x12R\n\x0cis_streaming\x18\x08 
\x01(\x0b\x32-.spark.connect.AnalyzePlanRequest.IsStreamingH\x00R\x0bisStreaming\x12O\n\x0binput_files\x18\t \x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.InputFilesH\x00R\ninputFiles\x12U\n\rspark_version\x18\n \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SparkVersionH\x00R\x0csparkVersion\x12I\n\tddl_parse\x18\x0b \x01(\x0b\x32*.spark.connect.AnalyzePlanRequest.DDLParseH\x00R\x08\x64\x64lParse\x12X\n\x0esame_semantics\x18\x0c \x01(\x0b\x32/.spark.connect.AnalyzePlanRequest.SameSemanticsH\x00R\rsameSemantics\x12U\n\rsemantic_hash\x18\r \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SemanticHashH\x00R\x0csemanticHash\x12\x45\n\x07persist\x18\x0e \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.PersistH\x00R\x07persist\x12K\n\tunpersist\x18\x0f \x01(\x0b\x32+.spark.connect.AnalyzePlanRequest.UnpersistH\x00R\tunpersist\x12_\n\x11get_storage_level\x18\x10 \x01(\x0b\x32\x31.spark.connect.AnalyzePlanRequest.GetStorageLevelH\x00R\x0fgetStorageLevel\x1a\x31\n\x06Schema\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\xbb\x02\n\x07\x45xplain\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12X\n\x0c\x65xplain_mode\x18\x02 \x01(\x0e\x32\x35.spark.connect.AnalyzePlanRequest.Explain.ExplainModeR\x0b\x65xplainMode"\xac\x01\n\x0b\x45xplainMode\x12\x1c\n\x18\x45XPLAIN_MODE_UNSPECIFIED\x10\x00\x12\x17\n\x13\x45XPLAIN_MODE_SIMPLE\x10\x01\x12\x19\n\x15\x45XPLAIN_MODE_EXTENDED\x10\x02\x12\x18\n\x14\x45XPLAIN_MODE_CODEGEN\x10\x03\x12\x15\n\x11\x45XPLAIN_MODE_COST\x10\x04\x12\x1a\n\x16\x45XPLAIN_MODE_FORMATTED\x10\x05\x1aZ\n\nTreeString\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12\x19\n\x05level\x18\x02 \x01(\x05H\x00R\x05level\x88\x01\x01\x42\x08\n\x06_level\x1a\x32\n\x07IsLocal\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x36\n\x0bIsStreaming\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x35\n\nInputFiles\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x0e\n\x0cSparkVersion\x1a)\n\x08\x44\x44LParse\x12\x1d\n\nddl_string\x18\x01 \x01(\tR\tddlString\x1ay\n\rSameSemantics\x12\x34\n\x0btarget_plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\ntargetPlan\x12\x32\n\nother_plan\x18\x02 \x01(\x0b\x32\x13.spark.connect.PlanR\totherPlan\x1a\x37\n\x0cSemanticHash\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x97\x01\n\x07Persist\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x45\n\rstorage_level\x18\x02 \x01(\x0b\x32\x1b.spark.connect.StorageLevelH\x00R\x0cstorageLevel\x88\x01\x01\x42\x10\n\x0e_storage_level\x1an\n\tUnpersist\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x1f\n\x08\x62locking\x18\x02 \x01(\x08H\x00R\x08\x62locking\x88\x01\x01\x42\x0b\n\t_blocking\x1a\x46\n\x0fGetStorageLevel\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relationB\t\n\x07\x61nalyzeB\x0e\n\x0c_client_type"\x99\r\n\x13\x41nalyzePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x43\n\x06schema\x18\x02 \x01(\x0b\x32).spark.connect.AnalyzePlanResponse.SchemaH\x00R\x06schema\x12\x46\n\x07\x65xplain\x18\x03 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.ExplainH\x00R\x07\x65xplain\x12P\n\x0btree_string\x18\x04 \x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.TreeStringH\x00R\ntreeString\x12G\n\x08is_local\x18\x05 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.IsLocalH\x00R\x07isLocal\x12S\n\x0cis_streaming\x18\x06 
\x01(\x0b\x32..spark.connect.AnalyzePlanResponse.IsStreamingH\x00R\x0bisStreaming\x12P\n\x0binput_files\x18\x07 \x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.InputFilesH\x00R\ninputFiles\x12V\n\rspark_version\x18\x08 \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SparkVersionH\x00R\x0csparkVersion\x12J\n\tddl_parse\x18\t \x01(\x0b\x32+.spark.connect.AnalyzePlanResponse.DDLParseH\x00R\x08\x64\x64lParse\x12Y\n\x0esame_semantics\x18\n \x01(\x0b\x32\x30.spark.connect.AnalyzePlanResponse.SameSemanticsH\x00R\rsameSemantics\x12V\n\rsemantic_hash\x18\x0b \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SemanticHashH\x00R\x0csemanticHash\x12\x46\n\x07persist\x18\x0c \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.PersistH\x00R\x07persist\x12L\n\tunpersist\x18\r \x01(\x0b\x32,.spark.connect.AnalyzePlanResponse.UnpersistH\x00R\tunpersist\x12`\n\x11get_storage_level\x18\x0e \x01(\x0b\x32\x32.spark.connect.AnalyzePlanResponse.GetStorageLevelH\x00R\x0fgetStorageLevel\x1a\x39\n\x06Schema\x12/\n\x06schema\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1a\x30\n\x07\x45xplain\x12%\n\x0e\x65xplain_string\x18\x01 \x01(\tR\rexplainString\x1a-\n\nTreeString\x12\x1f\n\x0btree_string\x18\x01 \x01(\tR\ntreeString\x1a$\n\x07IsLocal\x12\x19\n\x08is_local\x18\x01 \x01(\x08R\x07isLocal\x1a\x30\n\x0bIsStreaming\x12!\n\x0cis_streaming\x18\x01 \x01(\x08R\x0bisStreaming\x1a"\n\nInputFiles\x12\x14\n\x05\x66iles\x18\x01 \x03(\tR\x05\x66iles\x1a(\n\x0cSparkVersion\x12\x18\n\x07version\x18\x01 \x01(\tR\x07version\x1a;\n\x08\x44\x44LParse\x12/\n\x06parsed\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06parsed\x1a\'\n\rSameSemantics\x12\x16\n\x06result\x18\x01 \x01(\x08R\x06result\x1a&\n\x0cSemanticHash\x12\x16\n\x06result\x18\x01 \x01(\x05R\x06result\x1a\t\n\x07Persist\x1a\x0b\n\tUnpersist\x1aS\n\x0fGetStorageLevel\x12@\n\rstorage_level\x18\x01 \x01(\x0b\x32\x1b.spark.connect.StorageLevelR\x0cstorageLevelB\x08\n\x06result"\xa0\x04\n\x12\x45xecutePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12&\n\x0coperation_id\x18\x06 \x01(\tH\x00R\x0boperationId\x88\x01\x01\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x01R\nclientType\x88\x01\x01\x12X\n\x0frequest_options\x18\x05 \x03(\x0b\x32/.spark.connect.ExecutePlanRequest.RequestOptionR\x0erequestOptions\x12\x12\n\x04tags\x18\x07 \x03(\tR\x04tags\x1a\xa5\x01\n\rRequestOption\x12K\n\x10reattach_options\x18\x01 \x01(\x0b\x32\x1e.spark.connect.ReattachOptionsH\x00R\x0freattachOptions\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textensionB\x10\n\x0erequest_optionB\x0f\n\r_operation_idB\x0e\n\x0c_client_type"\xe6\x0f\n\x13\x45xecutePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12!\n\x0coperation_id\x18\x0c \x01(\tR\x0boperationId\x12\x1f\n\x0bresponse_id\x18\r \x01(\tR\nresponseId\x12P\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32-.spark.connect.ExecutePlanResponse.ArrowBatchH\x00R\narrowBatch\x12\x63\n\x12sql_command_result\x18\x05 \x01(\x0b\x32\x33.spark.connect.ExecutePlanResponse.SqlCommandResultH\x00R\x10sqlCommandResult\x12~\n#write_stream_operation_start_result\x18\x08 \x01(\x0b\x32..spark.connect.WriteStreamOperationStartResultH\x00R\x1fwriteStreamOperationStartResult\x12q\n\x1estreaming_query_command_result\x18\t 
\x01(\x0b\x32*.spark.connect.StreamingQueryCommandResultH\x00R\x1bstreamingQueryCommandResult\x12k\n\x1cget_resources_command_result\x18\n \x01(\x0b\x32(.spark.connect.GetResourcesCommandResultH\x00R\x19getResourcesCommandResult\x12\x87\x01\n&streaming_query_manager_command_result\x18\x0b \x01(\x0b\x32\x31.spark.connect.StreamingQueryManagerCommandResultH\x00R"streamingQueryManagerCommandResult\x12\\\n\x0fresult_complete\x18\x0e \x01(\x0b\x32\x31.spark.connect.ExecutePlanResponse.ResultCompleteH\x00R\x0eresultComplete\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12\x44\n\x07metrics\x18\x04 \x01(\x0b\x32*.spark.connect.ExecutePlanResponse.MetricsR\x07metrics\x12]\n\x10observed_metrics\x18\x06 \x03(\x0b\x32\x32.spark.connect.ExecutePlanResponse.ObservedMetricsR\x0fobservedMetrics\x12/\n\x06schema\x18\x07 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1aG\n\x10SqlCommandResult\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x1av\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x12&\n\x0cstart_offset\x18\x03 \x01(\x03H\x00R\x0bstartOffset\x88\x01\x01\x42\x0f\n\r_start_offset\x1a\x85\x04\n\x07Metrics\x12Q\n\x07metrics\x18\x01 \x03(\x0b\x32\x37.spark.connect.ExecutePlanResponse.Metrics.MetricObjectR\x07metrics\x1a\xcc\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12z\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32M.spark.connect.ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1a{\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ExecutePlanResponse.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricType\x1at\n\x0fObservedMetrics\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x39\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x12\x12\n\x04keys\x18\x03 \x03(\tR\x04keys\x1a\x10\n\x0eResultCompleteB\x0f\n\rresponse_type"A\n\x08KeyValue\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x19\n\x05value\x18\x02 \x01(\tH\x00R\x05value\x88\x01\x01\x42\x08\n\x06_value"\x84\x08\n\rConfigRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\x44\n\toperation\x18\x03 \x01(\x0b\x32&.spark.connect.ConfigRequest.OperationR\toperation\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x1a\xf2\x03\n\tOperation\x12\x34\n\x03set\x18\x01 \x01(\x0b\x32 .spark.connect.ConfigRequest.SetH\x00R\x03set\x12\x34\n\x03get\x18\x02 \x01(\x0b\x32 .spark.connect.ConfigRequest.GetH\x00R\x03get\x12W\n\x10get_with_default\x18\x03 \x01(\x0b\x32+.spark.connect.ConfigRequest.GetWithDefaultH\x00R\x0egetWithDefault\x12G\n\nget_option\x18\x04 \x01(\x0b\x32&.spark.connect.ConfigRequest.GetOptionH\x00R\tgetOption\x12>\n\x07get_all\x18\x05 \x01(\x0b\x32#.spark.connect.ConfigRequest.GetAllH\x00R\x06getAll\x12:\n\x05unset\x18\x06 \x01(\x0b\x32".spark.connect.ConfigRequest.UnsetH\x00R\x05unset\x12P\n\ris_modifiable\x18\x07 \x01(\x0b\x32).spark.connect.ConfigRequest.IsModifiableH\x00R\x0cisModifiableB\t\n\x07op_type\x1a\x34\n\x03Set\x12-\n\x05pairs\x18\x01 
\x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x19\n\x03Get\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a?\n\x0eGetWithDefault\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x1f\n\tGetOption\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a\x30\n\x06GetAll\x12\x1b\n\x06prefix\x18\x01 \x01(\tH\x00R\x06prefix\x88\x01\x01\x42\t\n\x07_prefix\x1a\x1b\n\x05Unset\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a"\n\x0cIsModifiable\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keysB\x0e\n\x0c_client_type"z\n\x0e\x43onfigResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12-\n\x05pairs\x18\x02 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x12\x1a\n\x08warnings\x18\x03 \x03(\tR\x08warnings"\xe7\x06\n\x13\x41\x64\x64\x41rtifactsRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x06 \x01(\tH\x01R\nclientType\x88\x01\x01\x12@\n\x05\x62\x61tch\x18\x03 \x01(\x0b\x32(.spark.connect.AddArtifactsRequest.BatchH\x00R\x05\x62\x61tch\x12Z\n\x0b\x62\x65gin_chunk\x18\x04 \x01(\x0b\x32\x37.spark.connect.AddArtifactsRequest.BeginChunkedArtifactH\x00R\nbeginChunk\x12H\n\x05\x63hunk\x18\x05 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkH\x00R\x05\x63hunk\x1a\x35\n\rArtifactChunk\x12\x12\n\x04\x64\x61ta\x18\x01 \x01(\x0cR\x04\x64\x61ta\x12\x10\n\x03\x63rc\x18\x02 \x01(\x03R\x03\x63rc\x1ao\n\x13SingleChunkArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x44\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x04\x64\x61ta\x1a]\n\x05\x42\x61tch\x12T\n\tartifacts\x18\x01 \x03(\x0b\x32\x36.spark.connect.AddArtifactsRequest.SingleChunkArtifactR\tartifacts\x1a\xc1\x01\n\x14\x42\x65ginChunkedArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x1f\n\x0btotal_bytes\x18\x02 \x01(\x03R\ntotalBytes\x12\x1d\n\nnum_chunks\x18\x03 \x01(\x03R\tnumChunks\x12U\n\rinitial_chunk\x18\x04 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x0cinitialChunkB\t\n\x07payloadB\x0e\n\x0c_client_type"\xbc\x01\n\x14\x41\x64\x64\x41rtifactsResponse\x12Q\n\tartifacts\x18\x01 \x03(\x0b\x32\x33.spark.connect.AddArtifactsResponse.ArtifactSummaryR\tartifacts\x1aQ\n\x0f\x41rtifactSummary\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12*\n\x11is_crc_successful\x18\x02 \x01(\x08R\x0fisCrcSuccessful"\xc3\x01\n\x17\x41rtifactStatusesRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x00R\nclientType\x88\x01\x01\x12\x14\n\x05names\x18\x04 \x03(\tR\x05namesB\x0e\n\x0c_client_type"\x8c\x02\n\x18\x41rtifactStatusesResponse\x12Q\n\x08statuses\x18\x01 \x03(\x0b\x32\x35.spark.connect.ArtifactStatusesResponse.StatusesEntryR\x08statuses\x1a(\n\x0e\x41rtifactStatus\x12\x16\n\x06\x65xists\x18\x01 \x01(\x08R\x06\x65xists\x1as\n\rStatusesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ArtifactStatusesResponse.ArtifactStatusR\x05value:\x02\x38\x01"\xd8\x03\n\x10InterruptRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x01R\nclientType\x88\x01\x01\x12T\n\x0einterrupt_type\x18\x04 \x01(\x0e\x32-.spark.connect.InterruptRequest.InterruptTypeR\rinterruptType\x12%\n\roperation_tag\x18\x05 
\x01(\tH\x00R\x0coperationTag\x12#\n\x0coperation_id\x18\x06 \x01(\tH\x00R\x0boperationId"\x80\x01\n\rInterruptType\x12\x1e\n\x1aINTERRUPT_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12INTERRUPT_TYPE_ALL\x10\x01\x12\x16\n\x12INTERRUPT_TYPE_TAG\x10\x02\x12\x1f\n\x1bINTERRUPT_TYPE_OPERATION_ID\x10\x03\x42\x0b\n\tinterruptB\x0e\n\x0c_client_type"[\n\x11InterruptResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\'\n\x0finterrupted_ids\x18\x02 \x03(\tR\x0einterruptedIds"5\n\x0fReattachOptions\x12"\n\x0creattachable\x18\x01 \x01(\x08R\x0creattachable"\x93\x02\n\x16ReattachExecuteRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12!\n\x0coperation_id\x18\x03 \x01(\tR\x0boperationId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x12-\n\x10last_response_id\x18\x05 \x01(\tH\x01R\x0elastResponseId\x88\x01\x01\x42\x0e\n\x0c_client_typeB\x13\n\x11_last_response_id"\xc6\x03\n\x15ReleaseExecuteRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12!\n\x0coperation_id\x18\x03 \x01(\tR\x0boperationId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x01R\nclientType\x88\x01\x01\x12R\n\x0brelease_all\x18\x05 \x01(\x0b\x32/.spark.connect.ReleaseExecuteRequest.ReleaseAllH\x00R\nreleaseAll\x12X\n\rrelease_until\x18\x06 \x01(\x0b\x32\x31.spark.connect.ReleaseExecuteRequest.ReleaseUntilH\x00R\x0creleaseUntil\x1a\x0c\n\nReleaseAll\x1a/\n\x0cReleaseUntil\x12\x1f\n\x0bresponse_id\x18\x01 \x01(\tR\nresponseIdB\t\n\x07releaseB\x0e\n\x0c_client_type"p\n\x16ReleaseExecuteResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12&\n\x0coperation_id\x18\x02 \x01(\tH\x00R\x0boperationId\x88\x01\x01\x42\x0f\n\r_operation_id"\xab\x01\n\x15ReleaseSessionRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x00R\nclientType\x88\x01\x01\x42\x0e\n\x0c_client_type"7\n\x16ReleaseSessionResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId"\xc9\x01\n\x18\x46\x65tchErrorDetailsRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\x19\n\x08\x65rror_id\x18\x03 \x01(\tR\x07\x65rrorId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x42\x0e\n\x0c_client_type"\xbe\x0b\n\x19\x46\x65tchErrorDetailsResponse\x12)\n\x0eroot_error_idx\x18\x01 \x01(\x05H\x00R\x0crootErrorIdx\x88\x01\x01\x12\x46\n\x06\x65rrors\x18\x02 \x03(\x0b\x32..spark.connect.FetchErrorDetailsResponse.ErrorR\x06\x65rrors\x1a\xae\x01\n\x11StackTraceElement\x12\'\n\x0f\x64\x65\x63laring_class\x18\x01 \x01(\tR\x0e\x64\x65\x63laringClass\x12\x1f\n\x0bmethod_name\x18\x02 \x01(\tR\nmethodName\x12 \n\tfile_name\x18\x03 \x01(\tH\x00R\x08\x66ileName\x88\x01\x01\x12\x1f\n\x0bline_number\x18\x04 \x01(\x05R\nlineNumberB\x0c\n\n_file_name\x1a\xef\x02\n\x0cQueryContext\x12\x64\n\x0c\x63ontext_type\x18\n \x01(\x0e\x32\x41.spark.connect.FetchErrorDetailsResponse.QueryContext.ContextTypeR\x0b\x63ontextType\x12\x1f\n\x0bobject_type\x18\x01 \x01(\tR\nobjectType\x12\x1f\n\x0bobject_name\x18\x02 \x01(\tR\nobjectName\x12\x1f\n\x0bstart_index\x18\x03 \x01(\x05R\nstartIndex\x12\x1d\n\nstop_index\x18\x04 \x01(\x05R\tstopIndex\x12\x1a\n\x08\x66ragment\x18\x05 \x01(\tR\x08\x66ragment\x12\x1a\n\x08\x63\x61llSite\x18\x06 
\x01(\tR\x08\x63\x61llSite\x12\x18\n\x07summary\x18\x07 \x01(\tR\x07summary"%\n\x0b\x43ontextType\x12\x07\n\x03SQL\x10\x00\x12\r\n\tDATAFRAME\x10\x01\x1a\x99\x03\n\x0eSparkThrowable\x12$\n\x0b\x65rror_class\x18\x01 \x01(\tH\x00R\nerrorClass\x88\x01\x01\x12}\n\x12message_parameters\x18\x02 \x03(\x0b\x32N.spark.connect.FetchErrorDetailsResponse.SparkThrowable.MessageParametersEntryR\x11messageParameters\x12\\\n\x0equery_contexts\x18\x03 \x03(\x0b\x32\x35.spark.connect.FetchErrorDetailsResponse.QueryContextR\rqueryContexts\x12 \n\tsql_state\x18\x04 \x01(\tH\x01R\x08sqlState\x88\x01\x01\x1a\x44\n\x16MessageParametersEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x0e\n\x0c_error_classB\x0c\n\n_sql_state\x1a\xdb\x02\n\x05\x45rror\x12\x30\n\x14\x65rror_type_hierarchy\x18\x01 \x03(\tR\x12\x65rrorTypeHierarchy\x12\x18\n\x07message\x18\x02 \x01(\tR\x07message\x12[\n\x0bstack_trace\x18\x03 \x03(\x0b\x32:.spark.connect.FetchErrorDetailsResponse.StackTraceElementR\nstackTrace\x12 \n\tcause_idx\x18\x04 \x01(\x05H\x00R\x08\x63\x61useIdx\x88\x01\x01\x12\x65\n\x0fspark_throwable\x18\x05 \x01(\x0b\x32\x37.spark.connect.FetchErrorDetailsResponse.SparkThrowableH\x01R\x0esparkThrowable\x88\x01\x01\x42\x0c\n\n_cause_idxB\x12\n\x10_spark_throwableB\x11\n\x0f_root_error_idx2\xb2\x07\n\x13SparkConnectService\x12X\n\x0b\x45xecutePlan\x12!.spark.connect.ExecutePlanRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12V\n\x0b\x41nalyzePlan\x12!.spark.connect.AnalyzePlanRequest\x1a".spark.connect.AnalyzePlanResponse"\x00\x12G\n\x06\x43onfig\x12\x1c.spark.connect.ConfigRequest\x1a\x1d.spark.connect.ConfigResponse"\x00\x12[\n\x0c\x41\x64\x64\x41rtifacts\x12".spark.connect.AddArtifactsRequest\x1a#.spark.connect.AddArtifactsResponse"\x00(\x01\x12\x63\n\x0e\x41rtifactStatus\x12&.spark.connect.ArtifactStatusesRequest\x1a\'.spark.connect.ArtifactStatusesResponse"\x00\x12P\n\tInterrupt\x12\x1f.spark.connect.InterruptRequest\x1a .spark.connect.InterruptResponse"\x00\x12`\n\x0fReattachExecute\x12%.spark.connect.ReattachExecuteRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12_\n\x0eReleaseExecute\x12$.spark.connect.ReleaseExecuteRequest\x1a%.spark.connect.ReleaseExecuteResponse"\x00\x12_\n\x0eReleaseSession\x12$.spark.connect.ReleaseSessionRequest\x1a%.spark.connect.ReleaseSessionResponse"\x00\x12h\n\x11\x46\x65tchErrorDetails\x12\'.spark.connect.FetchErrorDetailsRequest\x1a(.spark.connect.FetchErrorDetailsResponse"\x00\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -199,22 +199,26 @@ _RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_end = 11234 _RELEASEEXECUTERESPONSE._serialized_start = 11263 _RELEASEEXECUTERESPONSE._serialized_end = 11375 - _FETCHERRORDETAILSREQUEST._serialized_start = 11378 - _FETCHERRORDETAILSREQUEST._serialized_end = 11579 - _FETCHERRORDETAILSRESPONSE._serialized_start = 11582 - _FETCHERRORDETAILSRESPONSE._serialized_end = 13052 - _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_start = 11727 - _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_end = 11901 - _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_start = 11904 - _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_end = 12271 - _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_start = 12234 - _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_end = 12271 - _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_start = 
12274 - _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_end = 12683 - _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_start = 12585 - _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_end = 12653 - _FETCHERRORDETAILSRESPONSE_ERROR._serialized_start = 12686 - _FETCHERRORDETAILSRESPONSE_ERROR._serialized_end = 13033 - _SPARKCONNECTSERVICE._serialized_start = 13055 - _SPARKCONNECTSERVICE._serialized_end = 13904 + _RELEASESESSIONREQUEST._serialized_start = 11378 + _RELEASESESSIONREQUEST._serialized_end = 11549 + _RELEASESESSIONRESPONSE._serialized_start = 11551 + _RELEASESESSIONRESPONSE._serialized_end = 11606 + _FETCHERRORDETAILSREQUEST._serialized_start = 11609 + _FETCHERRORDETAILSREQUEST._serialized_end = 11810 + _FETCHERRORDETAILSRESPONSE._serialized_start = 11813 + _FETCHERRORDETAILSRESPONSE._serialized_end = 13283 + _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_start = 11958 + _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_end = 12132 + _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_start = 12135 + _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_end = 12502 + _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_start = 12465 + _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_end = 12502 + _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_start = 12505 + _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_end = 12914 + _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_start = 12816 + _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_end = 12884 + _FETCHERRORDETAILSRESPONSE_ERROR._serialized_start = 12917 + _FETCHERRORDETAILSRESPONSE_ERROR._serialized_end = 13264 + _SPARKCONNECTSERVICE._serialized_start = 13286 + _SPARKCONNECTSERVICE._serialized_end = 14232 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi index c29feb4164cf1..20abbcb348bdd 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.pyi +++ b/python/pyspark/sql/connect/proto/base_pb2.pyi @@ -2763,6 +2763,84 @@ class ReleaseExecuteResponse(google.protobuf.message.Message): global___ReleaseExecuteResponse = ReleaseExecuteResponse +class ReleaseSessionRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SESSION_ID_FIELD_NUMBER: builtins.int + USER_CONTEXT_FIELD_NUMBER: builtins.int + CLIENT_TYPE_FIELD_NUMBER: builtins.int + session_id: builtins.str + """(Required) + + The session_id of the request to reattach to. + This must be an id of existing session. + """ + @property + def user_context(self) -> global___UserContext: + """(Required) User context + + user_context.user_id and session+id both identify a unique remote spark session on the + server side. + """ + client_type: builtins.str + """Provides optional information about the client sending the request. This field + can be used for language or version specific information and is only intended for + logging purposes and will not be interpreted by the server. + """ + def __init__( + self, + *, + session_id: builtins.str = ..., + user_context: global___UserContext | None = ..., + client_type: builtins.str | None = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "_client_type", + b"_client_type", + "client_type", + b"client_type", + "user_context", + b"user_context", + ], + ) -> builtins.bool: ... 
+ def ClearField( + self, + field_name: typing_extensions.Literal[ + "_client_type", + b"_client_type", + "client_type", + b"client_type", + "session_id", + b"session_id", + "user_context", + b"user_context", + ], + ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_client_type", b"_client_type"] + ) -> typing_extensions.Literal["client_type"] | None: ... + +global___ReleaseSessionRequest = ReleaseSessionRequest + +class ReleaseSessionResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SESSION_ID_FIELD_NUMBER: builtins.int + session_id: builtins.str + """Session id of the session on which the release executed.""" + def __init__( + self, + *, + session_id: builtins.str = ..., + ) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["session_id", b"session_id"] + ) -> None: ... + +global___ReleaseSessionResponse = ReleaseSessionResponse + class FetchErrorDetailsRequest(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor diff --git a/python/pyspark/sql/connect/proto/base_pb2_grpc.py b/python/pyspark/sql/connect/proto/base_pb2_grpc.py index f6c5573ded6b5..12675747e0f92 100644 --- a/python/pyspark/sql/connect/proto/base_pb2_grpc.py +++ b/python/pyspark/sql/connect/proto/base_pb2_grpc.py @@ -70,6 +70,11 @@ def __init__(self, channel): request_serializer=spark_dot_connect_dot_base__pb2.ReleaseExecuteRequest.SerializeToString, response_deserializer=spark_dot_connect_dot_base__pb2.ReleaseExecuteResponse.FromString, ) + self.ReleaseSession = channel.unary_unary( + "/spark.connect.SparkConnectService/ReleaseSession", + request_serializer=spark_dot_connect_dot_base__pb2.ReleaseSessionRequest.SerializeToString, + response_deserializer=spark_dot_connect_dot_base__pb2.ReleaseSessionResponse.FromString, + ) self.FetchErrorDetails = channel.unary_unary( "/spark.connect.SparkConnectService/FetchErrorDetails", request_serializer=spark_dot_connect_dot_base__pb2.FetchErrorDetailsRequest.SerializeToString, @@ -141,6 +146,16 @@ def ReleaseExecute(self, request, context): context.set_details("Method not implemented!") raise NotImplementedError("Method not implemented!") + def ReleaseSession(self, request, context): + """Release a session. + All the executions in the session will be released. Any further requests for the session with + that session_id for the given user_id will fail. If the session didn't exist or was already + released, this is a noop. 
+ """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + def FetchErrorDetails(self, request, context): """FetchErrorDetails retrieves the matched exception with details based on a provided error id.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) @@ -190,6 +205,11 @@ def add_SparkConnectServiceServicer_to_server(servicer, server): request_deserializer=spark_dot_connect_dot_base__pb2.ReleaseExecuteRequest.FromString, response_serializer=spark_dot_connect_dot_base__pb2.ReleaseExecuteResponse.SerializeToString, ), + "ReleaseSession": grpc.unary_unary_rpc_method_handler( + servicer.ReleaseSession, + request_deserializer=spark_dot_connect_dot_base__pb2.ReleaseSessionRequest.FromString, + response_serializer=spark_dot_connect_dot_base__pb2.ReleaseSessionResponse.SerializeToString, + ), "FetchErrorDetails": grpc.unary_unary_rpc_method_handler( servicer.FetchErrorDetails, request_deserializer=spark_dot_connect_dot_base__pb2.FetchErrorDetailsRequest.FromString, @@ -438,6 +458,35 @@ def ReleaseExecute( metadata, ) + @staticmethod + def ReleaseSession( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/spark.connect.SparkConnectService/ReleaseSession", + spark_dot_connect_dot_base__pb2.ReleaseSessionRequest.SerializeToString, + spark_dot_connect_dot_base__pb2.ReleaseSessionResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + @staticmethod def FetchErrorDetails( request, diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 09bd60606c769..1aa857b4f6175 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -254,6 +254,9 @@ def __init__(self, connection: Union[str, ChannelBuilder], userId: Optional[str] self._client = SparkConnectClient(connection=connection, user_id=userId) self._session_id = self._client._session_id + # Set to false to prevent client.release_session on close() (testing only) + self.release_session_on_close = True + @classmethod def _set_default_and_active_session(cls, session: "SparkSession") -> None: """ @@ -645,15 +648,16 @@ def clearTags(self) -> None: clearTags.__doc__ = PySparkSession.clearTags.__doc__ def stop(self) -> None: - # Stopping the session will only close the connection to the current session (and - # the life cycle of the session is maintained by the server), - # whereas the regular PySpark session immediately terminates the Spark Context - # itself, meaning that stopping all Spark sessions. + # Whereas the regular PySpark session immediately terminates the Spark Context + # itself, meaning that stopping all Spark sessions, this will only stop this one session + # on the server. # It is controversial to follow the existing the regular Spark session's behavior # specifically in Spark Connect the Spark Connect server is designed for # multi-tenancy - the remote client side cannot just stop the server and stop # other remote clients being used from other users. 
with SparkSession._lock: + if not self.is_stopped and self.release_session_on_close: + self.client.release_session() self.client.close() if self is SparkSession._default_session: SparkSession._default_session = None diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 34bd314c76f7c..f024a03c2686c 100755 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -3437,6 +3437,7 @@ def test_can_create_multiple_sessions_to_different_remotes(self): # Gets currently active session. same = PySparkSession.builder.remote("sc://other.remote.host:114/").getOrCreate() self.assertEquals(other, same) + same.release_session_on_close = False # avoid sending release to dummy connection same.stop() # Make sure the environment is clean. From a04d4e2233c0d20c6a86c64391b1e1a6071b4550 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 2 Nov 2023 09:18:42 +0900 Subject: [PATCH 12/13] [SPARK-45761][K8S][INFRA][DOCS] Upgrade `Volcano` to 1.8.1 ### What changes were proposed in this pull request? This PR aims to upgrade `Volcano` to 1.8.1 in K8s integration test document and GitHub Action job. ### Why are the changes needed? To bring the latest feature and bug fixes in addition to the test coverage for Volcano scheduler 1.8.1. - https://github.com/volcano-sh/volcano/releases/tag/v1.8.1 - https://github.com/volcano-sh/volcano/pull/3101 ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass the CIs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43624 from dongjoon-hyun/SPARK-45761. Authored-by: Dongjoon Hyun Signed-off-by: Hyukjin Kwon --- .github/workflows/build_and_test.yml | 2 +- resource-managers/kubernetes/integration-tests/README.md | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 5825185f34450..eded5da5c1ddd 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -1063,7 +1063,7 @@ jobs: export PVC_TESTS_VM_PATH=$PVC_TMP_DIR minikube mount ${PVC_TESTS_HOST_PATH}:${PVC_TESTS_VM_PATH} --gid=0 --uid=185 & kubectl create clusterrolebinding serviceaccounts-cluster-admin --clusterrole=cluster-admin --group=system:serviceaccounts || true - kubectl apply -f https://raw.githubusercontent.com/volcano-sh/volcano/v1.8.0/installer/volcano-development.yaml || true + kubectl apply -f https://raw.githubusercontent.com/volcano-sh/volcano/v1.8.1/installer/volcano-development.yaml || true eval $(minikube docker-env) build/sbt -Psparkr -Pkubernetes -Pvolcano -Pkubernetes-integration-tests -Dspark.kubernetes.test.driverRequestCores=0.5 -Dspark.kubernetes.test.executorRequestCores=0.2 -Dspark.kubernetes.test.volcanoMaxConcurrencyJobNum=1 -Dtest.exclude.tags=local "kubernetes-integration-tests/test" - name: Upload Spark on K8S integration tests log files diff --git a/resource-managers/kubernetes/integration-tests/README.md b/resource-managers/kubernetes/integration-tests/README.md index d39fdfbfd3966..d5ccd3fe756b7 100644 --- a/resource-managers/kubernetes/integration-tests/README.md +++ b/resource-managers/kubernetes/integration-tests/README.md @@ -329,11 +329,11 @@ You can also specify your specific dockerfile to build JVM/Python/R based image ## Requirements - A minimum of 6 CPUs and 9G of memory is required to complete all Volcano test cases. -- Volcano v1.8.0. 
+- Volcano v1.8.1. ## Installation - kubectl apply -f https://raw.githubusercontent.com/volcano-sh/volcano/v1.8.0/installer/volcano-development.yaml + kubectl apply -f https://raw.githubusercontent.com/volcano-sh/volcano/v1.8.1/installer/volcano-development.yaml ## Run tests @@ -354,5 +354,5 @@ You can also specify `volcano` tag to only run Volcano test: ## Cleanup Volcano - kubectl delete -f https://raw.githubusercontent.com/volcano-sh/volcano/v1.8.0/installer/volcano-development.yaml + kubectl delete -f https://raw.githubusercontent.com/volcano-sh/volcano/v1.8.1/installer/volcano-development.yaml From 30ec6e358536dfb695fcc1b8c3f084acb576d871 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Wed, 1 Nov 2023 21:08:04 -0700 Subject: [PATCH 13/13] [SPARK-45742][CORE][CONNECT][MLLIB][PYTHON] Introduce an implicit function for Scala Array to wrap into `immutable.ArraySeq` ### What changes were proposed in this pull request? Currently, we need to use `immutable.ArraySeq.unsafeWrapArray(array)` to wrap an Array into an `immutable.ArraySeq`, which makes the code look bloated. So this PR introduces an implicit function `toImmutableArraySeq` to make it easier for Scala Array to be wrapped into `immutable.ArraySeq`. After this pr, we can use the following way to wrap an array into an `immutable.ArraySeq`: ```scala import org.apache.spark.util.ArrayImplicits._ val dataArray = ... val immutableArraySeq = dataArray.toImmutableArraySeq ``` At the same time, this pr replaces the existing use of `immutable.ArraySeq.unsafeWrapArray(array)` with the new method. On the other hand, this implicit function will be conducive to the progress of work SPARK-45686 and SPARK-45687. ### Why are the changes needed? Makes the code for wrapping a Scala Array into an `immutable.ArraySeq` look less bloated. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Pass GitHub Actions ### Was this patch authored or co-authored using generative AI tooling? No Closes #43607 from LuciferYang/SPARK-45742. Authored-by: yangjie01 Signed-off-by: Dongjoon Hyun --- .../apache/spark/util/ArrayImplicits.scala | 36 +++++++++++++ .../org/apache/spark/sql/SparkSession.scala | 4 +- .../client/GrpcExceptionConverter.scala | 4 +- .../connect/planner/SparkConnectPlanner.scala | 27 +++++----- .../spark/sql/connect/utils/ErrorUtils.scala | 32 ++++++------ .../spark/util/ArrayImplicitsSuite.scala | 50 +++++++++++++++++++ .../python/GaussianMixtureModelWrapper.scala | 4 +- .../mllib/api/python/LDAModelWrapper.scala | 8 +-- 8 files changed, 126 insertions(+), 39 deletions(-) create mode 100644 common/utils/src/main/scala/org/apache/spark/util/ArrayImplicits.scala create mode 100644 core/src/test/scala/org/apache/spark/util/ArrayImplicitsSuite.scala diff --git a/common/utils/src/main/scala/org/apache/spark/util/ArrayImplicits.scala b/common/utils/src/main/scala/org/apache/spark/util/ArrayImplicits.scala new file mode 100644 index 0000000000000..08997a800c957 --- /dev/null +++ b/common/utils/src/main/scala/org/apache/spark/util/ArrayImplicits.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import scala.collection.immutable + +/** + * Implicit methods related to Scala Array. + */ +private[spark] object ArrayImplicits { + + implicit class SparkArrayOps[T](xs: Array[T]) { + + /** + * Wraps an Array[T] as an immutable.ArraySeq[T] without copying. + */ + def toImmutableArraySeq: immutable.ArraySeq[T] = + if (xs eq null) null + else immutable.ArraySeq.unsafeWrapArray(xs) + } +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 1cc1c8400fa89..34756f9a440bb 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -21,7 +21,6 @@ import java.net.URI import java.util.concurrent.TimeUnit._ import java.util.concurrent.atomic.{AtomicLong, AtomicReference} -import scala.collection.immutable import scala.jdk.CollectionConverters._ import scala.reflect.runtime.universe.TypeTag @@ -45,6 +44,7 @@ import org.apache.spark.sql.internal.{CatalogImpl, SqlApiConf} import org.apache.spark.sql.streaming.DataStreamReader import org.apache.spark.sql.streaming.StreamingQueryManager import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ArrayImplicits._ /** * The entry point to programming Spark with the Dataset and DataFrame API. @@ -248,7 +248,7 @@ class SparkSession private[sql] ( proto.SqlCommand .newBuilder() .setSql(sqlText) - .addAllPosArguments(immutable.ArraySeq.unsafeWrapArray(args.map(lit(_).expr)).asJava))) + .addAllPosArguments(args.map(lit(_).expr).toImmutableArraySeq.asJava))) val plan = proto.Plan.newBuilder().setCommand(cmd) // .toBuffer forces that the iterator is consumed and closed val responseSeq = client.execute(plan.build()).toBuffer.toSeq diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala index 3e53722caeb07..652797bc2e40f 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.connect.client import java.time.DateTimeException -import scala.collection.immutable import scala.jdk.CollectionConverters._ import scala.reflect.ClassTag @@ -37,6 +36,7 @@ import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.streaming.StreamingQueryException +import org.apache.spark.util.ArrayImplicits._ /** * GrpcExceptionConverter handles the conversion of StatusRuntimeExceptions into Spark exceptions. 
@@ -375,7 +375,7 @@ private[client] object GrpcExceptionConverter { FetchErrorDetailsResponse.Error .newBuilder() .setMessage(message) - .addAllErrorTypeHierarchy(immutable.ArraySeq.unsafeWrapArray(classes).asJava) + .addAllErrorTypeHierarchy(classes.toImmutableArraySeq.asJava) .build())) } } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index ec57909ad144e..018e293795e9d 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.connect.planner -import scala.collection.immutable import scala.collection.mutable import scala.jdk.CollectionConverters._ import scala.util.Try @@ -80,6 +79,7 @@ import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StreamingQ import org.apache.spark.sql.types._ import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.storage.CacheId +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils final case class InvalidCommandInput( @@ -3184,9 +3184,9 @@ class SparkConnectPlanner( case StreamingQueryManagerCommand.CommandCase.ACTIVE => val active_queries = session.streams.active respBuilder.getActiveBuilder.addAllActiveQueries( - immutable.ArraySeq - .unsafeWrapArray(active_queries - .map(query => buildStreamingQueryInstance(query))) + active_queries + .map(query => buildStreamingQueryInstance(query)) + .toImmutableArraySeq .asJava) case StreamingQueryManagerCommand.CommandCase.GET_QUERY => @@ -3265,15 +3265,16 @@ class SparkConnectPlanner( .setGetResourcesCommandResult( proto.GetResourcesCommandResult .newBuilder() - .putAllResources(session.sparkContext.resources.view - .mapValues(resource => - proto.ResourceInformation - .newBuilder() - .setName(resource.name) - .addAllAddresses(immutable.ArraySeq.unsafeWrapArray(resource.addresses).asJava) - .build()) - .toMap - .asJava) + .putAllResources( + session.sparkContext.resources.view + .mapValues(resource => + proto.ResourceInformation + .newBuilder() + .setName(resource.name) + .addAllAddresses(resource.addresses.toImmutableArraySeq.asJava) + .build()) + .toMap + .asJava) .build()) .build()) } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala index 837ee5a00227c..744fa3c8aa1a4 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.connect.utils import java.util.UUID import scala.annotation.tailrec -import scala.collection.immutable import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ @@ -43,6 +42,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.connect.config.Connect import org.apache.spark.sql.connect.service.{ExecuteEventsManager, SessionHolder, SessionKey, SparkConnectService} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.ArrayImplicits._ private[connect] object ErrorUtils extends Logging { @@ -91,21 +91,21 @@ 
private[connect] object ErrorUtils extends Logging { if (serverStackTraceEnabled) { builder.addAllStackTrace( - immutable.ArraySeq - .unsafeWrapArray(currentError.getStackTrace - .map { stackTraceElement => - val stackTraceBuilder = FetchErrorDetailsResponse.StackTraceElement - .newBuilder() - .setDeclaringClass(stackTraceElement.getClassName) - .setMethodName(stackTraceElement.getMethodName) - .setLineNumber(stackTraceElement.getLineNumber) - - if (stackTraceElement.getFileName != null) { - stackTraceBuilder.setFileName(stackTraceElement.getFileName) - } - - stackTraceBuilder.build() - }) + currentError.getStackTrace + .map { stackTraceElement => + val stackTraceBuilder = FetchErrorDetailsResponse.StackTraceElement + .newBuilder() + .setDeclaringClass(stackTraceElement.getClassName) + .setMethodName(stackTraceElement.getMethodName) + .setLineNumber(stackTraceElement.getLineNumber) + + if (stackTraceElement.getFileName != null) { + stackTraceBuilder.setFileName(stackTraceElement.getFileName) + } + + stackTraceBuilder.build() + } + .toImmutableArraySeq .asJava) } diff --git a/core/src/test/scala/org/apache/spark/util/ArrayImplicitsSuite.scala b/core/src/test/scala/org/apache/spark/util/ArrayImplicitsSuite.scala new file mode 100644 index 0000000000000..135af550c4b39 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/ArrayImplicitsSuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util + +import scala.collection.immutable + +import org.apache.spark.SparkFunSuite +import org.apache.spark.util.ArrayImplicits._ + +class ArrayImplicitsSuite extends SparkFunSuite { + + test("Int Array") { + val data = Array(1, 2, 3) + val arraySeq = data.toImmutableArraySeq + assert(arraySeq.getClass === classOf[immutable.ArraySeq.ofInt]) + assert(arraySeq.length === 3) + assert(arraySeq.unsafeArray.sameElements(data)) + } + + test("TestClass Array") { + val data = Array(TestClass(1), TestClass(2), TestClass(3)) + val arraySeq = data.toImmutableArraySeq + assert(arraySeq.getClass === classOf[immutable.ArraySeq.ofRef[TestClass]]) + assert(arraySeq.length === 3) + assert(arraySeq.unsafeArray.sameElements(data)) + } + + test("Null Array") { + val data: Array[Int] = null + val arraySeq = data.toImmutableArraySeq + assert(arraySeq == null) + } + + case class TestClass(i: Int) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala index 1eed97a8d4f65..2f3f396730be2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala @@ -17,12 +17,12 @@ package org.apache.spark.mllib.api.python -import scala.collection.immutable import scala.jdk.CollectionConverters._ import org.apache.spark.SparkContext import org.apache.spark.mllib.clustering.GaussianMixtureModel import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.util.ArrayImplicits._ /** * Wrapper around GaussianMixtureModel to provide helper methods in Python @@ -38,7 +38,7 @@ private[python] class GaussianMixtureModelWrapper(model: GaussianMixtureModel) { val modelGaussians = model.gaussians.map { gaussian => Array[Any](gaussian.mu, gaussian.sigma) } - SerDe.dumps(immutable.ArraySeq.unsafeWrapArray(modelGaussians).asJava) + SerDe.dumps(modelGaussians.toImmutableArraySeq.asJava) } def predictSoft(point: Vector): Vector = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala index b919b0a8c3f2e..6a6c6cf6bcfb3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala @@ -16,12 +16,12 @@ */ package org.apache.spark.mllib.api.python -import scala.collection.immutable import scala.jdk.CollectionConverters._ import org.apache.spark.SparkContext import org.apache.spark.mllib.clustering.LDAModel import org.apache.spark.mllib.linalg.Matrix +import org.apache.spark.util.ArrayImplicits._ /** * Wrapper around LDAModel to provide helper methods in Python @@ -36,11 +36,11 @@ private[python] class LDAModelWrapper(model: LDAModel) { def describeTopics(maxTermsPerTopic: Int): Array[Byte] = { val topics = model.describeTopics(maxTermsPerTopic).map { case (terms, termWeights) => - val jTerms = immutable.ArraySeq.unsafeWrapArray(terms).asJava - val jTermWeights = immutable.ArraySeq.unsafeWrapArray(termWeights).asJava + val jTerms = terms.toImmutableArraySeq.asJava + val jTermWeights = termWeights.toImmutableArraySeq.asJava Array[Any](jTerms, jTermWeights) } - SerDe.dumps(immutable.ArraySeq.unsafeWrapArray(topics).asJava) + SerDe.dumps(topics.toImmutableArraySeq.asJava) } def save(sc: SparkContext, path: 
String): Unit = model.save(sc, path)
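
For reference, a minimal, self-contained sketch of the usage pattern this patch introduces. The extension is re-declared here purely for illustration, because `org.apache.spark.util.ArrayImplicits` is `private[spark]`; inside Spark code you would simply `import org.apache.spark.util.ArrayImplicits._`. The object and method names below (other than `toImmutableArraySeq` itself) are illustrative only:

```scala
import scala.collection.immutable

object ToImmutableArraySeqSketch {
  // Re-declared for illustration; Spark-internal code imports ArrayImplicits._ instead.
  implicit class SparkArrayOps[T](xs: Array[T]) {
    // Wraps an Array[T] as an immutable.ArraySeq[T] without copying.
    def toImmutableArraySeq: immutable.ArraySeq[T] =
      if (xs eq null) null else immutable.ArraySeq.unsafeWrapArray(xs)
  }

  def main(args: Array[String]): Unit = {
    val data = Array(1, 2, 3)
    val seq: immutable.ArraySeq[Int] = data.toImmutableArraySeq
    // The wrapper keeps the primitive specialization (immutable.ArraySeq.ofInt)
    // and shares the underlying array rather than copying it.
    println(seq.getClass.getSimpleName) // ofInt
    println(seq.mkString(", "))         // 1, 2, 3
  }
}
```

As at the call sites touched by this patch, the wrapped sequence can then be handed to Java APIs via `.asJava` from `scala.jdk.CollectionConverters._`, e.g. `resource.addresses.toImmutableArraySeq.asJava`.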