From 630b1777904f15c7ac05c3cd61c0006cd692bc93 Mon Sep 17 00:00:00 2001 From: Siying Dong Date: Tue, 8 Aug 2023 11:11:56 +0900 Subject: [PATCH 01/30] [SPARK-44683][SS] Logging level isn't passed to RocksDB state store provider correctly ### What changes were proposed in this pull request? The logging level is passed into RocksDB in a correct way. ### Why are the changes needed? We pass log4j's log level to RocksDB so that RocksDB debug log can go to log4j. However, we pass in log level after we create the logger. However, the way it is set isn't effective. This has two impacts: (1) setting DEBUG level don't make RocksDB generate DEBUG level logs; (2) setting WARN or ERROR level does prevent INFO level logging, but RocksDB still makes JNI calls to Scala, which is an unnecessary overhead. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Manually change the log level and observe the log lines in unit tests. Closes #42354 from siying/rocks_log_level. Authored-by: Siying Dong Signed-off-by: Jungtaek Lim --- .../apache/spark/sql/execution/streaming/state/RocksDB.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index d4366fe732be4..a2868df941178 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -611,8 +611,11 @@ class RocksDB( if (log.isWarnEnabled) dbLogLevel = InfoLogLevel.WARN_LEVEL if (log.isInfoEnabled) dbLogLevel = InfoLogLevel.INFO_LEVEL if (log.isDebugEnabled) dbLogLevel = InfoLogLevel.DEBUG_LEVEL - dbOptions.setLogger(dbLogger) + dbLogger.setInfoLogLevel(dbLogLevel) + // The log level set in dbLogger is effective and the one to dbOptions isn't applied to + // customized logger. We still set it as it might show up in RocksDB config file or logging. dbOptions.setInfoLogLevel(dbLogLevel) + dbOptions.setLogger(dbLogger) logInfo(s"Set RocksDB native logging level to $dbLogLevel") dbLogger } From 7493c5764f9644878babacccd4f688fe13ef84aa Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 8 Aug 2023 04:15:07 +0200 Subject: [PATCH 02/30] [SPARK-43429][CONNECT] Add Default & Active SparkSession for Scala Client ### What changes were proposed in this pull request? This adds the `default` and `active` session variables to `SparkSession`: - `default` session is global value. It is typically the first session created through `getOrCreate`. It can be changed through `set` or `clear`. If the session is closed and it is the `default` session we clear the `default` session. - `active` session is a thread local value. It is typically the first session created in this thread or it inherits is value from its parent thread. It can be changed through `set` or `clear`, please note that these methods operate thread locally, so they won't change the parent or children. If the session is closed and it is the `active` session for the current thread then we clear the active value (only for the current thread!). ### Why are the changes needed? To increase compatibility with the existing SparkSession API in `sql/core`. ### Does this PR introduce _any_ user-facing change? Yes. It adds a couple methods that were missing from the Scala Client. ### How was this patch tested? Added tests to `SparkSessionSuite`. 
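To make the default/active semantics described above concrete, here is a minimal usage sketch (not part of the patch; the `sc://` connection strings are placeholders, and only the API names this change introduces are used):

```
import org.apache.spark.sql.SparkSession

// The first session created through getOrCreate becomes both the global
// default session and this thread's active session.
val s1 = SparkSession.builder().remote("sc://host-a:15002").getOrCreate()
assert(SparkSession.getDefaultSession.contains(s1))
assert(SparkSession.getActiveSession.contains(s1))

// The active session is thread-local; overriding it here does not affect
// other threads, which inherit whatever was active when they were spawned.
val s2 = SparkSession.builder().remote("sc://host-b:15002").create()
SparkSession.setActiveSession(s2)
assert(SparkSession.active == s2) // active takes precedence over default

// Closing a session clears whichever slot(s) it currently occupies.
s2.close() // clears the active slot, for this thread only
s1.close() // clears the global default slot
assert(SparkSession.getActiveSession.isEmpty && SparkSession.getDefaultSession.isEmpty)
```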
Closes #42367 from hvanhovell/SPARK-43429. Authored-by: Herman van Hovell Signed-off-by: Herman van Hovell --- .../org/apache/spark/sql/SparkSession.scala | 100 ++++++++++-- .../apache/spark/sql/SparkSessionSuite.scala | 144 ++++++++++++++++-- .../CheckConnectJvmClientCompatibility.scala | 2 - 3 files changed, 225 insertions(+), 21 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 355d7edadc788..7367ed153f7db 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import java.io.Closeable import java.net.URI import java.util.concurrent.TimeUnit._ -import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.atomic.{AtomicLong, AtomicReference} import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.TypeTag @@ -730,6 +730,23 @@ object SparkSession extends Logging { override def load(c: Configuration): SparkSession = create(c) }) + /** The active SparkSession for the current thread. */ + private val activeThreadSession = new InheritableThreadLocal[SparkSession] + + /** Reference to the root SparkSession. */ + private val defaultSession = new AtomicReference[SparkSession] + + /** + * Set the (global) default [[SparkSession]], and (thread-local) active [[SparkSession]] when + * they are not set yet. + */ + private def setDefaultAndActiveSession(session: SparkSession): Unit = { + defaultSession.compareAndSet(null, session) + if (getActiveSession.isEmpty) { + setActiveSession(session) + } + } + /** * Create a new [[SparkSession]] based on the connect client [[Configuration]]. */ @@ -742,8 +759,17 @@ object SparkSession extends Logging { */ private[sql] def onSessionClose(session: SparkSession): Unit = { sessions.invalidate(session.client.configuration) + defaultSession.compareAndSet(session, null) + if (getActiveSession.contains(session)) { + clearActiveSession() + } } + /** + * Creates a [[SparkSession.Builder]] for constructing a [[SparkSession]]. + * + * @since 3.4.0 + */ def builder(): Builder = new Builder() private[sql] lazy val cleaner = { @@ -799,10 +825,15 @@ object SparkSession extends Logging { * * This will always return a newly created session. * + * This method will update the default and/or active session if they are not set. + * * @since 3.5.0 */ def create(): SparkSession = { - tryCreateSessionFromClient().getOrElse(SparkSession.this.create(builder.configuration)) + val session = tryCreateSessionFromClient() + .getOrElse(SparkSession.this.create(builder.configuration)) + setDefaultAndActiveSession(session) + session } /** @@ -811,30 +842,79 @@ object SparkSession extends Logging { * If a session exist with the same configuration that is returned instead of creating a new * session. * + * This method will update the default and/or active session if they are not set. + * * @since 3.5.0 */ def getOrCreate(): SparkSession = { - tryCreateSessionFromClient().getOrElse(sessions.get(builder.configuration)) + val session = tryCreateSessionFromClient() + .getOrElse(sessions.get(builder.configuration)) + setDefaultAndActiveSession(session) + session } } - def getActiveSession: Option[SparkSession] = { - throw new UnsupportedOperationException("getActiveSession is not supported") + /** + * Returns the default SparkSession. 
+ * + * @since 3.5.0 + */ + def getDefaultSession: Option[SparkSession] = Option(defaultSession.get()) + + /** + * Sets the default SparkSession. + * + * @since 3.5.0 + */ + def setDefaultSession(session: SparkSession): Unit = { + defaultSession.set(session) } - def getDefaultSession: Option[SparkSession] = { - throw new UnsupportedOperationException("getDefaultSession is not supported") + /** + * Clears the default SparkSession. + * + * @since 3.5.0 + */ + def clearDefaultSession(): Unit = { + defaultSession.set(null) } + /** + * Returns the active SparkSession for the current thread. + * + * @since 3.5.0 + */ + def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get()) + + /** + * Changes the SparkSession that will be returned in this thread and its children when + * SparkSession.getOrCreate() is called. This can be used to ensure that a given thread receives + * an isolated SparkSession. + * + * @since 3.5.0 + */ def setActiveSession(session: SparkSession): Unit = { - throw new UnsupportedOperationException("setActiveSession is not supported") + activeThreadSession.set(session) } + /** + * Clears the active SparkSession for current thread. + * + * @since 3.5.0 + */ def clearActiveSession(): Unit = { - throw new UnsupportedOperationException("clearActiveSession is not supported") + activeThreadSession.remove() } + /** + * Returns the currently active SparkSession, otherwise the default one. If there is no default + * SparkSession, throws an exception. + * + * @since 3.5.0 + */ def active: SparkSession = { - throw new UnsupportedOperationException("active is not supported") + getActiveSession + .orElse(getDefaultSession) + .getOrElse(throw new IllegalStateException("No active or default Spark session found")) } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala index 97fb46bf48af4..f06744399f833 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala @@ -16,6 +16,10 @@ */ package org.apache.spark.sql +import java.util.concurrent.{Executors, Phaser} + +import scala.util.control.NonFatal + import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, MethodDescriptor} import org.apache.spark.sql.connect.client.util.ConnectFunSuite @@ -24,6 +28,10 @@ import org.apache.spark.sql.connect.client.util.ConnectFunSuite * Tests for non-dataframe related SparkSession operations. 
*/ class SparkSessionSuite extends ConnectFunSuite { + private val connectionString1: String = "sc://test.it:17845" + private val connectionString2: String = "sc://test.me:14099" + private val connectionString3: String = "sc://doit:16845" + test("default") { val session = SparkSession.builder().getOrCreate() assert(session.client.configuration.host == "localhost") @@ -32,16 +40,15 @@ class SparkSessionSuite extends ConnectFunSuite { } test("remote") { - val session = SparkSession.builder().remote("sc://test.me:14099").getOrCreate() + val session = SparkSession.builder().remote(connectionString2).getOrCreate() assert(session.client.configuration.host == "test.me") assert(session.client.configuration.port == 14099) session.close() } test("getOrCreate") { - val connectionString = "sc://test.it:17865" - val session1 = SparkSession.builder().remote(connectionString).getOrCreate() - val session2 = SparkSession.builder().remote(connectionString).getOrCreate() + val session1 = SparkSession.builder().remote(connectionString1).getOrCreate() + val session2 = SparkSession.builder().remote(connectionString1).getOrCreate() try { assert(session1 eq session2) } finally { @@ -51,9 +58,8 @@ class SparkSessionSuite extends ConnectFunSuite { } test("create") { - val connectionString = "sc://test.it:17845" - val session1 = SparkSession.builder().remote(connectionString).create() - val session2 = SparkSession.builder().remote(connectionString).create() + val session1 = SparkSession.builder().remote(connectionString1).create() + val session2 = SparkSession.builder().remote(connectionString1).create() try { assert(session1 ne session2) assert(session1.client.configuration == session2.client.configuration) @@ -64,8 +70,7 @@ class SparkSessionSuite extends ConnectFunSuite { } test("newSession") { - val connectionString = "sc://doit:16845" - val session1 = SparkSession.builder().remote(connectionString).create() + val session1 = SparkSession.builder().remote(connectionString3).create() val session2 = session1.newSession() try { assert(session1 ne session2) @@ -92,5 +97,126 @@ class SparkSessionSuite extends ConnectFunSuite { assertThrows[RuntimeException] { session.range(10).count() } + session.close() + } + + test("Default/Active session") { + // Make sure we start with a clean slate. + SparkSession.clearDefaultSession() + SparkSession.clearActiveSession() + assert(SparkSession.getDefaultSession.isEmpty) + assert(SparkSession.getActiveSession.isEmpty) + intercept[IllegalStateException](SparkSession.active) + + // Create a session + val session1 = SparkSession.builder().remote(connectionString1).getOrCreate() + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session1)) + assert(SparkSession.active == session1) + + // Create another session... 
+ val session2 = SparkSession.builder().remote(connectionString2).create() + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session1)) + SparkSession.setActiveSession(session2) + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session2)) + + // Clear sessions + SparkSession.clearDefaultSession() + assert(SparkSession.getDefaultSession.isEmpty) + SparkSession.clearActiveSession() + assert(SparkSession.getDefaultSession.isEmpty) + + // Flip sessions + SparkSession.setActiveSession(session1) + SparkSession.setDefaultSession(session2) + assert(SparkSession.getDefaultSession.contains(session2)) + assert(SparkSession.getActiveSession.contains(session1)) + + // Close session1 + session1.close() + assert(SparkSession.getDefaultSession.contains(session2)) + assert(SparkSession.getActiveSession.isEmpty) + + // Close session2 + session2.close() + assert(SparkSession.getDefaultSession.isEmpty) + assert(SparkSession.getActiveSession.isEmpty) + } + + test("active session in multiple threads") { + SparkSession.clearDefaultSession() + SparkSession.clearActiveSession() + val session1 = SparkSession.builder().remote(connectionString1).create() + val session2 = SparkSession.builder().remote(connectionString1).create() + SparkSession.setActiveSession(session2) + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session2)) + + val phaser = new Phaser(2) + val executor = Executors.newFixedThreadPool(2) + def execute(block: Phaser => Unit): java.util.concurrent.Future[Boolean] = { + executor.submit[Boolean] { () => + try { + block(phaser) + true + } catch { + case NonFatal(e) => + phaser.forceTermination() + throw e + } + } + } + + try { + val script1 = execute { phaser => + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session2)) + + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session2)) + session1.close() + + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.isEmpty) + assert(SparkSession.getActiveSession.contains(session2)) + SparkSession.clearActiveSession() + + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.isEmpty) + assert(SparkSession.getActiveSession.isEmpty) + } + val script2 = execute { phaser => + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session2)) + SparkSession.clearActiveSession() + val internalSession = SparkSession.builder().remote(connectionString3).getOrCreate() + + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(internalSession)) + + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.isEmpty) + assert(SparkSession.getActiveSession.contains(internalSession)) + + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.isEmpty) + assert(SparkSession.getActiveSession.contains(internalSession)) + internalSession.close() + assert(SparkSession.getActiveSession.isEmpty) + } + assert(script1.get()) + assert(script2.get()) + assert(SparkSession.getActiveSession.contains(session2)) + session2.close() + assert(SparkSession.getActiveSession.isEmpty) + } finally { + executor.shutdown() + 
} } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 6e577e0f21257..2bf9c41fb2cbd 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -207,8 +207,6 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[Problem]("org.apache.spark.sql.RelationalGroupedDataset.apply"), // SparkSession - ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.clearDefaultSession"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.setDefaultSession"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sparkContext"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sharedState"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sessionState"), From aa1261dc129618d27a1bdc743a5fdd54219f7c01 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Mon, 7 Aug 2023 19:16:38 -0700 Subject: [PATCH 03/30] [SPARK-44641][SQL] Incorrect result in certain scenarios when SPJ is not triggered ### What changes were proposed in this pull request? This PR makes sure we use unique partition values when calculating the final partitions in `BatchScanExec`, to make sure no duplicated partitions are generated. ### Why are the changes needed? When `spark.sql.sources.v2.bucketing.pushPartValues.enabled` and `spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled` are enabled, and SPJ is not triggered, currently Spark will generate incorrect/duplicated results. This is because with both configs enabled, Spark will delay the partition grouping until the time it calculates the final partitions used by the input RDD. To calculate the partitions, it uses partition values from the `KeyGroupedPartitioning` to find out the right ordering for the partitions. However, since grouping is not done when the partition values is computed, there could be duplicated partition values. This means the result could contain duplicated partitions too. ### Does this PR introduce _any_ user-facing change? No, this is a bug fix. ### How was this patch tested? Added a new test case for this scenario. Closes #42324 from sunchao/SPARK-44641. Authored-by: Chao Sun Signed-off-by: Chao Sun --- .../plans/physical/partitioning.scala | 9 ++- .../datasources/v2/BatchScanExec.scala | 9 ++- .../KeyGroupedPartitioningSuite.scala | 56 +++++++++++++++++++ 3 files changed, 72 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index bd8ba54ddd736..456005768bd42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -313,7 +313,7 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) * by `expressions`. `partitionValues`, if defined, should contain value of partition key(s) in * ascending order, after evaluated by the transforms in `expressions`, for each input partition. 
* In addition, its length must be the same as the number of input partitions (and thus is a 1-1 - * mapping), and each row in `partitionValues` must be unique. + * mapping). The `partitionValues` may contain duplicated partition values. * * For example, if `expressions` is `[years(ts_col)]`, then a valid value of `partitionValues` is * `[0, 1, 2]`, which represents 3 input partitions with distinct partition values. All rows @@ -355,6 +355,13 @@ case class KeyGroupedPartitioning( override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = KeyGroupedShuffleSpec(this, distribution) + + lazy val uniquePartitionValues: Seq[InternalRow] = { + partitionValues + .map(InternalRowComparableWrapper(_, expressions)) + .distinct + .map(_.row) + } } object KeyGroupedPartitioning { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index 4b53819739262..eba3c71f871e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -190,10 +190,17 @@ case class BatchScanExec( Seq.fill(numSplits)(Seq.empty)) } } else { + // either `commonPartitionValues` is not defined, or it is defined but + // `applyPartialClustering` is false. val partitionMapping = groupedPartitions.map { case (row, parts) => InternalRowComparableWrapper(row, p.expressions) -> parts }.toMap - finalPartitions = p.partitionValues.map { partValue => + + // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there + // could exist duplicated partition values, as partition grouping is not done + // at the beginning and postponed to this method. It is important to use unique + // partition values here so that grouped partitions won't get duplicated. 
+ finalPartitions = p.uniquePartitionValues.map { partValue => // Use empty partition for those partition values that are not present partitionMapping.getOrElse( InternalRowComparableWrapper(partValue, p.expressions), Seq.empty) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 880c30ba9f98d..8461f528277c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -1039,4 +1039,60 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } } + + test("SPARK-44641: duplicated records when SPJ is not triggered") { + val items_partitions = Array(bucket(8, "id")) + createTable(items, items_schema, items_partitions) + sql(s""" + INSERT INTO testcat.ns.$items VALUES + (1, 'aa', 40.0, cast('2020-01-01' as timestamp)), + (1, 'aa', 41.0, cast('2020-01-15' as timestamp)), + (2, 'bb', 10.0, cast('2020-01-01' as timestamp)), + (2, 'bb', 10.5, cast('2020-01-01' as timestamp)), + (3, 'cc', 15.5, cast('2020-02-01' as timestamp))""") + + val purchases_partitions = Array(bucket(8, "item_id")) + createTable(purchases, purchases_schema, purchases_partitions) + sql(s"""INSERT INTO testcat.ns.$purchases VALUES + (1, 42.0, cast('2020-01-01' as timestamp)), + (1, 44.0, cast('2020-01-15' as timestamp)), + (1, 45.0, cast('2020-01-15' as timestamp)), + (2, 11.0, cast('2020-01-01' as timestamp)), + (3, 19.5, cast('2020-02-01' as timestamp))""") + + Seq(true, false).foreach { pushDownValues => + Seq(true, false).foreach { partiallyClusteredEnabled => + withSQLConf( + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partiallyClusteredEnabled.toString) { + + // join keys are not the same as the partition keys, therefore SPJ is not triggered. + val df = sql( + s""" + SELECT id, name, i.price as purchase_price, p.item_id, p.price as sale_price + FROM testcat.ns.$items i JOIN testcat.ns.$purchases p + ON i.arrive_time = p.time ORDER BY id, purchase_price, p.item_id, sale_price + """) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.nonEmpty, "shuffle should exist when SPJ is not used") + + checkAnswer(df, + Seq( + Row(1, "aa", 40.0, 1, 42.0), + Row(1, "aa", 40.0, 2, 11.0), + Row(1, "aa", 41.0, 1, 44.0), + Row(1, "aa", 41.0, 1, 45.0), + Row(2, "bb", 10.0, 1, 42.0), + Row(2, "bb", 10.0, 2, 11.0), + Row(2, "bb", 10.5, 1, 42.0), + Row(2, "bb", 10.5, 2, 11.0), + Row(3, "cc", 15.5, 3, 19.5) + ) + ) + } + } + } + } } From 6dadd188f3652816c291919a2413f73c13bb1b47 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Tue, 8 Aug 2023 11:04:53 +0800 Subject: [PATCH 04/30] [SPARK-44554][INFRA] Make Python linter related checks pass of branch-3.3/3.4 daily testing ### What changes were proposed in this pull request? The daily testing of `branch-3.3/3.4` uses the same yml file as the master now and the upgrade to `MyPy` in https://github.com/apache/spark/pull/41690 resulted in Python linter check failure of `branch-3.3/3.4`, - branch-3.3: https://github.com/apache/spark/actions/runs/5677524469/job/15386025539 - branch-3.4: https://github.com/apache/spark/actions/runs/5678626664/job/15389273919 image So this pr do the following change for workaround: 1. 
Install different Python linter dependencies for `branch-3.3/3.4`, the dependency list comes from the corresponding branch to ensure compatibility with the version 2. Skip `Install dependencies for Python code generation check` and `Python code generation check` for `branch-3.3/3.4` due to they do not use `Buf remote plugins` and `Buf remote generation` is no longer supported. Meanwhile, the protobuf files in the branch generally do not change, so we can skip this check. ### Why are the changes needed? Make Python linter related checks pass of branch-3.3/3.4 daily testing ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? - Pass GitHub Actions - Manually checked branch-3.4, the newly added condition should be ok Closes #42167 from LuciferYang/SPARK-44554. Lead-authored-by: yangjie01 Co-authored-by: YangJie Signed-off-by: yangjie01 --- .github/workflows/build_and_test.yml | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index cd68c0904d9a4..b4559dea42bb9 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -657,7 +657,22 @@ jobs: - name: Spark connect jvm client mima check if: inputs.branch != 'branch-3.3' run: ./dev/connect-jvm-client-mima-check + - name: Install Python linter dependencies for branch-3.3 + if: inputs.branch == 'branch-3.3' + run: | + # SPARK-44554: Copy from https://github.com/apache/spark/blob/073d0b60d31bf68ebacdc005f59b928a5902670f/.github/workflows/build_and_test.yml#L501-L508 + # Should delete this section after SPARK 3.3 EOL. + python3.9 -m pip install 'flake8==3.9.0' pydata_sphinx_theme 'mypy==0.920' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' numpydoc 'jinja2<3.0.0' 'black==21.12b0' + python3.9 -m pip install 'pandas-stubs==1.2.0.53' + - name: Install Python linter dependencies for branch-3.4 + if: inputs.branch == 'branch-3.4' + run: | + # SPARK-44554: Copy from https://github.com/apache/spark/blob/a05c27e85829fe742c1828507a1fd180cdc84b54/.github/workflows/build_and_test.yml#L571-L578 + # Should delete this section after SPARK 3.4 EOL. + python3.9 -m pip install 'flake8==3.9.0' pydata_sphinx_theme 'mypy==0.920' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' numpydoc 'jinja2<3.0.0' 'black==22.6.0' + python3.9 -m pip install 'pandas-stubs==1.2.0.53' ipython 'grpcio==1.48.1' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' - name: Install Python linter dependencies + if: inputs.branch != 'branch-3.3' && inputs.branch != 'branch-3.4' run: | # TODO(SPARK-32407): Sphinx 3.1+ does not correctly index nested classes. # See also https://github.com/sphinx-doc/sphinx/issues/7551. 
@@ -668,6 +683,7 @@ jobs: - name: Python linter run: PYTHON_EXECUTABLE=python3.9 ./dev/lint-python - name: Install dependencies for Python code generation check + if: inputs.branch != 'branch-3.3' && inputs.branch != 'branch-3.4' run: | # See more in "Installation" https://docs.buf.build/installation#tarball curl -LO https://github.com/bufbuild/buf/releases/download/v1.24.0/buf-Linux-x86_64.tar.gz @@ -676,6 +692,7 @@ jobs: rm buf-Linux-x86_64.tar.gz python3.9 -m pip install 'protobuf==3.20.3' 'mypy-protobuf==3.3.0' - name: Python code generation check + if: inputs.branch != 'branch-3.3' && inputs.branch != 'branch-3.4' run: if test -f ./dev/connect-check-protos.py; then PATH=$PATH:$HOME/buf/bin PYTHON_EXECUTABLE=python3.9 ./dev/connect-check-protos.py; fi - name: Install JavaScript linter dependencies run: | From 25053d98186489d9f2061c9b815a5a33f7e309c4 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Tue, 8 Aug 2023 11:06:21 +0800 Subject: [PATCH 05/30] [SPARK-44689][CONNECT] Make the exception handling of function `SparkConnectPlanner#unpackScalarScalaUDF` more universal ### What changes were proposed in this pull request? This PR changes the exception handling in the `unpackScalarScalaUD` function in `SparkConnectPlanner` from determining the exception type based on a fixed nesting level to using Guava `Throwables` to get the root cause and then determining the type of the root cause. This makes it compatible with differences between different Java versions. ### Why are the changes needed? The following failure occurred when testing `UDFClassLoadingE2ESuite` in Java 17 daily test: https://github.com/apache/spark/actions/runs/5766913899/job/15635782831 ``` [info] UDFClassLoadingE2ESuite: [info] - update class loader after stubbing: new session *** FAILED *** (101 milliseconds) [info] "Exception in SerializedLambda.readResolve" did not contain "java.lang.NoSuchMethodException: org.apache.spark.sql.connect.client.StubClassDummyUdf" (UDFClassLoadingE2ESuite.scala:57) ... [info] - update class loader after stubbing: same session *** FAILED *** (52 milliseconds) [info] "Exception in SerializedLambda.readResolve" did not contain "java.lang.NoSuchMethodException: org.apache.spark.sql.connect.client.StubClassDummyUdf" (UDFClassLoadingE2ESuite.scala:73) ... 
``` After analysis, it was found that there are differences in the exception stack generated on the server side between Java 8 and Java 17: - Java 8 ``` java.io.IOException: unexpected exception type at java.io.ObjectStreamClass.throwMiscException(ObjectStreamClass.java:1750) at java.io.ObjectStreamClass.invokeReadResolve(ObjectStreamClass.java:1280) at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2222) at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1669) at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2431) at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2355) at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2213) at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1669) at java.io.ObjectInputStream.readObject(ObjectInputStream.java:503) at java.io.ObjectInputStream.readObject(ObjectInputStream.java:461) at org.apache.spark.util.SparkSerDeUtils.deserialize(SparkSerDeUtils.scala:50) at org.apache.spark.util.SparkSerDeUtils.deserialize$(SparkSerDeUtils.scala:41) at org.apache.spark.util.Utils$.deserialize(Utils.scala:95) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.unpackScalarScalaUDF(SparkConnectPlanner.scala:1516) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.org$apache$spark$sql$connect$planner$SparkConnectPlanner$$unpackUdf(SparkConnectPlanner.scala:1507) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformScalarScalaFunction(SparkConnectPlanner.scala:1544) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.handleRegisterScalarScalaUDF(SparkConnectPlanner.scala:2565) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.handleRegisterUserDefinedFunction(SparkConnectPlanner.scala:2492) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.process(SparkConnectPlanner.scala:2363) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.handleCommand(ExecuteThreadRunner.scala:202) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.$anonfun$executeInternal$1(ExecuteThreadRunner.scala:158) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.$anonfun$executeInternal$1$adapted(ExecuteThreadRunner.scala:132) at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$2(SessionHolder.scala:184) at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:900) at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$1(SessionHolder.scala:184) at org.apache.spark.JobArtifactSet$.withActiveJobArtifactState(JobArtifactSet.scala:94) at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withContextClassLoader$1(SessionHolder.scala:171) at org.apache.spark.util.Utils$.withContextClassLoader(Utils.scala:179) at org.apache.spark.sql.connect.service.SessionHolder.withContextClassLoader(SessionHolder.scala:170) at org.apache.spark.sql.connect.service.SessionHolder.withSession(SessionHolder.scala:183) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.executeInternal(ExecuteThreadRunner.scala:132) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.org$apache$spark$sql$connect$execution$ExecuteThreadRunner$$execute(ExecuteThreadRunner.scala:84) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner$ExecutionThread.run(ExecuteThreadRunner.scala:227) Caused by: java.lang.NoSuchMethodException: org.apache.spark.sql.connect.client.StubClassDummyUdf.$deserializeLambda$(java.lang.invoke.SerializedLambda) at 
java.lang.Class.getDeclaredMethod(Class.java:2130) at java.lang.invoke.SerializedLambda$1.run(SerializedLambda.java:224) at java.lang.invoke.SerializedLambda$1.run(SerializedLambda.java:221) at java.security.AccessController.doPrivileged(Native Method) at java.lang.invoke.SerializedLambda.readResolve(SerializedLambda.java:221) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at java.io.ObjectStreamClass.invokeReadResolve(ObjectStreamClass.java:1274) ... 31 more ``` - Java 17 ``` java.lang.RuntimeException: Exception in SerializedLambda.readResolve at java.base/java.lang.invoke.SerializedLambda.readResolve(SerializedLambda.java:288) at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:77) at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.base/java.lang.reflect.Method.invoke(Method.java:568) at java.base/java.io.ObjectStreamClass.invokeReadResolve(ObjectStreamClass.java:1190) at java.base/java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2266) at java.base/java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1733) at java.base/java.io.ObjectInputStream$FieldValues.(ObjectInputStream.java:2606) at java.base/java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2457) at java.base/java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2257) at java.base/java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1733) at java.base/java.io.ObjectInputStream.readObject(ObjectInputStream.java:509) at java.base/java.io.ObjectInputStream.readObject(ObjectInputStream.java:467) at org.apache.spark.util.SparkSerDeUtils.deserialize(SparkSerDeUtils.scala:50) at org.apache.spark.util.SparkSerDeUtils.deserialize$(SparkSerDeUtils.scala:41) at org.apache.spark.util.Utils$.deserialize(Utils.scala:95) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.unpackScalarScalaUDF(SparkConnectPlanner.scala:1517) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.org$apache$spark$sql$connect$planner$SparkConnectPlanner$$unpackUdf(SparkConnectPlanner.scala:1507) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformScalarScalaFunction(SparkConnectPlanner.scala:1552) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.handleRegisterScalarScalaUDF(SparkConnectPlanner.scala:2573) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.handleRegisterUserDefinedFunction(SparkConnectPlanner.scala:2500) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.process(SparkConnectPlanner.scala:2371) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.handleCommand(ExecuteThreadRunner.scala:202) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.$anonfun$executeInternal$1(ExecuteThreadRunner.scala:158) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.$anonfun$executeInternal$1$adapted(ExecuteThreadRunner.scala:132) at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$2(SessionHolder.scala:184) at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:900) at 
org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$1(SessionHolder.scala:184) at org.apache.spark.JobArtifactSet$.withActiveJobArtifactState(JobArtifactSet.scala:94) at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withContextClassLoader$1(SessionHolder.scala:171) at org.apache.spark.util.Utils$.withContextClassLoader(Utils.scala:179) at org.apache.spark.sql.connect.service.SessionHolder.withContextClassLoader(SessionHolder.scala:170) at org.apache.spark.sql.connect.service.SessionHolder.withSession(SessionHolder.scala:183) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.executeInternal(ExecuteThreadRunner.scala:132) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.org$apache$spark$sql$connect$execution$ExecuteThreadRunner$$execute(ExecuteThreadRunner.scala:84) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner$ExecutionThread.run(ExecuteThreadRunner.scala:227) Caused by: java.security.PrivilegedActionException: java.lang.NoSuchMethodException: org.apache.spark.sql.connect.client.StubClassDummyUdf.$deserializeLambda$(java.lang.invoke.SerializedLambda) at java.base/java.security.AccessController.doPrivileged(AccessController.java:573) at java.base/java.lang.invoke.SerializedLambda.readResolve(SerializedLambda.java:269) ... 36 more Caused by: java.lang.NoSuchMethodException: org.apache.spark.sql.connect.client.StubClassDummyUdf.$deserializeLambda$(java.lang.invoke.SerializedLambda) at java.base/java.lang.Class.getDeclaredMethod(Class.java:2675) at java.base/java.lang.invoke.SerializedLambda$1.run(SerializedLambda.java:272) at java.base/java.lang.invoke.SerializedLambda$1.run(SerializedLambda.java:269) at java.base/java.security.AccessController.doPrivileged(AccessController.java:569) ... 37 more ``` While their root exceptions are both `NoSuchMethodException`, the levels of nesting are different. We can add an exception check branch to make it compatible with Java 17, for example: ```scala case e: IOException if e.getCause.isInstanceOf[NoSuchMethodException] => throw new ClassNotFoundException(... ${e.getCause} ...) case e: RuntimeException if e.getCause != null && e.getCause.getCause.isInstanceOf[NoSuchMethodException] => throw new ClassNotFoundException(... ${e.getCause.getCause} ...) ``` But if future Java versions change the nested levels of exceptions again, this will necessitate another modification of this part of the code. Therefore, this PR has been revised to fetch the root cause of the exception and conduct a type check on the root cause to make it as universal as possible. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? - Pass Git Hub Actions - Manually check with Java 17 ``` java -version openjdk version "17.0.8" 2023-07-18 LTS OpenJDK Runtime Environment Zulu17.44+15-CA (build 17.0.8+7-LTS) OpenJDK 64-Bit Server VM Zulu17.44+15-CA (build 17.0.8+7-LTS, mixed mode, sharing) ``` run ``` build/sbt clean "connect-client-jvm/testOnly *UDFClassLoadingE2ESuite" -Phive ``` Before ``` [info] UDFClassLoadingE2ESuite: [info] - update class loader after stubbing: new session *** FAILED *** (60 milliseconds) [info] "Exception in SerializedLambda.readResolve" did not contain "java.lang.NoSuchMethodException: org.apache.spark.sql.connect.client.StubClassDummyUdf" (UDFClassLoadingE2ESuite.scala:57) ... 
[info] - update class loader after stubbing: same session *** FAILED *** (15 milliseconds) [info] "Exception in SerializedLambda.readResolve" did not contain "java.lang.NoSuchMethodException: org.apache.spark.sql.connect.client.StubClassDummyUdf" (UDFClassLoadingE2ESuite.scala:73) ... [info] Run completed in 9 seconds, 565 milliseconds. [info] Total number of tests run: 2 [info] Suites: completed 1, aborted 0 [info] Tests: succeeded 0, failed 2, canceled 0, ignored 0, pending 0 [info] *** 2 TESTS FAILED *** [error] Failed tests: [error] org.apache.spark.sql.connect.client.UDFClassLoadingE2ESuite [error] (connect-client-jvm / Test / testOnly) sbt.TestsFailedException: Tests unsuccessful ``` After ``` [info] UDFClassLoadingE2ESuite: [info] - update class loader after stubbing: new session (116 milliseconds) [info] - update class loader after stubbing: same session (41 milliseconds) [info] Run completed in 9 seconds, 781 milliseconds. [info] Total number of tests run: 2 [info] Suites: completed 1, aborted 0 [info] Tests: succeeded 2, failed 0, canceled 0, ignored 0, pending 0 [info] All tests passed. ``` Closes #42360 from LuciferYang/unpackScalarScalaUDF-exception-java17. Authored-by: yangjie01 Signed-off-by: yangjie01 --- .../connect/planner/SparkConnectPlanner.scala | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 7136476b515f9..f70a17e580a3e 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.connect.planner -import java.io.IOException - import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.Try +import com.google.common.base.Throwables import com.google.common.collect.{Lists, Maps} import com.google.protobuf.{Any => ProtoAny, ByteString} import io.grpc.{Context, Status, StatusRuntimeException} @@ -1518,11 +1517,15 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { logDebug(s"Unpack using class loader: ${Utils.getContextOrSparkClassLoader}") Utils.deserialize[T](fun.getPayload.toByteArray, Utils.getContextOrSparkClassLoader) } catch { - case e: IOException if e.getCause.isInstanceOf[NoSuchMethodException] => - throw new ClassNotFoundException( - s"Failed to load class correctly due to ${e.getCause}. " + - "Make sure the artifact where the class is defined is installed by calling" + - " session.addArtifact.") + case t: Throwable => + Throwables.getRootCause(t) match { + case nsm: NoSuchMethodException => + throw new ClassNotFoundException( + s"Failed to load class correctly due to $nsm. " + + "Make sure the artifact where the class is defined is installed by calling" + + " session.addArtifact.") + case _ => throw t + } } } From 590b77f76284ad03ad8b3b6d30b23983c66513fc Mon Sep 17 00:00:00 2001 From: allisonwang-db Date: Tue, 8 Aug 2023 11:09:58 +0800 Subject: [PATCH 06/30] [SPARK-44005][PYTHON] Improve error messages for regular Python UDTFs that return non-tuple values ### What changes were proposed in this pull request? This PR improves error messages for regular Python UDTFs when the result rows are not one of tuple, list and dict. 
Note this is supported when arrow optimization is enabled. ### Why are the changes needed? To make Python UDTFs more user friendly. ### Does this PR introduce _any_ user-facing change? Yes. ``` class TestUDTF: def eval(self, a: int): yield a ``` Before this PR, this will fail with this error `Unexpected tuple 1 with StructType` After this PR, this will have a more user-friendly error: `[UDTF_INVALID_OUTPUT_ROW_TYPE] The type of an individual output row in the UDTF is invalid. Each row should be a tuple, list, or dict, but got 'int'. Please make sure that the output rows are of the correct type.` ### How was this patch tested? Existing UTs. Closes #42353 from allisonwang-db/spark-44005-non-tuple-return-val. Authored-by: allisonwang-db Signed-off-by: Ruifeng Zheng --- python/pyspark/errors/error_classes.py | 5 +++++ python/pyspark/sql/tests/test_udtf.py | 26 +++++++++++--------------- python/pyspark/worker.py | 12 +++++++++--- 3 files changed, 25 insertions(+), 18 deletions(-) diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py index 24885e94d3255..bc32afeb87a9f 100644 --- a/python/pyspark/errors/error_classes.py +++ b/python/pyspark/errors/error_classes.py @@ -743,6 +743,11 @@ "User defined table function encountered an error in the '' method: " ] }, + "UDTF_INVALID_OUTPUT_ROW_TYPE" : { + "message" : [ + "The type of an individual output row in the UDTF is invalid. Each row should be a tuple, list, or dict, but got ''. Please make sure that the output rows are of the correct type." + ] + }, "UDTF_RETURN_NOT_ITERABLE" : { "message" : [ "The return value of the UDTF is invalid. It should be an iterable (e.g., generator or list), but got ''. Please make sure that the UDTF returns one of these types." diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index b2f473996bcb6..300067716e9de 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -163,24 +163,21 @@ def eval(self, a: int, b: int): self.assertEqual(rows, [Row(a=1, b=2), Row(a=2, b=3)]) def test_udtf_eval_returning_non_tuple(self): + @udtf(returnType="a: int") class TestUDTF: def eval(self, a: int): yield a - func = udtf(TestUDTF, returnType="a: int") - # TODO(SPARK-44005): improve this error message - with self.assertRaisesRegex(PythonException, "Unexpected tuple 1 with StructType"): - func(lit(1)).collect() + with self.assertRaisesRegex(PythonException, "UDTF_INVALID_OUTPUT_ROW_TYPE"): + TestUDTF(lit(1)).collect() - def test_udtf_eval_returning_non_generator(self): + @udtf(returnType="a: int") class TestUDTF: def eval(self, a: int): return (a,) - func = udtf(TestUDTF, returnType="a: int") - # TODO(SPARK-44005): improve this error message - with self.assertRaisesRegex(PythonException, "Unexpected tuple 1 with StructType"): - func(lit(1)).collect() + with self.assertRaisesRegex(PythonException, "UDTF_INVALID_OUTPUT_ROW_TYPE"): + TestUDTF(lit(1)).collect() def test_udtf_with_invalid_return_value(self): @udtf(returnType="x: int") @@ -1852,21 +1849,20 @@ def eval(self): self.spark.conf.set("spark.sql.execution.pythonUDTF.arrow.enabled", old_value) def test_udtf_eval_returning_non_tuple(self): + @udtf(returnType="a: int") class TestUDTF: def eval(self, a: int): yield a - func = udtf(TestUDTF, returnType="a: int") # When arrow is enabled, it can handle non-tuple return value. 
- self.assertEqual(func(lit(1)).collect(), [Row(a=1)]) + assertDataFrameEqual(TestUDTF(lit(1)), [Row(a=1)]) - def test_udtf_eval_returning_non_generator(self): + @udtf(returnType="a: int") class TestUDTF: def eval(self, a: int): - return (a,) + return [a] - func = udtf(TestUDTF, returnType="a: int") - self.assertEqual(func(lit(1)).collect(), [Row(a=1)]) + assertDataFrameEqual(TestUDTF(lit(1)), [Row(a=1)]) def test_numeric_output_type_casting(self): class TestUDTF: diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index b32e20e3b0418..6f27400387e72 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -648,9 +648,8 @@ def wrap_udtf(f, return_type): return_type_size = len(return_type) def verify_and_convert_result(result): - # TODO(SPARK-44005): support returning non-tuple values - if result is not None and hasattr(result, "__len__"): - if len(result) != return_type_size: + if result is not None: + if hasattr(result, "__len__") and len(result) != return_type_size: raise PySparkRuntimeError( error_class="UDTF_RETURN_SCHEMA_MISMATCH", message_parameters={ @@ -658,6 +657,13 @@ def verify_and_convert_result(result): "actual": str(len(result)), }, ) + + if not (isinstance(result, (list, dict, tuple)) or hasattr(result, "__dict__")): + raise PySparkRuntimeError( + error_class="UDTF_INVALID_OUTPUT_ROW_TYPE", + message_parameters={"type": type(result).__name__}, + ) + return toInternal(result) # Evaluate the function and return a tuple back to the executor. From b4b91212b1d4ce8f47f9e1abeb26b06122c01f13 Mon Sep 17 00:00:00 2001 From: Shuyou Dong Date: Tue, 8 Aug 2023 12:17:53 +0900 Subject: [PATCH 07/30] [SPARK-44703][CORE] Log eventLog rewrite duration when compact old event log files ### What changes were proposed in this pull request? Log eventLog rewrite duration when compact old event log files. ### Why are the changes needed? When enable `spark.eventLog.rolling.enabled` and the number of eventLog files exceeds the value of `spark.history.fs.eventLog.rolling.maxFilesToRetain`, HistoryServer will compact the old event log files into one compact file. Currently there is no log the rewrite duration in rewrite method, this metric is useful for understand the compact duration, so we need add logs in the method. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Manual test. Closes #42378 from shuyouZZ/SPARK-44703. 
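The timing this patch adds (see the diff below) is the usual wall-clock pattern around `start()`/`stop()`; as a standalone sketch of the same idea, assuming an slf4j logger as a stand-in for Spark's `logInfo`, and with the helper name and the rewrite call purely illustrative rather than part of `EventLogFileCompactor`:

```
import org.slf4j.LoggerFactory

object RewriteTimingSketch {
  private val log = LoggerFactory.getLogger(getClass)

  // Run a block, log how long it took in milliseconds, and return its result,
  // mirroring the startTime/duration bookkeeping placed around the log writer.
  def timed[T](label: String)(body: => T): T = {
    val startTime = System.currentTimeMillis()
    val result = body
    val duration = System.currentTimeMillis() - startTime
    log.info(s"Finished $label took $duration ms.")
    result
  }
}

// Hypothetical usage: timed("rewriting eventLog files")(rewriteAllEventLogs())
```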
Authored-by: Shuyou Dong Signed-off-by: Jungtaek Lim --- .../apache/spark/deploy/history/EventLogFileCompactor.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/EventLogFileCompactor.scala b/core/src/main/scala/org/apache/spark/deploy/history/EventLogFileCompactor.scala index 8558f765175fc..27040e83533ff 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/EventLogFileCompactor.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/EventLogFileCompactor.scala @@ -149,6 +149,7 @@ class EventLogFileCompactor( val logWriter = new CompactedEventLogFileWriter(lastIndexEventLogPath, "dummy", None, lastIndexEventLogPath.getParent.toUri, sparkConf, hadoopConf) + val startTime = System.currentTimeMillis() logWriter.start() eventLogFiles.foreach { file => EventFilter.applyFilterToFile(fs, filters, file.getPath, @@ -158,6 +159,8 @@ class EventLogFileCompactor( ) } logWriter.stop() + val duration = System.currentTimeMillis() - startTime + logInfo(s"Finished rewriting eventLog files to ${logWriter.logPath} took $duration ms.") logWriter.logPath } From d2b60ff51fabdb38899e649aa2e700112534d79c Mon Sep 17 00:00:00 2001 From: itholic Date: Tue, 8 Aug 2023 16:16:11 +0900 Subject: [PATCH 08/30] [SPARK-43567][PS] Support `use_na_sentinel` for `factorize` ### What changes were proposed in this pull request? This PR proposes to support `use_na_sentinel` for `factorize`. ### Why are the changes needed? To match the behavior with [pandas 2](https://pandas.pydata.org/docs/dev/whatsnew/v2.0.0.html) ### Does this PR introduce _any_ user-facing change? Yes, the `na_sentinel` is removed in favor of `use_na_sentinel`. ### How was this patch tested? Enabling the existing tests. Closes #42270 from itholic/pandas_use_na_sentinel. Authored-by: itholic Signed-off-by: Hyukjin Kwon --- .../migration_guide/pyspark_upgrade.rst | 1 + python/pyspark/pandas/base.py | 39 +++++++------------ .../connect/series/test_parity_compute.py | 4 ++ .../pandas/tests/indexes/test_category.py | 8 +--- .../pandas/tests/series/test_compute.py | 20 ++++------ 5 files changed, 29 insertions(+), 43 deletions(-) diff --git a/python/docs/source/migration_guide/pyspark_upgrade.rst b/python/docs/source/migration_guide/pyspark_upgrade.rst index 7a691ee264571..d26f1cbbe0dc4 100644 --- a/python/docs/source/migration_guide/pyspark_upgrade.rst +++ b/python/docs/source/migration_guide/pyspark_upgrade.rst @@ -29,6 +29,7 @@ Upgrading from PySpark 3.5 to 4.0 * In Spark 4.0, ``Series.append`` has been removed from pandas API on Spark, use ``ps.concat`` instead. * In Spark 4.0, ``DataFrame.mad`` has been removed from pandas API on Spark. * In Spark 4.0, ``Series.mad`` has been removed from pandas API on Spark. +* In Spark 4.0, ``na_sentinel`` parameter from ``Index.factorize`` and `Series.factorize`` has been removed from pandas API on Spark, use ``use_na_sentinel`` instead. 
Upgrading from PySpark 3.3 to 3.4 diff --git a/python/pyspark/pandas/base.py b/python/pyspark/pandas/base.py index 2de260e6e9351..0685af769872a 100644 --- a/python/pyspark/pandas/base.py +++ b/python/pyspark/pandas/base.py @@ -1614,7 +1614,7 @@ def take(self: IndexOpsLike, indices: Sequence[int]) -> IndexOpsLike: return cast(IndexOpsLike, self._psdf.iloc[indices].index) def factorize( - self: IndexOpsLike, sort: bool = True, na_sentinel: Optional[int] = -1 + self: IndexOpsLike, sort: bool = True, use_na_sentinel: bool = True ) -> Tuple[IndexOpsLike, pd.Index]: """ Encode the object as an enumerated type or categorical variable. @@ -1625,11 +1625,11 @@ def factorize( Parameters ---------- sort : bool, default True - na_sentinel : int or None, default -1 - Value to mark "not found". If None, will not drop the NaN - from the uniques of the values. - - .. deprecated:: 3.4.0 + use_na_sentinel : bool, default True + If True, the sentinel -1 will be used for NaN values, effectively assigning them + a distinct category. If False, NaN values will be encoded as non-negative integers, + treating them as unique categories in the encoding process and retaining them in the + set of unique categories in the data. Returns ------- @@ -1658,7 +1658,7 @@ def factorize( >>> uniques Index(['a', 'b', 'c'], dtype='object') - >>> codes, uniques = psser.factorize(na_sentinel=None) + >>> codes, uniques = psser.factorize(use_na_sentinel=False) >>> codes 0 1 1 3 @@ -1669,17 +1669,6 @@ def factorize( >>> uniques Index(['a', 'b', 'c', None], dtype='object') - >>> codes, uniques = psser.factorize(na_sentinel=-2) - >>> codes - 0 1 - 1 -2 - 2 0 - 3 2 - 4 1 - dtype: int32 - >>> uniques - Index(['a', 'b', 'c'], dtype='object') - For Index: >>> psidx = ps.Index(['b', None, 'a', 'c', 'b']) @@ -1691,8 +1680,8 @@ def factorize( """ from pyspark.pandas.series import first_series - assert (na_sentinel is None) or isinstance(na_sentinel, int) assert sort is True + use_na_sentinel = -1 if use_na_sentinel else False # type: ignore[assignment] warnings.warn( "Argument `na_sentinel` will be removed in 4.0.0.", @@ -1716,7 +1705,7 @@ def factorize( scol = map_scol[self.spark.column] codes, uniques = self._with_new_scol( scol.alias(self._internal.data_spark_column_names[0]) - ).factorize(na_sentinel=na_sentinel) + ).factorize(use_na_sentinel=use_na_sentinel) return codes, uniques.astype(self.dtype) uniq_sdf = self._internal.spark_frame.select(self.spark.column).distinct() @@ -1743,13 +1732,13 @@ def factorize( # Constructs `unique_to_code` mapping non-na unique to code unique_to_code = {} - if na_sentinel is not None: - na_sentinel_code = na_sentinel + if use_na_sentinel: + na_sentinel_code = use_na_sentinel code = 0 for unique in uniques_list: if pd.isna(unique): - if na_sentinel is None: - na_sentinel_code = code + if not use_na_sentinel: + na_sentinel_code = code # type: ignore[assignment] else: unique_to_code[unique] = code code += 1 @@ -1767,7 +1756,7 @@ def factorize( codes = self._with_new_scol(new_scol.alias(self._internal.data_spark_column_names[0])) - if na_sentinel is not None: + if use_na_sentinel: # Drops the NaN from the uniques of the values uniques_list = [x for x in uniques_list if not pd.isna(x)] diff --git a/python/pyspark/pandas/tests/connect/series/test_parity_compute.py b/python/pyspark/pandas/tests/connect/series/test_parity_compute.py index 8876fcb139885..31916f12b4e7f 100644 --- a/python/pyspark/pandas/tests/connect/series/test_parity_compute.py +++ 
b/python/pyspark/pandas/tests/connect/series/test_parity_compute.py @@ -24,6 +24,10 @@ class SeriesParityComputeTests(SeriesComputeMixin, PandasOnSparkTestUtils, ReusedConnectTestCase): pass + @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.") + def test_factorize(self): + super().test_factorize() + if __name__ == "__main__": from pyspark.pandas.tests.connect.series.test_parity_compute import * # noqa: F401 diff --git a/python/pyspark/pandas/tests/indexes/test_category.py b/python/pyspark/pandas/tests/indexes/test_category.py index ffffae828c437..6aa92b7e1e390 100644 --- a/python/pyspark/pandas/tests/indexes/test_category.py +++ b/python/pyspark/pandas/tests/indexes/test_category.py @@ -210,10 +210,6 @@ def test_astype(self): self.assert_eq(pscidx.astype(str), pcidx.astype(str)) - @unittest.skipIf( - LooseVersion(pd.__version__) >= LooseVersion("2.0.0"), - "TODO(SPARK-43567): Enable CategoricalIndexTests.test_factorize for pandas 2.0.0.", - ) def test_factorize(self): pidx = pd.CategoricalIndex([1, 2, 3, None]) psidx = ps.from_pandas(pidx) @@ -224,8 +220,8 @@ def test_factorize(self): self.assert_eq(kcodes.tolist(), pcodes.tolist()) self.assert_eq(kuniques, puniques) - pcodes, puniques = pidx.factorize(na_sentinel=-2) - kcodes, kuniques = psidx.factorize(na_sentinel=-2) + pcodes, puniques = pidx.factorize(use_na_sentinel=-2) + kcodes, kuniques = psidx.factorize(use_na_sentinel=-2) self.assert_eq(kcodes.tolist(), pcodes.tolist()) self.assert_eq(kuniques, puniques) diff --git a/python/pyspark/pandas/tests/series/test_compute.py b/python/pyspark/pandas/tests/series/test_compute.py index 155649179e6ef..784bf29e1a25b 100644 --- a/python/pyspark/pandas/tests/series/test_compute.py +++ b/python/pyspark/pandas/tests/series/test_compute.py @@ -407,10 +407,6 @@ def test_abs(self): self.assert_eq(abs(psser), abs(pser)) self.assert_eq(np.abs(psser), np.abs(pser)) - @unittest.skipIf( - LooseVersion(pd.__version__) >= LooseVersion("2.0.0"), - "TODO(SPARK-43550): Enable SeriesTests.test_factorize for pandas 2.0.0.", - ) def test_factorize(self): pser = pd.Series(["a", "b", "a", "b"]) psser = ps.from_pandas(pser) @@ -492,27 +488,27 @@ def test_factorize(self): pser = pd.Series(["a", "b", "a", np.nan, None]) psser = ps.from_pandas(pser) - pcodes, puniques = pser.factorize(sort=True, na_sentinel=-2) - kcodes, kuniques = psser.factorize(na_sentinel=-2) + pcodes, puniques = pser.factorize(sort=True, use_na_sentinel=-2) + kcodes, kuniques = psser.factorize(use_na_sentinel=-2) self.assert_eq(pcodes.tolist(), kcodes.to_list()) self.assert_eq(puniques, kuniques) - pcodes, puniques = pser.factorize(sort=True, na_sentinel=2) - kcodes, kuniques = psser.factorize(na_sentinel=2) + pcodes, puniques = pser.factorize(sort=True, use_na_sentinel=2) + kcodes, kuniques = psser.factorize(use_na_sentinel=2) self.assert_eq(pcodes.tolist(), kcodes.to_list()) self.assert_eq(puniques, kuniques) if not pd_below_1_1_2: - pcodes, puniques = pser.factorize(sort=True, na_sentinel=None) - kcodes, kuniques = psser.factorize(na_sentinel=None) + pcodes, puniques = pser.factorize(sort=True, use_na_sentinel=None) + kcodes, kuniques = psser.factorize(use_na_sentinel=None) self.assert_eq(pcodes.tolist(), kcodes.to_list()) # puniques is Index(['a', 'b', nan], dtype='object') self.assert_eq(ps.Index(["a", "b", None]), kuniques) psser = ps.Series([1, 2, np.nan, 4, 5]) # Arrow takes np.nan as null psser.loc[3] = np.nan # Spark takes np.nan as NaN - kcodes, kuniques = psser.factorize(na_sentinel=None) - 
pcodes, puniques = psser._to_pandas().factorize(sort=True, na_sentinel=None) + kcodes, kuniques = psser.factorize(use_na_sentinel=None) + pcodes, puniques = psser._to_pandas().factorize(sort=True, use_na_sentinel=None) self.assert_eq(pcodes.tolist(), kcodes.to_list()) self.assert_eq(puniques, kuniques) From f7879b4c2500046cd7d889ba94adedd3000f8c41 Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Tue, 8 Aug 2023 13:26:19 +0500 Subject: [PATCH 09/30] [SPARK-44680][SQL] Improve the error for parameters in `DEFAULT` ### What changes were proposed in this pull request? In the PR, I propose to check that `DEFAULT` clause contains a parameter. If so, raise appropriate error about the feature is not supported. Currently, table creation with `DEFAULT` containing any parameters finishes successfully even parameters are not supported in such case: ```sql scala> spark.sql("CREATE TABLE t12(c1 int default :parm)", args = Map("parm" -> 5)).show() ++ || ++ ++ scala> spark.sql("describe t12"); org.apache.spark.sql.AnalysisException: [INVALID_DEFAULT_VALUE.UNRESOLVED_EXPRESSION] Failed to execute EXISTS_DEFAULT command because the destination table column `c1` has a DEFAULT value :parm, which fails to resolve as a valid expression. ``` ### Why are the changes needed? This improves user experience with Spark SQL by saying about the root cause of the issue. ### Does this PR introduce _any_ user-facing change? Yes. After the change, the table creation completes w/ the error: ```sql scala> spark.sql("CREATE TABLE t12(c1 int default :parm)", args = Map("parm" -> 5)).show() org.apache.spark.sql.catalyst.parser.ParseException: [UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT] The feature is not supported: Parameter markers are not allowed in DEFAULT.(line 1, pos 32) == SQL == CREATE TABLE t12(c1 int default :parm) --------------------------------^^^ ``` ### How was this patch tested? By running new test: ``` $ build/sbt "test:testOnly *ParametersSuite" ``` Closes #42365 from MaxGekk/fix-param-in-DEFAULT. 
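A minimal sketch of the resulting behavior from the Scala shell (`spark` is the shell-provided session; the table and parameter names simply mirror the examples above): parameter markers keep working in ordinary query expressions, while a marker inside `DEFAULT` now fails at parse time instead of producing a table whose `DEFAULT` can never be resolved.

```scala
// Named parameter in a query expression: still supported.
spark.sql("SELECT :parm + 1 AS c", args = Map("parm" -> 5)).show()

// Named parameter inside DEFAULT: rejected at parse time with
// UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT
// instead of silently creating an unusable table.
spark.sql("CREATE TABLE t12(c1 int default :parm) USING parquet", args = Map("parm" -> 5))
```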
Authored-by: Max Gekk Signed-off-by: Max Gekk --- .../spark/sql/catalyst/parser/AstBuilder.scala | 12 ++++++++---- .../org/apache/spark/sql/ParametersSuite.scala | 15 +++++++++++++++ 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 1b9dda51bf077..0635e6a1b44fc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.CurrentOrigin +import org.apache.spark.sql.catalyst.trees.TreePattern.PARAMETER import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils, GeneratedColumn, IntervalUtils, ResolveDefaultColumns} import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ, getZoneId, stringToDate, stringToTimestamp, stringToTimestampWithoutTimeZone} @@ -3153,9 +3154,12 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { ctx.asScala.headOption.map(visitLocationSpec) } - private def verifyAndGetExpression(exprCtx: ExpressionContext): String = { + private def verifyAndGetExpression(exprCtx: ExpressionContext, place: String): String = { // Make sure it can be converted to Catalyst expressions. - expression(exprCtx) + val expr = expression(exprCtx) + if (expr.containsPattern(PARAMETER)) { + throw QueryParsingErrors.parameterMarkerNotAllowed(place, expr.origin) + } // Extract the raw expression text so that we can save the user provided text. We don't // use `Expression.sql` to avoid storing incorrect text caused by bugs in any expression's // `sql` method. 
Note: `exprCtx.getText` returns a string without spaces, so we need to @@ -3170,7 +3174,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { */ override def visitDefaultExpression(ctx: DefaultExpressionContext): String = withOrigin(ctx) { - verifyAndGetExpression(ctx.expression()) + verifyAndGetExpression(ctx.expression(), "DEFAULT") } /** @@ -3178,7 +3182,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { */ override def visitGenerationExpression(ctx: GenerationExpressionContext): String = withOrigin(ctx) { - verifyAndGetExpression(ctx.expression()) + verifyAndGetExpression(ctx.expression(), "GENERATED") } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala index 1ab9dce1c94ec..a72c9a600adea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala @@ -487,4 +487,19 @@ class ParametersSuite extends QueryTest with SharedSparkSession { start = 7, stop = 13)) } + + test("SPARK-44680: parameters in DEFAULT") { + checkError( + exception = intercept[AnalysisException] { + spark.sql( + "CREATE TABLE t11(c1 int default :parm) USING parquet", + args = Map("parm" -> 5)) + }, + errorClass = "UNSUPPORTED_FEATURE.PARAMETER_MARKER_IN_UNEXPECTED_STATEMENT", + parameters = Map("statement" -> "DEFAULT"), + context = ExpectedContext( + fragment = "default :parm", + start = 24, + stop = 36)) + } } From f9d417fc17a82ddf02d6bbab82abc8e1aa264356 Mon Sep 17 00:00:00 2001 From: vicennial Date: Tue, 8 Aug 2023 17:30:16 +0900 Subject: [PATCH 10/30] [SPARK-44657][CONNECT] Fix incorrect limit handling in ArrowBatchWithSchemaIterator and config parsing of CONNECT_GRPC_ARROW_MAX_BATCH_SIZE MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? Fixes the limit checking of `maxEstimatedBatchSize` and `maxRecordsPerBatch` to respect the more restrictive limit and fixes the config parsing of `CONNECT_GRPC_ARROW_MAX_BATCH_SIZE` by converting the value to bytes. ### Why are the changes needed? Bugfix. In the arrow writer [code](https://github.com/apache/spark/blob/6161bf44f40f8146ea4c115c788fd4eaeb128769/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala#L154-L163) , the conditions don’t seem to hold what the documentation says regd "maxBatchSize and maxRecordsPerBatch, respect whatever smaller" since it seems to actually respect the conf which is "larger" (i.e less restrictive) due to || operator. Further, when the `CONNECT_GRPC_ARROW_MAX_BATCH_SIZE` conf is read, the value is not converted to bytes from MiB ([example](https://github.com/apache/spark/blob/3e5203c64c06cc8a8560dfa0fb6f52e74589b583/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala#L103)). ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. Closes #42321 from vicennial/SPARK-44657. 
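A simplified sketch of the condition logic described above (not the actual `ArrowConverters` code, which additionally always writes the first row and treats non-positive limits as unlimited): with `||`, the writer keeps adding rows while either limit still has room, so a batch only closes once both limits are exceeded; the fix closes the batch as soon as either limit is hit.

```scala
object BatchLimitSketch {
  // Illustrative limits, standing in for maxRecordsPerBatch / maxEstimatedBatchSize.
  val maxRecordsPerBatch = 3
  val maxEstimatedBatchSize: Long = 10L * 1024

  // Old condition (simplified): '||' means the less restrictive limit wins.
  def oldKeepWriting(rowCount: Int, estimatedBytes: Long): Boolean =
    estimatedBytes < maxEstimatedBatchSize || rowCount < maxRecordsPerBatch

  // Fixed condition (simplified): stop as soon as either limit is exceeded.
  def newKeepWriting(rowCount: Int, estimatedBytes: Long): Boolean =
    !(estimatedBytes >= maxEstimatedBatchSize) && !(rowCount >= maxRecordsPerBatch)

  def main(args: Array[String]): Unit = {
    // 3 rows already written but only 1 KiB estimated: the record limit is hit,
    // the size limit is not.
    println(oldKeepWriting(3, 1024)) // true  -> old loop keeps writing past the record limit
    println(newKeepWriting(3, 1024)) // false -> fixed loop closes the batch here
  }
}
```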
Authored-by: vicennial Signed-off-by: Hyukjin Kwon --- .../spark/sql/connect/config/Connect.scala | 10 +-- .../planner/SparkConnectServiceSuite.scala | 61 +++++++++++++++++++ .../sql/execution/arrow/ArrowConverters.scala | 21 ++++--- 3 files changed, 79 insertions(+), 13 deletions(-) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala index 64c2d6f1cb623..e25cb5cbab279 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala @@ -49,12 +49,12 @@ object Connect { val CONNECT_GRPC_ARROW_MAX_BATCH_SIZE = ConfigBuilder("spark.connect.grpc.arrow.maxBatchSize") .doc( - "When using Apache Arrow, limit the maximum size of one arrow batch that " + - "can be sent from server side to client side. Currently, we conservatively use 70% " + - "of it because the size is not accurate but estimated.") + "When using Apache Arrow, limit the maximum size of one arrow batch, in bytes unless " + + "otherwise specified, that can be sent from server side to client side. Currently, we " + + "conservatively use 70% of it because the size is not accurate but estimated.") .version("3.4.0") - .bytesConf(ByteUnit.MiB) - .createWithDefaultString("4m") + .bytesConf(ByteUnit.BYTE) + .createWithDefault(4 * 1024 * 1024) val CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE = ConfigBuilder("spark.connect.grpc.maxInboundMessageSize") diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala index e833d12c4f595..285f3103b190b 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -238,6 +238,67 @@ class SparkConnectServiceSuite extends SharedSparkSession with MockitoSugar with } } + test("SPARK-44657: Arrow batches respect max batch size limit") { + // Set 10 KiB as the batch size limit + val batchSize = 10 * 1024 + withSparkConf("spark.connect.grpc.arrow.maxBatchSize" -> batchSize.toString) { + // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21 + assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) + val instance = new SparkConnectService(false) + val connect = new MockRemoteSession() + val context = proto.UserContext + .newBuilder() + .setUserId("c1") + .build() + val plan = proto.Plan + .newBuilder() + .setRoot(connect.sql("select * from range(0, 15000, 1, 1)")) + .build() + val request = proto.ExecutePlanRequest + .newBuilder() + .setPlan(plan) + .setUserContext(context) + .setSessionId(UUID.randomUUID.toString()) + .build() + + // Execute plan. + @volatile var done = false + val responses = mutable.Buffer.empty[proto.ExecutePlanResponse] + instance.executePlan( + request, + new StreamObserver[proto.ExecutePlanResponse] { + override def onNext(v: proto.ExecutePlanResponse): Unit = { + responses += v + } + + override def onError(throwable: Throwable): Unit = { + throw throwable + } + + override def onCompleted(): Unit = { + done = true + } + }) + // The current implementation is expected to be blocking. This is here to make sure it is. 
+ assert(done) + + // 1 schema + 1 metric + at least 2 data batches + assert(responses.size > 3) + + val allocator = new RootAllocator() + + // Check the 'data' batches + responses.tail.dropRight(1).foreach { response => + assert(response.hasArrowBatch) + val batch = response.getArrowBatch + assert(batch.getData != null) + // Batch size must be <= 70% since we intentionally use this multiplier for the size + // estimator. + assert(batch.getData.size() <= batchSize * 0.7) + } + } + } + gridTest("SPARK-43923: commands send events")( Seq( proto.Command diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 59d931bbe4849..86dd7984b5859 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -150,17 +150,22 @@ private[sql] object ArrowConverters extends Logging { // Always write the schema. MessageSerializer.serialize(writeChannel, arrowSchema) + def isBatchSizeLimitExceeded: Boolean = { + // If `maxEstimatedBatchSize` is zero or negative, it implies unlimited. + maxEstimatedBatchSize > 0 && estimatedBatchSize >= maxEstimatedBatchSize + } + def isRecordLimitExceeded: Boolean = { + // If `maxRecordsPerBatch` is zero or negative, it implies unlimited. + maxRecordsPerBatch > 0 && rowCountInLastBatch >= maxRecordsPerBatch + } // Always write the first row. while (rowIter.hasNext && ( - // For maxBatchSize and maxRecordsPerBatch, respect whatever smaller. // If the size in bytes is positive (set properly), always write the first row. - rowCountInLastBatch == 0 && maxEstimatedBatchSize > 0 || - // If the size in bytes of rows are 0 or negative, unlimit it. - estimatedBatchSize <= 0 || - estimatedBatchSize < maxEstimatedBatchSize || - // If the size of rows are 0 or negative, unlimit it. - maxRecordsPerBatch <= 0 || - rowCountInLastBatch < maxRecordsPerBatch)) { + (rowCountInLastBatch == 0 && maxEstimatedBatchSize > 0) || + // If either limit is hit, create a batch. This implies that the limit that is hit first + // triggers the creation of a batch even if the other limit is not yet hit, hence + // preferring the more restrictive limit. + (!isBatchSizeLimitExceeded && !isRecordLimitExceeded))) { val row = rowIter.next() arrowWriter.write(row) estimatedBatchSize += (row match { From 29e8331681c6214390f426806d19ee9673b073e1 Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Tue, 8 Aug 2023 17:07:04 +0800 Subject: [PATCH 11/30] [SPARK-44714] Ease restriction of LCA resolution regarding queries with having ### What changes were proposed in this pull request? This PR eases some restriction of LCA resolution regarding queries with having. Previously LCA won't rewrite (to the new plan shape) when the whole queries contains `UnresolvedHaving`, in case it breaks the plan shape of `UnresolvedHaving - Aggregate` that can be recognized by other rules. But this limitation is too strict and it causes some deadlock in having - lca - window queries. See https://issues.apache.org/jira/browse/SPARK-42936 for more details and examples. With this PR, it will only skip LCA resolution on the `Aggregate` whose direct parent is `UnresolvedHaving`. This is enabled by a new bottom-up resolution without using the transform or resolve utility function. This PR also recognizes a vulnerability related to `TEMP_RESOVLED_COLUMN` and comments in the code. 
It should be considered as future work. ### Why are the changes needed? More complete functionality and better user experience. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? New tests. Closes #42276 from anchovYu/lca-limitation-better-error. Authored-by: Xinyi Yu Signed-off-by: Wenchen Fan --- .../sql/catalyst/analysis/Analyzer.scala | 2 + .../ResolveLateralColumnAliasReference.scala | 200 ++++++++++-------- .../spark/sql/LateralColumnAliasSuite.scala | 145 +++++++++++++ 3 files changed, 261 insertions(+), 86 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6c1d774a1b5fd..0c792ded8f890 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -285,6 +285,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor AddMetadataColumns :: DeduplicateRelations :: ResolveReferences :: + // Please do not insert any other rules in between. See the TODO comments in rule + // ResolveLateralColumnAliasReference for more details. ResolveLateralColumnAliasReference :: ResolveExpressionsWithNamePlaceholders :: ResolveDeserializer :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala index 5d89de0008478..c249a3506f2d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.WindowExpression.hasWindowExpre import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, TEMP_RESOLVED_COLUMN, UNRESOLVED_HAVING} +import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, TEMP_RESOLVED_COLUMN} import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf @@ -131,95 +131,97 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { (pList.exists(hasWindowExpression) && p.expressions.forall(_.resolved) && p.childrenResolved) } - override def apply(plan: LogicalPlan): LogicalPlan = { - if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) { - plan - } else if (plan.containsAnyPattern(TEMP_RESOLVED_COLUMN, UNRESOLVED_HAVING)) { - // It should not change the plan if `TempResolvedColumn` or `UnresolvedHaving` is present in - // the query plan. These plans need certain plan shape to get recognized and resolved by other - // rules, such as Filter/Sort + Aggregate to be matched by ResolveAggregateFunctions. - // LCA resolution can break the plan shape, like adding Project above Aggregate. 
- plan - } else { - // phase 2: unwrap - plan.resolveOperatorsUpWithPruning( - _.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE), ruleId) { - case p @ Project(projectList, child) if ruleApplicableOnOperator(p, projectList) - && projectList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) => - var aliasMap = AttributeMap.empty[AliasEntry] - val referencedAliases = collection.mutable.Set.empty[AliasEntry] - def unwrapLCAReference(e: NamedExpression): NamedExpression = { - e.transformWithPruning(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { - case lcaRef: LateralColumnAliasReference if aliasMap.contains(lcaRef.a) => - val aliasEntry = aliasMap.get(lcaRef.a).get - // If there is no chaining of lateral column alias reference, push down the alias - // and unwrap the LateralColumnAliasReference to the NamedExpression inside - // If there is chaining, don't resolve and save to future rounds - if (!aliasEntry.alias.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { - referencedAliases += aliasEntry - lcaRef.ne - } else { - lcaRef - } - case lcaRef: LateralColumnAliasReference if !aliasMap.contains(lcaRef.a) => - // It shouldn't happen, but restore to unresolved attribute to be safe. - UnresolvedAttribute(lcaRef.nameParts) - }.asInstanceOf[NamedExpression] - } - val newProjectList = projectList.zipWithIndex.map { - case (a: Alias, idx) => - val lcaResolved = unwrapLCAReference(a) - // Insert the original alias instead of rewritten one to detect chained LCA - aliasMap += (a.toAttribute -> AliasEntry(a, idx)) - lcaResolved - case (e, _) => - unwrapLCAReference(e) - } + /** Internal application method. A hand-written bottom-up recursive traverse. */ + private def apply0(plan: LogicalPlan): LogicalPlan = { + plan match { + case p: LogicalPlan if !p.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE) => + p - if (referencedAliases.isEmpty) { - p - } else { - val outerProjectList = collection.mutable.Seq(newProjectList: _*) - val innerProjectList = - collection.mutable.ArrayBuffer(child.output.map(_.asInstanceOf[NamedExpression]): _*) - referencedAliases.foreach { case AliasEntry(alias: Alias, idx) => - outerProjectList.update(idx, alias.toAttribute) - innerProjectList += alias - } - p.copy( - projectList = outerProjectList.toSeq, - child = Project(innerProjectList.toSeq, child) - ) - } + // It should not change the Aggregate (and thus the plan shape) if its parent is an + // UnresolvedHaving, to avoid breaking the shape pattern `UnresolvedHaving - Aggregate` + // matched by ResolveAggregateFunctions. See SPARK-42936 and SPARK-44714 for more details. 
+ case u @ UnresolvedHaving(_, agg: Aggregate) => + u.copy(child = agg.mapChildren(apply0)) - case agg @ Aggregate(groupingExpressions, aggregateExpressions, _) - if ruleApplicableOnOperator(agg, aggregateExpressions) - && aggregateExpressions.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) => + case pOriginal: Project if ruleApplicableOnOperator(pOriginal, pOriginal.projectList) + && pOriginal.projectList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) => + val p @ Project(projectList, child) = pOriginal.mapChildren(apply0) + var aliasMap = AttributeMap.empty[AliasEntry] + val referencedAliases = collection.mutable.Set.empty[AliasEntry] + def unwrapLCAReference(e: NamedExpression): NamedExpression = { + e.transformWithPruning(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { + case lcaRef: LateralColumnAliasReference if aliasMap.contains(lcaRef.a) => + val aliasEntry = aliasMap.get(lcaRef.a).get + // If there is no chaining of lateral column alias reference, push down the alias + // and unwrap the LateralColumnAliasReference to the NamedExpression inside + // If there is chaining, don't resolve and save to future rounds + if (!aliasEntry.alias.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { + referencedAliases += aliasEntry + lcaRef.ne + } else { + lcaRef + } + case lcaRef: LateralColumnAliasReference if !aliasMap.contains(lcaRef.a) => + // It shouldn't happen, but restore to unresolved attribute to be safe. + UnresolvedAttribute(lcaRef.nameParts) + }.asInstanceOf[NamedExpression] + } + val newProjectList = projectList.zipWithIndex.map { + case (a: Alias, idx) => + val lcaResolved = unwrapLCAReference(a) + // Insert the original alias instead of rewritten one to detect chained LCA + aliasMap += (a.toAttribute -> AliasEntry(a, idx)) + lcaResolved + case (e, _) => + unwrapLCAReference(e) + } - // Check if current Aggregate is eligible to lift up with Project: the aggregate - // expression only contains: 1) aggregate functions, 2) grouping expressions, 3) leaf - // expressions excluding attributes not in grouping expressions - // This check is to prevent unnecessary transformation on invalid plan, to guarantee it - // throws the same exception. For example, cases like non-aggregate expressions not - // in group by, once transformed, will throw a different exception: missing input. - def eligibleToLiftUp(exp: Expression): Boolean = { - exp match { - case _: AggregateExpression => true - case e if groupingExpressions.exists(_.semanticEquals(e)) => true - case a: Attribute => false - case s: ScalarSubquery if s.children.nonEmpty - && !groupingExpressions.exists(_.semanticEquals(s)) => false - // Manually skip detection on function itself because it can be an aggregate function. - // This is to avoid expressions like sum(salary) over () eligible to lift up. 
- case WindowExpression(function, spec) => - function.children.forall(eligibleToLiftUp) && eligibleToLiftUp(spec) - case e => e.children.forall(eligibleToLiftUp) - } - } - if (!aggregateExpressions.forall(eligibleToLiftUp)) { - return agg + if (referencedAliases.isEmpty) { + p + } else { + val outerProjectList = collection.mutable.Seq(newProjectList: _*) + val innerProjectList = + collection.mutable.ArrayBuffer(child.output.map(_.asInstanceOf[NamedExpression]): _*) + referencedAliases.foreach { case AliasEntry(alias: Alias, idx) => + outerProjectList.update(idx, alias.toAttribute) + innerProjectList += alias } + p.copy( + projectList = outerProjectList.toSeq, + child = Project(innerProjectList.toSeq, child) + ) + } + + case aggOriginal: Aggregate + if ruleApplicableOnOperator(aggOriginal, aggOriginal.aggregateExpressions) + && aggOriginal.aggregateExpressions.exists( + _.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) => + val agg @ Aggregate(groupingExpressions, aggregateExpressions, _) = + aggOriginal.mapChildren(apply0) + // Check if current Aggregate is eligible to lift up with Project: the aggregate + // expression only contains: 1) aggregate functions, 2) grouping expressions, 3) leaf + // expressions excluding attributes not in grouping expressions + // This check is to prevent unnecessary transformation on invalid plan, to guarantee it + // throws the same exception. For example, cases like non-aggregate expressions not + // in group by, once transformed, will throw a different exception: missing input. + def eligibleToLiftUp(exp: Expression): Boolean = { + exp match { + case _: AggregateExpression => true + case e if groupingExpressions.exists(_.semanticEquals(e)) => true + case a: Attribute => false + case s: ScalarSubquery if s.children.nonEmpty + && !groupingExpressions.exists(_.semanticEquals(s)) => false + // Manually skip detection on function itself because it can be an aggregate function. + // This is to avoid expressions like sum(salary) over () eligible to lift up. + case WindowExpression(function, spec) => + function.children.forall(eligibleToLiftUp) && eligibleToLiftUp(spec) + case e => e.children.forall(eligibleToLiftUp) + } + } + if (!aggregateExpressions.forall(eligibleToLiftUp)) { + agg + } else { val newAggExprs = collection.mutable.Set.empty[NamedExpression] val expressionMap = collection.mutable.LinkedHashMap.empty[Expression, NamedExpression] // Extract the expressions to keep in the Aggregate. Return the transformed expression @@ -262,7 +264,33 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { projectList = projectExprs, child = agg.copy(aggregateExpressions = newAggExprs.toSeq) ) - } + } + + case p: LogicalPlan => + p.mapChildren(apply0) + } + } + + override def apply(plan: LogicalPlan): LogicalPlan = { + if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) { + plan + } else if (plan.containsAnyPattern(TEMP_RESOLVED_COLUMN)) { + // It should not change the plan if `TempResolvedColumn` is present in the query plan. These + // plans need certain plan shape to get recognized and resolved by other rules, such as + // Filter/Sort + Aggregate to be matched by ResolveAggregateFunctions. LCA resolution can + // break the plan shape, like adding Project above Aggregate. + // TODO: this condition only guarantees to keep the shape after the plan has + // `TempResolvedColumn`. 
However, it does not consider the case of breaking the shape even + // before `TempResolvedColumn` is generated by matching Filter/Sort - Aggregate in + // ResolveReferences. Currently the correctness of this case now relies on the rule + // application order, that ResolveReference is right before the application of + // ResolveLateralColumnAliasReference. The condition in the two rules guarantees that the + // case can never happen. We should consider to remove this order dependency but still assure + // correctness in the future. + plan + } else { + // phase 2: unwrap + apply0(plan) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala index 1e3a0d70c7fd6..cc4aeb42326f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala @@ -669,6 +669,20 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { s"FROM $testTable GROUP BY dept ORDER BY max(name)"), Row(1, 1) :: Row(2, 2) :: Row(6, 6) :: Nil ) + checkAnswer( + sql("SELECT dept, avg(salary) AS a, a + 10 FROM employee GROUP BY dept ORDER BY max(name)"), + Row(1, 9500, 9510) :: Row(2, 11000, 11010) :: Row(6, 12000, 12010) :: Nil + ) + checkAnswer( + sql("SELECT dept, avg(salary) AS a, a + 10 AS b " + + "FROM employee GROUP BY dept ORDER BY max(name)"), + Row(1, 9500, 9510) :: Row(2, 11000, 11010) :: Row(6, 12000, 12010) :: Nil + ) + checkAnswer( + sql("SELECT dept, avg(salary) AS a, a + cast(10 as double) AS b " + + "FROM employee GROUP BY dept ORDER BY max(name)"), + Row(1, 9500, 9510) :: Row(2, 11000, 11010) :: Row(6, 12000, 12010) :: Nil + ) // having cond is resolved by aggregate's child checkAnswer( @@ -676,6 +690,21 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { s"FROM $testTable GROUP BY dept HAVING max(name) = 'david'"), Row(1250, 2, 11000, 11010) :: Nil ) + checkAnswer( + sql("SELECT dept, avg(salary) AS a, a + 10 " + + "FROM employee GROUP BY dept HAVING max(bonus) > 1200"), + Row(2, 11000, 11010) :: Nil + ) + checkAnswer( + sql("SELECT dept, avg(salary) AS a, a + 10 AS b " + + "FROM employee GROUP BY dept HAVING max(bonus) > 1200"), + Row(2, 11000, 11010) :: Nil + ) + checkAnswer( + sql("SELECT dept, avg(salary) AS a, a + cast(10 as double) AS b " + + "FROM employee GROUP BY dept HAVING max(bonus) > 1200"), + Row(2, 11000, 11010) :: Nil + ) // having cond is resolved by aggregate itself checkAnswer( sql(s"SELECT avg(bonus) AS a, a FROM $testTable GROUP BY dept HAVING a > 1200"), @@ -1139,4 +1168,120 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { // non group by or non aggregate function in Aggregate queries negative cases are covered in // "Aggregate expressions not eligible to lift up, throws same error as inline". 
} + + test("Still resolves when Aggregate with LCA is not the direct child of Having") { + // Previously there was a limitation of lca that it can't resolve the query when it satisfies + // all the following criteria: + // 1) the main (outer) query has having clause + // 2) there is a window expression in the query + // 3) in the same SELECT list as the window expression in 2), there is an lca + // Though [UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_AGGREGATE_WITH_WINDOW_AND_HAVING] is + // still not supported, after SPARK-44714, a lot other limitations are + // lifted because it allows to resolve LCA when the query has UnresolvedHaving but its direct + // child does not contain an LCA. + // Testcases in this test focus on this change regarding enablement of resolution. + + // CTE definition contains window and LCA; outer query contains having + checkAnswer( + sql( + s""" + |with w as ( + | select name, dept, salary, rank() over (partition by dept order by salary) as r, r + | from $testTable + |) + |select dept + |from w + |group by dept + |having max(salary) > 10000 + |""".stripMargin), + Row(2) :: Row(6) :: Nil + ) + checkAnswer( + sql( + s""" + |with w as ( + | select name, dept, salary, rank() over (partition by dept order by salary) as r, r + | from $testTable + |) + |select dept as d, d + |from w + |group by dept + |having max(salary) > 10000 + |""".stripMargin), + Row(2, 2) :: Row(6, 6) :: Nil + ) + checkAnswer( + sql( + s""" + |with w as ( + | select name, dept, salary, rank() over (partition by dept order by salary) as r, r + | from $testTable + |) + |select dept as d + |from w + |group by dept + |having d = 2 + |""".stripMargin), + Row(2) :: Nil + ) + + // inner subquery contains window and LCA; outer query contains having + checkAnswer( + sql( + s""" + |SELECT + | dept + |FROM + | ( + | select + | name, dept, salary, rank() over (partition by dept order by salary) as r, + | 1 as a, a + 1 as e + | FROM + | $testTable + | ) AS inner_t + |GROUP BY + | dept + |HAVING max(salary) > 10000 + |""".stripMargin), + Row(2) :: Row(6) :: Nil + ) + checkAnswer( + sql( + s""" + |SELECT + | dept as d, d + |FROM + | ( + | select + | name, dept, salary, rank() over (partition by dept order by salary) as r, + | 1 as a, a + 1 as e + | FROM + | $testTable + | ) AS inner_t + |GROUP BY + | dept + |HAVING max(salary) > 10000 + |""".stripMargin), + Row(2, 2) :: Row(6, 6) :: Nil + ) + checkAnswer( + sql( + s""" + |SELECT + | dept as d + |FROM + | ( + | select + | name, dept, salary, rank() over (partition by dept order by salary) as r, + | 1 as a, a + 1 as e + | FROM + | $testTable + | ) AS inner_t + |GROUP BY + | dept + |HAVING d = 2 + |""".stripMargin), + Row(2) :: Nil + ) + } } From 74fa07c5702004ed2bcd83872687473122e13bab Mon Sep 17 00:00:00 2001 From: Jia Fan Date: Tue, 8 Aug 2023 17:46:09 +0800 Subject: [PATCH 12/30] [SPARK-44236][SQL] Disable WholeStageCodegen when set `spark.sql.codegen.factoryMode` to NO_CODEGEN ### What changes were proposed in this pull request? After #41467 , we fix the `CodegenInterpretedPlanTest ` will execute codeGen even set `spark.sql.codegen.factoryMode` to `NO_CODEGEN`. Before this PR, `spark.sql.codegen.factoryMode` can't disable WholeStageCodegen, many test case want to disable codegen by set `spark.sql.codegen.factoryMode` to `NO_CODEGEN`, but it not work for WholeStageCodegen. So this PR change the `spark.sql.codegen.factoryMode` behavior, when set `NO_CODEGEN`, we will disable `WholeStageCodegen` too. ### Why are the changes needed? 
Fix the `spark.sql.codegen.factoryMode` config behavior. ### Does this PR introduce _any_ user-facing change? Yes, the config logic changed. ### How was this patch tested? add new test. Closes #41779 from Hisoka-X/SPARK-44236_wholecodegen_disable. Authored-by: Jia Fan Signed-off-by: Wenchen Fan --- .../CodeGeneratorWithInterpretedFallback.scala | 3 +-- .../scala/org/apache/spark/sql/internal/SQLConf.scala | 2 ++ .../apache/spark/sql/catalyst/plans/PlanTest.scala | 3 +-- .../spark/sql/execution/WholeStageCodegenExec.scala | 3 ++- .../spark/sql/execution/WholeStageCodegenSuite.scala | 11 +++++++++++ 5 files changed, 17 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala index 0509b852cfdde..62a1afecfd7f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala @@ -38,8 +38,7 @@ abstract class CodeGeneratorWithInterpretedFallback[IN, OUT] extends Logging { def createObject(in: IN): OUT = { // We are allowed to choose codegen-only or no-codegen modes if under tests. - val config = SQLConf.get.getConf(SQLConf.CODEGEN_FACTORY_MODE) - val fallbackMode = CodegenObjectFactoryMode.withName(config) + val fallbackMode = CodegenObjectFactoryMode.withName(SQLConf.get.codegenFactoryMode) fallbackMode match { case CodegenObjectFactoryMode.CODEGEN_ONLY => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index bcf8ce2bc5407..e4f335a9a08f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -4721,6 +4721,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def codegenFallback: Boolean = getConf(CODEGEN_FALLBACK) + def codegenFactoryMode: String = getConf(CODEGEN_FACTORY_MODE) + def codegenComments: Boolean = getConf(StaticSQLConf.CODEGEN_COMMENTS) def loggingMaxLinesForCodegen: Int = getConf(CODEGEN_LOGGING_MAX_LINES) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index ebf48c5f863d2..e90a956ab4fde 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -47,8 +47,7 @@ trait CodegenInterpretedPlanTest extends PlanTest { super.test(testName + " (codegen path)", testTags: _*)( withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode) { testFun })(pos) super.test(testName + " (interpreted path)", testTags: _*)( - withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> interpretedMode) { - withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { testFun }})(pos) + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> interpretedMode) { testFun })(pos) } protected def testFallback( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 5fc51cc6e310b..40de623f73d59 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -942,7 +942,8 @@ case class CollapseCodegenStages( } def apply(plan: SparkPlan): SparkPlan = { - if (conf.wholeStageEnabled) { + if (conf.wholeStageEnabled && CodegenObjectFactoryMode.withName(conf.codegenFactoryMode) + != CodegenObjectFactoryMode.NO_CODEGEN) { insertWholeStageCodegen(plan) } else { plan diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 0aaeedd5f06d1..5a413c77754f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.{Dataset, QueryTest, Row, SaveMode} +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode import org.apache.spark.sql.catalyst.expressions.codegen.{ByteCodeStats, CodeAndComment, CodeGenerator} import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} @@ -182,6 +183,16 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession Seq(Row(0, 0, 0), Row(1, 1, 1), Row(2, 2, 2), Row(3, 3, 3), Row(4, 4, 4))) } + test("SPARK-44236: disable WholeStageCodegen when set spark.sql.codegen.factoryMode is " + + "NO_CODEGEN") { + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString) { + val df = spark.range(10).select($"id" + 1) + val plan = df.queryExecution.executedPlan + assert(!plan.exists(_.isInstanceOf[WholeStageCodegenExec])) + checkAnswer(df, 1L to 10L map { i => Row(i) }) + } + } + test("Full Outer ShuffledHashJoin and SortMergeJoin should be included in WholeStageCodegen") { val df1 = spark.range(5).select($"id".as("k1")) val df2 = spark.range(10).select($"id".as("k2")) From c46d4caa59865e9b99e02f6adc79f49f9ebc8f7f Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 8 Aug 2023 15:04:07 +0200 Subject: [PATCH 13/30] [SPARK-44713][CONNECT][SQL] Move shared classes to sql/api ### What changes were proposed in this pull request? This PR deduplicates the following classes: - `org.apache.spark.sql.SaveMode` - `org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction` - `org.apache.spark.api.java.function.MapGroupsWithStateFunction` - `org.apache.spark.sql.streaming.GroupState` These classes were all duplicates in the Scala Client. I have moved the original versions to `sql/api` and I removed the connect equivalents. ### Why are the changes needed? Duplication is always good :). ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Compilation. Closes #42386 from hvanhovell/SPARK-44713. 
Authored-by: Herman van Hovell Signed-off-by: Herman van Hovell --- .../java/org/apache/spark/sql/SaveMode.java | 58 --- .../FlatMapGroupsWithStateFunction.java | 39 -- .../function/MapGroupsWithStateFunction.java | 38 -- .../spark/sql/streaming/GroupState.scala | 336 ------------------ project/MimaExcludes.scala | 13 +- .../FlatMapGroupsWithStateFunction.java | 0 .../function/MapGroupsWithStateFunction.java | 0 .../java/org/apache/spark/sql/SaveMode.java | 0 .../spark/sql/streaming/GroupState.scala | 0 9 files changed, 6 insertions(+), 478 deletions(-) delete mode 100644 connector/connect/client/jvm/src/main/java/org/apache/spark/sql/SaveMode.java delete mode 100644 connector/connect/common/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java delete mode 100644 connector/connect/common/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java delete mode 100644 connector/connect/common/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala rename sql/{core => api}/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java (100%) rename sql/{core => api}/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java (100%) rename sql/{core => api}/src/main/java/org/apache/spark/sql/SaveMode.java (100%) rename sql/{core => api}/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala (100%) diff --git a/connector/connect/client/jvm/src/main/java/org/apache/spark/sql/SaveMode.java b/connector/connect/client/jvm/src/main/java/org/apache/spark/sql/SaveMode.java deleted file mode 100644 index 95af157687c85..0000000000000 --- a/connector/connect/client/jvm/src/main/java/org/apache/spark/sql/SaveMode.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql; - -import org.apache.spark.annotation.Stable; - -/** - * SaveMode is used to specify the expected behavior of saving a DataFrame to a data source. - * - * @since 3.4.0 - */ -@Stable -public enum SaveMode { - /** - * Append mode means that when saving a DataFrame to a data source, if data/table already exists, - * contents of the DataFrame are expected to be appended to existing data. - * - * @since 3.4.0 - */ - Append, - /** - * Overwrite mode means that when saving a DataFrame to a data source, - * if data/table already exists, existing data is expected to be overwritten by the contents of - * the DataFrame. - * - * @since 3.4.0 - */ - Overwrite, - /** - * ErrorIfExists mode means that when saving a DataFrame to a data source, if data already exists, - * an exception is expected to be thrown. 
- * - * @since 3.4.0 - */ - ErrorIfExists, - /** - * Ignore mode means that when saving a DataFrame to a data source, if data already exists, - * the save operation is expected to not save the contents of the DataFrame and to not - * change the existing data. - * - * @since 3.4.0 - */ - Ignore -} diff --git a/connector/connect/common/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java b/connector/connect/common/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java deleted file mode 100644 index c917c8d28be04..0000000000000 --- a/connector/connect/common/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.api.java.function; - -import java.io.Serializable; -import java.util.Iterator; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.annotation.Experimental; -import org.apache.spark.sql.streaming.GroupState; - -/** - * ::Experimental:: - * Base interface for a map function used in - * {@code org.apache.spark.sql.KeyValueGroupedDataset.flatMapGroupsWithState( - * FlatMapGroupsWithStateFunction, org.apache.spark.sql.streaming.OutputMode, - * org.apache.spark.sql.Encoder, org.apache.spark.sql.Encoder)} - * @since 3.5.0 - */ -@Experimental -@Evolving -public interface FlatMapGroupsWithStateFunction extends Serializable { - Iterator call(K key, Iterator values, GroupState state) throws Exception; -} diff --git a/connector/connect/common/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java b/connector/connect/common/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java deleted file mode 100644 index ae179ad7d276f..0000000000000 --- a/connector/connect/common/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.api.java.function; - -import java.io.Serializable; -import java.util.Iterator; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.annotation.Experimental; -import org.apache.spark.sql.streaming.GroupState; - -/** - * ::Experimental:: - * Base interface for a map function used in - * {@code org.apache.spark.sql.KeyValueGroupedDataset.mapGroupsWithState( - * MapGroupsWithStateFunction, org.apache.spark.sql.Encoder, org.apache.spark.sql.Encoder)} - * @since 3.5.0 - */ -@Experimental -@Evolving -public interface MapGroupsWithStateFunction extends Serializable { - R call(K key, Iterator values, GroupState state) throws Exception; -} diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala deleted file mode 100644 index bd418a89534ad..0000000000000 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala +++ /dev/null @@ -1,336 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.streaming - -import org.apache.spark.annotation.{Evolving, Experimental} -import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState - -/** - * :: Experimental :: - * - * Wrapper class for interacting with per-group state data in `mapGroupsWithState` and - * `flatMapGroupsWithState` operations on `KeyValueGroupedDataset`. - * - * Detail description on `[map/flatMap]GroupsWithState` operation - * -------------------------------------------------------------- Both, `mapGroupsWithState` and - * `flatMapGroupsWithState` in `KeyValueGroupedDataset` will invoke the user-given function on - * each group (defined by the grouping function in `Dataset.groupByKey()`) while maintaining a - * user-defined per-group state between invocations. For a static batch Dataset, the function will - * be invoked once per group. For a streaming Dataset, the function will be invoked for each group - * repeatedly in every trigger. That is, in every batch of the `StreamingQuery`, the function will - * be invoked once for each group that has data in the trigger. Furthermore, if timeout is set, - * then the function will be invoked on timed-out groups (more detail below). - * - * The function is invoked with the following parameters. - * - The key of the group. - * - An iterator containing all the values for this group. - * - A user-defined state object set by previous invocations of the given function. - * - * In case of a batch Dataset, there is only one invocation and the state object will be empty as - * there is no prior state. 
Essentially, for batch Datasets, `[map/flatMap]GroupsWithState` is - * equivalent to `[map/flatMap]Groups` and any updates to the state and/or timeouts have no - * effect. - * - * The major difference between `mapGroupsWithState` and `flatMapGroupsWithState` is that the - * former allows the function to return one and only one record, whereas the latter allows the - * function to return any number of records (including no records). Furthermore, the - * `flatMapGroupsWithState` is associated with an operation output mode, which can be either - * `Append` or `Update`. Semantically, this defines whether the output records of one trigger is - * effectively replacing the previously output records (from previous triggers) or is appending to - * the list of previously output records. Essentially, this defines how the Result Table (refer to - * the semantics in the programming guide) is updated, and allows us to reason about the semantics - * of later operations. - * - * Important points to note about the function (both mapGroupsWithState and - * flatMapGroupsWithState). - * - In a trigger, the function will be called only the groups present in the batch. So do not - * assume that the function will be called in every trigger for every group that has state. - * - There is no guaranteed ordering of values in the iterator in the function, neither with - * batch, nor with streaming Datasets. - * - All the data will be shuffled before applying the function. - * - If timeout is set, then the function will also be called with no values. See more details - * on `GroupStateTimeout` below. - * - * Important points to note about using `GroupState`. - * - The value of the state cannot be null. So updating state with null will throw - * `IllegalArgumentException`. - * - Operations on `GroupState` are not thread-safe. This is to avoid memory barriers. - * - If `remove()` is called, then `exists()` will return `false`, `get()` will throw - * `NoSuchElementException` and `getOption()` will return `None` - * - After that, if `update(newState)` is called, then `exists()` will again return `true`, - * `get()` and `getOption()`will return the updated value. - * - * Important points to note about using `GroupStateTimeout`. - * - The timeout type is a global param across all the groups (set as `timeout` param in - * `[map|flatMap]GroupsWithState`, but the exact timeout duration/timestamp is configurable - * per group by calling `setTimeout...()` in `GroupState`. - * - Timeouts can be either based on processing time (i.e. - * `GroupStateTimeout.ProcessingTimeTimeout`) or event time (i.e. - * `GroupStateTimeout.EventTimeTimeout`). - * - With `ProcessingTimeTimeout`, the timeout duration can be set by calling - * `GroupState.setTimeoutDuration`. The timeout will occur when the clock has advanced by the - * set duration. Guarantees provided by this timeout with a duration of D ms are as follows: - * - Timeout will never occur before the clock time has advanced by D ms - * - Timeout will occur eventually when there is a trigger in the query (i.e. after D ms). So - * there is no strict upper bound on when the timeout would occur. For example, the trigger - * interval of the query will affect when the timeout actually occurs. If there is no data - * in the stream (for any group) for a while, then there will not be any trigger and timeout - * function call will not occur until there is data. - * - Since the processing time timeout is based on the clock time, it is affected by the - * variations in the system clock (i.e. 
time zone changes, clock skew, etc.). - * - With `EventTimeTimeout`, the user also has to specify the event time watermark in the query - * using `Dataset.withWatermark()`. With this setting, data that is older than the watermark - * is filtered out. The timeout can be set for a group by setting a timeout timestamp - * using`GroupState.setTimeoutTimestamp()`, and the timeout would occur when the watermark - * advances beyond the set timestamp. You can control the timeout delay by two parameters - - * (i) watermark delay and an additional duration beyond the timestamp in the event (which is - * guaranteed to be newer than watermark due to the filtering). Guarantees provided by this - * timeout are as follows: - * - Timeout will never occur before the watermark has exceeded the set timeout. - * - Similar to processing time timeouts, there is no strict upper bound on the delay when the - * timeout actually occurs. The watermark can advance only when there is data in the stream - * and the event time of the data has actually advanced. - * - When the timeout occurs for a group, the function is called for that group with no values, - * and `GroupState.hasTimedOut()` set to true. - * - The timeout is reset every time the function is called on a group, that is, when the group - * has new data, or the group has timed out. So the user has to set the timeout duration every - * time the function is called, otherwise, there will not be any timeout set. - * - * `[map/flatMap]GroupsWithState` can take a user defined initial state as an additional argument. - * This state will be applied when the first batch of the streaming query is processed. If there - * are no matching rows in the data for the keys present in the initial state, the state is still - * applied and the function will be invoked with the values being an empty iterator. - * - * Scala example of using GroupState in `mapGroupsWithState`: - * {{{ - * // A mapping function that maintains an integer state for string keys and returns a string. - * // Additionally, it sets a timeout to remove the state if it has not received data for an hour. - * def mappingFunction(key: String, value: Iterator[Int], state: GroupState[Int]): String = { - * - * if (state.hasTimedOut) { // If called when timing out, remove the state - * state.remove() - * - * } else if (state.exists) { // If state exists, use it for processing - * val existingState = state.get // Get the existing state - * val shouldRemove = ... // Decide whether to remove the state - * if (shouldRemove) { - * state.remove() // Remove the state - * - * } else { - * val newState = ... - * state.update(newState) // Set the new state - * state.setTimeoutDuration("1 hour") // Set the timeout - * } - * - * } else { - * val initialState = ... - * state.update(initialState) // Set the initial state - * state.setTimeoutDuration("1 hour") // Set the timeout - * } - * ... - * // return something - * } - * - * dataset - * .groupByKey(...) - * .mapGroupsWithState(GroupStateTimeout.ProcessingTimeTimeout)(mappingFunction) - * }}} - * - * Java example of using `GroupState`: - * {{{ - * // A mapping function that maintains an integer state for string keys and returns a string. - * // Additionally, it sets a timeout to remove the state if it has not received data for an hour. 
- * MapGroupsWithStateFunction mappingFunction = - * new MapGroupsWithStateFunction() { - * - * @Override - * public String call(String key, Iterator value, GroupState state) { - * if (state.hasTimedOut()) { // If called when timing out, remove the state - * state.remove(); - * - * } else if (state.exists()) { // If state exists, use it for processing - * int existingState = state.get(); // Get the existing state - * boolean shouldRemove = ...; // Decide whether to remove the state - * if (shouldRemove) { - * state.remove(); // Remove the state - * - * } else { - * int newState = ...; - * state.update(newState); // Set the new state - * state.setTimeoutDuration("1 hour"); // Set the timeout - * } - * - * } else { - * int initialState = ...; // Set the initial state - * state.update(initialState); - * state.setTimeoutDuration("1 hour"); // Set the timeout - * } - * ... - * // return something - * } - * }; - * - * dataset - * .groupByKey(...) - * .mapGroupsWithState( - * mappingFunction, Encoders.INT, Encoders.STRING, GroupStateTimeout.ProcessingTimeTimeout); - * }}} - * - * @tparam S - * User-defined type of the state to be stored for each group. Must be encodable into Spark SQL - * types (see `Encoder` for more details). - * @since 3.5.0 - */ -@Experimental -@Evolving -trait GroupState[S] extends LogicalGroupState[S] { - - /** Whether state exists or not. */ - def exists: Boolean - - /** Get the state value if it exists, or throw NoSuchElementException. */ - @throws[NoSuchElementException]("when state does not exist") - def get: S - - /** Get the state value as a scala Option. */ - def getOption: Option[S] - - /** Update the value of the state. */ - def update(newState: S): Unit - - /** Remove this state. */ - def remove(): Unit - - /** - * Whether the function has been called because the key has timed out. - * @note - * This can return true only when timeouts are enabled in `[map/flatMap]GroupsWithState`. - */ - def hasTimedOut: Boolean - - /** - * Set the timeout duration in ms for this key. - * - * @note - * [[GroupStateTimeout Processing time timeout]] must be enabled in - * `[map/flatMap]GroupsWithState` for calling this method. - * @note - * This method has no effect when used in a batch query. - */ - @throws[IllegalArgumentException]("if 'durationMs' is not positive") - @throws[UnsupportedOperationException]( - "if processing time timeout has not been enabled in [map|flatMap]GroupsWithState") - def setTimeoutDuration(durationMs: Long): Unit - - /** - * Set the timeout duration for this key as a string. For example, "1 hour", "2 days", etc. - * - * @note - * [[GroupStateTimeout Processing time timeout]] must be enabled in - * `[map/flatMap]GroupsWithState` for calling this method. - * @note - * This method has no effect when used in a batch query. - */ - @throws[IllegalArgumentException]("if 'duration' is not a valid duration") - @throws[UnsupportedOperationException]( - "if processing time timeout has not been enabled in [map|flatMap]GroupsWithState") - def setTimeoutDuration(duration: String): Unit - - /** - * Set the timeout timestamp for this key as milliseconds in epoch time. This timestamp cannot - * be older than the current watermark. - * - * @note - * [[GroupStateTimeout Event time timeout]] must be enabled in `[map/flatMap]GroupsWithState` - * for calling this method. - * @note - * This method has no effect when used in a batch query. 
- */ - @throws[IllegalArgumentException]( - "if 'timestampMs' is not positive or less than the current watermark in a streaming query") - @throws[UnsupportedOperationException]( - "if event time timeout has not been enabled in [map|flatMap]GroupsWithState") - def setTimeoutTimestamp(timestampMs: Long): Unit - - /** - * Set the timeout timestamp for this key as milliseconds in epoch time and an additional - * duration as a string (e.g. "1 hour", "2 days", etc.). The final timestamp (including the - * additional duration) cannot be older than the current watermark. - * - * @note - * [[GroupStateTimeout Event time timeout]] must be enabled in `[map/flatMap]GroupsWithState` - * for calling this method. - * @note - * This method has no side effect when used in a batch query. - */ - @throws[IllegalArgumentException]( - "if 'additionalDuration' is invalid or the final timeout timestamp is less than " + - "the current watermark in a streaming query") - @throws[UnsupportedOperationException]( - "if event time timeout has not been enabled in [map|flatMap]GroupsWithState") - def setTimeoutTimestamp(timestampMs: Long, additionalDuration: String): Unit - - /** - * Set the timeout timestamp for this key as a java.sql.Date. This timestamp cannot be older - * than the current watermark. - * - * @note - * [[GroupStateTimeout Event time timeout]] must be enabled in `[map/flatMap]GroupsWithState` - * for calling this method. - * @note - * This method has no side effect when used in a batch query. - */ - @throws[UnsupportedOperationException]( - "if event time timeout has not been enabled in [map|flatMap]GroupsWithState") - def setTimeoutTimestamp(timestamp: java.sql.Date): Unit - - /** - * Set the timeout timestamp for this key as a java.sql.Date and an additional duration as a - * string (e.g. "1 hour", "2 days", etc.). The final timestamp (including the additional - * duration) cannot be older than the current watermark. - * - * @note - * [[GroupStateTimeout Event time timeout]] must be enabled in `[map/flatMap]GroupsWithState` - * for calling this method. - * @note - * This method has no side effect when used in a batch query. - */ - @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") - @throws[UnsupportedOperationException]( - "if event time timeout has not been enabled in [map|flatMap]GroupsWithState") - def setTimeoutTimestamp(timestamp: java.sql.Date, additionalDuration: String): Unit - - /** - * Get the current event time watermark as milliseconds in epoch time. - * - * @note - * In a streaming query, this can be called only when watermark is set before calling - * `[map/flatMap]GroupsWithState`. In a batch query, this method always returns -1. - */ - @throws[UnsupportedOperationException]( - "if watermark has not been set before in [map|flatMap]GroupsWithState") - def getCurrentWatermarkMs(): Long - - /** - * Get the current processing time as milliseconds in epoch time. - * @note - * In a streaming query, this will return a constant value throughout the duration of a - * trigger, even if the trigger is re-executed. 
- */ - def getCurrentProcessingTimeMs(): Long -} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 8da132f5de3c5..c1e1d08759028 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -58,12 +58,6 @@ object MimaExcludes { ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.listTables"), // [SPARK-43992][SQL][PYTHON][CONNECT] Add optional pattern for Catalog.listFunctions ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.listFunctions"), - // [SPARK-43919][SQL] Extract JSON functionality out of Row - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.Row.json"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.Row.prettyJson"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.MutableAggregationBuffer.json"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.MutableAggregationBuffer.prettyJson"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.MutableAggregationBuffer.jsonValue"), // [SPARK-43952][CORE][CONNECT][SQL] Add SparkContext APIs for query cancellation by tag ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.JobData.this"), // [SPARK-44205][SQL] Extract Catalyst Code from DecimalType @@ -79,7 +73,12 @@ object MimaExcludes { // [SPARK-44198][CORE] Support propagation of the log level to the executors ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages$SparkAppConfig$"), // [SPARK-44692][CONNECT][SQL] Move Trigger(s) to sql/api - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.Trigger") + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.Trigger"), + // [SPARK-44713][CONNECT][SQL] Move shared classes to sql/api + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.api.java.function.MapGroupsWithStateFunction"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SaveMode"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.GroupState") ) // Default exclude rules diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java b/sql/api/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java rename to sql/api/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java b/sql/api/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java rename to sql/api/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/SaveMode.java b/sql/api/src/main/java/org/apache/spark/sql/SaveMode.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/SaveMode.java rename to sql/api/src/main/java/org/apache/spark/sql/SaveMode.java diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala similarity index 100% rename from sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala rename to sql/api/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala From 93af0848e467fe4d58c0fb1242b738931390d6f8 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 8 Aug 2023 15:05:18 +0200 Subject: [PATCH 14/30] [SPARK-44710][CONNECT] Add Dataset.dropDuplicatesWithinWatermark to Spark Connect Scala Client ### What changes were proposed in this pull request? This PR adds `Dataset.dropDuplicatesWithinWatermark` to the Spark Connect Scala Client. ### Why are the changes needed? Increase compatibility with the current sql/core APIs. ### Does this PR introduce _any_ user-facing change? Yes. It adds a new method to the scala client. ### How was this patch tested? Added a new (rudimentary) test to `ClientStreamingQuerySuite`. Closes #42384 from hvanhovell/SPARK-44710. Authored-by: Herman van Hovell Signed-off-by: Herman van Hovell --- .../scala/org/apache/spark/sql/Dataset.scala | 39 ++++++++++-------- .../apache/spark/sql/ClientE2ETestSuite.scala | 20 +++++++++ .../query-tests/queries/distinct.json | 3 +- .../query-tests/queries/distinct.proto.bin | Bin 50 -> 52 bytes .../query-tests/queries/dropDuplicates.json | 3 +- .../queries/dropDuplicates.proto.bin | Bin 50 -> 52 bytes .../queries/dropDuplicates_names_array.json | 3 +- .../dropDuplicates_names_array.proto.bin | Bin 55 -> 57 bytes .../queries/dropDuplicates_names_seq.json | 3 +- .../dropDuplicates_names_seq.proto.bin | Bin 54 -> 56 bytes .../queries/dropDuplicates_varargs.json | 3 +- .../queries/dropDuplicates_varargs.proto.bin | Bin 58 -> 60 bytes 12 files changed, 51 insertions(+), 23 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index 8a7dce3987a44..5f263903c8bbc 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2399,6 +2399,19 @@ class Dataset[T] private[sql] ( .addAllColumnNames(cols.asJava) } + private def buildDropDuplicates( + columns: Option[Seq[String]], + withinWaterMark: Boolean): Dataset[T] = sparkSession.newDataset(encoder) { builder => + val dropBuilder = builder.getDeduplicateBuilder + .setInput(plan.getRoot) + .setWithinWatermark(withinWaterMark) + if (columns.isDefined) { + dropBuilder.addAllColumnNames(columns.get.asJava) + } else { + dropBuilder.setAllColumnsAsKeys(true) + } + } + /** * Returns a new Dataset that contains only the unique rows from this Dataset. This is an alias * for `distinct`. 
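For reference, a hedged sketch of how the newly wired `dropDuplicatesWithinWatermark` variants are intended to be used from the Scala client against a streaming Dataset (the connect URL, source, and column names below are illustrative placeholders, not part of this patch):

```scala
import org.apache.spark.sql.SparkSession

// Minimal sketch only: assumes a reachable Spark Connect server at the given URL.
val spark = SparkSession.builder().remote("sc://localhost").getOrCreate()

val deduped = spark.readStream
  .format("rate") // the rate source produces `timestamp` and `value` columns
  .load()
  .withWatermark("timestamp", "10 minutes")
  // Deduplicate on `value`; per-key state is kept only until the watermark
  // advances past the first-seen event time plus the watermark delay.
  .dropDuplicatesWithinWatermark("value")

val query = deduped.writeStream
  .format("console")
  .start()
```

On a batch Dataset the server rejects the operator, which is what the new `ClientE2ETestSuite` test below asserts.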
@@ -2406,11 +2419,7 @@ class Dataset[T] private[sql] ( * @group typedrel * @since 3.4.0 */ - def dropDuplicates(): Dataset[T] = sparkSession.newDataset(encoder) { builder => - builder.getDeduplicateBuilder - .setInput(plan.getRoot) - .setAllColumnsAsKeys(true) - } + def dropDuplicates(): Dataset[T] = buildDropDuplicates(None, withinWaterMark = false) /** * (Scala-specific) Returns a new Dataset with duplicate rows removed, considering only the @@ -2419,11 +2428,8 @@ class Dataset[T] private[sql] ( * @group typedrel * @since 3.4.0 */ - def dropDuplicates(colNames: Seq[String]): Dataset[T] = sparkSession.newDataset(encoder) { - builder => - builder.getDeduplicateBuilder - .setInput(plan.getRoot) - .addAllColumnNames(colNames.asJava) + def dropDuplicates(colNames: Seq[String]): Dataset[T] = { + buildDropDuplicates(Option(colNames), withinWaterMark = false) } /** @@ -2443,16 +2449,14 @@ class Dataset[T] private[sql] ( */ @scala.annotation.varargs def dropDuplicates(col1: String, cols: String*): Dataset[T] = { - val colNames: Seq[String] = col1 +: cols - dropDuplicates(colNames) + dropDuplicates(col1 +: cols) } - def dropDuplicatesWithinWatermark(): Dataset[T] = { - dropDuplicatesWithinWatermark(this.columns) - } + def dropDuplicatesWithinWatermark(): Dataset[T] = + buildDropDuplicates(None, withinWaterMark = true) def dropDuplicatesWithinWatermark(colNames: Seq[String]): Dataset[T] = { - throw new UnsupportedOperationException("dropDuplicatesWithinWatermark is not implemented.") + buildDropDuplicates(Option(colNames), withinWaterMark = true) } def dropDuplicatesWithinWatermark(colNames: Array[String]): Dataset[T] = { @@ -2461,8 +2465,7 @@ class Dataset[T] private[sql] ( @scala.annotation.varargs def dropDuplicatesWithinWatermark(col1: String, cols: String*): Dataset[T] = { - val colNames: Seq[String] = col1 +: cols - dropDuplicatesWithinWatermark(colNames) + dropDuplicatesWithinWatermark(col1 +: cols) } /** diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index ebd3d037bba5c..074cf170dd39d 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -1183,6 +1183,26 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM val joined = ds1.joinWith(ds2, $"a.value._1" === $"b.value._2", "inner") checkSameResult(Seq((Some((2, 3)), Some((1, 2)))), joined) } + + test("dropDuplicatesWithinWatermark not supported in batch DataFrame") { + def testAndVerify(df: Dataset[_]): Unit = { + val exc = intercept[AnalysisException] { + df.write.format("noop").mode(SaveMode.Append).save() + } + + assert(exc.getMessage.contains("dropDuplicatesWithinWatermark is not supported")) + assert(exc.getMessage.contains("batch DataFrames/DataSets")) + } + + val result = spark.range(10).dropDuplicatesWithinWatermark() + testAndVerify(result) + + val result2 = spark + .range(10) + .withColumn("newcol", col("id")) + .dropDuplicatesWithinWatermark("newcol") + testAndVerify(result2) + } } private[sql] case class ClassData(a: String, b: Int) diff --git a/connector/connect/common/src/test/resources/query-tests/queries/distinct.json b/connector/connect/common/src/test/resources/query-tests/queries/distinct.json index ae796b520353c..15c320d462b31 100644 --- 
a/connector/connect/common/src/test/resources/query-tests/queries/distinct.json +++ b/connector/connect/common/src/test/resources/query-tests/queries/distinct.json @@ -11,6 +11,7 @@ "schema": "struct\u003cid:bigint,a:int,b:double\u003e" } }, - "allColumnsAsKeys": true + "allColumnsAsKeys": true, + "withinWatermark": false } } \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/distinct.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/distinct.proto.bin index 07430c43831060375665ce94cd79d8496286af57..078223b1f3edd6cd65da4da9c1156c873f004ec4 100644 GIT binary patch delta 15 WcmXpqVdG*FU@X#`$fm=jzyJUcrUDND delta 12 TcmXppV&h^GU@X#^$fg4T3X=ir diff --git a/connector/connect/common/src/test/resources/query-tests/queries/dropDuplicates.json b/connector/connect/common/src/test/resources/query-tests/queries/dropDuplicates.json index ae796b520353c..15c320d462b31 100644 --- a/connector/connect/common/src/test/resources/query-tests/queries/dropDuplicates.json +++ b/connector/connect/common/src/test/resources/query-tests/queries/dropDuplicates.json @@ -11,6 +11,7 @@ "schema": "struct\u003cid:bigint,a:int,b:double\u003e" } }, - "allColumnsAsKeys": true + "allColumnsAsKeys": true, + "withinWatermark": false } } \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/dropDuplicates.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/dropDuplicates.proto.bin index 07430c43831060375665ce94cd79d8496286af57..078223b1f3edd6cd65da4da9c1156c873f004ec4 100644 GIT binary patch delta 15 WcmXpqVdG*FU@X#`$fm=jzyJUcrUDND delta 12 TcmXppV&h^GU@X#^$fg4T3X=ir diff --git a/connector/connect/common/src/test/resources/query-tests/queries/dropDuplicates_names_array.json b/connector/connect/common/src/test/resources/query-tests/queries/dropDuplicates_names_array.json index e72e23c86caf0..23df6972a517b 100644 --- a/connector/connect/common/src/test/resources/query-tests/queries/dropDuplicates_names_array.json +++ b/connector/connect/common/src/test/resources/query-tests/queries/dropDuplicates_names_array.json @@ -11,6 +11,7 @@ "schema": "struct\u003cid:bigint,a:int,b:double\u003e" } }, - "columnNames": ["a", "id"] + "columnNames": ["a", "id"], + "withinWatermark": false } } \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/dropDuplicates_names_array.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/dropDuplicates_names_array.proto.bin index c8e3885fbf804423437ffc4a9e33a3a4f90d788f..3bdbeb0d3863702dc7d21689c0ed7bff925872cd 100644 GIT binary patch delta 15 WcmXrFWaDBIU@S78$Y#i-zyJUdZUQ3! 
delta 12 TcmcCCXX9cLU@S76$Yuxt3m*ah diff --git a/connector/connect/common/src/test/resources/query-tests/queries/dropDuplicates_names_seq.json b/connector/connect/common/src/test/resources/query-tests/queries/dropDuplicates_names_seq.json index 754cecac4b256..6ef72770b9a63 100644 --- a/connector/connect/common/src/test/resources/query-tests/queries/dropDuplicates_names_seq.json +++ b/connector/connect/common/src/test/resources/query-tests/queries/dropDuplicates_names_seq.json @@ -11,6 +11,7 @@ "schema": "struct\u003cid:bigint,a:int,b:double\u003e" } }, - "columnNames": ["a", "b"] + "columnNames": ["a", "b"], + "withinWatermark": false } } \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/dropDuplicates_names_seq.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/dropDuplicates_names_seq.proto.bin index 1a2d635e58e56cca21c63807141260af3069f07b..65b4942c5684e4400d46528c431ccd7c345127ff 100644 GIT binary patch delta 15 WcmXrBVB=yEU@S74$Y#K#zyJUdMgkuI delta 12 TcmcC8W8-2HU@S72$YuZl3j+cB diff --git a/connector/connect/common/src/test/resources/query-tests/queries/dropDuplicates_varargs.json b/connector/connect/common/src/test/resources/query-tests/queries/dropDuplicates_varargs.json index c4a8df30c5867..2b6d46a313513 100644 --- a/connector/connect/common/src/test/resources/query-tests/queries/dropDuplicates_varargs.json +++ b/connector/connect/common/src/test/resources/query-tests/queries/dropDuplicates_varargs.json @@ -11,6 +11,7 @@ "schema": "struct\u003cid:bigint,a:int,b:double\u003e" } }, - "columnNames": ["a", "b", "id"] + "columnNames": ["a", "b", "id"], + "withinWatermark": false } } \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/dropDuplicates_varargs.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/dropDuplicates_varargs.proto.bin index 719a373c2e3843693dbab093d9ebf652d5914465..57f0d7e5afa6733d2de656e92f033e311022c8b7 100644 GIT binary patch delta 15 WcmcDrVdG*FU@S75$Y#Q%zyJUd<^nGO delta 12 TcmcDqV&h^GU@S73$Yufn3v&Vs From 8c444f497137d5abb3a94b576ec0fea55dc18bbc Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 8 Aug 2023 15:41:36 +0200 Subject: [PATCH 15/30] [SPARK-44715][CONNECT] Bring back callUdf and udf function ### What changes were proposed in this pull request? This PR adds the `udf` (with a return type), and `callUDF` functions to `functions.scala` for the Spark Connect Scala Client. ### Why are the changes needed? We want the Spark Connect Scala Client to be as compatible as possible with the existing sql/core APIs. ### Does this PR introduce _any_ user-facing change? Yes. It adds more exposed functions. ### How was this patch tested? Added tests to `UserDefinedFunctionE2ETestSuite` and `FunctionTestSuite`. I have also updated the compatibility checks. Closes #42387 from hvanhovell/SPARK-44715. 
Authored-by: Herman van Hovell Signed-off-by: Herman van Hovell --- .../org/apache/spark/sql/functions.scala | 40 +++++++++++++++++++ .../apache/spark/sql/FunctionTestSuite.scala | 2 + .../sql/UserDefinedFunctionE2ETestSuite.scala | 20 ++++++++++ .../CheckConnectJvmClientCompatibility.scala | 7 ---- 4 files changed, 62 insertions(+), 7 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala index 89bfc99817948..fa8c5782e0614 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala @@ -8056,6 +8056,46 @@ object functions { } // scalastyle:off line.size.limit + /** + * Defines a deterministic user-defined function (UDF) using a Scala closure. For this variant, + * the caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. + * + * Note that, although the Scala closure can have primitive-type function argument, it doesn't + * work well with null values. Because the Scala closure is passed in as Any type, there is no + * type information for the function arguments. Without the type information, Spark may blindly + * pass null to the Scala closure with primitive-type argument, and the closure will see the + * default value of the Java type for the null argument, e.g. `udf((x: Int) => x, IntegerType)`, + * the result is 0 for null input. + * + * @param f + * A closure in Scala + * @param dataType + * The output data type of the UDF + * + * @group udf_funcs + * @since 3.5.0 + */ + @deprecated( + "Scala `udf` method with return type parameter is deprecated. " + + "Please use Scala `udf` method without return type parameter.", + "3.0.0") + def udf(f: AnyRef, dataType: DataType): UserDefinedFunction = { + ScalarUserDefinedFunction(f, dataType) + } + + /** + * Call an user-defined function. + * + * @group udf_funcs + * @since 3.5.0 + */ + @scala.annotation.varargs + @deprecated("Use call_udf") + def callUDF(udfName: String, cols: Column*): Column = + call_function(udfName, cols: _*) + /** * Call an user-defined function. 
Example: * {{{ diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala index 32004b6bcc11d..4a8e108357fa7 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala @@ -249,6 +249,8 @@ class FunctionTestSuite extends ConnectFunSuite { pbFn.to_protobuf(a, "FakeMessage", "fakeBytes".getBytes(), Map.empty[String, String].asJava), pbFn.to_protobuf(a, "FakeMessage", "fakeBytes".getBytes())) + testEquals("call_udf", callUDF("bob", lit(1)), call_udf("bob", lit(1))) + test("assert_true no message") { val e = assert_true(a).expr assert(e.hasUnresolvedFunction) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala index 258fa1e7c740a..3a931c9a6ba80 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala @@ -24,9 +24,11 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.JavaConverters._ import org.apache.spark.api.java.function._ +import org.apache.spark.sql.api.java.UDF2 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{PrimitiveIntEncoder, PrimitiveLongEncoder} import org.apache.spark.sql.connect.client.util.QueryTest import org.apache.spark.sql.functions.{col, struct, udf} +import org.apache.spark.sql.types.IntegerType /** * All tests in this class requires client UDF defined in this test class synced with the server. 
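For context, a hedged sketch of how the re-added overloads are typically used (the connect URL and names are illustrative placeholders; it also assumes `spark.udf.register` is available in the client, as it is in sql/core):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{call_udf, callUDF, col, udf}
import org.apache.spark.sql.types.IntegerType

// Minimal sketch only: assumes a reachable Spark Connect server at the given URL.
val spark = SparkSession.builder().remote("sc://localhost").getOrCreate()

// Deprecated `udf` variant with an explicit return type and no input type coercion.
val plusOne = udf((i: Long) => (i + 1).toInt, IntegerType)
spark.range(3).select(plusOne(col("id"))).show()

// `callUDF` is kept as a deprecated alias that forwards to `call_udf` / `call_function`.
spark.udf.register("plus_one", plusOne)
spark.range(3).select(callUDF("plus_one", col("id"))).show()
spark.range(3).select(call_udf("plus_one", col("id"))).show()
```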
@@ -250,4 +252,22 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest { "b", "c") } + + test("(deprecated) scala UDF with dataType") { + val session: SparkSession = spark + import session.implicits._ + val fn = udf(((i: Long) => (i + 1).toInt), IntegerType) + checkDataset(session.range(2).select(fn($"id")).as[Int], 1, 2) + } + + test("java UDF") { + val session: SparkSession = spark + import session.implicits._ + val fn = udf( + new UDF2[Long, Long, Int] { + override def call(t1: Long, t2: Long): Int = (t1 + t2 + 1).toInt + }, + IntegerType) + checkDataset(session.range(2).select(fn($"id", $"id" + 2)).as[Int], 3, 5) + } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 2bf9c41fb2cbd..d380a1bbb653e 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -191,8 +191,6 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.javaRDD"), // functions - ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.udf"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.callUDF"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.unwrap_udt"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.udaf"), @@ -214,14 +212,11 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.listenerManager"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.experimental"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.udtf"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.streams"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.createDataFrame"), ProblemFilters.exclude[Problem]( "org.apache.spark.sql.SparkSession.baseRelationToDataFrame"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.createDataset"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.executeCommand"), - // TODO(SPARK-44068): Support positional parameters in Scala connect client - ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sql"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.this"), // SparkSession#implicits @@ -266,8 +261,6 @@ object CheckConnectJvmClientCompatibility { "org.apache.spark.sql.streaming.StreamingQueryException.time"), // Classes missing from streaming API - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ForeachWriter"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.GroupState"), ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.streaming.TestGroupState"), ProblemFilters.exclude[MissingClassProblem]( From 418bba5ad6053449a141f3c9c31ed3ad998995b8 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Tue, 8 Aug 2023 18:32:25 +0200 Subject: [PATCH 16/30] [SPARK-44709][CONNECT] Run ExecuteGrpcResponseSender in reattachable execute in new thread to fix flow control ### What changes were proposed in this pull request? 
If executePlan / reattachExecute handling is done directly on the GRPC thread, flow control OnReady events are getting queued until after the handler returns, so OnReadyHandler never gets notified until after the handler exits. The correct way to use it is for the handler to delegate work to another thread and exit. See https://github.com/grpc/grpc-java/issues/7361 Tidied up and added a lot of logging and statistics to ExecuteGrpcResponseSender and ExecuteResponseObserver to be able to observe this behaviour. Followup work in https://issues.apache.org/jira/browse/SPARK-44625 is needed for cleanup of abandoned executions that will also make sure that these threads are joined. ### Why are the changes needed? ExecuteGrpcResponseSender gets stuck waiting on grpcCallObserverReadySignal because events from OnReadyHandler do not arrive. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added extensive debugging to ExecuteGrpcResponseSender and ExecuteResponseObserver and tested and observer the behaviour of all the threads. Closes #42355 from juliuszsompolski/spark-rpc-extra-thread. Authored-by: Juliusz Sompolski Signed-off-by: Herman van Hovell --- .../spark/sql/connect/config/Connect.scala | 13 +- .../execution/CachedStreamResponse.scala | 2 + .../execution/ExecuteGrpcResponseSender.scala | 164 +++++++++++++----- .../execution/ExecuteResponseObserver.scala | 116 +++++++++++-- .../execution/ExecuteThreadRunner.scala | 3 +- .../sql/connect/service/ExecuteHolder.scala | 21 ++- .../SparkConnectExecutePlanHandler.scala | 20 +-- .../SparkConnectReattachExecuteHandler.scala | 22 +-- 8 files changed, 264 insertions(+), 97 deletions(-) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala index e25cb5cbab279..0be53064cc040 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala @@ -74,6 +74,17 @@ object Connect { .intConf .createWithDefault(1024) + val CONNECT_EXECUTE_REATTACHABLE_ENABLED = + ConfigBuilder("spark.connect.execute.reattachable.enabled") + .internal() + .doc("Enables reattachable execution on the server. If disabled and a client requests it, " + + "non-reattachable execution will follow and should run until query completion. 
This will " + + "work, unless there is a GRPC stream error, in which case the client will discover that " + + "execution is not reattachable when trying to reattach fails.") + .version("3.5.0") + .booleanConf + .createWithDefault(true) + val CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_DURATION = ConfigBuilder("spark.connect.execute.reattachable.senderMaxStreamDuration") .internal() @@ -82,7 +93,7 @@ object Connect { "Set to 0 for unlimited.") .version("3.5.0") .timeConf(TimeUnit.MILLISECONDS) - .createWithDefaultString("5m") + .createWithDefaultString("2m") val CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_SIZE = ConfigBuilder("spark.connect.execute.reattachable.senderMaxStreamSize") diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/CachedStreamResponse.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/CachedStreamResponse.scala index ec9fce785badb..a2bbe14f2014c 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/CachedStreamResponse.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/CachedStreamResponse.scala @@ -22,6 +22,8 @@ import com.google.protobuf.MessageLite private[execution] case class CachedStreamResponse[T <: MessageLite]( // the actual cached response response: T, + // the id of the response, an UUID. + responseId: String, // index of the response in the response stream. // responses produced in the stream are numbered consecutively starting from 1. streamIndex: Long) { diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala index 88124080ccad5..7b51a90ca3741 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala @@ -17,24 +17,24 @@ package org.apache.spark.sql.connect.execution -import com.google.protobuf.MessageLite +import com.google.protobuf.Message import io.grpc.stub.{ServerCallStreamObserver, StreamObserver} import org.apache.spark.{SparkEnv, SparkSQLException} import org.apache.spark.internal.Logging +import org.apache.spark.sql.connect.common.ProtoUtils import org.apache.spark.sql.connect.config.Connect.{CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_DURATION, CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_SIZE} import org.apache.spark.sql.connect.service.ExecuteHolder /** - * ExecuteGrpcResponseSender sends responses to the GRPC stream. It runs on the RPC thread, and - * gets notified by ExecuteResponseObserver about available responses. It notifies the - * ExecuteResponseObserver back about cached responses that can be removed after being sent out. + * ExecuteGrpcResponseSender sends responses to the GRPC stream. It consumes responses from + * ExecuteResponseObserver and sends them out as responses to ExecutePlan or ReattachExecute. * @param executeHolder * The execution this sender attaches to. 
* @param grpcObserver * the GRPC request StreamObserver */ -private[connect] class ExecuteGrpcResponseSender[T <: MessageLite]( +private[connect] class ExecuteGrpcResponseSender[T <: Message]( val executeHolder: ExecuteHolder, grpcObserver: StreamObserver[T]) extends Logging { @@ -47,15 +47,69 @@ private[connect] class ExecuteGrpcResponseSender[T <: MessageLite]( // Signal to wake up when grpcCallObserver.isReady() private val grpcCallObserverReadySignal = new Object + // Stats + private var consumeSleep = 0L + private var sendSleep = 0L + /** * Detach this sender from executionObserver. Called only from executionObserver that this - * sender is attached to. executionObserver holds lock, and needs to notify after this call. + * sender is attached to. Lock on executionObserver is held, and notifyAll will wake up this + * sender if sleeping. */ - def detach(): Unit = { + def detach(): Unit = executionObserver.synchronized { if (detached == true) { throw new IllegalStateException("ExecuteGrpcResponseSender already detached!") } detached = true + executionObserver.notifyAll() + } + + def run(lastConsumedStreamIndex: Long): Unit = { + if (executeHolder.reattachable) { + // In reattachable execution, check if grpcObserver is ready for sending, by using + // setOnReadyHandler of the ServerCallStreamObserver. Otherwise, calling grpcObserver.onNext + // can queue the responses without sending them, and it is unknown how far behind it is, and + // hence how much the executionObserver needs to buffer. + // + // Because OnReady events get queued on the same GRPC inboud queue as the executePlan or + // reattachExecute RPC handler that this is executing in, OnReady events will not arrive and + // not trigger the OnReadyHandler unless this thread returns from executePlan/reattachExecute. + // Therefore, we launch another thread to operate on the grpcObserver and send the responses, + // while this thread will exit from the executePlan/reattachExecute call, allowing GRPC + // to send the OnReady events. + // See https://github.com/grpc/grpc-java/issues/7361 + + val t = new Thread( + s"SparkConnectGRPCSender_" + + s"opId=${executeHolder.operationId}_startIndex=$lastConsumedStreamIndex") { + override def run(): Unit = { + execute(lastConsumedStreamIndex) + } + } + executeHolder.grpcSenderThreads += t + + val grpcCallObserver = grpcObserver.asInstanceOf[ServerCallStreamObserver[T]] + grpcCallObserver.setOnReadyHandler(() => { + logTrace(s"Stream ready, notify grpcCallObserverReadySignal.") + grpcCallObserverReadySignal.synchronized { + grpcCallObserverReadySignal.notifyAll() + } + }) + + // Start the thread and exit + t.start() + } else { + // Non reattachable execute runs directly in the GRPC thread. + try { + execute(lastConsumedStreamIndex) + } finally { + if (!executeHolder.reattachable) { + // Non reattachable executions release here immediately. + // (Reattachable executions release with ReleaseExecute RPC.) + executeHolder.close() + } + } + } } /** @@ -70,26 +124,16 @@ private[connect] class ExecuteGrpcResponseSender[T <: MessageLite]( * the last index that was already consumed and sent. This sender will start from index after * that. 
0 means start from beginning (since first response has index 1) */ - def run(lastConsumedStreamIndex: Long): Unit = { - logDebug( - s"GrpcResponseSender run for $executeHolder, " + + def execute(lastConsumedStreamIndex: Long): Unit = { + logInfo( + s"Starting for opId=${executeHolder.operationId}, " + s"reattachable=${executeHolder.reattachable}, " + s"lastConsumedStreamIndex=$lastConsumedStreamIndex") + val startTime = System.nanoTime() // register to be notified about available responses. executionObserver.attachConsumer(this) - // In reattachable execution, we check if grpcCallObserver is ready for sending. - // See sendResponse - if (executeHolder.reattachable) { - val grpcCallObserver = grpcObserver.asInstanceOf[ServerCallStreamObserver[T]] - grpcCallObserver.setOnReadyHandler(() => { - grpcCallObserverReadySignal.synchronized { - grpcCallObserverReadySignal.notifyAll() - } - }) - } - var nextIndex = lastConsumedStreamIndex + 1 var finished = false @@ -129,30 +173,38 @@ private[connect] class ExecuteGrpcResponseSender[T <: MessageLite]( // Get next available response. // Wait until either this sender got detached or next response is ready, // or the stream is complete and it had already sent all responses. - logDebug(s"Trying to get next response with index=$nextIndex.") + logTrace(s"Trying to get next response with index=$nextIndex.") executionObserver.synchronized { - logDebug(s"Acquired executionObserver lock.") + logTrace(s"Acquired executionObserver lock.") + val sleepStart = System.nanoTime() + var sleepEnd = 0L while (!detachedFromObserver && !gotResponse && !streamFinished && !deadlineLimitReached) { - logDebug(s"Try to get response with index=$nextIndex from observer.") + logTrace(s"Try to get response with index=$nextIndex from observer.") response = executionObserver.consumeResponse(nextIndex) - logDebug(s"Response index=$nextIndex from observer: ${response.isDefined}") + logTrace(s"Response index=$nextIndex from observer: ${response.isDefined}") // If response is empty, release executionObserver lock and wait to get notified. // The state of detached, response and lastIndex are change under lock in // executionObserver, and will notify upon state change. if (response.isEmpty) { val timeout = Math.max(1, deadlineTimeMillis - System.currentTimeMillis()) - logDebug(s"Wait for response to become available with timeout=$timeout ms.") + logTrace(s"Wait for response to become available with timeout=$timeout ms.") executionObserver.wait(timeout) - logDebug(s"Reacquired executionObserver lock after waiting.") + logTrace(s"Reacquired executionObserver lock after waiting.") + sleepEnd = System.nanoTime() } } - logDebug( - s"Exiting loop: detached=$detached, response=$response, " + + logTrace( + s"Exiting loop: detached=$detached, " + + s"response=${response.map(r => ProtoUtils.abbreviate(r.response))}, " + s"lastIndex=${executionObserver.getLastResponseIndex()}, " + s"deadline=${deadlineLimitReached}") + if (sleepEnd > 0) { + consumeSleep += sleepEnd - sleepStart + logTrace(s"Slept waiting for execution stream for ${sleepEnd - sleepStart}ns.") + } } // Process the outcome of the inner loop. @@ -160,24 +212,30 @@ private[connect] class ExecuteGrpcResponseSender[T <: MessageLite]( // This sender got detached by the observer. // This only happens if this RPC is actually dead, and the client already came back with // a ReattachExecute RPC. Kill this RPC. - logDebug(s"Detached from observer at index ${nextIndex - 1}. 
Complete stream.") + logWarning( + s"Got detached from opId=${executeHolder.operationId} at index ${nextIndex - 1}." + + s"totalTime=${System.nanoTime - startTime}ns " + + s"waitingForResults=${consumeSleep}ns waitingForSend=${sendSleep}ns") throw new SparkSQLException(errorClass = "INVALID_CURSOR.DISCONNECTED", Map.empty) } else if (gotResponse) { // There is a response available to be sent. - val sent = sendResponse(response.get.response, deadlineTimeMillis) + val sent = sendResponse(response.get, deadlineTimeMillis) if (sent) { - logDebug(s"Sent response index=$nextIndex.") sentResponsesSize += response.get.serializedByteSize nextIndex += 1 assert(finished == false) } else { - // If it wasn't sent, time deadline must have been reached before stream became available. + // If it wasn't sent, time deadline must have been reached before stream became available, + // will exit in the enxt loop iterattion. assert(deadlineLimitReached) - finished = true } } else if (streamFinished) { // Stream is finished and all responses have been sent - logDebug(s"Stream finished and sent all responses up to index ${nextIndex - 1}.") + logInfo( + s"Stream finished for opId=${executeHolder.operationId}, " + + s"sent all responses up to last index ${nextIndex - 1}. " + + s"totalTime=${System.nanoTime - startTime}ns " + + s"waitingForResults=${consumeSleep}ns waitingForSend=${sendSleep}ns") executionObserver.getError() match { case Some(t) => grpcObserver.onError(t) case None => grpcObserver.onCompleted() @@ -186,7 +244,11 @@ private[connect] class ExecuteGrpcResponseSender[T <: MessageLite]( } else if (deadlineLimitReached) { // The stream is not complete, but should be finished now. // The client needs to reattach with ReattachExecute. - logDebug(s"Deadline reached, finishing stream after index ${nextIndex - 1}.") + logInfo( + s"Deadline reached, shutting down stream for opId=${executeHolder.operationId} " + + s"after index ${nextIndex - 1}. 
" + + s"totalTime=${System.nanoTime - startTime}ns " + + s"waitingForResults=${consumeSleep}ns waitingForSend=${sendSleep}ns") grpcObserver.onCompleted() finished = true } @@ -205,10 +267,15 @@ private[connect] class ExecuteGrpcResponseSender[T <: MessageLite]( * @return * true if the response was sent, false otherwise (meaning deadline passed) */ - private def sendResponse(response: T, deadlineTimeMillis: Long): Boolean = { + private def sendResponse( + response: CachedStreamResponse[T], + deadlineTimeMillis: Long): Boolean = { if (!executeHolder.reattachable) { // no flow control in non-reattachable execute - grpcObserver.onNext(response) + logDebug( + s"SEND opId=${executeHolder.operationId} responseId=${response.responseId} " + + s"idx=${response.streamIndex} (no flow control)") + grpcObserver.onNext(response.response) true } else { // In reattachable execution, we control the flow, and only pass the response to the @@ -225,19 +292,28 @@ private[connect] class ExecuteGrpcResponseSender[T <: MessageLite]( val grpcCallObserver = grpcObserver.asInstanceOf[ServerCallStreamObserver[T]] grpcCallObserverReadySignal.synchronized { - logDebug(s"Acquired grpcCallObserverReadySignal lock.") + logTrace(s"Acquired grpcCallObserverReadySignal lock.") + val sleepStart = System.nanoTime() + var sleepEnd = 0L while (!grpcCallObserver.isReady() && deadlineTimeMillis >= System.currentTimeMillis()) { val timeout = Math.max(1, deadlineTimeMillis - System.currentTimeMillis()) - logDebug(s"Wait for grpcCallObserver to become ready with timeout=$timeout ms.") + var sleepStart = System.nanoTime() + logTrace(s"Wait for grpcCallObserver to become ready with timeout=$timeout ms.") grpcCallObserverReadySignal.wait(timeout) - logDebug(s"Reacquired grpcCallObserverReadySignal lock after waiting.") + logTrace(s"Reacquired grpcCallObserverReadySignal lock after waiting.") + sleepEnd = System.nanoTime() } if (grpcCallObserver.isReady()) { - logDebug(s"grpcCallObserver is ready, sending response.") - grpcCallObserver.onNext(response) + val sleepTime = if (sleepEnd > 0L) sleepEnd - sleepStart else 0L + logDebug( + s"SEND opId=${executeHolder.operationId} responseId=${response.responseId} " + + s"idx=${response.streamIndex}" + + s"(waiting ${sleepTime}ns for GRPC stream to be ready)") + sendSleep += sleepTime + grpcCallObserver.onNext(response.response) true } else { - logDebug(s"grpcCallObserver is not ready, exiting.") + logTrace(s"grpcCallObserver is not ready, exiting.") false } } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala index 8af0f72b8dafe..0573f7b3dae8d 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteResponseObserver.scala @@ -21,13 +21,13 @@ import java.util.UUID import scala.collection.mutable -import com.google.protobuf.MessageLite +import com.google.protobuf.Message import io.grpc.stub.StreamObserver import org.apache.spark.{SparkEnv, SparkSQLException} import org.apache.spark.connect.proto import org.apache.spark.internal.Logging -import org.apache.spark.sql.connect.config.Connect.CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_SIZE +import org.apache.spark.sql.connect.config.Connect.CONNECT_EXECUTE_REATTACHABLE_OBSERVER_RETRY_BUFFER_SIZE import 
org.apache.spark.sql.connect.service.ExecuteHolder /** @@ -47,7 +47,7 @@ import org.apache.spark.sql.connect.service.ExecuteHolder * @see * attachConsumer */ -private[connect] class ExecuteResponseObserver[T <: MessageLite](val executeHolder: ExecuteHolder) +private[connect] class ExecuteResponseObserver[T <: Message](val executeHolder: ExecuteHolder) extends StreamObserver[T] with Logging { @@ -85,12 +85,18 @@ private[connect] class ExecuteResponseObserver[T <: MessageLite](val executeHold */ private var responseSender: Option[ExecuteGrpcResponseSender[T]] = None + // Statistics about cached responses. + private var cachedSizeUntilHighestConsumed = CachedSize() + private var cachedSizeUntilLastProduced = CachedSize() + private var autoRemovedSize = CachedSize() + private var totalSize = CachedSize() + /** * Total size of response to be held buffered after giving out with getResponse. 0 for none, any * value greater than 0 will buffer the response from getResponse. */ private val retryBufferSize = if (executeHolder.reattachable) { - SparkEnv.get.conf.get(CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_SIZE).toLong + SparkEnv.get.conf.get(CONNECT_EXECUTE_REATTACHABLE_OBSERVER_RETRY_BUFFER_SIZE).toLong } else { 0 } @@ -101,11 +107,19 @@ private[connect] class ExecuteResponseObserver[T <: MessageLite](val executeHold } lastProducedIndex += 1 val processedResponse = setCommonResponseFields(r) - responses += - ((lastProducedIndex, CachedStreamResponse[T](processedResponse, lastProducedIndex))) - responseIndexToId += ((lastProducedIndex, getResponseId(processedResponse))) - responseIdToIndex += ((getResponseId(processedResponse), lastProducedIndex)) - logDebug(s"Saved response with index=$lastProducedIndex") + val responseId = getResponseId(processedResponse) + val response = CachedStreamResponse[T](processedResponse, responseId, lastProducedIndex) + + responses += ((lastProducedIndex, response)) + responseIndexToId += ((lastProducedIndex, responseId)) + responseIdToIndex += ((responseId, lastProducedIndex)) + + cachedSizeUntilLastProduced.add(response) + totalSize.add(response) + + logDebug( + s"Execution opId=${executeHolder.operationId} produced response " + + s"responseId=${responseId} idx=$lastProducedIndex") notifyAll() } @@ -115,7 +129,9 @@ private[connect] class ExecuteResponseObserver[T <: MessageLite](val executeHold } error = Some(t) finalProducedIndex = Some(lastProducedIndex) // no responses to be send after error. - logDebug(s"Error. Last stream index is $lastProducedIndex.") + logDebug( + s"Execution opId=${executeHolder.operationId} produced error. " + + s"Last stream index is $lastProducedIndex.") notifyAll() } @@ -124,18 +140,17 @@ private[connect] class ExecuteResponseObserver[T <: MessageLite](val executeHold throw new IllegalStateException("Stream onCompleted can't be called after stream completed") } finalProducedIndex = Some(lastProducedIndex) - logDebug(s"Completed. Last stream index is $lastProducedIndex.") + logDebug( + s"Execution opId=${executeHolder.operationId} completed stream. " + + s"Last stream index is $lastProducedIndex.") notifyAll() } /** Attach a new consumer (ExecuteResponseGRPCSender). */ def attachConsumer(newSender: ExecuteGrpcResponseSender[T]): Unit = synchronized { // detach the current sender before attaching new one - // this.synchronized() needs to be held while detaching a sender, and the detached sender - // needs to be notified with notifyAll() afterwards. 
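As an aside on the threading change driving this patch: the RPC handler must return so that gRPC can deliver OnReady callbacks, and the actual sending happens on a separate thread that blocks until the call is ready. A minimal standalone sketch of that pattern (not the actual Spark Connect code; `Resp` and `produceNext` are placeholders):

```scala
import io.grpc.stub.{ServerCallStreamObserver, StreamObserver}

// Sketch of the flow-control pattern: delegate sending to a worker thread and use
// setOnReadyHandler/isReady instead of calling onNext directly from the RPC handler.
def handleRpc[Resp](observer: StreamObserver[Resp], produceNext: () => Option[Resp]): Unit = {
  val call = observer.asInstanceOf[ServerCallStreamObserver[Resp]]
  val readySignal = new Object

  // OnReady events share the gRPC inbound queue with this handler, so they are only
  // delivered after the handler returns; the handler must not do the sending itself.
  call.setOnReadyHandler(() => readySignal.synchronized { readySignal.notifyAll() })

  new Thread("response-sender") {
    override def run(): Unit = {
      var next = produceNext()
      while (next.isDefined) {
        readySignal.synchronized {
          while (!call.isReady) readySignal.wait() // wait until gRPC can accept more data
        }
        call.onNext(next.get)
        next = produceNext()
      }
      call.onCompleted()
    }
  }.start()
  // Return immediately so gRPC can dispatch OnReady callbacks to the handler above.
}
```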
responseSender.foreach(_.detach()) responseSender = Some(newSender) - notifyAll() // consumer } /** @@ -150,9 +165,18 @@ private[connect] class ExecuteResponseObserver[T <: MessageLite](val executeHold assert(index <= highestConsumedIndex + 1) val ret = responses.get(index) if (ret.isDefined) { - if (index > highestConsumedIndex) highestConsumedIndex = index + if (index > highestConsumedIndex) { + highestConsumedIndex = index + cachedSizeUntilHighestConsumed.add(ret.get) + } // When the response is consumed, figure what previous responses can be uncached. - removeCachedResponses(index) + // (We keep at least one response before the one we send to consumer now) + removeCachedResponses(index - 1) + logDebug( + s"CONSUME opId=${executeHolder.operationId} responseId=${ret.get.responseId} " + + s"idx=$index. size=${ret.get.serializedByteSize} " + + s"cachedUntilConsumed=$cachedSizeUntilHighestConsumed " + + s"cachedUntilProduced=$cachedSizeUntilLastProduced") } else if (index <= highestConsumedIndex) { // If index is <= highestConsumedIndex and not available, it was already removed from cache. // This may happen if ReattachExecute is too late and the cached response was evicted. @@ -191,6 +215,25 @@ private[connect] class ExecuteResponseObserver[T <: MessageLite](val executeHold def removeResponsesUntilId(responseId: String): Unit = synchronized { val index = getResponseIndexById(responseId) removeResponsesUntilIndex(index) + logDebug( + s"RELEASE opId=${executeHolder.operationId} until " + + s"responseId=$responseId " + + s"idx=$index. " + + s"cachedUntilConsumed=$cachedSizeUntilHighestConsumed " + + s"cachedUntilProduced=$cachedSizeUntilLastProduced") + } + + /** Remove all cached responses */ + def removeAll(): Unit = synchronized { + removeResponsesUntilIndex(lastProducedIndex) + logInfo( + s"Release all for opId=${executeHolder.operationId}. Execution stats: " + + s"total=${totalSize} " + + s"autoRemoved=${autoRemovedSize} " + + s"cachedUntilConsumed=$cachedSizeUntilHighestConsumed " + + s"cachedUntilProduced=$cachedSizeUntilLastProduced " + + s"maxCachedUntilConsumed=${cachedSizeUntilHighestConsumed.max} " + + s"maxCachedUntilProduced=${cachedSizeUntilLastProduced.max}") } /** Returns if the stream is finished. */ @@ -218,16 +261,31 @@ private[connect] class ExecuteResponseObserver[T <: MessageLite](val executeHold totalResponsesSize += responses.get(i).get.serializedByteSize i -= 1 } - removeResponsesUntilIndex(i) + if (responses.get(i).isDefined) { + logDebug( + s"AUTORELEASE opId=${executeHolder.operationId} until idx=$i. " + + s"cachedUntilConsumed=$cachedSizeUntilHighestConsumed " + + s"cachedUntilProduced=$cachedSizeUntilLastProduced") + removeResponsesUntilIndex(i, true) + } else { + logDebug( + s"NO AUTORELEASE opId=${executeHolder.operationId}. " + + s"cachedUntilConsumed=$cachedSizeUntilHighestConsumed " + + s"cachedUntilProduced=$cachedSizeUntilLastProduced") + } } /** * Remove cached responses until given index. Iterating backwards, once an index is encountered * that has been removed, all earlier indexes would also be removed. 
*/ - private def removeResponsesUntilIndex(index: Long) = { + private def removeResponsesUntilIndex(index: Long, autoRemoved: Boolean = false) = { var i = index while (i >= 1 && responses.get(i).isDefined) { + val r = responses.get(i).get + cachedSizeUntilHighestConsumed.remove(r) + cachedSizeUntilLastProduced.remove(r) + if (autoRemoved) autoRemovedSize.add(r) responses.remove(i) i -= 1 } @@ -258,4 +316,26 @@ private[connect] class ExecuteResponseObserver[T <: MessageLite](val executeHold executePlanResponse.getResponseId } } + + /** + * Helper for counting statistics about cached responses. + */ + private case class CachedSize(var bytes: Long = 0L, var num: Long = 0L) { + var maxBytes: Long = 0L + var maxNum: Long = 0L + + def add(t: CachedStreamResponse[T]): Unit = { + bytes += t.serializedByteSize + if (bytes > maxBytes) maxBytes = bytes + num += 1 + if (num > maxNum) maxNum = num + } + + def remove(t: CachedStreamResponse[T]): Unit = { + bytes -= t.serializedByteSize + num -= 1 + } + + def max: CachedSize = CachedSize(maxBytes, maxNum) + } } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala index 930ccae5d4c76..62083d4892f78 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala @@ -222,7 +222,8 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends .build() } - private class ExecutionThread extends Thread { + private class ExecutionThread + extends Thread(s"SparkConnectExecuteThread_opId=${executeHolder.operationId}") { override def run(): Unit = { execute() } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala index 4eb90f9f1639a..105af0dc0baae 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala @@ -20,11 +20,13 @@ package org.apache.spark.sql.connect.service import java.util.UUID import scala.collection.JavaConverters._ +import scala.collection.mutable -import org.apache.spark.SparkSQLException +import org.apache.spark.{SparkEnv, SparkSQLException} import org.apache.spark.connect.proto import org.apache.spark.internal.Logging import org.apache.spark.sql.connect.common.ProtoUtils +import org.apache.spark.sql.connect.config.Connect.CONNECT_EXECUTE_REATTACHABLE_ENABLED import org.apache.spark.sql.connect.execution.{ExecuteGrpcResponseSender, ExecuteResponseObserver, ExecuteThreadRunner} import org.apache.spark.util.SystemClock @@ -36,6 +38,8 @@ private[connect] class ExecuteHolder( val sessionHolder: SessionHolder) extends Logging { + val session = sessionHolder.session + val operationId = if (request.hasOperationId) { try { UUID.fromString(request.getOperationId).toString @@ -73,8 +77,11 @@ private[connect] class ExecuteHolder( * If execution is reattachable, it's life cycle is not limited to a single ExecutePlanRequest, * but can be reattached with ReattachExecute, and released with ReleaseExecute */ - val reattachable: Boolean = request.getRequestOptionsList.asScala.exists { option => - 
option.hasReattachOptions && option.getReattachOptions.getReattachable == true + val reattachable: Boolean = { + SparkEnv.get.conf.get(CONNECT_EXECUTE_REATTACHABLE_ENABLED) && + request.getRequestOptionsList.asScala.exists { option => + option.hasReattachOptions && option.getReattachOptions.getReattachable == true + } } /** @@ -83,7 +90,12 @@ private[connect] class ExecuteHolder( */ var attached: Boolean = true - val session = sessionHolder.session + /** + * Threads that execute the ExecuteGrpcResponseSender and send the GRPC responses. + * + * TODO(SPARK-44625): Joining and cleaning up these threads during cleanup. + */ + val grpcSenderThreads: mutable.ArrayBuffer[Thread] = new mutable.ArrayBuffer[Thread]() val responseObserver: ExecuteResponseObserver[proto.ExecutePlanResponse] = new ExecuteResponseObserver[proto.ExecutePlanResponse](this) @@ -162,6 +174,7 @@ private[connect] class ExecuteHolder( def close(): Unit = { runner.interrupt() runner.join() + responseObserver.removeAll() eventsManager.postClosed() sessionHolder.removeExecuteHolder(operationId) } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala index 0226b4e5ed3d5..9daf1e17b5e29 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala @@ -31,20 +31,10 @@ class SparkConnectExecutePlanHandler(responseObserver: StreamObserver[proto.Exec .getOrCreateIsolatedSession(v.getUserContext.getUserId, v.getSessionId) val executeHolder = sessionHolder.createExecuteHolder(v) - try { - executeHolder.eventsManager.postStarted() - executeHolder.start() - val responseSender = - new ExecuteGrpcResponseSender[proto.ExecutePlanResponse](executeHolder, responseObserver) - executeHolder.attachAndRunGrpcResponseSender(responseSender) - } finally { - if (!executeHolder.reattachable) { - // Non reattachable executions release here immediately. - executeHolder.close() - } else { - // Reattachable executions close release with ReleaseExecute RPC. - // TODO We mark in the ExecuteHolder that RPC detached. 
- } - } + executeHolder.eventsManager.postStarted() + executeHolder.start() + val responseSender = + new ExecuteGrpcResponseSender[proto.ExecutePlanResponse](executeHolder, responseObserver) + executeHolder.attachAndRunGrpcResponseSender(responseSender) } } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReattachExecuteHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReattachExecuteHandler.scala index 362846a87b529..b70c82ab13792 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReattachExecuteHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReattachExecuteHandler.scala @@ -44,20 +44,14 @@ class SparkConnectReattachExecuteHandler( messageParameters = Map.empty) } - try { - val responseSender = - new ExecuteGrpcResponseSender[proto.ExecutePlanResponse](executeHolder, responseObserver) - if (v.hasLastResponseId) { - // start from response after lastResponseId - executeHolder.attachAndRunGrpcResponseSender(responseSender, v.getLastResponseId) - } else { - // start from the start of the stream. - executeHolder.attachAndRunGrpcResponseSender(responseSender) - } - } finally { - // Reattachable executions do not free the execution here, but client needs to call - // ReleaseExecute RPC. - // TODO We mark in the ExecuteHolder that RPC detached. + val responseSender = + new ExecuteGrpcResponseSender[proto.ExecutePlanResponse](executeHolder, responseObserver) + if (v.hasLastResponseId) { + // start from response after lastResponseId + executeHolder.attachAndRunGrpcResponseSender(responseSender, v.getLastResponseId) + } else { + // start from the start of the stream. + executeHolder.attachAndRunGrpcResponseSender(responseSender) } } } From 3af2e77bd20cad3d9fe23cc0689eed29d5f5a537 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 8 Aug 2023 16:03:38 -0700 Subject: [PATCH 17/30] [SPARK-44725][DOCS] Document `spark.network.timeoutInterval` ### What changes were proposed in this pull request? This PR aims to document `spark.network.timeoutInterval` configuration. ### Why are the changes needed? Like `spark.network.timeout`, `spark.network.timeoutInterval` exists since Apache Spark 1.3.x. https://github.com/apache/spark/blob/418bba5ad6053449a141f3c9c31ed3ad998995b8/core/src/main/scala/org/apache/spark/internal/config/Network.scala#L48-L52 Since this is a user-facing configuration like the following, we had better document it. https://github.com/apache/spark/blob/418bba5ad6053449a141f3c9c31ed3ad998995b8/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala#L91-L93 ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Manual because this is a doc-only change. Closes #42402 from dongjoon-hyun/SPARK-44725. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- docs/configuration.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/configuration.md b/docs/configuration.md index 99ef055955838..a70c049c87c6f 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -2397,6 +2397,14 @@ Apart from these, the following properties are also available, and may be useful 1.3.0 + + spark.network.timeoutInterval + 60s + + Interval for the driver to check and expire dead executors. 
+ + 1.3.2 + spark.network.io.preferDirectBufs true From ea00ecd1f88beda2f97b03ccdea5d78dd178f3b6 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Wed, 9 Aug 2023 08:58:45 +0900 Subject: [PATCH 18/30] [SPARK-44722][CONNECT] ExecutePlanResponseReattachableIterator._call_iter: AttributeError: 'NoneType' object has no attribute 'message' ### What changes were proposed in this pull request? Tiny error during exception handling: status might in some cases be None. Make a similar change in scala client just to be defensive, because there it appears that it may never happen (when we catch `StatusRuntimeException`, `StatusProto.fromThrowable` should never return null). ### Why are the changes needed? Fix an error during reattach retry handling. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Tests that will be published in a followup. Closes #42397 from juliuszsompolski/SPARK-44722. Authored-by: Juliusz Sompolski Signed-off-by: Hyukjin Kwon --- .../client/ExecutePlanResponseReattachableIterator.scala | 6 ++---- python/pyspark/sql/connect/client/reattach.py | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala index d412d9b577064..7a50801d8a6e5 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala @@ -211,10 +211,8 @@ class ExecutePlanResponseReattachableIterator( iterFun(iter) } catch { case ex: StatusRuntimeException - if StatusProto - .fromThrowable(ex) - .getMessage - .contains("INVALID_HANDLE.OPERATION_NOT_FOUND") => + if Option(StatusProto.fromThrowable(ex)) + .exists(_.getMessage.contains("INVALID_HANDLE.OPERATION_NOT_FOUND")) => if (lastReturnedResponseId.isDefined) { throw new IllegalStateException( "OPERATION_NOT_FOUND on the server but responses were already received from it.", diff --git a/python/pyspark/sql/connect/client/reattach.py b/python/pyspark/sql/connect/client/reattach.py index 70c7d126ff105..c5c45904c9baa 100644 --- a/python/pyspark/sql/connect/client/reattach.py +++ b/python/pyspark/sql/connect/client/reattach.py @@ -239,7 +239,7 @@ def _call_iter(self, iter_fun: Callable) -> Any: return iter_fun() except grpc.RpcError as e: status = rpc_status.from_call(cast(grpc.Call, e)) - if "INVALID_HANDLE.OPERATION_NOT_FOUND" in status.message: + if status is not None and "INVALID_HANDLE.OPERATION_NOT_FOUND" in status.message: if self._last_returned_response_id is not None: raise RuntimeError( "OPERATION_NOT_FOUND on the server but " From 69bd8358827f7c39b620d5cb22126fffb1267384 Mon Sep 17 00:00:00 2001 From: Amanda Liu Date: Wed, 9 Aug 2023 08:59:29 +0900 Subject: [PATCH 19/30] [SPARK-44665][PYTHON] Add support for pandas DataFrame assertDataFrameEqual ### What changes were proposed in this pull request? This PR adds support for pandas DataFrame in `assertDataFrameEqual` (in addition to pandas-on-Spark), while delaying all pandas imports until pandas environment dependency is verified. ### Why are the changes needed? 
The changes are needed to allow users to compare pandas DataFrame equality, while still ensuring compatibility for environments with no pandas installation. ### Does this PR introduce _any_ user-facing change? Yes, the PR affects the user-facing function `assertDataFrameEqual` ### How was this patch tested? Added tests to `python/pyspark/sql/tests/test_utils.py` and `python/pyspark/sql/tests/connect/test_utils.py` Closes #42332 from asl3/support-pandas-df. Authored-by: Amanda Liu Signed-off-by: Hyukjin Kwon --- python/pyspark/pandas/tests/test_utils.py | 16 +- python/pyspark/sql/tests/test_utils.py | 191 ++++++++++++++++-- python/pyspark/testing/pandasutils.py | 233 ++++++++++++++-------- python/pyspark/testing/utils.py | 88 ++++---- 4 files changed, 378 insertions(+), 150 deletions(-) diff --git a/python/pyspark/pandas/tests/test_utils.py b/python/pyspark/pandas/tests/test_utils.py index 0bb03dd8749da..60961dcf252ae 100644 --- a/python/pyspark/pandas/tests/test_utils.py +++ b/python/pyspark/pandas/tests/test_utils.py @@ -111,7 +111,7 @@ def test_validate_index_loc(self): with self.assertRaisesRegex(IndexError, err_msg): validate_index_loc(psidx, -4) - def test_assert_df_assertPandasOnSparkEqual(self): + def test_assert_df_assert_pandas_on_spark_equal(self): import pyspark.pandas as ps psdf1 = ps.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) @@ -120,15 +120,15 @@ def test_assert_df_assertPandasOnSparkEqual(self): assertPandasOnSparkEqual(psdf1, psdf2, checkRowOrder=False) assertPandasOnSparkEqual(psdf1, psdf2, checkRowOrder=True) - def test_assertPandasOnSparkEqual_ignoreOrder_default(self): + def test_assert_pandas_on_spark_equal_ignore_order(self): import pyspark.pandas as ps psdf1 = ps.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) psdf2 = ps.DataFrame({"a": [2, 1, 3], "b": [5, 4, 6], "c": [8, 7, 9]}) - assertPandasOnSparkEqual(psdf1, psdf2) + assertPandasOnSparkEqual(psdf1, psdf2, checkRowOrder=False) - def test_assert_series_assertPandasOnSparkEqual(self): + def test_assert_series_assert_pandas_on_spark_equal(self): import pyspark.pandas as ps s1 = ps.Series([212.32, 100.0001]) @@ -136,7 +136,7 @@ def test_assert_series_assertPandasOnSparkEqual(self): assertPandasOnSparkEqual(s1, s2, checkExact=False) - def test_assert_index_assertPandasOnSparkEqual(self): + def test_assert_index_assert_pandas_on_spark_equal(self): import pyspark.pandas as ps s1 = ps.Index([212.300001, 100.000]) @@ -144,7 +144,7 @@ def test_assert_index_assertPandasOnSparkEqual(self): assertPandasOnSparkEqual(s1, s2, almost=True) - def test_assert_error_assertPandasOnSparkEqual(self): + def test_assert_error_assert_pandas_on_spark_equal(self): import pyspark.pandas as ps list1 = [10, 20, 30] @@ -165,13 +165,13 @@ def test_assert_error_assertPandasOnSparkEqual(self): }, ) - def test_assert_None_assertPandasOnSparkEqual(self): + def test_assert_None_assert_pandas_on_spark_equal(self): psdf1 = None psdf2 = None assertPandasOnSparkEqual(psdf1, psdf2) - def test_assert_empty_assertPandasOnSparkEqual(self): + def test_assert_empty_assert_pandas_on_spark_equal(self): import pyspark.pandas as ps psdf1 = ps.DataFrame() diff --git a/python/pyspark/sql/tests/test_utils.py b/python/pyspark/sql/tests/test_utils.py index 93895465de7f7..e1b7f298d0a8b 100644 --- a/python/pyspark/sql/tests/test_utils.py +++ b/python/pyspark/sql/tests/test_utils.py @@ -15,6 +15,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import unittest +import difflib +from itertools import zip_longest + from pyspark.sql.functions import sha2, to_timestamp from pyspark.errors import ( AnalysisException, @@ -23,7 +27,7 @@ IllegalArgumentException, SparkUpgradeException, ) -from pyspark.testing.utils import assertDataFrameEqual, assertSchemaEqual, _context_diff +from pyspark.testing.utils import assertDataFrameEqual, assertSchemaEqual, _context_diff, have_numpy from pyspark.testing.sqlutils import ReusedSQLTestCase from pyspark.sql import Row import pyspark.sql.functions as F @@ -40,12 +44,7 @@ IntegerType, BooleanType, ) -from pyspark.sql.dataframe import DataFrame -import pyspark.pandas as ps - -import difflib -from typing import List, Union -from itertools import zip_longest +from pyspark.testing.sqlutils import have_pandas class UtilsTestsMixin: @@ -690,7 +689,7 @@ def test_assert_unequal_null_actual(self): exception=pe.exception, error_class="INVALID_TYPE_DF_EQUALITY_ARG", message_parameters={ - "expected_type": Union[DataFrame, ps.DataFrame, List[Row]], + "expected_type": "Union[DataFrame, ps.DataFrame, List[Row]]", "arg_name": "actual", "actual_type": None, }, @@ -703,7 +702,7 @@ def test_assert_unequal_null_actual(self): exception=pe.exception, error_class="INVALID_TYPE_DF_EQUALITY_ARG", message_parameters={ - "expected_type": Union[DataFrame, ps.DataFrame, List[Row]], + "expected_type": "Union[DataFrame, ps.DataFrame, List[Row]]", "arg_name": "actual", "actual_type": None, }, @@ -726,7 +725,7 @@ def test_assert_unequal_null_expected(self): exception=pe.exception, error_class="INVALID_TYPE_DF_EQUALITY_ARG", message_parameters={ - "expected_type": Union[DataFrame, ps.DataFrame, List[Row]], + "expected_type": "Union[DataFrame, ps.DataFrame, List[Row]]", "arg_name": "expected", "actual_type": None, }, @@ -739,33 +738,195 @@ def test_assert_unequal_null_expected(self): exception=pe.exception, error_class="INVALID_TYPE_DF_EQUALITY_ARG", message_parameters={ - "expected_type": Union[DataFrame, ps.DataFrame, List[Row]], + "expected_type": "Union[DataFrame, ps.DataFrame, List[Row]]", "arg_name": "expected", "actual_type": None, }, ) + @unittest.skipIf(not have_pandas or not have_numpy, "no pandas or numpy dependency") def test_assert_equal_exact_pandas_df(self): + import pandas as pd + import numpy as np + + df1 = pd.DataFrame( + data=np.array([(1, 2, 3), (4, 5, 6), (7, 8, 9)]), columns=["a", "b", "c"] + ) + df2 = pd.DataFrame( + data=np.array([(1, 2, 3), (4, 5, 6), (7, 8, 9)]), columns=["a", "b", "c"] + ) + + assertDataFrameEqual(df1, df2, checkRowOrder=False) + assertDataFrameEqual(df1, df2, checkRowOrder=True) + + @unittest.skipIf(not have_pandas or not have_numpy, "no pandas or numpy dependency") + def test_assert_approx_equal_pandas_df(self): + import pandas as pd + import numpy as np + + # test that asserts close enough equality for pandas df + df1 = pd.DataFrame( + data=np.array([(1, 2, 3), (4, 5, 6), (7, 8, 59)]), columns=["a", "b", "c"] + ) + df2 = pd.DataFrame( + data=np.array([(1, 2, 3), (4, 5, 6), (7, 8, 59.0001)]), columns=["a", "b", "c"] + ) + + assertDataFrameEqual(df1, df2, checkRowOrder=False) + assertDataFrameEqual(df1, df2, checkRowOrder=True) + + @unittest.skipIf(not have_pandas or not have_numpy, "no pandas or numpy dependency") + def test_assert_approx_equal_fail_exact_pandas_df(self): + import pandas as pd + import numpy as np + + # test that asserts close enough equality for pandas df + df1 = pd.DataFrame( + data=np.array([(1, 2, 3), (4, 5, 6), (7, 8, 59)]), columns=["a", "b", "c"] + ) + df2 = 
pd.DataFrame( + data=np.array([(1, 2, 3), (4, 5, 6), (7, 8, 59.0001)]), columns=["a", "b", "c"] + ) + + with self.assertRaises(PySparkAssertionError) as pe: + assertDataFrameEqual(df1, df2, checkRowOrder=False, rtol=0, atol=0) + + self.check_error( + exception=pe.exception, + error_class="DIFFERENT_PANDAS_DATAFRAME", + message_parameters={ + "left": df1.to_string(), + "left_dtype": str(df1.dtypes), + "right": df2.to_string(), + "right_dtype": str(df2.dtypes), + }, + ) + + with self.assertRaises(PySparkAssertionError) as pe: + assertDataFrameEqual(df1, df2, checkRowOrder=True, rtol=0, atol=0) + + self.check_error( + exception=pe.exception, + error_class="DIFFERENT_PANDAS_DATAFRAME", + message_parameters={ + "left": df1.to_string(), + "left_dtype": str(df1.dtypes), + "right": df2.to_string(), + "right_dtype": str(df2.dtypes), + }, + ) + + @unittest.skipIf(not have_pandas or not have_numpy, "no pandas or numpy dependency") + def test_assert_unequal_pandas_df(self): + import pandas as pd + import numpy as np + + df1 = pd.DataFrame( + data=np.array([(1, 2, 3), (4, 5, 6), (6, 5, 4)]), columns=["a", "b", "c"] + ) + df2 = pd.DataFrame( + data=np.array([(1, 2, 3), (4, 5, 6), (7, 8, 9)]), columns=["a", "b", "c"] + ) + + with self.assertRaises(PySparkAssertionError) as pe: + assertDataFrameEqual(df1, df2, checkRowOrder=False) + + self.check_error( + exception=pe.exception, + error_class="DIFFERENT_PANDAS_DATAFRAME", + message_parameters={ + "left": df1.to_string(), + "left_dtype": str(df1.dtypes), + "right": df2.to_string(), + "right_dtype": str(df2.dtypes), + }, + ) + + with self.assertRaises(PySparkAssertionError) as pe: + assertDataFrameEqual(df1, df2, checkRowOrder=True) + + self.check_error( + exception=pe.exception, + error_class="DIFFERENT_PANDAS_DATAFRAME", + message_parameters={ + "left": df1.to_string(), + "left_dtype": str(df1.dtypes), + "right": df2.to_string(), + "right_dtype": str(df2.dtypes), + }, + ) + + @unittest.skipIf(not have_pandas or not have_numpy, "no pandas or numpy dependency") + def test_assert_type_error_pandas_df(self): + import pyspark.pandas as ps + import pandas as pd + import numpy as np + + df1 = ps.DataFrame(data=[10, 20, 30], columns=["Numbers"]) + df2 = pd.DataFrame( + data=np.array([(1, 2, 3), (4, 5, 6), (6, 5, 4)]), columns=["a", "b", "c"] + ) + + with self.assertRaises(PySparkAssertionError) as pe: + assertDataFrameEqual(df1, df2, checkRowOrder=False) + + self.check_error( + exception=pe.exception, + error_class="DIFFERENT_PANDAS_DATAFRAME", + message_parameters={ + "left": df1.to_string(), + "left_dtype": str(df1.dtypes), + "right": df2.to_string(), + "right_dtype": str(df2.dtypes), + }, + ) + + with self.assertRaises(PySparkAssertionError) as pe: + assertDataFrameEqual(df1, df2, checkRowOrder=True) + + self.check_error( + exception=pe.exception, + error_class="DIFFERENT_PANDAS_DATAFRAME", + message_parameters={ + "left": df1.to_string(), + "left_dtype": str(df1.dtypes), + "right": df2.to_string(), + "right_dtype": str(df2.dtypes), + }, + ) + + @unittest.skipIf(not have_pandas, "no pandas dependency") + def test_assert_equal_exact_pandas_on_spark_df(self): + import pyspark.pandas as ps + df1 = ps.DataFrame(data=[10, 20, 30], columns=["Numbers"]) df2 = ps.DataFrame(data=[10, 20, 30], columns=["Numbers"]) assertDataFrameEqual(df1, df2, checkRowOrder=False) assertDataFrameEqual(df1, df2, checkRowOrder=True) - def test_assert_equal_exact_pandas_df(self): + @unittest.skipIf(not have_pandas, "no pandas dependency") + def 
test_assert_equal_exact_pandas_on_spark_df(self): + import pyspark.pandas as ps + df1 = ps.DataFrame(data=[10, 20, 30], columns=["Numbers"]) df2 = ps.DataFrame(data=[30, 20, 10], columns=["Numbers"]) assertDataFrameEqual(df1, df2) - def test_assert_equal_approx_pandas_df(self): + @unittest.skipIf(not have_pandas, "no pandas dependency") + def test_assert_equal_approx_pandas_on_spark_df(self): + import pyspark.pandas as ps + df1 = ps.DataFrame(data=[10.0001, 20.32, 30.1], columns=["Numbers"]) df2 = ps.DataFrame(data=[10.0, 20.32, 30.1], columns=["Numbers"]) assertDataFrameEqual(df1, df2, checkRowOrder=False) assertDataFrameEqual(df1, df2, checkRowOrder=True) + @unittest.skipIf(not have_pandas, "no pandas dependency") def test_assert_error_pandas_pyspark_df(self): + import pyspark.pandas as ps import pandas as pd df1 = ps.DataFrame(data=[10, 20, 30], columns=["Numbers"]) @@ -818,7 +979,7 @@ def test_assert_error_non_pyspark_df(self): exception=pe.exception, error_class="INVALID_TYPE_DF_EQUALITY_ARG", message_parameters={ - "expected_type": Union[DataFrame, ps.DataFrame, List[Row]], + "expected_type": "Union[DataFrame, ps.DataFrame, List[Row]]", "arg_name": "actual", "actual_type": type(dict1), }, @@ -831,7 +992,7 @@ def test_assert_error_non_pyspark_df(self): exception=pe.exception, error_class="INVALID_TYPE_DF_EQUALITY_ARG", message_parameters={ - "expected_type": Union[DataFrame, ps.DataFrame, List[Row]], + "expected_type": "Union[DataFrame, ps.DataFrame, List[Row]]", "arg_name": "actual", "actual_type": type(dict1), }, diff --git a/python/pyspark/testing/pandasutils.py b/python/pyspark/testing/pandasutils.py index 39196873482b1..1122944b2c08a 100644 --- a/python/pyspark/testing/pandasutils.py +++ b/python/pyspark/testing/pandasutils.py @@ -23,7 +23,7 @@ from contextlib import contextmanager from distutils.version import LooseVersion import decimal -from typing import Union +from typing import Any, Union import pyspark.pandas as ps from pyspark.pandas.frame import DataFrame @@ -153,18 +153,29 @@ def _assert_pandas_equal( def _assert_pandas_almost_equal( - left: Union[pd.DataFrame, pd.Series, pd.Index], right: Union[pd.DataFrame, pd.Series, pd.Index] + left: Union[pd.DataFrame, pd.Series, pd.Index], + right: Union[pd.DataFrame, pd.Series, pd.Index], + rtol: float = 1e-5, + atol: float = 1e-8, ): """ This function checks if given pandas objects approximately same, which means the conditions below: - Both objects are nullable - - Compare floats rounding to the number of decimal places, 7 after - dropping missing values (NaN, NaT, None) + - Compare decimals and floats, where two values a and b are approximately equal + if they satisfy the following formula: + absolute(a - b) <= (atol + rtol * absolute(b)) + where rtol=1e-5 and atol=1e-8 by default """ - # following pandas convention, rtol=1e-5 and atol=1e-8 - rtol = 1e-5 - atol = 1e-8 + + def compare_vals_approx(val1, val2): + # compare vals for approximate equality + if isinstance(lval, (float, decimal.Decimal)) or isinstance(rval, (float, decimal.Decimal)): + if abs(float(lval) - float(rval)) > (atol + rtol * abs(float(rval))): + return False + elif val1 != val2: + return False + return True if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame): if left.shape != right.shape: @@ -200,19 +211,16 @@ def _assert_pandas_almost_equal( }, ) for lval, rval in zip(left[lcol].dropna(), right[rcol].dropna()): - if (isinstance(lval, float) or isinstance(lval, decimal.Decimal)) and ( - isinstance(rval, float) or isinstance(rval, 
decimal.Decimal) - ): - if abs(float(lval) - float(rval)) > (atol + rtol * abs(float(rval))): - raise PySparkAssertionError( - error_class="DIFFERENT_PANDAS_DATAFRAME", - message_parameters={ - "left": left.to_string(), - "left_dtype": str(left.dtypes), - "right": right.to_string(), - "right_dtype": str(right.dtypes), - }, - ) + if not compare_vals_approx(lval, rval): + raise PySparkAssertionError( + error_class="DIFFERENT_PANDAS_DATAFRAME", + message_parameters={ + "left": left.to_string(), + "left_dtype": str(left.dtypes), + "right": right.to_string(), + "right_dtype": str(right.dtypes), + }, + ) if left.columns.names != right.columns.names: raise PySparkAssertionError( error_class="DIFFERENT_PANDAS_DATAFRAME", @@ -246,19 +254,16 @@ def _assert_pandas_almost_equal( }, ) for lval, rval in zip(left.dropna(), right.dropna()): - if (isinstance(lval, float) or isinstance(lval, decimal.Decimal)) and ( - isinstance(rval, float) or isinstance(rval, decimal.Decimal) - ): - if abs(float(lval) - float(rval)) > (atol + rtol * abs(float(rval))): - raise PySparkAssertionError( - error_class="DIFFERENT_PANDAS_SERIES", - message_parameters={ - "left": left.to_string(), - "left_dtype": str(left.dtype), - "right": right.to_string(), - "right_dtype": str(right.dtype), - }, - ) + if not compare_vals_approx(lval, rval): + raise PySparkAssertionError( + error_class="DIFFERENT_PANDAS_SERIES", + message_parameters={ + "left": left.to_string(), + "left_dtype": str(left.dtype), + "right": right.to_string(), + "right_dtype": str(right.dtype), + }, + ) elif isinstance(left, pd.MultiIndex) and isinstance(right, pd.MultiIndex): if len(left) != len(right): raise PySparkAssertionError( @@ -271,19 +276,16 @@ def _assert_pandas_almost_equal( }, ) for lval, rval in zip(left, right): - if (isinstance(lval, float) or isinstance(lval, decimal.Decimal)) and ( - isinstance(rval, float) or isinstance(rval, decimal.Decimal) - ): - if abs(float(lval) - float(rval)) > (atol + rtol * abs(float(rval))): - raise PySparkAssertionError( - error_class="DIFFERENT_PANDAS_MULTIINDEX", - message_parameters={ - "left": left, - "left_dtype": str(left.dtype), - "right": right, - "right_dtype": str(right.dtype), - }, - ) + if not compare_vals_approx(lval, rval): + raise PySparkAssertionError( + error_class="DIFFERENT_PANDAS_MULTIINDEX", + message_parameters={ + "left": left, + "left_dtype": str(left.dtype), + "right": right, + "right_dtype": str(right.dtype), + }, + ) elif isinstance(left, pd.Index) and isinstance(right, pd.Index): if len(left) != len(right): raise PySparkAssertionError( @@ -307,21 +309,39 @@ def _assert_pandas_almost_equal( }, ) for lval, rval in zip(left.dropna(), right.dropna()): - if (isinstance(lval, float) or isinstance(lval, decimal.Decimal)) and ( - isinstance(rval, float) or isinstance(rval, decimal.Decimal) - ): - if abs(float(lval) - float(rval)) > (atol + rtol * abs(float(rval))): - raise PySparkAssertionError( - error_class="DIFFERENT_PANDAS_INDEX", - message_parameters={ - "left": left, - "left_dtype": str(left.dtype), - "right": right, - "right_dtype": str(right.dtype), - }, - ) + if not compare_vals_approx(lval, rval): + raise PySparkAssertionError( + error_class="DIFFERENT_PANDAS_INDEX", + message_parameters={ + "left": left, + "left_dtype": str(left.dtype), + "right": right, + "right_dtype": str(right.dtype), + }, + ) else: - raise ValueError("Unexpected values: (%s, %s)" % (left, right)) + if not isinstance(left, (pd.DataFrame, pd.Series, pd.Index)): + raise PySparkAssertionError( + 
error_class="INVALID_TYPE_DF_EQUALITY_ARG", + message_parameters={ + "expected_type": f"{pd.DataFrame.__name__}, " + f"{pd.Series.__name__}, " + f"{pd.Index.__name__}, ", + "arg_name": "left", + "actual_type": type(left), + }, + ) + elif not isinstance(right, (pd.DataFrame, pd.Series, pd.Index)): + raise PySparkAssertionError( + error_class="INVALID_TYPE_DF_EQUALITY_ARG", + message_parameters={ + "expected_type": f"{pd.DataFrame.__name__}, " + f"{pd.Series.__name__}, " + f"{pd.Index.__name__}, ", + "arg_name": "right", + "actual_type": type(right), + }, + ) def assertPandasOnSparkEqual( @@ -329,20 +349,22 @@ def assertPandasOnSparkEqual( expected: Union[DataFrame, pd.DataFrame, Series, pd.Series, Index, pd.Index], checkExact: bool = True, almost: bool = False, - checkRowOrder: bool = False, + rtol: float = 1e-5, + atol: float = 1e-8, + checkRowOrder: bool = True, ): r""" - A util function to assert equality between actual (pandas-on-Spark DataFrame) and expected - (pandas-on-Spark or pandas DataFrame). + A util function to assert equality between actual (pandas-on-Spark object) and expected + (pandas-on-Spark or pandas object). .. versionadded:: 3.5.0 Parameters ---------- - actual: pyspark.pandas.frame.DataFrame - The DataFrame that is being compared or tested. - expected: pyspark.pandas.frame.DataFrame or pd.DataFrame - The expected DataFrame, for comparison with the actual result. + actual: pandas-on-Spark DataFrame, Series, or Index + The object that is being compared or tested. + expected: pandas-on-Spark or pandas DataFrame, Series, or Index + The expected object, for comparison with the actual result. checkExact: bool, optional A flag indicating whether to compare exact equality. If set to 'True' (default), the data is compared exactly. @@ -354,10 +376,16 @@ def assertPandasOnSparkEqual( (see documentation for more details). If set to 'False' (default), the data is compared exactly with `unittest`'s `assertEqual`. + rtol : float, optional + The relative tolerance, used in asserting almost equality for float values in actual + and expected. Set to 1e-5 by default. (See Notes) + atol : float, optional + The absolute tolerance, used in asserting almost equality for float values in actual + and expected. Set to 1e-8 by default. (See Notes) checkRowOrder : bool, optional A flag indicating whether the order of rows should be considered in the comparison. - If set to `False` (default), the row order is not taken into account. - If set to `True`, the order of rows is important and will be checked during comparison. + If set to `False`, the row order is not taken into account. + If set to `True` (default), the order of rows will be checked during comparison. (See Notes) Notes @@ -365,6 +393,11 @@ def assertPandasOnSparkEqual( For `checkRowOrder`, note that pandas-on-Spark DataFrame ordering is non-deterministic, unless explicitly sorted. + When `almost` is set to True, approximate equality will be asserted, where two values + a and b are approximately equal if they satisfy the following formula: + + ``absolute(a - b) <= (atol + rtol * absolute(b))``. 
+ Examples -------- >>> import pyspark.pandas as ps @@ -407,8 +440,9 @@ def assertPandasOnSparkEqual( }, ) else: - actual = actual.to_pandas() - if not isinstance(expected, pd.DataFrame): + if not isinstance(actual, (pd.DataFrame, pd.Index, pd.Series)): + actual = actual.to_pandas() + if not isinstance(expected, (pd.DataFrame, pd.Index, pd.Series)): expected = expected.to_pandas() if not checkRowOrder: @@ -418,25 +452,40 @@ def assertPandasOnSparkEqual( expected = expected.sort_values(by=expected.columns[0], ignore_index=True) if almost: - _assert_pandas_almost_equal(actual, expected) + _assert_pandas_almost_equal(actual, expected, rtol=rtol, atol=atol) else: _assert_pandas_equal(actual, expected, checkExact=checkExact) class PandasOnSparkTestUtils: - def convert_str_to_lambda(self, func): + def convert_str_to_lambda(self, func: str): """ This function converts `func` str to lambda call """ return lambda x: getattr(x, func)() - def assertPandasEqual(self, left, right, check_exact=True): + def assertPandasEqual(self, left: Any, right: Any, check_exact: bool = True): _assert_pandas_equal(left, right, check_exact) - def assertPandasAlmostEqual(self, left, right): - _assert_pandas_almost_equal(left, right) - - def assert_eq(self, left, right, check_exact=True, almost=False): + def assertPandasAlmostEqual( + self, + left: Any, + right: Any, + rtol: float = 1e-5, + atol: float = 1e-8, + ): + _assert_pandas_almost_equal(left, right, rtol=rtol, atol=atol) + + def assert_eq( + self, + left: Any, + right: Any, + check_exact: bool = True, + almost: bool = False, + rtol: float = 1e-5, + atol: float = 1e-8, + check_row_order: bool = True, + ): """ Asserts if two arbitrary objects are equal or not. If given objects are Koalas DataFrame or Series, they are converted into pandas' and compared. @@ -444,17 +493,37 @@ def assert_eq(self, left, right, check_exact=True, almost=False): :param left: object to compare :param right: object to compare :param check_exact: if this is False, the comparison is done less precisely. - :param almost: if this is enabled, the comparison is delegated to `unittest`'s - `assertAlmostEqual`. See its documentation for more details. + :param almost: if this is enabled, the comparison asserts approximate equality + for float and decimal values, where two values a and b are approximately equal + if they satisfy the following formula: + absolute(a - b) <= (atol + rtol * absolute(b)) + :param rtol: The relative tolerance, used in asserting approximate equality for + float values. Set to 1e-5 by default. + :param atol: The absolute tolerance, used in asserting approximate equality for + float values in actual and expected. Set to 1e-8 by default. + :param check_row_order: A flag indicating whether the order of rows should be considered + in the comparison. If set to False, row order will be ignored. 
""" import pandas as pd from pandas.api.types import is_list_like + # for pandas-on-Spark DataFrames, allow choice to ignore row order + if isinstance(left, (ps.DataFrame, ps.Series, ps.Index)): + return assertPandasOnSparkEqual( + left, + right, + checkExact=check_exact, + almost=almost, + rtol=rtol, + atol=atol, + checkRowOrder=check_row_order, + ) + lobj = self._to_pandas(left) robj = self._to_pandas(right) if isinstance(lobj, (pd.DataFrame, pd.Series, pd.Index)): if almost: - _assert_pandas_almost_equal(lobj, robj) + _assert_pandas_almost_equal(lobj, robj, rtol=rtol, atol=atol) else: _assert_pandas_equal(lobj, robj, checkExact=check_exact) elif is_list_like(lobj) and is_list_like(robj): @@ -470,7 +539,7 @@ def assert_eq(self, left, right, check_exact=True, almost=False): self.assertEqual(lobj, robj) @staticmethod - def _to_pandas(obj): + def _to_pandas(obj: Any): if isinstance(obj, (DataFrame, Series, Index)): return obj.to_pandas() else: diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py index 8e02803efe5cb..7dd723634e2f9 100644 --- a/python/pyspark/testing/utils.py +++ b/python/pyspark/testing/utils.py @@ -39,7 +39,6 @@ from pyspark.sql.dataframe import DataFrame from pyspark.sql import Row from pyspark.sql.types import StructType, AtomicType, StructField -import pyspark.pandas as ps have_scipy = False have_numpy = False @@ -360,9 +359,16 @@ def compare_datatypes_ignore_nullable(dt1: Any, dt2: Any): ) +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import pandas + import pyspark.pandas + + def assertDataFrameEqual( - actual: Union[DataFrame, ps.DataFrame, List[Row]], - expected: Union[DataFrame, ps.DataFrame, List[Row]], + actual: Union[DataFrame, "pandas.DataFrame", "pyspark.pandas.DataFrame", List[Row]], + expected: Union[DataFrame, "pandas.DataFrame", "pyspark.pandas.DataFrame", List[Row]], checkRowOrder: bool = False, rtol: float = 1e-5, atol: float = 1e-8, @@ -371,7 +377,7 @@ def assertDataFrameEqual( A util function to assert equality between `actual` and `expected` (DataFrames or lists of Rows), with optional parameters `checkRowOrder`, `rtol`, and `atol`. - Supports Spark, Spark Connect, and pandas-on-Spark DataFrames. + Supports Spark, Spark Connect, pandas, and pandas-on-Spark DataFrames. For more information about pandas-on-Spark DataFrame equality, see the docs for `assertPandasOnSparkEqual`. @@ -379,9 +385,9 @@ def assertDataFrameEqual( Parameters ---------- - actual : DataFrame (Spark, Spark Connect, or pandas-on-Spark) or list of Rows + actual : DataFrame (Spark, Spark Connect, pandas, or pandas-on-Spark) or list of Rows The DataFrame that is being compared or tested. - expected : DataFrame (Spark, Spark Connect, or pandas-on-Spark) or list of Rows + expected : DataFrame (Spark, Spark Connect, pandas, or pandas-on-Spark) or list of Rows The expected result of the operation, for comparison with the actual result. checkRowOrder : bool, optional A flag indicating whether the order of rows should be considered in the comparison. @@ -446,16 +452,13 @@ def assertDataFrameEqual( Row(id='2', amount=3000.0) ! 
Row(id='3', amount=2003.0) """ - import pyspark.pandas as ps - from pyspark.testing.pandasutils import assertPandasOnSparkEqual - if actual is None and expected is None: return True elif actual is None: raise PySparkAssertionError( error_class="INVALID_TYPE_DF_EQUALITY_ARG", message_parameters={ - "expected_type": Union[DataFrame, ps.DataFrame, List[Row]], + "expected_type": "Union[DataFrame, ps.DataFrame, List[Row]]", "arg_name": "actual", "actual_type": None, }, @@ -464,61 +467,56 @@ def assertDataFrameEqual( raise PySparkAssertionError( error_class="INVALID_TYPE_DF_EQUALITY_ARG", message_parameters={ - "expected_type": Union[DataFrame, ps.DataFrame, List[Row]], + "expected_type": "Union[DataFrame, ps.DataFrame, List[Row]]", "arg_name": "expected", "actual_type": None, }, ) + has_pandas = False try: - # If Spark Connect dependencies are available, allow Spark Connect DataFrame - from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame + # If pandas dependencies are available, allow pandas or pandas-on-Spark DataFrame + import pyspark.pandas as ps + import pandas as pd + from pyspark.testing.pandasutils import PandasOnSparkTestUtils + + has_pandas = True + except ImportError: + # no pandas, so we won't call pandasutils functions + pass - if isinstance(actual, ps.DataFrame) or isinstance(expected, ps.DataFrame): - # handle pandas DataFrames - # assert approximate equality for float data - return assertPandasOnSparkEqual( - actual, expected, checkExact=False, checkRowOrder=checkRowOrder - ) - elif not isinstance(actual, (DataFrame, ConnectDataFrame, list)): - raise PySparkAssertionError( - error_class="INVALID_TYPE_DF_EQUALITY_ARG", - message_parameters={ - "expected_type": Union[DataFrame, ps.DataFrame, List[Row]], - "arg_name": "actual", - "actual_type": type(actual), - }, - ) - elif not isinstance(expected, (DataFrame, ConnectDataFrame, list)): - raise PySparkAssertionError( - error_class="INVALID_TYPE_DF_EQUALITY_ARG", - message_parameters={ - "expected_type": Union[DataFrame, ps.DataFrame, List[Row]], - "arg_name": "expected", - "actual_type": type(expected), - }, - ) - except Exception: - if isinstance(actual, ps.DataFrame) or isinstance(expected, ps.DataFrame): + if has_pandas: + if ( + isinstance(actual, pd.DataFrame) + or isinstance(expected, pd.DataFrame) + or isinstance(actual, ps.DataFrame) + or isinstance(expected, ps.DataFrame) + ): # handle pandas DataFrames # assert approximate equality for float data - return assertPandasOnSparkEqual( - actual, expected, checkExact=False, checkRowOrder=checkRowOrder + return PandasOnSparkTestUtils().assert_eq( + actual, expected, almost=True, rtol=rtol, atol=atol, check_row_order=checkRowOrder ) - elif not isinstance(actual, (DataFrame, list)): + + from pyspark.sql.utils import get_dataframe_class + + # if is_remote(), allow Connect DataFrame + SparkDataFrame = get_dataframe_class() + + if not isinstance(actual, (DataFrame, SparkDataFrame, list)): raise PySparkAssertionError( error_class="INVALID_TYPE_DF_EQUALITY_ARG", message_parameters={ - "expected_type": Union[DataFrame, ps.DataFrame, List[Row]], + "expected_type": "Union[DataFrame, ps.DataFrame, List[Row]]", "arg_name": "actual", "actual_type": type(actual), }, ) - elif not isinstance(expected, (DataFrame, list)): + elif not isinstance(expected, (DataFrame, SparkDataFrame, list)): raise PySparkAssertionError( error_class="INVALID_TYPE_DF_EQUALITY_ARG", message_parameters={ - "expected_type": Union[DataFrame, ps.DataFrame, List[Row]], + "expected_type": "Union[DataFrame, 
ps.DataFrame, List[Row]]", "arg_name": "expected", "actual_type": type(expected), }, From 3920a4189ca30fdc9cab5948ab6045909858e8eb Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 9 Aug 2023 09:23:29 +0900 Subject: [PATCH 20/30] [SPARK-44723][BUILD] Upgrade `gcs-connector` to 2.2.16 ### What changes were proposed in this pull request? This PR aims to upgrade `gcs-connector` to 2.2.16. ### Why are the changes needed? - https://github.com/GoogleCloudDataproc/hadoop-connectors/releases/tag/v2.2.16 (2023-06-30) - https://github.com/GoogleCloudDataproc/hadoop-connectors/releases/tag/v2.2.15 (2023-06-02) ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass the CIs and do the manual tests. **BUILD** ``` dev/make-distribution.sh -Phadoop-cloud ``` **TEST** ``` $ export KEYFILE=your-credential-file.json $ export EMAIL=$(jq -r '.client_email' < $KEYFILE) $ export PRIVATE_KEY_ID=$(jq -r '.private_key_id' < $KEYFILE) $ export PRIVATE_KEY="$(jq -r '.private_key' < $KEYFILE)" $ bin/spark-shell \ -c spark.hadoop.fs.gs.auth.service.account.email=$EMAIL \ -c spark.hadoop.fs.gs.auth.service.account.private.key.id=$PRIVATE_KEY_ID \ -c spark.hadoop.fs.gs.auth.service.account.private.key="$PRIVATE_KEY" Setting default log level to "WARN". To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel). 23/08/08 10:43:29 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable Spark context Web UI available at http://localhost:4040 Spark context available as 'sc' (master = local[*], app id = local-1691516610108). Spark session available as 'spark'. Welcome to ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ /___/ .__/\_,_/_/ /_/\_\ version 4.0.0-SNAPSHOT /_/ Using Scala version 2.12.18 (OpenJDK 64-Bit Server VM, Java 1.8.0_312) Type in expressions to have them evaluated. Type :help for more information. scala> spark.read.text("gs://apache-spark-bucket/README.md").count() 23/08/08 10:43:46 WARN GhfsStorageStatistics: Detected potential high latency for operation op_get_file_status. latencyMs=823; previousMaxLatencyMs=0; operationCount=1; context=gs://apache-spark-bucket/README.md res0: Long = 124 scala> spark.read.orc("examples/src/main/resources/users.orc").write.mode("overwrite").orc("gs://apache-spark-bucket/users.orc") 23/08/08 10:43:59 WARN GhfsStorageStatistics: Detected potential high latency for operation op_delete. latencyMs=549; previousMaxLatencyMs=0; operationCount=1; context=gs://apache-spark-bucket/users.orc 23/08/08 10:43:59 WARN GhfsStorageStatistics: Detected potential high latency for operation op_mkdirs. latencyMs=440; previousMaxLatencyMs=0; operationCount=1; context=gs://apache-spark-bucket/users.orc/_temporary/0 23/08/08 10:44:04 WARN GhfsStorageStatistics: Detected potential high latency for operation op_delete. latencyMs=631; previousMaxLatencyMs=549; operationCount=2; context=gs://apache-spark-bucket/users.orc/_temporary 23/08/08 10:44:05 WARN GhfsStorageStatistics: Detected potential high latency for operation stream_write_close_operations. 
latencyMs=572; previousMaxLatencyMs=393; operationCount=2; context=gs://apache-spark-bucket/users.orc/_SUCCESS scala> scala> spark.read.orc("gs://apache-spark-bucket/users.orc").show() +------+--------------+----------------+ | name|favorite_color|favorite_numbers| +------+--------------+----------------+ |Alyssa| NULL| [3, 9, 15, 20]| | Ben| red| []| +------+--------------+----------------+ ``` Closes #42401 from dongjoon-hyun/SPARK-44723. Authored-by: Dongjoon Hyun Signed-off-by: Hyukjin Kwon --- dev/deps/spark-deps-hadoop-3-hive-2.3 | 2 +- pom.xml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 1f8c079a9bc8c..416753ab2010c 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -62,7 +62,7 @@ datasketches-memory/2.1.0//datasketches-memory-2.1.0.jar derby/10.14.2.0//derby-10.14.2.0.jar dropwizard-metrics-hadoop-metrics2-reporter/0.1.2//dropwizard-metrics-hadoop-metrics2-reporter-0.1.2.jar flatbuffers-java/1.12.0//flatbuffers-java-1.12.0.jar -gcs-connector/hadoop3-2.2.14/shaded/gcs-connector-hadoop3-2.2.14-shaded.jar +gcs-connector/hadoop3-2.2.16/shaded/gcs-connector-hadoop3-2.2.16-shaded.jar gmetric4j/1.0.10//gmetric4j-1.0.10.jar gson/2.2.4//gson-2.2.4.jar guava/14.0.1//guava-14.0.1.jar diff --git a/pom.xml b/pom.xml index 76e3596edd430..624df0c314a0e 100644 --- a/pom.xml +++ b/pom.xml @@ -160,7 +160,7 @@ 1.11.655 0.12.8 - hadoop3-2.2.14 + hadoop3-2.2.16 4.5.14 4.4.16 From 66d8e6a3d83d1c686ea68165c656324a17c88a9a Mon Sep 17 00:00:00 2001 From: itholic Date: Wed, 9 Aug 2023 10:07:56 +0900 Subject: [PATCH 21/30] [SPARK-44695][PYTHON] Improve error message for `DataFrame.toDF` ### What changes were proposed in this pull request? This PR proposes to improve error message for `DataFrame.toDF` ### Why are the changes needed? The current error message is not helpful to solve the problem. ### Does this PR introduce _any_ user-facing change? Displaying more clear error message than before. **Before** ```python >>> df = spark.createDataFrame([("John", 30), ("Alice", 25), ("Bob", 28)]) >>> cols = ['A', None] >>> df.toDF(*cols) Traceback (most recent call last): ... py4j.protocol.Py4JJavaError: An error occurred while calling o54.toDF. : org.apache.spark.SparkException: [INTERNAL_ERROR] The Spark SQL phase analysis failed with an internal error. You hit a bug in Spark or the Spark plugins you use. Please, report this bug to the corresponding communities or vendors, and provide the full stack trace. 
at org.apache.spark.SparkException$.internalError(SparkException.scala:98) at org.apache.spark.sql.execution.QueryExecution$.toInternalError(QueryExecution.scala:519) at org.apache.spark.sql.execution.QueryExecution$.withInternalError(QueryExecution.scala:531) at org.apache.spark.sql.execution.QueryExecution.$anonfun$executePhase$1(QueryExecution.scala:202) at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:858) at org.apache.spark.sql.execution.QueryExecution.executePhase(QueryExecution.scala:201) at org.apache.spark.sql.execution.QueryExecution.analyzed$lzycompute(QueryExecution.scala:76) at org.apache.spark.sql.execution.QueryExecution.analyzed(QueryExecution.scala:74) at org.apache.spark.sql.execution.QueryExecution.assertAnalyzed(QueryExecution.scala:66) at org.apache.spark.sql.Dataset$.$anonfun$ofRows$1(Dataset.scala:92) at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:858) at org.apache.spark.sql.Dataset$.ofRows(Dataset.scala:90) at org.apache.spark.sql.Dataset.withPlan(Dataset.scala:4318) at org.apache.spark.sql.Dataset.select(Dataset.scala:1541) at org.apache.spark.sql.Dataset.toDF(Dataset.scala:539) at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.base/java.lang.reflect.Method.invoke(Method.java:566) at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244) at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374) at py4j.Gateway.invoke(Gateway.java:282) at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132) at py4j.commands.CallCommand.execute(CallCommand.java:79) at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182) at py4j.ClientServerConnection.run(ClientServerConnection.java:106) at java.base/java.lang.Thread.run(Thread.java:829) Caused by: java.lang.NullPointerException at org.apache.spark.sql.catalyst.analysis.ColumnResolutionHelper.$anonfun$resolveLateralColumnAlias$2(ColumnResolutionHelper.scala:308) at scala.collection.immutable.List.map(List.scala:297) at org.apache.spark.sql.catalyst.analysis.ColumnResolutionHelper.resolveLateralColumnAlias(ColumnResolutionHelper.scala:305) at org.apache.spark.sql.catalyst.analysis.ColumnResolutionHelper.resolveLateralColumnAlias$(ColumnResolutionHelper.scala:260) at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveReferences$.resolveLateralColumnAlias(Analyzer.scala:1462) at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveReferences$$anonfun$apply$16.applyOrElse(Analyzer.scala:1602) at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveReferences$$anonfun$apply$16.applyOrElse(Analyzer.scala:1487) at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.$anonfun$resolveOperatorsUpWithPruning$3(AnalysisHelper.scala:138) at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(origin.scala:76) at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.$anonfun$resolveOperatorsUpWithPruning$1(AnalysisHelper.scala:138) at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper$.allowInvokingTransformsInAnalyzer(AnalysisHelper.scala:323) at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.resolveOperatorsUpWithPruning(AnalysisHelper.scala:134) at 
org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.resolveOperatorsUpWithPruning$(AnalysisHelper.scala:130) at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.resolveOperatorsUpWithPruning(LogicalPlan.scala:32) at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.resolveOperatorsUp(AnalysisHelper.scala:111) at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.resolveOperatorsUp$(AnalysisHelper.scala:110) at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.resolveOperatorsUp(LogicalPlan.scala:32) at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveReferences$.apply(Analyzer.scala:1487) at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveReferences$.apply(Analyzer.scala:1462) at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$execute$2(RuleExecutor.scala:222) at scala.collection.LinearSeqOptimized.foldLeft(LinearSeqOptimized.scala:126) at scala.collection.LinearSeqOptimized.foldLeft$(LinearSeqOptimized.scala:122) at scala.collection.immutable.List.foldLeft(List.scala:91) at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$execute$1(RuleExecutor.scala:219) at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$execute$1$adapted(RuleExecutor.scala:211) at scala.collection.immutable.List.foreach(List.scala:431) at org.apache.spark.sql.catalyst.rules.RuleExecutor.execute(RuleExecutor.scala:211) at org.apache.spark.sql.catalyst.analysis.Analyzer.org$apache$spark$sql$catalyst$analysis$Analyzer$$executeSameContext(Analyzer.scala:228) at org.apache.spark.sql.catalyst.analysis.Analyzer.$anonfun$execute$1(Analyzer.scala:224) at org.apache.spark.sql.catalyst.analysis.AnalysisContext$.withNewAnalysisContext(Analyzer.scala:173) at org.apache.spark.sql.catalyst.analysis.Analyzer.execute(Analyzer.scala:224) at org.apache.spark.sql.catalyst.analysis.Analyzer.execute(Analyzer.scala:188) at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$executeAndTrack$1(RuleExecutor.scala:182) at org.apache.spark.sql.catalyst.QueryPlanningTracker$.withTracker(QueryPlanningTracker.scala:88) at org.apache.spark.sql.catalyst.rules.RuleExecutor.executeAndTrack(RuleExecutor.scala:182) at org.apache.spark.sql.catalyst.analysis.Analyzer.$anonfun$executeAndCheck$1(Analyzer.scala:209) at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper$.markInAnalyzer(AnalysisHelper.scala:330) at org.apache.spark.sql.catalyst.analysis.Analyzer.executeAndCheck(Analyzer.scala:208) at org.apache.spark.sql.execution.QueryExecution.$anonfun$analyzed$1(QueryExecution.scala:76) at org.apache.spark.sql.catalyst.QueryPlanningTracker.measurePhase(QueryPlanningTracker.scala:111) at org.apache.spark.sql.execution.QueryExecution.$anonfun$executePhase$2(QueryExecution.scala:202) at org.apache.spark.sql.execution.QueryExecution$.withInternalError(QueryExecution.scala:529) ... 24 more ``` **After** ```python >>> df = spark.createDataFrame([("John", 30), ("Alice", 25), ("Bob", 28)]) >>> cols = ['A', None] >>> df.toDF(*cols) Traceback (most recent call last): ... raise PySparkTypeError( pyspark.errors.exceptions.base.PySparkTypeError: [NOT_LIST_OF_STR] Argument `cols` should be a list[str], got NoneType. ``` ### How was this patch tested? Add UT. Closes #42369 from itholic/improve_error_toDF. 
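For reference, the check added on both the Spark Connect and classic DataFrame paths boils down to the following standalone sketch (the helper name is illustrative; the real change inlines the loop in `toDF`, as shown in the diff below):

```python
from pyspark.errors import PySparkTypeError

def _validate_col_names(cols):  # illustrative helper; the actual patch inlines this loop in toDF
    # Reject non-string column names eagerly, before any JVM call, so users get a
    # clear NOT_LIST_OF_STR error instead of an analyzer-side NullPointerException.
    for col in cols:
        if not isinstance(col, str):
            raise PySparkTypeError(
                error_class="NOT_LIST_OF_STR",
                message_parameters={"arg_name": "cols", "arg_type": type(col).__name__},
            )
```

Failing fast on the Python side is what replaces the opaque internal error shown above with the targeted `PySparkTypeError`.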
Authored-by: itholic Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/connect/dataframe.py | 6 ++++++ python/pyspark/sql/dataframe.py | 6 ++++++ .../sql/tests/connect/test_parity_dataframe.py | 3 +++ python/pyspark/sql/tests/test_dataframe.py | 17 +++++++++++++++++ 4 files changed, 32 insertions(+) diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 12e424b5ef137..14d9c2c9d05a8 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -1732,6 +1732,12 @@ def to(self, schema: StructType) -> "DataFrame": to.__doc__ = PySparkDataFrame.to.__doc__ def toDF(self, *cols: str) -> "DataFrame": + for col_ in cols: + if not isinstance(col_, str): + raise PySparkTypeError( + error_class="NOT_LIST_OF_STR", + message_parameters={"arg_name": "cols", "arg_type": type(col_).__name__}, + ) return DataFrame.withPlan(plan.ToDF(self._plan, list(cols)), self._session) toDF.__doc__ = PySparkDataFrame.toDF.__doc__ diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 8e655dc3a88e5..f6fe17539c6e3 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -5304,6 +5304,12 @@ def toDF(self, *cols: str) -> "DataFrame": | 16| Bob| +---+-----+ """ + for col in cols: + if not isinstance(col, str): + raise PySparkTypeError( + error_class="NOT_LIST_OF_STR", + message_parameters={"arg_name": "cols", "arg_type": type(col).__name__}, + ) jdf = self._jdf.toDF(self._jseq(cols)) return DataFrame(jdf, self.sparkSession) diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py index a74afc4d504bc..ccc0b997e8d4b 100644 --- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py +++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py @@ -84,6 +84,9 @@ def test_to_pandas_with_duplicated_column_names(self): def test_to_pandas_from_mixed_dataframe(self): self.check_to_pandas_from_mixed_dataframe() + def test_toDF_with_string(self): + super().test_toDF_with_string() + if __name__ == "__main__": import unittest diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index 527a51cc239ed..33049233dee98 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -1008,6 +1008,23 @@ def test_sample(self): IllegalArgumentException, lambda: self.spark.range(1).sample(-1.0).count() ) + def test_toDF_with_string(self): + df = self.spark.createDataFrame([("John", 30), ("Alice", 25), ("Bob", 28)]) + data = [("John", 30), ("Alice", 25), ("Bob", 28)] + + result = df.toDF("key", "value") + self.assertEqual(result.schema.simpleString(), "struct") + self.assertEqual(result.collect(), data) + + with self.assertRaises(PySparkTypeError) as pe: + df.toDF("key", None) + + self.check_error( + exception=pe.exception, + error_class="NOT_LIST_OF_STR", + message_parameters={"arg_name": "cols", "arg_type": "NoneType"}, + ) + def test_toDF_with_schema_string(self): data = [Row(key=i, value=str(i)) for i in range(100)] rdd = self.sc.parallelize(data, 5) From 9c9a4a844da30c2fccfffee3441344e1468573d6 Mon Sep 17 00:00:00 2001 From: itholic Date: Wed, 9 Aug 2023 10:08:31 +0900 Subject: [PATCH 22/30] [SPARK-43568][SPARK-43633][PS] Support `Categorical` APIs for pandas 2 ### What changes were proposed in this pull request? 
This PR proposes to support `Categorical` APIs for [pandas 2](https://pandas.pydata.org/docs/dev/whatsnew/v2.0.0.html), and match the behavior. ### Why are the changes needed? To support pandas API on Spark with pandas 2.0.0 and above. ### Does this PR introduce _any_ user-facing change? The behavior is matched with pandas 2.0.0 and above. e.g. ```diff >>> psser 0 1 1 2 2 3 3 1 4 2 5 3 Name: a, dtype: category Categories (3, int64): [1, 2, 3] >>> psser.cat.remove_categories([1, 2, 3]) 0 NaN 1 NaN 2 NaN 3 NaN 4 NaN 5 NaN Name: a, dtype: category - Categories (0, object): [] + Categories (0, int64): [] ``` ### How was this patch tested? Enabling the existing tests. Closes #42273 from itholic/pandas_categorical. Authored-by: itholic Signed-off-by: Hyukjin Kwon --- .../migration_guide/pyspark_upgrade.rst | 1 + python/pyspark/pandas/categorical.py | 66 ++++++------------- .../pyspark/pandas/tests/test_categorical.py | 38 +---------- 3 files changed, 24 insertions(+), 81 deletions(-) diff --git a/python/docs/source/migration_guide/pyspark_upgrade.rst b/python/docs/source/migration_guide/pyspark_upgrade.rst index d26f1cbbe0dc4..b029bcc649f9f 100644 --- a/python/docs/source/migration_guide/pyspark_upgrade.rst +++ b/python/docs/source/migration_guide/pyspark_upgrade.rst @@ -30,6 +30,7 @@ Upgrading from PySpark 3.5 to 4.0 * In Spark 4.0, ``DataFrame.mad`` has been removed from pandas API on Spark. * In Spark 4.0, ``Series.mad`` has been removed from pandas API on Spark. * In Spark 4.0, ``na_sentinel`` parameter from ``Index.factorize`` and `Series.factorize`` has been removed from pandas API on Spark, use ``use_na_sentinel`` instead. +* In Spark 4.0, ``inplace`` parameter from ``Categorical.add_categories``, ``Categorical.remove_categories``, ``Categorical.set_categories``, ``Categorical.rename_categories``, ``Categorical.reorder_categories``, ``Categorical.as_ordered``, ``Categorical.as_unordered`` have been removed from pandas API on Spark. Upgrading from PySpark 3.3 to 3.4 diff --git a/python/pyspark/pandas/categorical.py b/python/pyspark/pandas/categorical.py index 36b11caf5b627..7043d1709ee81 100644 --- a/python/pyspark/pandas/categorical.py +++ b/python/pyspark/pandas/categorical.py @@ -15,7 +15,6 @@ # limitations under the License. 
# from typing import Any, Callable, List, Optional, Union, TYPE_CHECKING, cast -import warnings import pandas as pd from pandas.api.types import ( # type: ignore[attr-defined] @@ -250,14 +249,11 @@ def add_categories(self, new_categories: Union[pd.Index, Any, List]) -> Optional ) return DataFrame(internal)._psser_for(self._data._column_label).copy() - def _set_ordered(self, *, ordered: bool, inplace: bool) -> Optional["ps.Series"]: + def _set_ordered(self, *, ordered: bool) -> Optional["ps.Series"]: from pyspark.pandas.frame import DataFrame if self.ordered == ordered: - if inplace: - return None - else: - return self._data.copy() + return self._data.copy() else: internal = self._data._psdf._internal.with_new_spark_column( self._data._column_label, @@ -266,24 +262,12 @@ def _set_ordered(self, *, ordered: bool, inplace: bool) -> Optional["ps.Series"] dtype=CategoricalDtype(categories=self.categories, ordered=ordered) ), ) - if inplace: - self._data._psdf._update_internal_frame(internal) - return None - else: - return DataFrame(internal)._psser_for(self._data._column_label).copy() + return DataFrame(internal)._psser_for(self._data._column_label).copy() - def as_ordered(self, inplace: bool = False) -> Optional["ps.Series"]: + def as_ordered(self) -> Optional["ps.Series"]: """ Set the Categorical to be ordered. - Parameters - ---------- - inplace : bool, default False - Whether or not to set the ordered attribute in-place or return - a copy of this categorical with ordered set to True. - - .. deprecated:: 3.4.0 - Returns ------- Series or None @@ -312,26 +296,12 @@ def as_ordered(self, inplace: bool = False) -> Optional["ps.Series"]: dtype: category Categories (3, object): ['a' < 'b' < 'c'] """ - if inplace: - warnings.warn( - "The `inplace` parameter in as_ordered is deprecated " - "and will be removed in a future version.", - FutureWarning, - ) - return self._set_ordered(ordered=True, inplace=inplace) + return self._set_ordered(ordered=True) - def as_unordered(self, inplace: bool = False) -> Optional["ps.Series"]: + def as_unordered(self) -> Optional["ps.Series"]: """ Set the Categorical to be unordered. - Parameters - ---------- - inplace : bool, default False - Whether or not to set the ordered attribute in-place or return - a copy of this categorical with ordered set to False. - - .. deprecated:: 3.4.0 - Returns ------- Series or None @@ -360,13 +330,7 @@ def as_unordered(self, inplace: bool = False) -> Optional["ps.Series"]: dtype: category Categories (3, object): ['a', 'b', 'c'] """ - if inplace: - warnings.warn( - "The `inplace` parameter in as_unordered is deprecated " - "and will be removed in a future version.", - FutureWarning, - ) - return self._set_ordered(ordered=False, inplace=inplace) + return self._set_ordered(ordered=False) def remove_categories(self, removals: Union[pd.Index, Any, List]) -> Optional["ps.Series"]: """ @@ -441,8 +405,13 @@ def remove_categories(self, removals: Union[pd.Index, Any, List]) -> Optional["p if len(categories) == 0: return self._data.copy() else: + data = [cat for cat in self.categories.sort_values() if cat not in categories] + if len(data) == 0: + # We should keep original dtype when even removing all categories. 
+ data = pd.Index(data, dtype=self.categories.dtype) # type: ignore[assignment] dtype = CategoricalDtype( - [cat for cat in self.categories if cat not in categories], ordered=self.ordered + categories=data, + ordered=self.ordered, ) return self._data.astype(dtype) @@ -488,7 +457,14 @@ def remove_unused_categories(self) -> Optional["ps.Series"]: """ categories = set(self._data.drop_duplicates()._to_pandas()) removals = [cat for cat in self.categories if cat not in categories] - return self.remove_categories(removals=removals) + categories = [cat for cat in removals if cat is not None] # type: ignore[assignment] + if len(categories) == 0: + return self._data.copy() + else: + dtype = CategoricalDtype( + [cat for cat in self.categories if cat not in categories], ordered=self.ordered + ) + return self._data.astype(dtype) def rename_categories( self, new_categories: Union[list, dict, Callable] diff --git a/python/pyspark/pandas/tests/test_categorical.py b/python/pyspark/pandas/tests/test_categorical.py index dae882a633d12..c45e063d6f466 100644 --- a/python/pyspark/pandas/tests/test_categorical.py +++ b/python/pyspark/pandas/tests/test_categorical.py @@ -65,21 +65,14 @@ def test_categorical_series(self): with self.assertRaisesRegex(ValueError, "Cannot call CategoricalAccessor on type int64"): ps.Series([1, 2, 3]).cat - @unittest.skipIf( - LooseVersion(pd.__version__) >= LooseVersion("2.0.0"), - "TODO(SPARK-43566): Enable CategoricalTests.test_categories_setter for pandas 2.0.0.", - ) def test_categories_setter(self): pdf, psdf = self.df_pair pser = pdf.a psser = psdf.a - pser.cat.categories = ["z", "y", "x"] - psser.cat.categories = ["z", "y", "x"] - if LooseVersion(pd.__version__) >= LooseVersion("1.3"): - # Bug in pandas 1.3. dtype is not updated properly with `inplace` argument. - pser = pser.astype(CategoricalDtype(categories=["x", "y", "z"])) + pser = pser.cat.rename_categories(["z", "y", "x"]) + psser = psser.cat.rename_categories(["z", "y", "x"]) self.assert_eq(pser, psser) self.assert_eq(pdf, psdf) @@ -103,10 +96,6 @@ def test_add_categories(self): self.assertRaises(ValueError, lambda: psser.cat.add_categories(4)) self.assertRaises(ValueError, lambda: psser.cat.add_categories([5, 5])) - @unittest.skipIf( - LooseVersion(pd.__version__) >= LooseVersion("2.0.0"), - "TODO(SPARK-43605): Enable CategoricalTests.test_remove_categories for pandas 2.0.0.", - ) def test_remove_categories(self): pdf, psdf = self.df_pair @@ -168,10 +157,6 @@ def test_reorder_categories(self): self.assertRaises(TypeError, lambda: psser.cat.reorder_categories(1)) self.assertRaises(TypeError, lambda: psdf.b.cat.reorder_categories("abcd")) - @unittest.skipIf( - LooseVersion(pd.__version__) >= LooseVersion("2.0.0"), - "TODO(SPARK-43565): Enable CategoricalTests.test_as_ordered_unordered for pandas 2.0.0.", - ) def test_as_ordered_unordered(self): pdf, psdf = self.df_pair @@ -181,28 +166,9 @@ def test_as_ordered_unordered(self): # as_ordered self.assert_eq(pser.cat.as_ordered(), psser.cat.as_ordered()) - pser.cat.as_ordered(inplace=True) - psser.cat.as_ordered(inplace=True) - if LooseVersion(pd.__version__) >= LooseVersion("1.3"): - # Bug in pandas 1.3. dtype is not updated properly with `inplace` argument. 
- pser = pser.astype(CategoricalDtype(categories=[1, 2, 3], ordered=True)) - - self.assert_eq(pser, psser) - self.assert_eq(pdf, psdf) - # as_unordered self.assert_eq(pser.cat.as_unordered(), psser.cat.as_unordered()) - pser.cat.as_unordered(inplace=True) - psser.cat.as_unordered(inplace=True) - if LooseVersion(pd.__version__) >= LooseVersion("1.3"): - # Bug in pandas 1.3. dtype is not updated properly with `inplace` argument. - pser = pser.astype(CategoricalDtype(categories=[1, 2, 3], ordered=False)) - pdf.a = pser - - self.assert_eq(pser, psser) - self.assert_eq(pdf, psdf) - def test_astype(self): pser = pd.Series(["a", "b", "c"]) psser = ps.from_pandas(pser) From e05959e1900cc687f61b794da47a1516d9baf66b Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Wed, 9 Aug 2023 11:03:40 +0900 Subject: [PATCH 23/30] [SPARK-44717][PYTHON][PS] Respect TimestampNTZ in resampling ### What changes were proposed in this pull request? This PR proposes to respect `TimestampNTZ` type in resampling at pandas API on Spark. ### Why are the changes needed? It still operates as if the timestamps are `TIMESTAMP_LTZ` even when `spark.sql.timestampType` is set to `TIMESTAMP_NTZ`, which is unexpected. ### Does this PR introduce _any_ user-facing change? This fixes a bug so end users can use exactly same behaviour with pandas with `TimestampNTZType` - pandas does not respect the local timezone with DST. While we might need to follow this even for `TimestampType`, this PR does not address the case as it might be controversial. ### How was this patch tested? Unittest was added. Closes #42392 from HyukjinKwon/SPARK-44717. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- python/pyspark/pandas/frame.py | 4 +- python/pyspark/pandas/resample.py | 43 +++++++++++------ .../tests/connect/test_parity_resample.py | 12 ++++- python/pyspark/pandas/tests/test_resample.py | 47 +++++++++++++++++++ 4 files changed, 88 insertions(+), 18 deletions(-) diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py index 72d4a88b69203..65c43eb7cf42c 100644 --- a/python/pyspark/pandas/frame.py +++ b/python/pyspark/pandas/frame.py @@ -13155,7 +13155,9 @@ def resample( if on is None and not isinstance(self.index, DatetimeIndex): raise NotImplementedError("resample currently works only for DatetimeIndex") - if on is not None and not isinstance(as_spark_type(on.dtype), TimestampType): + if on is not None and not isinstance( + as_spark_type(on.dtype), (TimestampType, TimestampNTZType) + ): raise NotImplementedError("`on` currently works only for TimestampType") agg_columns: List[ps.Series] = [] diff --git a/python/pyspark/pandas/resample.py b/python/pyspark/pandas/resample.py index c6c6019c07e6a..30f8c9d31695e 100644 --- a/python/pyspark/pandas/resample.py +++ b/python/pyspark/pandas/resample.py @@ -46,7 +46,8 @@ from pyspark.sql.types import ( NumericType, StructField, - TimestampType, + TimestampNTZType, + DataType, ) from pyspark import pandas as ps # For running doctests and reference resolution in PyCharm. 
@@ -130,6 +131,13 @@ def _resamplekey_scol(self) -> Column: else: return self._resamplekey.spark.column + @property + def _resamplekey_type(self) -> DataType: + if self._resamplekey is None: + return self._psdf.index.spark.data_type + else: + return self._resamplekey.spark.data_type + @property def _agg_columns_scols(self) -> List[Column]: return [s.spark.column for s in self._agg_columns] @@ -154,7 +162,8 @@ def get_make_interval( # type: ignore[return] col = col._jc if isinstance(col, Column) else F.lit(col)._jc return sql_utils.makeInterval(unit, col) - def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column: + def _bin_timestamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column: + key_type = self._resamplekey_type origin_scol = F.lit(origin) (rule_code, n) = (self._offset.rule_code, getattr(self._offset, "n")) left_closed, right_closed = (self._closed == "left", self._closed == "right") @@ -188,7 +197,7 @@ def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column: F.year(ts_scol) - (mod - n) ) - return F.to_timestamp( + ret = F.to_timestamp( F.make_date( F.when(edge_cond, edge_label).otherwise(non_edge_label), F.lit(12), F.lit(31) ) @@ -227,7 +236,7 @@ def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column: truncated_ts_scol - self.get_make_interval("MONTH", mod - n) ) - return F.to_timestamp( + ret = F.to_timestamp( F.last_day(F.when(edge_cond, edge_label).otherwise(non_edge_label)) ) @@ -242,15 +251,15 @@ def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column: ) if left_closed and left_labeled: - return F.date_trunc("DAY", ts_scol) + ret = F.date_trunc("DAY", ts_scol) elif left_closed and right_labeled: - return F.date_trunc("DAY", F.date_add(ts_scol, 1)) + ret = F.date_trunc("DAY", F.date_add(ts_scol, 1)) elif right_closed and left_labeled: - return F.when(edge_cond, F.date_trunc("DAY", F.date_sub(ts_scol, 1))).otherwise( + ret = F.when(edge_cond, F.date_trunc("DAY", F.date_sub(ts_scol, 1))).otherwise( F.date_trunc("DAY", ts_scol) ) else: - return F.when(edge_cond, F.date_trunc("DAY", ts_scol)).otherwise( + ret = F.when(edge_cond, F.date_trunc("DAY", ts_scol)).otherwise( F.date_trunc("DAY", F.date_add(ts_scol, 1)) ) @@ -272,13 +281,15 @@ def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column: else: non_edge_label = F.date_sub(truncated_ts_scol, mod - n) - return F.when(edge_cond, edge_label).otherwise(non_edge_label) + ret = F.when(edge_cond, edge_label).otherwise(non_edge_label) elif rule_code in ["H", "T", "S"]: unit_mapping = {"H": "HOUR", "T": "MINUTE", "S": "SECOND"} unit_str = unit_mapping[rule_code] truncated_ts_scol = F.date_trunc(unit_str, ts_scol) + if isinstance(key_type, TimestampNTZType): + truncated_ts_scol = F.to_timestamp_ntz(truncated_ts_scol) diff = timestampdiff(unit_str, origin_scol, truncated_ts_scol) mod = F.lit(0) if n == 1 else (diff % F.lit(n)) @@ -307,11 +318,16 @@ def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column: truncated_ts_scol + self.get_make_interval(unit_str, n), ).otherwise(truncated_ts_scol - self.get_make_interval(unit_str, mod - n)) - return F.when(edge_cond, edge_label).otherwise(non_edge_label) + ret = F.when(edge_cond, edge_label).otherwise(non_edge_label) else: raise ValueError("Got the unexpected unit {}".format(rule_code)) + if isinstance(key_type, TimestampNTZType): + return F.to_timestamp_ntz(ret) + else: + return ret + def _downsample(self, f: str) -> DataFrame: """ Downsample the defined function. 
@@ -374,12 +390,9 @@ def _downsample(self, f: str) -> DataFrame: bin_col_label = verify_temp_column_name(self._psdf, bin_col_name) bin_col_field = InternalField( dtype=np.dtype("datetime64[ns]"), - struct_field=StructField(bin_col_name, TimestampType(), True), - ) - bin_scol = self._bin_time_stamp( - ts_origin, - self._resamplekey_scol, + struct_field=StructField(bin_col_name, self._resamplekey_type, True), ) + bin_scol = self._bin_timestamp(ts_origin, self._resamplekey_scol) agg_columns = [ psser for psser in self._agg_columns if (isinstance(psser.spark.data_type, NumericType)) diff --git a/python/pyspark/pandas/tests/connect/test_parity_resample.py b/python/pyspark/pandas/tests/connect/test_parity_resample.py index e5957cc9b4a29..d5c901f113a05 100644 --- a/python/pyspark/pandas/tests/connect/test_parity_resample.py +++ b/python/pyspark/pandas/tests/connect/test_parity_resample.py @@ -16,17 +16,25 @@ # import unittest -from pyspark.pandas.tests.test_resample import ResampleTestsMixin +from pyspark.pandas.tests.test_resample import ResampleTestsMixin, ResampleWithTimezoneMixin from pyspark.testing.connectutils import ReusedConnectTestCase from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils -class ResampleTestsParityMixin( +class ResampleParityTests( ResampleTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase ): pass +class ResampleWithTimezoneTests( + ResampleWithTimezoneMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase +): + @unittest.skip("SPARK-44731: Support 'spark.sql.timestampType' in Python Spark Connect client") + def test_series_resample_with_timezone(self): + super().test_series_resample_with_timezone() + + if __name__ == "__main__": from pyspark.pandas.tests.connect.test_parity_resample import * # noqa: F401 diff --git a/python/pyspark/pandas/tests/test_resample.py b/python/pyspark/pandas/tests/test_resample.py index 0650fc40448e3..4061402590767 100644 --- a/python/pyspark/pandas/tests/test_resample.py +++ b/python/pyspark/pandas/tests/test_resample.py @@ -19,6 +19,8 @@ import unittest import inspect import datetime +import os + import numpy as np import pandas as pd @@ -283,10 +285,55 @@ def test_resample_on(self): ) +class ResampleWithTimezoneMixin: + timezone = None + + @classmethod + def setUpClass(cls): + cls.timezone = os.environ.get("TZ", None) + os.environ["TZ"] = "America/New_York" + super(ResampleWithTimezoneMixin, cls).setUpClass() + + @classmethod + def tearDownClass(cls): + super(ResampleWithTimezoneMixin, cls).tearDownClass() + if cls.timezone is not None: + os.environ["TZ"] = cls.timezone + + @property + def pdf(self): + np.random.seed(22) + index = pd.date_range(start="2011-01-02", end="2022-05-01", freq="1D") + return pd.DataFrame(np.random.rand(len(index), 2), index=index, columns=list("AB")) + + @property + def psdf(self): + return ps.from_pandas(self.pdf) + + def test_series_resample_with_timezone(self): + with self.sql_conf( + { + "spark.sql.session.timeZone": "Asia/Seoul", + "spark.sql.timestampType": "TIMESTAMP_NTZ", + } + ): + p_resample = self.pdf.resample(rule="1001H", closed="right", label="right") + ps_resample = self.psdf.resample(rule="1001H", closed="right", label="right") + self.assert_eq( + p_resample.sum().sort_index(), + ps_resample.sum().sort_index(), + almost=True, + ) + + class ResampleTests(ResampleTestsMixin, PandasOnSparkTestCase, TestUtils): pass +class ResampleWithTimezoneTests(ResampleWithTimezoneMixin, PandasOnSparkTestCase, TestUtils): + pass + + if __name__ == "__main__": from 
pyspark.pandas.tests.test_resample import * # noqa: F401 From eccc045a250818db5e0cfe7d00ac6dcae1ac3d7e Mon Sep 17 00:00:00 2001 From: mox692 Date: Wed, 9 Aug 2023 11:04:42 +0900 Subject: [PATCH 24/30] [MINOR][DOC] Fixed deprecated procedure syntax ### What changes were proposed in this pull request? The scala sample code in the quick-start chapter was using deprecated syntax, so this PR has fixed it. ### Why are the changes needed? procedure syntax (SI-7605) is now deprecated, and some editors or IDEs warn against this code. https://github.com/scala/bug/issues/7605 ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? `SKIP_API=1 bundle exec jekyll build` on local. Closes #42400 from mox692/deprecate_procedure_syntax. Authored-by: mox692 Signed-off-by: Hyukjin Kwon --- docs/quick-start.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/quick-start.md b/docs/quick-start.md index 91b23851f721e..8deb10e12cb25 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -313,7 +313,7 @@ named `SimpleApp.scala`: import org.apache.spark.sql.SparkSession object SimpleApp { - def main(args: Array[String]) { + def main(args: Array[String]): Unit = { val logFile = "YOUR_SPARK_HOME/README.md" // Should be some file on your system val spark = SparkSession.builder.appName("Simple Application").getOrCreate() val logData = spark.read.textFile(logFile).cache() From 3cec5a4c7d8cf86141b16236925e54886a807a42 Mon Sep 17 00:00:00 2001 From: itholic Date: Wed, 9 Aug 2023 14:10:00 +0900 Subject: [PATCH 25/30] [SPARK-43709][PS] Remove `closed` parameter from `ps.date_range` & enable test ### What changes were proposed in this pull request? This PR proposes to remove `closed` parameter from `ps.date_range` & enable test. See https://github.com/pandas-dev/pandas/issues/40245 more detail. ### Why are the changes needed? To support pandas 2.0.0 and above. ### Does this PR introduce _any_ user-facing change? `closed` parameter will no longer available from `ps.date_range` API. ### How was this patch tested? Enabling the existing UT. Closes #42389 from itholic/closed_removing. Authored-by: itholic Signed-off-by: Hyukjin Kwon --- .../migration_guide/pyspark_upgrade.rst | 1 + python/pyspark/pandas/namespace.py | 38 +------------------ python/pyspark/pandas/tests/test_namespace.py | 14 ------- 3 files changed, 2 insertions(+), 51 deletions(-) diff --git a/python/docs/source/migration_guide/pyspark_upgrade.rst b/python/docs/source/migration_guide/pyspark_upgrade.rst index b029bcc649f9f..1b247d4622787 100644 --- a/python/docs/source/migration_guide/pyspark_upgrade.rst +++ b/python/docs/source/migration_guide/pyspark_upgrade.rst @@ -31,6 +31,7 @@ Upgrading from PySpark 3.5 to 4.0 * In Spark 4.0, ``Series.mad`` has been removed from pandas API on Spark. * In Spark 4.0, ``na_sentinel`` parameter from ``Index.factorize`` and `Series.factorize`` has been removed from pandas API on Spark, use ``use_na_sentinel`` instead. * In Spark 4.0, ``inplace`` parameter from ``Categorical.add_categories``, ``Categorical.remove_categories``, ``Categorical.set_categories``, ``Categorical.rename_categories``, ``Categorical.reorder_categories``, ``Categorical.as_ordered``, ``Categorical.as_unordered`` have been removed from pandas API on Spark. +* In Spark 4.0, ``closed`` parameter from ``ps.date_range`` has been removed from pandas API on Spark. 
Upgrading from PySpark 3.3 to 3.4 diff --git a/python/pyspark/pandas/namespace.py b/python/pyspark/pandas/namespace.py index 5ffec6bedb988..ba93e5a3ee506 100644 --- a/python/pyspark/pandas/namespace.py +++ b/python/pyspark/pandas/namespace.py @@ -1751,7 +1751,7 @@ def pandas_to_datetime( ) -# TODO(SPARK-42621): Add `inclusive` parameter and replace `closed`. +# TODO(SPARK-42621): Add `inclusive` parameter. # See https://github.com/pandas-dev/pandas/issues/40245 def date_range( start: Union[str, Any] = None, @@ -1761,7 +1761,6 @@ def date_range( tz: Optional[Union[str, tzinfo]] = None, normalize: bool = False, name: Optional[str] = None, - closed: Optional[str] = None, **kwargs: Any, ) -> DatetimeIndex: """ @@ -1785,12 +1784,6 @@ def date_range( Normalize start/end dates to midnight before generating date range. name : str, default None Name of the resulting DatetimeIndex. - closed : {None, 'left', 'right'}, optional - Make the interval closed with respect to the given frequency to - the 'left', 'right', or both sides (None, the default). - - .. deprecated:: 3.4.0 - **kwargs For compatibility. Has no effect on the result. @@ -1874,37 +1867,9 @@ def date_range( DatetimeIndex(['2018-01-31', '2018-04-30', '2018-07-31', '2018-10-31', '2019-01-31'], dtype='datetime64[ns]', freq=None) - - `closed` controls whether to include `start` and `end` that are on the - boundary. The default includes boundary points on either end. - - >>> ps.date_range( - ... start='2017-01-01', end='2017-01-04', closed=None - ... ) # doctest: +SKIP - DatetimeIndex(['2017-01-01', '2017-01-02', '2017-01-03', '2017-01-04'], - dtype='datetime64[ns]', freq=None) - - Use ``closed='left'`` to exclude `end` if it falls on the boundary. - - >>> ps.date_range( - ... start='2017-01-01', end='2017-01-04', closed='left' - ... ) # doctest: +SKIP - DatetimeIndex(['2017-01-01', '2017-01-02', '2017-01-03'], dtype='datetime64[ns]', freq=None) - - Use ``closed='right'`` to exclude `start` if it falls on the boundary. - - >>> ps.date_range( - ... start='2017-01-01', end='2017-01-04', closed='right' - ... 
) # doctest: +SKIP - DatetimeIndex(['2017-01-02', '2017-01-03', '2017-01-04'], dtype='datetime64[ns]', freq=None) """ assert freq not in ["N", "ns"], "nanoseconds is not supported" assert tz is None, "Localized DatetimeIndex is not supported" - if closed is not None: - warnings.warn( - "Argument `closed` is deprecated in 3.4.0 and will be removed in 4.0.0.", - FutureWarning, - ) return cast( DatetimeIndex, @@ -1917,7 +1882,6 @@ def date_range( tz=tz, normalize=normalize, name=name, - closed=closed, **kwargs, ) ), diff --git a/python/pyspark/pandas/tests/test_namespace.py b/python/pyspark/pandas/tests/test_namespace.py index 64c58a702390c..d1d1e1af9354d 100644 --- a/python/pyspark/pandas/tests/test_namespace.py +++ b/python/pyspark/pandas/tests/test_namespace.py @@ -190,10 +190,6 @@ def test_to_datetime(self): self.assert_eq(pd.to_datetime(pdf), ps.to_datetime(psdf)) self.assert_eq(pd.to_datetime(dict_from_pdf), ps.to_datetime(dict_from_pdf)) - @unittest.skipIf( - LooseVersion(pd.__version__) >= LooseVersion("2.0.0"), - "TODO(SPARK-43709): Enable NamespaceTests.test_date_range for pandas 2.0.0.", - ) def test_date_range(self): self.assert_eq( ps.date_range(start="1/1/2018", end="1/08/2018"), @@ -225,16 +221,6 @@ def test_date_range(self): pd.date_range(start="1/1/2018", periods=5, freq=pd.offsets.MonthEnd(3)), ) - self.assert_eq( - ps.date_range(start="2017-01-01", end="2017-01-04", closed="left"), - pd.date_range(start="2017-01-01", end="2017-01-04", closed="left"), - ) - - self.assert_eq( - ps.date_range(start="2017-01-01", end="2017-01-04", closed="right"), - pd.date_range(start="2017-01-01", end="2017-01-04", closed="right"), - ) - self.assertRaises( AssertionError, lambda: ps.date_range(start="1/1/2018", periods=5, tz="Asia/Tokyo") ) From e584ed4ad96a0f0573455511d7be0e9b2afbeb96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=99=E8=89=AF?= Date: Wed, 9 Aug 2023 13:46:57 +0800 Subject: [PATCH 26/30] [SPARK-44581][YARN] Fix the bug that ShutdownHookManager gets wrong UGI from SecurityManager of ApplicationMaster MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? I make the SecurityManager instance a lazy value ### Why are the changes needed? fix the bug in issue [SPARK-44581](https://issues.apache.org/jira/browse/SPARK-44581) **Bug:** In spark3.2 it throws the org.apache.hadoop.security.AccessControlException, but in spark2.4 this hook does not throw exception. I rebuild the hadoop-client-api.jar, and add some debug log before the hadoop shutdown hook is created, and rebuild the spark-yarn.jar to add some debug log when creating the spark shutdown hook manager, here is the screenshot of the log: ![image](https://github.com/apache/spark/assets/62563545/ea338db3-646c-432c-bf16-1f445adc2ad9) We can see from the screenshot, the ShutdownHookManager is initialized before the ApplicationManager create a new ugi. **Reason** The main cause is that ShutdownHook thread is created before we create the ugi in ApplicationMaster. When we set the config key "hadoop.security.credential.provider.path", the ApplicationMaster will try to get a filesystem when generating SSLOptions, and when initialize the filesystem during which it will generate a new thread whose ugi is inherited from the current process (yarn). After this, it will generate a new ugi (SPARK_USER) in ApplicationMaster and execute the doAs() function. 
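The sketch below is illustrative only (the class and helper names are made up and stand in for `ApplicationMaster` and `new SecurityManager(sparkConf)`); it shows why the `lazy val` matters: the field is no longer evaluated in the constructor, but on first access inside `ugi.doAs`, so threads spawned during that initialization inherit SPARK_USER rather than the yarn process user.

```scala
import java.security.PrivilegedExceptionAction

import org.apache.hadoop.security.UserGroupInformation

class AmInitSketch {
  // Eager field: evaluated while the constructor runs, i.e. before any doAs(),
  // so helper threads spawned during initialization inherit the process (yarn) UGI.
  // private val securityMgr = buildSecurityManager()

  // Lazy field: evaluated on first access, which happens inside ugi.doAs below,
  // so the same initialization now runs as SPARK_USER.
  private lazy val securityMgr = buildSecurityManager()

  private def buildSecurityManager(): String = {
    // Stand-in for new SecurityManager(sparkConf), which can read the Hadoop
    // credential provider and initialize a FileSystem as a side effect.
    "security-manager"
  }

  def run(ugi: UserGroupInformation): Unit = {
    ugi.doAs(new PrivilegedExceptionAction[Unit] {
      override def run(): Unit = {
        val _ = securityMgr // first dereference happens under the intended UGI
      }
    })
  }
}
```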
Here is the chain of the call: ApplicationMaster.(ApplicationMaster.scala:83) -> org.apache.spark.SecurityManager.(SecurityManager.scala:98) -> org.apache.spark.SSLOptions$.parse(SSLOptions.scala:188) -> org.apache.hadoop.conf.Configuration.getPassword(Configuration.java:2353) -> org.apache.hadoop.conf.Configuration.getPasswordFromCredentialProviders(Configuration.java:2434) -> org.apache.hadoop.security.alias.CredentialProviderFactory.getProviders(CredentialProviderFactory.java:82) ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? I didn't add new UnitTest for this, but I rebuild the package, and runs a program in my cluster, and turns out that the user when I delete the staging file turns to be the same with the SPARK_USER. Closes #42405 from liangyu-1/SPARK-44581. Authored-by: 余良 Signed-off-by: Kent Yao --- .../scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 0149a3f62175b..4fa7b66c9e5a4 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -79,7 +79,7 @@ private[spark] class ApplicationMaster( private val isClusterMode = args.userClass != null - private val securityMgr = new SecurityManager(sparkConf) + private lazy val securityMgr = new SecurityManager(sparkConf) private var metricsSystem: Option[MetricsSystem] = None From bdee74b43451cb684d72a5829e064f676f58aed1 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 8 Aug 2023 22:58:22 -0700 Subject: [PATCH 27/30] [SPARK-44726][CORE] Improve `HeartbeatReceiver` config validation error message ### What changes were proposed in this pull request? This PR aims to improve `HeartbeatReceiver` to give a clear directional message for Apache Spark 4.0. ### Why are the changes needed? Currently, when we set improper `spark.network.timeout` value, the error message is misleading because it complains about the relationship between `spark.network.timeoutInterval` and `spark.storage.blockManagerHeartbeatTimeoutMs` which the users never have in their Spark jobs. ``` $ bin/spark-shell -c spark.network.timeout=30s ... java.lang.IllegalArgumentException: requirement failed: spark.network.timeoutInterval should be less than or equal to spark.storage.blockManagerHeartbeatTimeoutMs. ``` ### Does this PR introduce _any_ user-facing change? No. This PR gives more direct messages based on the users' configuration. ### How was this patch tested? Pass the CIs with the newly added test cases. Closes #42403 from dongjoon-hyun/SPARK-44726. 
Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../org/apache/spark/HeartbeatReceiver.scala | 12 ++++++-- .../apache/spark/HeartbeatReceiverSuite.scala | 29 ++++++++++++++++++- 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 825d9ce77947e..5999040894ae5 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -88,9 +88,15 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock) private val executorHeartbeatIntervalMs = sc.conf.get(config.EXECUTOR_HEARTBEAT_INTERVAL) - require(checkTimeoutIntervalMs <= executorTimeoutMs, - s"${Network.NETWORK_TIMEOUT_INTERVAL.key} should be less than or " + - s"equal to ${config.STORAGE_BLOCKMANAGER_HEARTBEAT_TIMEOUT.key}.") + if (sc.conf.get(config.STORAGE_BLOCKMANAGER_HEARTBEAT_TIMEOUT).isEmpty) { + require(checkTimeoutIntervalMs <= executorTimeoutMs, + s"${Network.NETWORK_TIMEOUT_INTERVAL.key} should be less than or " + + s"equal to ${Network.NETWORK_TIMEOUT.key}.") + } else { + require(checkTimeoutIntervalMs <= executorTimeoutMs, + s"${Network.NETWORK_TIMEOUT_INTERVAL.key} should be less than or " + + s"equal to ${config.STORAGE_BLOCKMANAGER_HEARTBEAT_TIMEOUT.key}.") + } require(executorHeartbeatIntervalMs <= executorTimeoutMs, s"${config.EXECUTOR_HEARTBEAT_INTERVAL.key} should be less than or " + s"equal to ${config.STORAGE_BLOCKMANAGER_HEARTBEAT_TIMEOUT.key}") diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index ee0a57736921c..a8351322e01a8 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -29,7 +29,8 @@ import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} -import org.apache.spark.internal.config.DYN_ALLOCATION_TESTING +import org.apache.spark.internal.config.{DYN_ALLOCATION_TESTING, STORAGE_BLOCKMANAGER_HEARTBEAT_TIMEOUT} +import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT import org.apache.spark.resource.{ResourceProfile, ResourceProfileManager} import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler._ @@ -237,6 +238,32 @@ class HeartbeatReceiverSuite } } + test("SPARK-44726: Show spark.network.timeout config error message") { + sc.stop() + val conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + .set(NETWORK_TIMEOUT.key, "30s") + val m = intercept[IllegalArgumentException] { + new SparkContext(conf) + }.getMessage + assert(m.contains("spark.network.timeoutInterval should be less than or equal to " + + NETWORK_TIMEOUT.key)) + } + + test("SPARK-44726: Show spark.storage.blockManagerHeartbeatTimeoutMs error message") { + sc.stop() + val conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + .set(STORAGE_BLOCKMANAGER_HEARTBEAT_TIMEOUT.key, "30s") + val m = intercept[IllegalArgumentException] { + new SparkContext(conf) + }.getMessage + assert(m.contains("spark.network.timeoutInterval should be less than or equal to " + + STORAGE_BLOCKMANAGER_HEARTBEAT_TIMEOUT.key)) + } + /** Manually send a heartbeat and return the response. 
*/ private def triggerHeartbeat( executorId: String, From f9058d69b2af774e78677ac4cad55c7c91eb42ae Mon Sep 17 00:00:00 2001 From: Jack Chen Date: Wed, 9 Aug 2023 15:10:51 +0800 Subject: [PATCH 28/30] [SPARK-44551][SQL] Fix behavior of null IN (empty list) in expression execution ### What changes were proposed in this pull request? `null IN (empty list)` incorrectly evaluates to null, when it should evaluate to false. (The reason it should be false is because a IN (b1, b2) is defined as a = b1 OR a = b2, and an empty IN list is treated as an empty OR which is false. This is specified by ANSI SQL.) Many places in Spark execution (In, InSet, InSubquery) and optimization (OptimizeIn, NullPropagation) implemented this wrong behavior. This is a longstanding correctness issue which has existed since null support for IN expressions was first added to Spark. This PR fixes Spark execution (In, InSet, InSubquery). See previous PR https://github.com/apache/spark/pull/42007 for optimization fixes. The behavior is under a flag, which will be available to revert to the legacy behavior if needed. This flag is set to disable the new behavior until all of the fix PRs are complete. See [this doc](https://docs.google.com/document/d/1k8AY8oyT-GI04SnP7eXttPDnDj-Ek-c3luF2zL6DPNU/edit) for more information. ### Why are the changes needed? Fix wrong SQL semantics ### Does this PR introduce _any_ user-facing change? Not yet, but will fix wrong SQL semantics when enabled ### How was this patch tested? Unit tests PredicateSuite and tests added in previous PR https://github.com/apache/spark/pull/42007 Closes #42163 from jchen5/null-in-empty-exec. Authored-by: Jack Chen Signed-off-by: Wenchen Fan --- .../sql/catalyst/expressions/predicates.scala | 196 +++++++++++------- .../sql/catalyst/optimizer/expressions.scala | 4 +- .../catalyst/expressions/PredicateSuite.scala | 35 ++-- .../in-subquery/in-null-semantics.sql | 3 +- .../in-subquery/in-null-semantics.sql.out | 4 +- .../org/apache/spark/sql/EmptyInSuite.scala | 3 +- 6 files changed, 144 insertions(+), 101 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index ee2ba7c73d1f3..31b872e04ce72 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -468,6 +468,8 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { override def foldable: Boolean = children.forall(_.foldable) final override val nodePatterns: Seq[TreePattern] = Seq(IN) + private val legacyNullInEmptyBehavior = + SQLConf.get.getConf(SQLConf.LEGACY_NULL_IN_EMPTY_LIST_BEHAVIOR) override lazy val canonicalized: Expression = { val basic = withNewChildren(children.map(_.canonicalized)).asInstanceOf[In] @@ -481,88 +483,104 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { override def toString: String = s"$value IN ${list.mkString("(", ",", ")")}" override def eval(input: InternalRow): Any = { - val evaluatedValue = value.eval(input) - if (evaluatedValue == null) { - null + if (list.isEmpty && !legacyNullInEmptyBehavior) { + // IN (empty list) is always false under current behavior. + // Under legacy behavior it's null if the left side is null, otherwise false (SPARK-44550). 
+ false } else { - var hasNull = false - list.foreach { e => - val v = e.eval(input) - if (v == null) { - hasNull = true - } else if (ordering.equiv(v, evaluatedValue)) { - return true - } - } - if (hasNull) { + val evaluatedValue = value.eval(input) + if (evaluatedValue == null) { null } else { - false + var hasNull = false + list.foreach { e => + val v = e.eval(input) + if (v == null) { + hasNull = true + } else if (ordering.equiv(v, evaluatedValue)) { + return true + } + } + if (hasNull) { + null + } else { + false + } } } } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaDataType = CodeGenerator.javaType(value.dataType) - val valueGen = value.genCode(ctx) - val listGen = list.map(_.genCode(ctx)) - // inTmpResult has 3 possible values: - // -1 means no matches found and there is at least one value in the list evaluated to null - val HAS_NULL = -1 - // 0 means no matches found and all values in the list are not null - val NOT_MATCHED = 0 - // 1 means one value in the list is matched - val MATCHED = 1 - val tmpResult = ctx.freshName("inTmpResult") - val valueArg = ctx.freshName("valueArg") - // All the blocks are meant to be inside a do { ... } while (false); loop. - // The evaluation of variables can be stopped when we find a matching value. - val listCode = listGen.map(x => - s""" - |${x.code} - |if (${x.isNull}) { - | $tmpResult = $HAS_NULL; // ${ev.isNull} = true; - |} else if (${ctx.genEqual(value.dataType, valueArg, x.value)}) { - | $tmpResult = $MATCHED; // ${ev.isNull} = false; ${ev.value} = true; - | continue; - |} + if (list.isEmpty && !legacyNullInEmptyBehavior) { + // IN (empty list) is always false under current behavior. + // Under legacy behavior it's null if the left side is null, otherwise false (SPARK-44550). + ev.copy(code = + code""" + |final boolean ${ev.isNull} = false; + |final boolean ${ev.value} = false; """.stripMargin) - - val codes = ctx.splitExpressionsWithCurrentInputs( - expressions = listCode, - funcName = "valueIn", - extraArguments = (javaDataType, valueArg) :: (CodeGenerator.JAVA_BYTE, tmpResult) :: Nil, - returnType = CodeGenerator.JAVA_BYTE, - makeSplitFunction = body => - s""" - |do { - | $body - |} while (false); - |return $tmpResult; - """.stripMargin, - foldFunctions = _.map { funcCall => + } else { + val javaDataType = CodeGenerator.javaType(value.dataType) + val valueGen = value.genCode(ctx) + val listGen = list.map(_.genCode(ctx)) + // inTmpResult has 3 possible values: + // -1 means no matches found and there is at least one value in the list evaluated to null + val HAS_NULL = -1 + // 0 means no matches found and all values in the list are not null + val NOT_MATCHED = 0 + // 1 means one value in the list is matched + val MATCHED = 1 + val tmpResult = ctx.freshName("inTmpResult") + val valueArg = ctx.freshName("valueArg") + // All the blocks are meant to be inside a do { ... } while (false); loop. + // The evaluation of variables can be stopped when we find a matching value. 
+ val listCode = listGen.map(x => s""" - |$tmpResult = $funcCall; - |if ($tmpResult == $MATCHED) { + |${x.code} + |if (${x.isNull}) { + | $tmpResult = $HAS_NULL; // ${ev.isNull} = true; + |} else if (${ctx.genEqual(value.dataType, valueArg, x.value)}) { + | $tmpResult = $MATCHED; // ${ev.isNull} = false; ${ev.value} = true; | continue; |} - """.stripMargin - }.mkString("\n")) - - ev.copy(code = - code""" - |${valueGen.code} - |byte $tmpResult = $HAS_NULL; - |if (!${valueGen.isNull}) { - | $tmpResult = $NOT_MATCHED; - | $javaDataType $valueArg = ${valueGen.value}; - | do { - | $codes - | } while (false); - |} - |final boolean ${ev.isNull} = ($tmpResult == $HAS_NULL); - |final boolean ${ev.value} = ($tmpResult == $MATCHED); - """.stripMargin) + """.stripMargin) + + val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = listCode, + funcName = "valueIn", + extraArguments = (javaDataType, valueArg) :: (CodeGenerator.JAVA_BYTE, tmpResult) :: Nil, + returnType = CodeGenerator.JAVA_BYTE, + makeSplitFunction = body => + s""" + |do { + | $body + |} while (false); + |return $tmpResult; + """.stripMargin, + foldFunctions = _.map { funcCall => + s""" + |$tmpResult = $funcCall; + |if ($tmpResult == $MATCHED) { + | continue; + |} + """.stripMargin + }.mkString("\n")) + + ev.copy(code = + code""" + |${valueGen.code} + |byte $tmpResult = $HAS_NULL; + |if (!${valueGen.isNull}) { + | $tmpResult = $NOT_MATCHED; + | $javaDataType $valueArg = ${valueGen.value}; + | do { + | $codes + | } while (false); + |} + |final boolean ${ev.isNull} = ($tmpResult == $HAS_NULL); + |final boolean ${ev.value} = ($tmpResult == $MATCHED); + """.stripMargin) + } } override def sql: String = { @@ -607,16 +625,27 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with override def nullable: Boolean = child.nullable || hasNull final override val nodePatterns: Seq[TreePattern] = Seq(INSET) + private val legacyNullInEmptyBehavior = + SQLConf.get.getConf(SQLConf.LEGACY_NULL_IN_EMPTY_LIST_BEHAVIOR) - protected override def nullSafeEval(value: Any): Any = { - if (set.contains(value)) { - true - } else if (isNaN(value)) { - hasNaN - } else if (hasNull) { - null - } else { + override def eval(input: InternalRow): Any = { + if (hset.isEmpty && !legacyNullInEmptyBehavior) { + // IN (empty list) is always false under current behavior. + // Under legacy behavior it's null if the left side is null, otherwise false (SPARK-44550). false + } else { + val value = child.eval(input) + if (value == null) { + null + } else if (set.contains(value)) { + true + } else if (isNaN(value)) { + hasNaN + } else if (hasNull) { + null + } else { + false + } } } @@ -629,7 +658,16 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - if (canBeComputedUsingSwitch && hset.size <= SQLConf.get.optimizerInSetSwitchThreshold) { + if (hset.isEmpty && !legacyNullInEmptyBehavior) { + // IN (empty list) is always false under current behavior. + // Under legacy behavior it's null if the left side is null, otherwise false (SPARK-44550). 
+ ev.copy(code = + code""" + ${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = false; + ${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = false; + """ + ) + } else if (canBeComputedUsingSwitch && hset.size <= SQLConf.get.optimizerInSetSwitchThreshold) { genCodeWithSwitch(ctx, ev) } else { genCodeWithSet(ctx, ev) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 7b44539929c84..8a7f54093d528 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -281,11 +281,11 @@ object OptimizeIn extends Rule[LogicalPlan] { _.containsPattern(IN), ruleId) { case q: LogicalPlan => q.transformExpressionsDownWithPruning(_.containsPattern(IN), ruleId) { case In(v, list) if list.isEmpty => + // IN (empty list) is always false under current behavior. + // Under legacy behavior it's null if the left side is null, otherwise false (SPARK-44550). if (!SQLConf.get.getConf(SQLConf.LEGACY_NULL_IN_EMPTY_LIST_BEHAVIOR)) { FalseLiteral } else { - // Incorrect legacy behavior optimizes to null if the left side is null, and otherwise - // to false. If(IsNotNull(v), FalseLiteral, Literal(null, BooleanType)) } case expr @ In(v, list) if expr.inSetConvertible => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 73cc9aca56828..55e0dd2179458 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -134,20 +134,27 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { } test("basic IN/INSET predicate test") { - checkInAndInSet(In(NonFoldableLiteral.create(null, IntegerType), Seq(Literal(1), - Literal(2))), null) - checkInAndInSet(In(NonFoldableLiteral.create(null, IntegerType), - Seq(NonFoldableLiteral.create(null, IntegerType))), null) - checkInAndInSet(In(NonFoldableLiteral.create(null, IntegerType), Seq.empty), null) - checkInAndInSet(In(Literal(1), Seq.empty), false) - checkInAndInSet(In(Literal(1), Seq(NonFoldableLiteral.create(null, IntegerType))), null) - checkInAndInSet(In(Literal(1), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), - true) - checkInAndInSet(In(Literal(2), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), - null) - checkInAndInSet(In(Literal(1), Seq(Literal(1), Literal(2))), true) - checkInAndInSet(In(Literal(2), Seq(Literal(1), Literal(2))), true) - checkInAndInSet(In(Literal(3), Seq(Literal(1), Literal(2))), false) + Seq(true, false).foreach { legacyNullInBehavior => + withSQLConf(SQLConf.LEGACY_NULL_IN_EMPTY_LIST_BEHAVIOR.key -> legacyNullInBehavior.toString) { + checkInAndInSet(In(NonFoldableLiteral.create(null, IntegerType), Seq(Literal(1), + Literal(2))), null) + checkInAndInSet(In(NonFoldableLiteral.create(null, IntegerType), + Seq(NonFoldableLiteral.create(null, IntegerType))), null) + checkInAndInSet(In(NonFoldableLiteral.create(null, IntegerType), Seq.empty), + expected = if (legacyNullInBehavior) null else false) + checkInAndInSet(In(Literal(1), Seq.empty), false) + checkInAndInSet(In(Literal(1), Seq(NonFoldableLiteral.create(null, IntegerType))), null) + checkInAndInSet(In(Literal(1), + 
Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), + true) + checkInAndInSet(In(Literal(2), + Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), + null) + checkInAndInSet(In(Literal(1), Seq(Literal(1), Literal(2))), true) + checkInAndInSet(In(Literal(2), Seq(Literal(1), Literal(2))), true) + checkInAndInSet(In(Literal(3), Seq(Literal(1), Literal(2))), false) + } + } checkEvaluation( And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1), diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-null-semantics.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-null-semantics.sql index b893d8970b4d6..cc01887a4e212 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-null-semantics.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-null-semantics.sql @@ -1,7 +1,7 @@ create temp view v (c) as values (1), (null); create temp view v_empty (e) as select 1 where false; --- Note: tables and temp views hit different optimization/execution codepaths +-- Note: tables and temp views hit different optimization/execution codepaths: expressions over temp views are evaled at query compilation time by ConvertToLocalRelation create table t(c int) using json; insert into t values (1), (null); create table t2(d int) using json; @@ -29,7 +29,6 @@ select null not in (select e from v_empty); -- IN subquery which is not rewritten to join - here we use IN in the ON condition because that is a case that doesn't get rewritten to join in RewritePredicateSubquery, so we can observe the execution behavior of InSubquery directly -- Correct results: column t2.d should be NULL because the ON condition is always false --- This will be fixed by the execution fixes. 
select * from t left join t2 on (t.c in (select e from t_empty)) is null; select * from t left join t2 on (t.c not in (select e from t_empty)) is null; diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-null-semantics.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-null-semantics.sql.out index 39b03576baaf0..169b49fda846d 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-null-semantics.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-null-semantics.sql.out @@ -137,7 +137,7 @@ select * from t left join t2 on (t.c in (select e from t_empty)) is null struct -- !query output 1 NULL -NULL 2 +NULL NULL -- !query @@ -146,7 +146,7 @@ select * from t left join t2 on (t.c not in (select e from t_empty)) is null struct -- !query output 1 NULL -NULL 2 +NULL NULL -- !query diff --git a/sql/core/src/test/scala/org/apache/spark/sql/EmptyInSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/EmptyInSuite.scala index c9e016c891e77..1534aba28c4ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/EmptyInSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/EmptyInSuite.scala @@ -49,9 +49,8 @@ with SharedSparkSession { withSQLConf( SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> excludedRules, SQLConf.LEGACY_NULL_IN_EMPTY_LIST_BEHAVIOR.key -> legacyNullInBehavior.toString) { - // We still get legacy behavior with disableOptimizeIn until execution is also fixed val expectedResultForNullInEmpty = - if (legacyNullInBehavior || disableOptimizeIn) null else false + if (legacyNullInBehavior) null else false val df = t.select(col("a"), col("a").isin(emptylist: _*)) checkAnswer( df, From 3ccb03ec5e7debf8556f38942917f3cbb3c0c631 Mon Sep 17 00:00:00 2001 From: srielau Date: Wed, 9 Aug 2023 16:51:28 +0800 Subject: [PATCH 29/30] [SPARK-42849][SQL] Session Variables ### What changes were proposed in this pull request? We introduce the concept of a temporary VARIABLE. This is a typed and named scalar object that only exists within a session, much like a TEMPORARY VIEW and a TEMPORARY FUNCTION. The purpose of the session variable is to carry intermediate data across multiple statements without bringing the data back into a notebook or a client and then back to the engine. Variables are living in the same namespace as columns. Being defined outside of any query, they are only resolved if a resolution to a column fails within the query (or update, delete, insert, merge). Like other session-level objects, there are no privileges associated. Syntax: DECLARE [OR REPLACE] [VARIABLE] varName [ dataType] [DEFAULT | =] defaultExpr DROP TEMPORARY VARIABLE [IF EXISTS] varname SET {VARIABLE | VAR} { varName = expr } [, ...] SET {VARIABLE | VAR} ( varName [, ...] ) = ( query ) varname := [[`system` .] `session` .] identifier Examples: -- Most fluffy form to create a variable DECLARE OR REPLACE VARIABLE var1 INT DEFAULT RAND(5) * 1000; -- Shortest form to create a variable DECLARE var1 = CURRENT_DATE(); SET VAR var = 10; -- You can reset a variable to its default (which may be non-deterministic such as CURRENt_TIMESTAMP()) SET VAR var = DEFAULT; -- Set multiple variable SET var1 = 6, var2 = (SELECT sum(c1) FROM t); -- Assign a single-row result of a query to set of variables SET VAR (var1, var2) = (SELECT sum(c1), max(c2) FROM t); -- Reference. 
variable in a query SELECT c1 FROM t WHERE c2 = var1; -- Disambiguate between a column and variable SELECT t.c1, session.c1 FROM T FAQ: - Why don't we use SET? Existing SET is so general (see SET .*? in the grammar) that it has proven (near) impossible to sort out configs, hivevars and variables)... In this PR we try to be nice by redirecting the user to use SET VARIABLE if they try to SET a variable. - Why SYSTEM.SESSION.varName? We want to sort temporary objects into the multipart namespace and have started using SESSION in error messages as a search path. Also long term, we want to support persisted variable definitions which will be multipart names. - Why not CREATE TEMPORARY VARIABLE? DECLARE is the generally accepted (and standardized in SQL/PSM) syntax to define local variables with SQL procedural logic. E.g. in anonymous blocks, procedures, and functions ([BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/procedural-language)). While we could use CREATE TEMPORARY as an alternative, it adds to the fluff. This is open for debate... We want a short form, and a DECLARE fits the bill. Open issues (Todo): - Block SQL variables from DDL like other temporary objects persisted objects such as VIEWS and DEFAULTs should not depend on them. ### Why are the changes needed? Having a declared and typed SQL variable eliminates several shortcomings of hivevars which are based on text substitution. E.g. protects against typos, type-safety, clean naming, and eliminates the need to push the value to the engine for every statement. ### Does this PR introduce _any_ user-facing change? Yes, see syntax above. ### How was this patch tested? A new sql-session-variable.sql QA testfile is part of this PR Closes #40474 from srielau/SPARK-42849-Session-Variables. 
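For completeness, the same flow from the Scala API is just the SQL forms above routed through `spark.sql` (illustrative only; the variable name and values are made up):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("session-variables-example").getOrCreate()

// Declare, assign, reference, and drop a session variable.
spark.sql("DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 10")
spark.sql("SET VAR var1 = 42")
// Variables are resolved only if column resolution fails, so they can be
// referenced like a column in a query.
spark.sql("SELECT var1 AS v").show()
spark.sql("DROP TEMPORARY VARIABLE IF EXISTS var1")
spark.stop()
```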
Lead-authored-by: srielau Co-authored-by: Wenchen Fan Co-authored-by: Serge Rielau Co-authored-by: Serge Rielau Signed-off-by: Wenchen Fan --- .../main/resources/error/error-classes.json | 146 +- .../apache/spark/SparkThrowableSuite.scala | 6 +- ...rnal-error-metadata-catalog-error-class.md | 52 + ...tions-invalid-default-value-error-class.md | 2 +- ...nditions-invalid-sql-syntax-error-class.md | 5 + ...onditions-unresolved-column-error-class.md | 2 +- ...ditions-unsupported-feature-error-class.md | 4 + docs/sql-error-conditions.md | 60 +- docs/sql-ref-ansi-compliance.md | 5 +- .../spark/sql/catalyst/parser/SqlBaseLexer.g4 | 3 + .../sql/catalyst/parser/SqlBaseParser.g4 | 18 +- .../sql/catalyst/analysis/Analyzer.scala | 42 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 15 +- .../analysis/ColumnResolutionHelper.scala | 87 +- .../catalyst/analysis/ResolveCatalogs.scala | 37 +- ...lveColumnDefaultInCommandInputQuery.scala} | 30 +- .../ResolveReferencesInAggregate.scala | 4 +- .../analysis/ResolveReferencesInSort.scala | 8 +- .../analysis/ResolveReferencesInUpdate.scala | 4 +- .../analysis/ResolveSetVariable.scala | 79 + .../analysis/TableOutputResolver.scala | 38 + .../catalyst/analysis/v2ResolutionPlans.scala | 2 +- .../catalog/TempVariableManager.scala | 68 + .../sql/catalyst/catalog/interface.scala | 22 + .../expressions/VariableReference.scala | 60 + .../sql/catalyst/parser/AstBuilder.scala | 101 +- .../catalyst/plans/logical/v2Commands.scala | 45 + .../sql/catalyst/rules/RuleIdCollection.scala | 1 + .../connector/catalog/CatalogManager.scala | 7 +- .../sql/errors/QueryCompilationErrors.scala | 42 +- .../sql/errors/QueryExecutionErrors.scala | 2 +- .../spark/sql/catalyst/SQLKeywordSuite.scala | 4 +- .../analysis/AnalysisErrorSuite.scala | 10 +- .../sql/catalyst/analysis/AnalysisSuite.scala | 50 +- .../spark/sql/execution/SparkPlanner.scala | 2 + .../sql/execution/command/SetCommand.scala | 22 +- .../command/v2/CreateVariableExec.scala | 49 + .../command/v2/DropVariableExec.scala | 55 + .../command/v2/SetVariableExec.scala | 72 + .../command/v2/V2CommandStrategy.scala | 41 + .../spark/sql/execution/command/views.scala | 42 +- .../apache/spark/sql/execution/subquery.scala | 2 +- .../sql-session-variables.sql.out | 2112 +++++++++++++++ .../analyzer-results/table-aliases.sql.out | 14 +- .../inputs/sql-session-variables.sql | 374 +++ .../sql-tests/results/ansi/keywords.sql.out | 3 + .../sql-tests/results/keywords.sql.out | 3 + .../results/sql-session-variables.sql.out | 2286 +++++++++++++++++ .../sql-tests/results/table-aliases.sql.out | 14 +- .../command/AlignAssignmentsSuiteBase.scala | 5 +- .../command/PlanResolutionSuite.scala | 8 +- .../spark/sql/sources/InsertSuite.scala | 4 +- .../ThriftServerWithSparkContextSuite.scala | 2 +- 53 files changed, 6014 insertions(+), 157 deletions(-) create mode 100644 docs/sql-error-conditions-internal-error-metadata-catalog-error-class.md rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/{ResolveColumnDefaultInInsert.scala => ResolveColumnDefaultInCommandInputQuery.scala} (84%) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSetVariable.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/TempVariableManager.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/VariableReference.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/CreateVariableExec.scala create mode 
100644 sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/DropVariableExec.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/SetVariableExec.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/V2CommandStrategy.scala create mode 100644 sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out create mode 100644 sql/core/src/test/resources/sql-tests/inputs/sql-session-variables.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/sql-session-variables.sql.out diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index 0ea1eed35e463..75125d2275d1f 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -52,6 +52,12 @@ ], "sqlState" : "22003" }, + "ASSIGNMENT_ARITY_MISMATCH" : { + "message" : [ + "The number of columns or variables assigned or aliased: does not match the number of source expressions: ." + ], + "sqlState" : "42802" + }, "AS_OF_JOIN" : { "message" : [ "Invalid as-of join." @@ -332,7 +338,7 @@ }, "CAST_OVERFLOW_IN_TABLE_INSERT" : { "message" : [ - "Fail to insert a value of type into the type column due to an overflow. Use `try_cast` on the input value to tolerate overflow and return NULL instead." + "Fail to assign a value of type to the type column or variable due to an overflow. Use `try_cast` on the input value to tolerate overflow and return NULL instead." ], "sqlState" : "22003" }, @@ -764,6 +770,13 @@ ], "sqlState" : "42704" }, + "DEFAULT_PLACEMENT_INVALID" : { + "message" : [ + "A DEFAULT keyword in a MERGE, INSERT, UPDATE, or SET VARIABLE command could not be directly assigned to a target column because it was part of an expression.", + "For example: `UPDATE SET c1 = DEFAULT` is allowed, but `UPDATE T SET c1 = DEFAULT + 1` is not allowed." + ], + "sqlState" : "42608" + }, "DISTINCT_WINDOW_FUNCTION_UNSUPPORTED" : { "message" : [ "Distinct window functions are not supported: ." @@ -791,6 +804,12 @@ "The metric name is not unique: . The same name cannot be used for metrics with different results. However multiple instances of metrics with with same result and name are allowed (e.g. self-joins)." ] }, + "DUPLICATE_ASSIGNMENTS" : { + "message" : [ + "The columns or variables appear more than once as assignment targets." + ], + "sqlState" : "42701" + }, "DUPLICATE_CLAUSES" : { "message" : [ "Found duplicate clauses: . Please, remove one of them." @@ -1227,6 +1246,44 @@ ], "sqlState" : "XX000" }, + "INTERNAL_ERROR_METADATA_CATALOG" : { + "message" : [ + "An object in the metadata catalog has been corrupted:" + ], + "subClass" : { + "SQL_CONFIG" : { + "message" : [ + "Corrupted view SQL configs in catalog." + ] + }, + "TABLE_NAME_CONTEXT" : { + "message" : [ + "Corrupted table name context in catalog: parts expected, but part is missing." + ] + }, + "TEMP_FUNCTION_REFERENCE" : { + "message" : [ + "Corrupted view referred temp functions names in catalog." + ] + }, + "TEMP_VARIABLE_REFERENCE" : { + "message" : [ + "Corrupted view referred temp variable names in catalog." + ] + }, + "TEMP_VIEW_REFERENCE" : { + "message" : [ + "Corrupted view referred temp view names in catalog." + ] + }, + "VIEW_QUERY_COLUMN_ARITY" : { + "message" : [ + "Corrupted view query output column names in catalog: parts expected, but part is missing." 
+ ] + } + }, + "sqlState" : "XX000" + }, "INTERNAL_ERROR_NETWORK" : { "message" : [ "" @@ -1344,7 +1401,7 @@ }, "INVALID_DEFAULT_VALUE" : { "message" : [ - "Failed to execute command because the destination table column has a DEFAULT value ," + "Failed to execute command because the destination column or variable has a DEFAULT value ," ], "subClass" : { "DATA_TYPE" : { @@ -1912,6 +1969,12 @@ "message" : [ "Unsupported function name ." ] + }, + "VARIABLE_TYPE_OR_DEFAULT_REQUIRED" : { + "message" : [ + "The definition of a SQL variable requires either a datatype or a DEFAULT clause.", + "For example, use `DECLARE name STRING` or `DECLARE name = 'SQL'` instead of `DECLARE name`." + ] } }, "sqlState" : "42000" @@ -2497,6 +2560,12 @@ ], "sqlState" : "42883" }, + "ROW_SUBQUERY_TOO_MANY_ROWS" : { + "message" : [ + "More than one row returned by a subquery used as a row." + ], + "sqlState" : "21000" + }, "RULE_ID_NOT_FOUND" : { "message" : [ "Not found an id for the rule name \"\". Please modify RuleIdCollection.scala if you are adding a new rule." @@ -2737,7 +2806,7 @@ }, "UNRESOLVED_COLUMN" : { "message" : [ - "A column or function parameter with name cannot be resolved." + "A column, variable, or function parameter with name cannot be resolved." ], "subClass" : { "WITHOUT_SUGGESTION" : { @@ -2801,6 +2870,12 @@ ], "sqlState" : "42703" }, + "UNRESOLVED_VARIABLE" : { + "message" : [ + "Cannot resolve variable on search path ." + ], + "sqlState" : "42883" + }, "UNSET_NONEXISTENT_PROPERTIES" : { "message" : [ "Attempted to unset non-existent properties [] in table ." @@ -3097,6 +3172,11 @@ " is a reserved table property, ." ] }, + "SET_VARIABLE_USING_SET" : { + "message" : [ + " is a VARIABLE and cannot be updated using the SET statement. Use SET VARIABLE = ... instead." + ] + }, "TABLE_OPERATION" : { "message" : [ "Table does not support . Please check the current catalog and namespace to make sure the qualified table name is expected, and also check the catalog implementation which is configured by \"spark.sql.catalog\"." @@ -3323,6 +3403,21 @@ "3. set \"spark.sql.legacy.allowUntypedScalaUDF\" to \"true\" and use this API with caution." ] }, + "VARIABLE_ALREADY_EXISTS" : { + "message" : [ + "Cannot create the variable because it already exists.", + "Choose a different name, or drop or replace the existing variable." + ], + "sqlState" : "42723" + }, + "VARIABLE_NOT_FOUND" : { + "message" : [ + "The variable cannot be found. Verify the spelling and correctness of the schema and catalog.", + "If you did not qualify the name with a schema and catalog, verify the current_schema() output, or qualify the name with the correct schema and catalog.", + "To tolerate the error on drop use DROP VARIABLE IF EXISTS." + ], + "sqlState" : "42883" + }, "VIEW_ALREADY_EXISTS" : { "message" : [ "Cannot create view because it already exists.", @@ -3682,11 +3777,6 @@ "FILTER expression contains window function. It cannot be used in an aggregate function." ] }, - "_LEGACY_ERROR_TEMP_1028" : { - "message" : [ - "Number of column aliases does not match number of columns. Number of column aliases: ; number of columns: ." - ] - }, "_LEGACY_ERROR_TEMP_1030" : { "message" : [ "Window aggregate function with filter predicate is not supported yet." @@ -3878,31 +3968,6 @@ "Number of buckets should be greater than 0 but less than or equal to bucketing.maxBuckets (``). Got ``." ] }, - "_LEGACY_ERROR_TEMP_1084" : { - "message" : [ - "Corrupted table name context in catalog: parts expected, but part is missing." 
- ] - }, - "_LEGACY_ERROR_TEMP_1085" : { - "message" : [ - "Corrupted view SQL configs in catalog." - ] - }, - "_LEGACY_ERROR_TEMP_1086" : { - "message" : [ - "Corrupted view query output column names in catalog: parts expected, but part is missing." - ] - }, - "_LEGACY_ERROR_TEMP_1087" : { - "message" : [ - "Corrupted view referred temp view names in catalog." - ] - }, - "_LEGACY_ERROR_TEMP_1088" : { - "message" : [ - "Corrupted view referred temp functions names in catalog." - ] - }, "_LEGACY_ERROR_TEMP_1089" : { "message" : [ "Column statistics deserialization is not supported for column of data type: ." @@ -4785,21 +4850,6 @@ "Sinks cannot request distribution and ordering in continuous execution mode." ] }, - "_LEGACY_ERROR_TEMP_1339" : { - "message" : [ - "Failed to execute INSERT INTO command because the VALUES list contains a DEFAULT column reference as part of another expression; this is not allowed." - ] - }, - "_LEGACY_ERROR_TEMP_1340" : { - "message" : [ - "Failed to execute UPDATE command because the SET list contains a DEFAULT column reference as part of another expression; this is not allowed." - ] - }, - "_LEGACY_ERROR_TEMP_1343" : { - "message" : [ - "Failed to execute MERGE INTO command because one of its INSERT or UPDATE assignments contains a DEFAULT column reference as part of another expression; this is not allowed." - ] - }, "_LEGACY_ERROR_TEMP_1344" : { "message" : [ "Invalid DEFAULT value for column : fails to parse as a valid literal value." diff --git a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala index 0249cde54884b..57c4fe31b3b92 100644 --- a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala @@ -383,7 +383,7 @@ class SparkThrowableSuite extends SparkFunSuite { "UNRESOLVED_COLUMN.WITH_SUGGESTION", Map("objectName" -> "`foo`", "proposal" -> "`bar`, `baz`") ) == - "[UNRESOLVED_COLUMN.WITH_SUGGESTION] A column or function parameter with " + + "[UNRESOLVED_COLUMN.WITH_SUGGESTION] A column, variable, or function parameter with " + "name `foo` cannot be resolved. Did you mean one of the following? [`bar`, `baz`]." ) @@ -395,7 +395,7 @@ class SparkThrowableSuite extends SparkFunSuite { "proposal" -> "`bar`, `baz`"), "" ) == - "[UNRESOLVED_COLUMN.WITH_SUGGESTION] A column or function parameter with " + + "[UNRESOLVED_COLUMN.WITH_SUGGESTION] A column, variable, or function parameter with " + "name `foo` cannot be resolved. Did you mean one of the following? [`bar`, `baz`]." ) } @@ -406,7 +406,7 @@ class SparkThrowableSuite extends SparkFunSuite { "UNRESOLVED_COLUMN.WITH_SUGGESTION", Map("objectName" -> "`foo`", "proposal" -> "`${bar}`, `baz`") ) == - "[UNRESOLVED_COLUMN.WITH_SUGGESTION] A column or function parameter with " + + "[UNRESOLVED_COLUMN.WITH_SUGGESTION] A column, variable, or function parameter with " + "name `foo` cannot be resolved. Did you mean one of the following? [`${bar}`, `baz`]." 
) } diff --git a/docs/sql-error-conditions-internal-error-metadata-catalog-error-class.md b/docs/sql-error-conditions-internal-error-metadata-catalog-error-class.md new file mode 100644 index 0000000000000..e451165612814 --- /dev/null +++ b/docs/sql-error-conditions-internal-error-metadata-catalog-error-class.md @@ -0,0 +1,52 @@ +--- +layout: global +title: INTERNAL_ERROR_METADATA_CATALOG error class +displayTitle: INTERNAL_ERROR_METADATA_CATALOG error class +license: | + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +--- + +[SQLSTATE: XX000](sql-error-conditions-sqlstates.html#class-XX-internal-error) + +An object in the metadata catalog has been corrupted: + +This error class has the following derived error classes: + +## SQL_CONFIG + +Corrupted view SQL configs in catalog. + +## TABLE_NAME_CONTEXT + +Corrupted table name context in catalog: `` parts expected, but part `` is missing. + +## TEMP_FUNCTION_REFERENCE + +Corrupted view referred temp functions names in catalog. + +## TEMP_VARIABLE_REFERENCE + +Corrupted view referred temp variable names in catalog. + +## TEMP_VIEW_REFERENCE + +Corrupted view referred temp view names in catalog. + +## VIEW_QUERY_COLUMN_ARITY + +Corrupted view query output column names in catalog: `` parts expected, but part `` is missing. + + diff --git a/docs/sql-error-conditions-invalid-default-value-error-class.md b/docs/sql-error-conditions-invalid-default-value-error-class.md index 466b5a9274cad..05c5680fc953c 100644 --- a/docs/sql-error-conditions-invalid-default-value-error-class.md +++ b/docs/sql-error-conditions-invalid-default-value-error-class.md @@ -21,7 +21,7 @@ license: | SQLSTATE: none assigned -Failed to execute `` command because the destination table column `` has a DEFAULT value ``, +Failed to execute `` command because the destination column or variable `` has a DEFAULT value ``, This error class has the following derived error classes: diff --git a/docs/sql-error-conditions-invalid-sql-syntax-error-class.md b/docs/sql-error-conditions-invalid-sql-syntax-error-class.md index b1e298f7b908b..d9be7bad10320 100644 --- a/docs/sql-error-conditions-invalid-sql-syntax-error-class.md +++ b/docs/sql-error-conditions-invalid-sql-syntax-error-class.md @@ -97,4 +97,9 @@ Cannot resolve window reference ``. Unsupported function name ``. +## VARIABLE_TYPE_OR_DEFAULT_REQUIRED + +The definition of a SQL variable requires either a datatype or a DEFAULT clause. +For example, use `DECLARE name STRING` or `DECLARE name = 'SQL'` instead of `DECLARE name`. 
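A minimal SQL sketch of the rule above (illustrative only; it assumes the DECLARE grammar added later in this patch, and reuses the placeholder name from the error message):

DECLARE name STRING;   -- OK: explicit type, the initial value defaults to NULL
DECLARE name = 'SQL';  -- OK: the type is inferred from the DEFAULT expression
DECLARE name;          -- Error: INVALID_SQL_SYNTAX.VARIABLE_TYPE_OR_DEFAULT_REQUIRED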
+ diff --git a/docs/sql-error-conditions-unresolved-column-error-class.md b/docs/sql-error-conditions-unresolved-column-error-class.md index bdda298d30189..cb7b9d4e1d29a 100644 --- a/docs/sql-error-conditions-unresolved-column-error-class.md +++ b/docs/sql-error-conditions-unresolved-column-error-class.md @@ -21,7 +21,7 @@ license: | [SQLSTATE: 42703](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) -A column or function parameter with name `` cannot be resolved. +A column, variable, or function parameter with name `` cannot be resolved. This error class has the following derived error classes: diff --git a/docs/sql-error-conditions-unsupported-feature-error-class.md b/docs/sql-error-conditions-unsupported-feature-error-class.md index 7a60dc76fa640..790b5c88e461d 100644 --- a/docs/sql-error-conditions-unsupported-feature-error-class.md +++ b/docs/sql-error-conditions-unsupported-feature-error-class.md @@ -177,6 +177,10 @@ set PROPERTIES and DBPROPERTIES at the same time. `` is a reserved table property, ``. +## SET_VARIABLE_USING_SET + +`` is a VARIABLE and cannot be updated using the SET statement. Use SET VARIABLE `` = ... instead. + ## TABLE_OPERATION Table `` does not support ``. Please check the current catalog and namespace to make sure the qualified table name is expected, and also check the catalog implementation which is configured by "spark.sql.catalog". diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md index b59bb1789488e..bd49ba94f5fb6 100644 --- a/docs/sql-error-conditions.md +++ b/docs/sql-error-conditions.md @@ -79,6 +79,12 @@ Ambiguous reference to the field ``. It appears `` times in the sc ``.`` If necessary set `` to "false" to bypass this error. +### ASSIGNMENT_ARITY_MISMATCH + +[SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +The number of columns or variables assigned or aliased: `` does not match the number of source expressions: ``. + ### [AS_OF_JOIN](sql-error-conditions-as-of-join-error-class.html) SQLSTATE: none assigned @@ -293,7 +299,7 @@ The value `` of the type `` cannot be cast to `` [SQLSTATE: 22003](sql-error-conditions-sqlstates.html#class-22-data-exception) -Fail to insert a value of `` type into the `` type column `` due to an overflow. Use `try_cast` on the input value to tolerate overflow and return NULL instead. +Fail to assign a value of `` type to the `` type column or variable `` due to an overflow. Use `try_cast` on the input value to tolerate overflow and return NULL instead. ### CODEC_NOT_AVAILABLE @@ -422,6 +428,13 @@ Decimal precision `` exceeds max precision ``. Default database `` does not exist, please create it first or change default database to ````. +### DEFAULT_PLACEMENT_INVALID + +[SQLSTATE: 42608](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +A DEFAULT keyword in a MERGE, INSERT, UPDATE, or SET VARIABLE command could not be directly assigned to a target column because it was part of an expression. +For example: `UPDATE SET c1 = DEFAULT` is allowed, but `UPDATE T SET c1 = DEFAULT + 1` is not allowed. + ### DISTINCT_WINDOW_FUNCTION_UNSUPPORTED SQLSTATE: none assigned @@ -452,6 +465,12 @@ SQLSTATE: none assigned The metric name is not unique: ``. The same name cannot be used for metrics with different results. However multiple instances of metrics with with same result and name are allowed (e.g. self-joins). 
+### DUPLICATE_ASSIGNMENTS + +[SQLSTATE: 42701](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +The columns or variables `` appear more than once as assignment targets. + ### DUPLICATE_CLAUSES SQLSTATE: none assigned @@ -774,6 +793,14 @@ For more details see [INSUFFICIENT_TABLE_PROPERTY](sql-error-conditions-insuffic `` +### [INTERNAL_ERROR_METADATA_CATALOG](sql-error-conditions-internal-error-metadata-catalog-error-class.html) + +[SQLSTATE: XX000](sql-error-conditions-sqlstates.html#class-XX-internal-error) + +An object in the metadata catalog has been corrupted: + +For more details see [INTERNAL_ERROR_METADATA_CATALOG](sql-error-conditions-internal-error-metadata-catalog-error-class.html) + ### INTERNAL_ERROR_NETWORK [SQLSTATE: XX000](sql-error-conditions-sqlstates.html#class-XX-internal-error) @@ -866,7 +893,7 @@ For more details see [INVALID_CURSOR](sql-error-conditions-invalid-cursor-error- SQLSTATE: none assigned -Failed to execute `` command because the destination table column `` has a DEFAULT value ``, +Failed to execute `` command because the destination column or variable `` has a DEFAULT value ``, For more details see [INVALID_DEFAULT_VALUE](sql-error-conditions-invalid-default-value-error-class.html) @@ -1610,6 +1637,12 @@ The function `` cannot be found. Verify the spelling and correctnes If you did not qualify the name with a schema and catalog, verify the current_schema() output, or qualify the name with the correct schema and catalog. To tolerate the error on drop use DROP FUNCTION IF EXISTS. +### ROW_SUBQUERY_TOO_MANY_ROWS + +[SQLSTATE: 21000](sql-error-conditions-sqlstates.html#class-21-cardinality-violation) + +More than one row returned by a subquery used as a row. + ### RULE_ID_NOT_FOUND [SQLSTATE: 22023](sql-error-conditions-sqlstates.html#class-22-data-exception) @@ -1868,7 +1901,7 @@ Cannot infer grouping columns for GROUP BY ALL based on the select clause. Pleas [SQLSTATE: 42703](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) -A column or function parameter with name `` cannot be resolved. +A column, variable, or function parameter with name `` cannot be resolved. For more details see [UNRESOLVED_COLUMN](sql-error-conditions-unresolved-column-error-class.html) @@ -1900,6 +1933,12 @@ Cannot resolve function `` on search path ``. USING column `` cannot be resolved on the `` side of the join. The ``-side columns: [``]. +### UNRESOLVED_VARIABLE + +[SQLSTATE: 42883](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +Cannot resolve variable `` on search path ``. + ### UNSET_NONEXISTENT_PROPERTIES SQLSTATE: none assigned @@ -2056,6 +2095,21 @@ You're using untyped Scala UDF, which does not have the input type information. 2. use Java UDF APIs, e.g. `udf(new UDF1[String, Integer] { override def call(s: String): Integer = s.length() }, IntegerType)`, if input types are all non primitive. 3. set "spark.sql.legacy.allowUntypedScalaUDF" to "true" and use this API with caution. +### VARIABLE_ALREADY_EXISTS + +[SQLSTATE: 42723](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +Cannot create the variable `` because it already exists. +Choose a different name, or drop or replace the existing variable. + +### VARIABLE_NOT_FOUND + +[SQLSTATE: 42883](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +The variable `` cannot be found. Verify the spelling and correctness of the schema and catalog. 
+If you did not qualify the name with a schema and catalog, verify the current_schema() output, or qualify the name with the correct schema and catalog. +To tolerate the error on drop use DROP VARIABLE IF EXISTS. + ### VIEW_ALREADY_EXISTS [SQLSTATE: 42P07](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md index 84af522ad2185..f3a0e8f9afbf0 100644 --- a/docs/sql-ref-ansi-compliance.md +++ b/docs/sql-ref-ansi-compliance.md @@ -349,7 +349,7 @@ By default, both `spark.sql.ansi.enabled` and `spark.sql.ansi.enforceReservedKey Below is a list of all the keywords in Spark SQL. |Keyword|Spark SQL
ANSI Mode|Spark SQL
Default Mode|SQL-2016| -|------|----------------------|-------------------------|--------| +|--|----------------------|-------------------------|--------| |ADD|non-reserved|non-reserved|non-reserved| |AFTER|non-reserved|non-reserved|non-reserved| |ALL|reserved|non-reserved|reserved| @@ -423,6 +423,7 @@ Below is a list of all the keywords in Spark SQL. |DBPROPERTIES|non-reserved|non-reserved|non-reserved| |DEC|non-reserved|non-reserved|reserved| |DECIMAL|non-reserved|non-reserved|reserved| +|DECLARE|non-reserved|non-reserved|non-reserved| |DEFAULT|non-reserved|non-reserved|non-reserved| |DEFINED|non-reserved|non-reserved|non-reserved| |DELETE|non-reserved|non-reserved|reserved| @@ -667,6 +668,8 @@ Below is a list of all the keywords in Spark SQL. |USING|reserved|strict-non-reserved|reserved| |VALUES|non-reserved|non-reserved|reserved| |VARCHAR|non-reserved|non-reserved|reserved| +|VAR|non-reserved|non-reserved|non-reserved| +|VARIABLE|non-reserved|non-reserved|non-reserved| |VERSION|non-reserved|non-reserved|non-reserved| |VIEW|non-reserved|non-reserved|non-reserved| |VIEWS|non-reserved|non-reserved|non-reserved| diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 index 74e8ee1ecf9fe..bf6370575a1ee 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 @@ -161,6 +161,7 @@ DATE_DIFF: 'DATE_DIFF'; DBPROPERTIES: 'DBPROPERTIES'; DEC: 'DEC'; DECIMAL: 'DECIMAL'; +DECLARE: 'DECLARE'; DEFAULT: 'DEFAULT'; DEFINED: 'DEFINED'; DELETE: 'DELETE'; @@ -404,6 +405,8 @@ USER: 'USER'; USING: 'USING'; VALUES: 'VALUES'; VARCHAR: 'VARCHAR'; +VAR: 'VAR'; +VARIABLE: 'VARIABLE'; VERSION: 'VERSION'; VIEW: 'VIEW'; VIEWS: 'VIEWS'; diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 1ea0f6e583d2c..a45ebee310682 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -165,7 +165,10 @@ statement | CREATE (OR REPLACE)? TEMPORARY? FUNCTION (IF NOT EXISTS)? identifierReference AS className=stringLit (USING resource (COMMA resource)*)? #createFunction - | DROP TEMPORARY? FUNCTION (IF EXISTS)? identifierReference #dropFunction + | DROP TEMPORARY? FUNCTION (IF EXISTS)? identifierReference #dropFunction + | DECLARE (OR REPLACE)? VARIABLE? + identifierReference dataType? variableDefaultExpression? #createVariable + | DROP TEMPORARY VARIABLE (IF EXISTS)? identifierReference #dropVariable | EXPLAIN (LOGICAL | FORMATTED | EXTENDED | CODEGEN | COST)? statement #explain | SHOW TABLES ((FROM | IN) identifierReference)? @@ -210,6 +213,9 @@ statement | SET TIME ZONE interval #setTimeZone | SET TIME ZONE timezone #setTimeZone | SET TIME ZONE .*? #setTimeZone + | SET (VARIABLE | VAR) assignmentList #setVariable + | SET (VARIABLE | VAR) LEFT_PAREN multipartIdentifierList RIGHT_PAREN EQ + LEFT_PAREN query RIGHT_PAREN #setVariable | SET configKey EQ configValue #setQuotedConfiguration | SET configKey (EQ .*?)? #setConfiguration | SET .*? 
EQ configValue #setQuotedConfiguration @@ -1109,6 +1115,10 @@ defaultExpression : DEFAULT expression ; +variableDefaultExpression + : (DEFAULT | EQ) expression + ; + colTypeList : colType (COMMA colType)* ; @@ -1335,6 +1345,7 @@ ansiNonReserved | DBPROPERTIES | DEC | DECIMAL + | DECLARE | DEFAULT | DEFINED | DELETE @@ -1525,6 +1536,8 @@ ansiNonReserved | USE | VALUES | VARCHAR + | VAR + | VARIABLE | VERSION | VIEW | VIEWS @@ -1640,6 +1653,7 @@ nonReserved | DBPROPERTIES | DEC | DECIMAL + | DECLARE | DEFAULT | DEFINED | DELETE @@ -1869,6 +1883,8 @@ nonReserved | USER | VALUES | VARCHAR + | VAR + | VARIABLE | VERSION | VIEW | VIEWS diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 0c792ded8f890..e4bf5f44ea20e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -136,6 +136,7 @@ case class AnalysisContext( // 2. If we are not resolving a view, this field will be updated everytime the analyzer // lookup a temporary function. And export to the view metadata. referredTempFunctionNames: mutable.Set[String] = mutable.Set.empty, + referredTempVariableNames: Seq[Seq[String]] = Seq.empty, outerPlan: Option[LogicalPlan] = None) object AnalysisContext { @@ -162,7 +163,8 @@ object AnalysisContext { maxNestedViewDepth, originContext.relationCache, viewDesc.viewReferredTempViewNames, - mutable.Set(viewDesc.viewReferredTempFunctionNames: _*)) + mutable.Set(viewDesc.viewReferredTempFunctionNames: _*), + viewDesc.viewReferredTempVariableNames) set(context) try f finally { set(originContext) } } @@ -284,7 +286,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor ResolveFieldNameAndPosition :: AddMetadataColumns :: DeduplicateRelations :: - ResolveReferences :: + new ResolveReferences(catalogManager) :: // Please do not insert any other rules in between. See the TODO comments in rule // ResolveLateralColumnAliasReference for more details. ResolveLateralColumnAliasReference :: @@ -307,6 +309,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor ResolveWindowFrame :: ResolveNaturalAndUsingJoin :: ResolveOutputRelation :: + new ResolveSetVariable(catalogManager) :: ExtractWindowExpressions :: GlobalAggregates :: ResolveAggregateFunctions :: @@ -1455,6 +1458,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor * Some plan nodes have special column reference resolution logic, please read these sub-rules for * details: * - [[ResolveReferencesInAggregate]] + * - [[ResolveReferencesInUpdate]] * - [[ResolveReferencesInSort]] * * Note: even if we use a single rule to resolve columns, it's still non-trivial to have a @@ -1463,7 +1467,17 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor * previous options are permanently not applicable. If the current option can be applicable * in the next iteration (other rules update the plan), we should not try the next option. 
*/ - object ResolveReferences extends Rule[LogicalPlan] with ColumnResolutionHelper { + class ResolveReferences(val catalogManager: CatalogManager) + extends Rule[LogicalPlan] with ColumnResolutionHelper { + + private val resolveColumnDefaultInCommandInputQuery = + new ResolveColumnDefaultInCommandInputQuery(catalogManager) + private val resolveReferencesInAggregate = + new ResolveReferencesInAggregate(catalogManager) + private val resolveReferencesInUpdate = + new ResolveReferencesInUpdate(catalogManager) + private val resolveReferencesInSort = + new ResolveReferencesInSort(catalogManager) /** * Return true if there're conflicting attributes among children's outputs of a plan @@ -1491,7 +1505,11 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { // Don't wait other rules to resolve the child plans of `InsertIntoStatement` as we need // to resolve column "DEFAULT" in the child plans so that they must be unresolved. - case i: InsertIntoStatement => ResolveColumnDefaultInInsert(i) + case i: InsertIntoStatement => resolveColumnDefaultInCommandInputQuery(i) + + // Don't wait other rules to resolve the child plans of `SetVariable` as we need + // to resolve column "DEFAULT" in the child plans so that they must be unresolved. + case s: SetVariable => resolveColumnDefaultInCommandInputQuery(s) // Wait for other rules to resolve child plans first case p: LogicalPlan if !p.childrenResolved => p @@ -1596,7 +1614,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor // rule: ResolveDeserializer. case plan if containsDeserializer(plan.expressions) => plan - case a: Aggregate => ResolveReferencesInAggregate(a) + case a: Aggregate => resolveReferencesInAggregate(a) // Special case for Project as it supports lateral column alias. case p: Project => @@ -1605,14 +1623,15 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor // Lateral column alias has higher priority than outer reference. val resolvedWithLCA = resolveLateralColumnAlias(resolvedNoOuter) val resolvedWithOuter = resolvedWithLCA.map(resolveOuterRef) - p.copy(projectList = resolvedWithOuter.map(_.asInstanceOf[NamedExpression])) + val resolvedWithVariables = resolvedWithOuter.map(p => resolveVariables(p)) + p.copy(projectList = resolvedWithVariables.map(_.asInstanceOf[NamedExpression])) case o: OverwriteByExpression if o.table.resolved => // The delete condition of `OverwriteByExpression` will be passed to the table // implementation and should be resolved based on the table schema. o.copy(deleteExpr = resolveExpressionByPlanOutput(o.deleteExpr, o.table)) - case u: UpdateTable => ResolveReferencesInUpdate(u) + case u: UpdateTable => resolveReferencesInUpdate(u) case m @ MergeIntoTable(targetTable, sourceTable, _, _, _, _) if !m.resolved && targetTable.resolved && sourceTable.resolved => @@ -1712,7 +1731,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor val resolvedNoOuter = partitionExprs.map(resolveExpressionByPlanChildren(_, r)) val (newPartitionExprs, newChild) = resolveExprsAndAddMissingAttrs(resolvedNoOuter, child) // Outer reference has lower priority than this. See the doc of `ResolveReferences`. 
- val finalPartitionExprs = newPartitionExprs.map(resolveOuterRef) + val resolvedWithOuter = newPartitionExprs.map(resolveOuterRef) + val finalPartitionExprs = resolvedWithOuter.map(e => resolveVariables(e)) if (child.output == newChild.output) { r.copy(finalPartitionExprs, newChild) } else { @@ -1727,7 +1747,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor val resolvedWithAgg = resolveColWithAgg(resolvedNoOuter, child) val (newCond, newChild) = resolveExprsAndAddMissingAttrs(Seq(resolvedWithAgg), child) // Outer reference has lowermost priority. See the doc of `ResolveReferences`. - val finalCond = resolveOuterRef(newCond.head) + val resolvedWithOuter = resolveOuterRef(newCond.head) + val finalCond = resolveVariables(resolvedWithOuter) if (child.output == newChild.output) { f.copy(condition = finalCond) } else { @@ -1736,7 +1757,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor Project(child.output, newFilter) } - case s: Sort if !s.resolved || s.missingInput.nonEmpty => ResolveReferencesInSort(s) + case s: Sort if !s.resolved || s.missingInput.nonEmpty => + resolveReferencesInSort(s) case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString(conf.maxToStringFields)}") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 8b04c8108bd08..fee5660017c7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BooleanSimplification, Decorrela import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.TreeNodeTag -import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, UNRESOLVED_WINDOW_EXPRESSION} +import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, PLAN_EXPRESSION, UNRESOLVED_WINDOW_EXPRESSION} import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, StringUtils, TypeUtils} import org.apache.spark.sql.connector.catalog.{LookupCatalog, SupportsPartitionManagement} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} @@ -259,7 +259,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB // Fail if we still have an unresolved all in group by. This needs to run before the // general unresolved check below to throw a more tailored error message. 
- ResolveReferencesInAggregate.checkUnresolvedGroupByAll(operator)
+ new ResolveReferencesInAggregate(catalogManager).checkUnresolvedGroupByAll(operator)
 getAllExpressions(operator).foreach(_.foreachUp {
 case a: Attribute if !a.resolved =>
@@ -650,6 +650,16 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
 case alter: AlterTableCommand => checkAlterTableCommand(alter)
+ case c: CreateVariable
+ if c.resolved && c.defaultExpr.child.containsPattern(PLAN_EXPRESSION) =>
+ val ident = c.name.asInstanceOf[ResolvedIdentifier]
+ val varName = toSQLId(
+ ident.catalog.name +: ident.identifier.namespace :+ ident.identifier.name)
+ throw QueryCompilationErrors.defaultValuesMayNotContainSubQueryExpressions(
+ "CREATE VARIABLE",
+ varName,
+ c.defaultExpr.originalSQL)
+
 case _ => // Falls back to the following checks
 }
@@ -753,6 +763,7 @@
 !o.isInstanceOf[Aggregate] &&
 !o.isInstanceOf[Window] && !o.isInstanceOf[Expand] &&
 !o.isInstanceOf[Generate] &&
+ !o.isInstanceOf[CreateVariable] &&
 // Lateral join is checked in checkSubqueryExpression.
 !o.isInstanceOf[LateralJoin] =>
 // The rule above is used to check Aggregate operator.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
index 98cbdea72d53b..b631d1fd8b6d1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
 import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.catalyst.util.toPrettySQL
+import org.apache.spark.sql.connector.catalog.{CatalogManager, Identifier}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
@@ -36,6 +37,8 @@ trait ColumnResolutionHelper extends Logging {
 def conf: SQLConf
+ def catalogManager: CatalogManager
+
 /**
 * This method tries to resolve expressions and find missing attributes recursively.
 * Specifically, when the expressions used in `Sort` or `Filter` contain unresolved attributes
@@ -192,7 +195,8 @@ trait ColumnResolutionHelper extends Logging {
 try {
 val resolved = innerResolve(expr, isTopLevel = true)
- if (allowOuter) resolveOuterRef(resolved) else resolved
+ val withOuterResolved = if (allowOuter) resolveOuterRef(resolved) else resolved
+ resolveVariables(withOuterResolved)
 } catch {
 case ae: AnalysisException if !throws =>
 logDebug(ae.getMessage)
@@ -233,6 +237,87 @@ trait ColumnResolutionHelper extends Logging {
 }
 }
+ def lookupVariable(nameParts: Seq[String]): Option[VariableReference] = {
+ // The temp variables live in `SYSTEM.SESSION`, and the name can be qualified or not.
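+ // For example, `v1`, `session.v1` and `system.session.v1` all refer to the same temp variable.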
+ def maybeTempVariableName(nameParts: Seq[String]): Boolean = { + nameParts.length == 1 || { + if (nameParts.length == 2) { + nameParts.head.equalsIgnoreCase(CatalogManager.SESSION_NAMESPACE) + } else if (nameParts.length == 3) { + nameParts(0).equalsIgnoreCase(CatalogManager.SYSTEM_CATALOG_NAME) && + nameParts(1).equalsIgnoreCase(CatalogManager.SESSION_NAMESPACE) + } else { + false + } + } + } + + if (maybeTempVariableName(nameParts)) { + val variableName = if (conf.caseSensitiveAnalysis) { + nameParts.last + } else { + nameParts.last.toLowerCase(Locale.ROOT) + } + catalogManager.tempVariableManager.get(variableName).map { varDef => + VariableReference( + nameParts, + FakeSystemCatalog, + Identifier.of(Array(CatalogManager.SESSION_NAMESPACE), variableName), + varDef) + } + } else { + None + } + } + + // Resolves `UnresolvedAttribute` to its value. + protected def resolveVariables(e: Expression): Expression = { + def resolveVariable(nameParts: Seq[String]): Option[Expression] = { + val isResolvingView = AnalysisContext.get.catalogAndNamespace.nonEmpty + if (isResolvingView) { + if (AnalysisContext.get.referredTempVariableNames.contains(nameParts)) { + lookupVariable(nameParts) + } else { + None + } + } else { + lookupVariable(nameParts) + } + } + + def resolve(nameParts: Seq[String]): Option[Expression] = { + var resolvedVariable: Option[Expression] = None + // We only support temp variables for now, so the variable name can at most have 3 parts. + var numInnerFields: Int = math.max(0, nameParts.length - 3) + // Follow the column resolution and prefer the longest match. This makes sure that users + // can always use fully qualified variable name to avoid name conflicts. + while (resolvedVariable.isEmpty && numInnerFields < nameParts.length) { + resolvedVariable = resolveVariable(nameParts.dropRight(numInnerFields)) + if (resolvedVariable.isEmpty) numInnerFields += 1 + } + + resolvedVariable.map { variable => + if (numInnerFields != 0) { + val nestedFields = nameParts.takeRight(numInnerFields) + nestedFields.foldLeft(variable: Expression) { (e, name) => + ExtractValue(e, Literal(name), conf.resolver) + } + } else { + variable + } + }.map(e => Alias(e, nameParts.last)()) + } + + e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, TEMP_RESOLVED_COLUMN)) { + case u: UnresolvedAttribute => + resolve(u.nameParts).getOrElse(u) + // Re-resolves `TempResolvedColumn` as variable references if it has tried to be resolved with + // Aggregate but failed. + case t: TempResolvedColumn if t.hasTried => + resolve(t.nameParts).getOrElse(t) + } + } + // Resolves `UnresolvedAttribute` to `TempResolvedColumn` via `plan.child.output` if plan is an // `Aggregate`. If `TempResolvedColumn` doesn't end up as aggregate function input or grouping // column, we will undo the column resolution later to avoid confusing error message. 
E,g,, if diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala index 221f1a0f3135c..788f79cde99e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.{CatalogManager, Identifier, LookupCatalog} +import org.apache.spark.sql.errors.QueryCompilationErrors /** * Resolves the catalog of the name parts for table/view/function/namespace. @@ -27,7 +28,16 @@ import org.apache.spark.sql.connector.catalog.{CatalogManager, Identifier, Looku class ResolveCatalogs(val catalogManager: CatalogManager) extends Rule[LogicalPlan] with LookupCatalog { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { + // We only support temp variables for now and the system catalog is not properly implemented + // yet. We need to resolve `UnresolvedIdentifier` for variable commands specially. + case c @ CreateVariable(UnresolvedIdentifier(nameParts, _), _, _) => + val resolved = resolveVariableName(nameParts) + c.copy(name = resolved) + case d @ DropVariable(UnresolvedIdentifier(nameParts, _), _) => + val resolved = resolveVariableName(nameParts) + d.copy(name = resolved) + case UnresolvedIdentifier(nameParts, allowTemp) => if (allowTemp && catalogManager.v1SessionCatalog.isTempView(nameParts)) { val ident = Identifier.of(nameParts.dropRight(1).toArray, nameParts.last) @@ -51,4 +61,29 @@ class ResolveCatalogs(val catalogManager: CatalogManager) case UnresolvedNamespace(CatalogAndNamespace(catalog, ns)) => ResolvedNamespace(catalog, ns) } + + private def resolveVariableName(nameParts: Seq[String]): ResolvedIdentifier = { + def ident: Identifier = Identifier.of(Array(CatalogManager.SESSION_NAMESPACE), nameParts.last) + if (nameParts.length == 1) { + ResolvedIdentifier(FakeSystemCatalog, ident) + } else if (nameParts.length == 2) { + if (nameParts.head.equalsIgnoreCase(CatalogManager.SESSION_NAMESPACE)) { + ResolvedIdentifier(FakeSystemCatalog, ident) + } else { + throw QueryCompilationErrors.unresolvedVariableError( + nameParts, Seq(CatalogManager.SYSTEM_CATALOG_NAME, CatalogManager.SESSION_NAMESPACE)) + } + } else if (nameParts.length == 3) { + if (nameParts(0).equalsIgnoreCase(CatalogManager.SYSTEM_CATALOG_NAME) && + nameParts(1).equalsIgnoreCase(CatalogManager.SESSION_NAMESPACE)) { + ResolvedIdentifier(FakeSystemCatalog, ident) + } else { + throw QueryCompilationErrors.unresolvedVariableError( + nameParts, Seq(CatalogManager.SYSTEM_CATALOG_NAME, CatalogManager.SESSION_NAMESPACE)) + } + } else { + throw QueryCompilationErrors.unresolvedVariableError( + nameParts, Seq(CatalogManager.SYSTEM_CATALOG_NAME, CatalogManager.SESSION_NAMESPACE)) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveColumnDefaultInInsert.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveColumnDefaultInCommandInputQuery.scala similarity index 84% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveColumnDefaultInInsert.scala rename to 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveColumnDefaultInCommandInputQuery.scala index f791966492666..47cf817c2a862 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveColumnDefaultInInsert.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveColumnDefaultInCommandInputQuery.scala @@ -18,29 +18,31 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.SQLConfHelper -import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.expressions.{Alias, VariableReference} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_ATTRIBUTE import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.{containsExplicitDefaultColumn, getDefaultValueExprOrNullLit, isExplicitDefaultColumn} +import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructField /** * A virtual rule to resolve column "DEFAULT" in [[Project]] and [[UnresolvedInlineTable]] under - * [[InsertIntoStatement]]. It's only used by the real rule `ResolveReferences`. + * [[InsertIntoStatement]] and [[SetVariable]]. It's only used by the real rule `ResolveReferences`. * * This virtual rule is triggered if: * 1. The column "DEFAULT" can't be resolved normally by `ResolveReferences`. This is guaranteed as * `ResolveReferences` resolves the query plan bottom up. This means that when we reach here to - * resolve [[InsertIntoStatement]], its child plans have already been resolved by - * `ResolveReferences`. - * 2. The plan nodes between [[Project]] and [[InsertIntoStatement]] are - * all unary nodes that inherit the output columns from its child. - * 3. The plan nodes between [[UnresolvedInlineTable]] and [[InsertIntoStatement]] are either + * resolve the command, its child plans have already been resolved by `ResolveReferences`. + * 2. The plan nodes between [[Project]] and command are all unary nodes that inherit the + * output columns from its child. + * 3. The plan nodes between [[UnresolvedInlineTable]] and command are either * [[Project]], or [[Aggregate]], or [[SubqueryAlias]]. */ -case object ResolveColumnDefaultInInsert extends SQLConfHelper with ColumnResolutionHelper { +class ResolveColumnDefaultInCommandInputQuery(val catalogManager: CatalogManager) + extends SQLConfHelper with ColumnResolutionHelper { + // TODO (SPARK-43752): support v2 write commands as well. def apply(plan: LogicalPlan): LogicalPlan = plan match { case i: InsertIntoStatement if conf.enableDefaultColumns && i.table.resolved && @@ -76,6 +78,18 @@ case object ResolveColumnDefaultInInsert extends SQLConfHelper with ColumnResolu } } + case s: SetVariable if s.targetVariables.forall(_.isInstanceOf[VariableReference]) && + s.sourceQuery.containsPattern(UNRESOLVED_ATTRIBUTE) => + val expectedQuerySchema = s.targetVariables.map { + case v: VariableReference => + StructField(v.identifier.name, v.dataType, v.nullable) + .withCurrentDefaultValue(v.varDef.defaultValueSQL) + } + // We match the query schema with the SET variable schema by position. If the n-th + // column of the query is the DEFAULT column, we should get the default value expression + // defined for the n-th variable of the SET. 
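+ // For example, in `SET VAR (v1, v2) = (SELECT 1, DEFAULT)`, the DEFAULT in the second
+ // output column resolves to the default expression declared for `v2`.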
+ s.withNewChildren(Seq(resolveColumnDefault(s.sourceQuery, expectedQuerySchema))) + case _ => plan } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala index 09ae87b071fdd..6bc1949a4e0c1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.{AliasHelper, Attribute, Expres import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, AppendColumns, LogicalPlan} import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, UNRESOLVED_ATTRIBUTE} +import org.apache.spark.sql.connector.catalog.CatalogManager /** * A virtual rule to resolve [[UnresolvedAttribute]] in [[Aggregate]]. It's only used by the real @@ -46,8 +47,9 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REF * 5. Resolves the columns to outer references with the outer plan if we are resolving subquery * expressions. */ -object ResolveReferencesInAggregate extends SQLConfHelper +class ResolveReferencesInAggregate(val catalogManager: CatalogManager) extends SQLConfHelper with ColumnResolutionHelper with AliasHelper { + def apply(a: Aggregate): Aggregate = { val planForResolve = a.child match { // SPARK-25942: Resolves aggregate expressions with `AppendColumns`'s children, instead of diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInSort.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInSort.scala index 54044932d9e3b..e4e9188662a45 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInSort.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInSort.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.expressions.SortOrder import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Sort} +import org.apache.spark.sql.connector.catalog.CatalogManager /** * A virtual rule to resolve [[UnresolvedAttribute]] in [[Sort]]. It's only used by the real @@ -45,7 +46,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Sort} * Note, 3 and 4 are actually orthogonal. If the child plan is Aggregate, 4 can only resolve columns * as the grouping columns, which is completely covered by 3. 
*/ -object ResolveReferencesInSort extends SQLConfHelper with ColumnResolutionHelper { +class ResolveReferencesInSort(val catalogManager: CatalogManager) + extends SQLConfHelper with ColumnResolutionHelper { def apply(s: Sort): LogicalPlan = { val resolvedNoOuter = s.order.map(resolveExpressionByPlanOutput(_, s.child)) @@ -53,7 +55,9 @@ object ResolveReferencesInSort extends SQLConfHelper with ColumnResolutionHelper val (missingAttrResolved, newChild) = resolveExprsAndAddMissingAttrs(resolvedWithAgg, s.child) val orderByAllResolved = resolveOrderByAll( s.global, newChild, missingAttrResolved.map(_.asInstanceOf[SortOrder])) - val finalOrdering = orderByAllResolved.map(e => resolveOuterRef(e).asInstanceOf[SortOrder]) + val resolvedWithOuter = orderByAllResolved.map(e => resolveOuterRef(e)) + val finalOrdering = resolvedWithOuter.map(e => resolveVariables(e) + .asInstanceOf[SortOrder]) if (s.child.output == newChild.output) { s.copy(order = finalOrdering) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInUpdate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInUpdate.scala index ead323ce9857b..92813a156988b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInUpdate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInUpdate.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.plans.logical.{Assignment, UpdateTable} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.resolveColumnDefaultInAssignmentValue +import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.errors.QueryCompilationErrors /** @@ -32,7 +33,8 @@ import org.apache.spark.sql.errors.QueryCompilationErrors * 3. Resolves the column to the default value expression, if the column is the assignment value * and the corresponding assignment key is a top-level column. */ -case object ResolveReferencesInUpdate extends SQLConfHelper with ColumnResolutionHelper { +class ResolveReferencesInUpdate(val catalogManager: CatalogManager) + extends SQLConfHelper with ColumnResolutionHelper { def apply(u: UpdateTable): UpdateTable = { assert(u.table.resolved) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSetVariable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSetVariable.scala new file mode 100644 index 0000000000000..ebf56ef1cc4f3 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSetVariable.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.SparkException +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{Limit, LogicalPlan, SetVariable} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND +import org.apache.spark.sql.connector.catalog.CatalogManager +import org.apache.spark.sql.errors.DataTypeErrors.toSQLId +import org.apache.spark.sql.errors.QueryCompilationErrors.unresolvedVariableError +import org.apache.spark.sql.types.IntegerType + +/** + * Resolves the target SQL variables that we want to set in SetVariable, and add cast if necessary + * to make the assignment valid. + */ +class ResolveSetVariable(val catalogManager: CatalogManager) extends Rule[LogicalPlan] + with ColumnResolutionHelper { + + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning( + _.containsPattern(COMMAND), ruleId) { + // Resolve the left hand side of the SET VAR command + case setVariable: SetVariable if !setVariable.targetVariables.forall(_.resolved) => + val resolvedVars = setVariable.targetVariables.map { + case u: UnresolvedAttribute => + lookupVariable(u.nameParts) match { + case Some(variable) => variable.copy(canFold = false) + case _ => throw unresolvedVariableError(u.nameParts, Seq("SYSTEM", "SESSION")) + } + + case other => throw SparkException.internalError( + "Unexpected target variable expression in SetVariable: " + other) + } + + // Protect against duplicate variable names + // Names are normalized when the variables are created. + // No need for case insensitive comparison here. + // TODO: we need to group by the qualified variable name once other catalogs support it. 
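+ // For example, `SET VAR v1 = 1, v1 = 2` fails with DUPLICATE_ASSIGNMENTS.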
+ val dups = resolvedVars.groupBy(_.identifier.name).filter(kv => kv._2.length > 1) + if (dups.nonEmpty) { + throw new AnalysisException( + errorClass = "DUPLICATE_ASSIGNMENTS", + messageParameters = Map("nameList" -> dups.keys.map(toSQLId).mkString(", "))) + } + + setVariable.copy(targetVariables = resolvedVars) + + case setVariable: SetVariable + if setVariable.targetVariables.forall(_.isInstanceOf[VariableReference]) && + setVariable.sourceQuery.resolved => + val targetVariables = setVariable.targetVariables.map(_.asInstanceOf[VariableReference]) + val withCasts = TableOutputResolver.resolveVariableOutputColumns( + targetVariables, setVariable.sourceQuery, conf) + val withLimit = if (withCasts.maxRows.exists(_ <= 2)) { + withCasts + } else { + Limit(Literal(2, IntegerType), withCasts) + } + setVariable.copy(sourceQuery = withLimit) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala index 894cd0b39911f..51f275f50dc5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import scala.collection.mutable +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} @@ -34,6 +35,43 @@ import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy import org.apache.spark.sql.types.{ArrayType, DataType, DecimalType, IntegralType, MapType, StructType} object TableOutputResolver { + + def resolveVariableOutputColumns( + expected: Seq[VariableReference], + query: LogicalPlan, + conf: SQLConf): LogicalPlan = { + + if (expected.size != query.output.size) { + throw new AnalysisException( + errorClass = "ASSIGNMENT_ARITY_MISMATCH", + messageParameters = Map( + "numTarget" -> expected.size.toString, + "numExpr" -> query.output.size.toString)) + } + + val resolved: Seq[NamedExpression] = { + query.output.zip(expected).map { case (inputCol, expected) => + if (DataTypeUtils.sameType(inputCol.dataType, expected.dataType)) { + inputCol + } else { + // SET VAR always uses the ANSI store assignment policy + val cast = Cast( + inputCol, + expected.dataType, + Option(conf.sessionLocalTimeZone), + ansiEnabled = true) + Alias(cast, expected.identifier.name)() + } + } + } + + if (resolved == query.output) { + query + } else { + Project(resolved, query) + } + } + def resolveOutputColumns( tableName: String, expected: Seq[Attribute], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala index 04d6337376c39..15ece1226d8c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala @@ -232,5 +232,5 @@ case class ResolvedIdentifier( // A fake v2 catalog to hold temp views. 
object FakeSystemCatalog extends CatalogPlugin { override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = {} - override def name(): String = "SYSTEM" + override def name(): String = "system" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/TempVariableManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/TempVariableManager.scala new file mode 100644 index 0000000000000..abe6cede0c550 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/TempVariableManager.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.catalog + +import javax.annotation.concurrent.GuardedBy + +import scala.collection.mutable + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.connector.catalog.CatalogManager.{SESSION_NAMESPACE, SYSTEM_CATALOG_NAME} +import org.apache.spark.sql.errors.DataTypeErrorsBase + +/** + * A thread-safe manager for temporary SQL variables (that live in the schema `SYSTEM.SESSION`), + * providing atomic operations to manage them, e.g. create, get, remove, etc. + * + * Note that, the variable name is always case-sensitive here, callers are responsible to format the + * variable name w.r.t. case-sensitive config. 
+ */ +class TempVariableManager extends DataTypeErrorsBase { + + @GuardedBy("this") + private val variables = new mutable.HashMap[String, VariableDefinition] + + def create( + name: String, + defaultValueSQL: String, + initValue: Literal, + overrideIfExists: Boolean): Unit = synchronized { + if (!overrideIfExists && variables.contains(name)) { + throw new AnalysisException( + errorClass = "VARIABLE_ALREADY_EXISTS", + messageParameters = Map( + "variableName" -> toSQLId(Seq(SYSTEM_CATALOG_NAME, SESSION_NAMESPACE, name)))) + } + variables.put(name, VariableDefinition(defaultValueSQL, initValue)) + } + + def get(name: String): Option[VariableDefinition] = synchronized { + variables.get(name) + } + + def remove(name: String): Boolean = synchronized { + variables.remove(name).isDefined + } + + def clear(): Unit = synchronized { + variables.clear() + } +} + +case class VariableDefinition(defaultValueSQL: String, currentValue: Literal) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 6b72500f3f672..4b04cfddbe8ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -29,6 +29,7 @@ import org.json4s.JsonAST.{JArray, JString} import org.json4s.jackson.JsonMethods._ import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, SQLConfHelper, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedLeafNode} import org.apache.spark.sql.catalyst.catalog.CatalogTable.VIEW_STORING_ANALYZED_PLAN @@ -370,6 +371,26 @@ case class CatalogTable( } } + /** + * Return temporary variable names the current view was referred. should be empty if the + * CatalogTable is not a Temporary View or created by older versions of Spark(before 3.4.0). + */ + def viewReferredTempVariableNames: Seq[Seq[String]] = { + try { + properties.get(VIEW_REFERRED_TEMP_VARIABLE_NAMES).map { json => + parse(json).asInstanceOf[JArray].arr.map { namePartsJson => + namePartsJson.asInstanceOf[JArray].arr.map(_.asInstanceOf[JString].s) + } + }.getOrElse(Seq.empty) + } catch { + case e: Exception => + throw new AnalysisException( + errorClass = "INTERNAL_ERROR_METADATA_CATALOG.TEMP_VARIABLE_REFERENCE", + messageParameters = Map.empty, + cause = Some(e)) + } + } + /** Syntactic sugar to update a field in `storage`. */ def withNewStorage( locationUri: Option[URI] = storage.locationUri, @@ -474,6 +495,7 @@ object CatalogTable { val VIEW_REFERRED_TEMP_VIEW_NAMES = VIEW_PREFIX + "referredTempViewNames" val VIEW_REFERRED_TEMP_FUNCTION_NAMES = VIEW_PREFIX + "referredTempFunctionsNames" + val VIEW_REFERRED_TEMP_VARIABLE_NAMES = VIEW_PREFIX + "referredTempVariablesNames" val VIEW_STORING_ANALYZED_PLAN = VIEW_PREFIX + "storingAnalyzedPlan" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/VariableReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/VariableReference.scala new file mode 100644 index 0000000000000..65240b97e972c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/VariableReference.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.catalog.VariableDefinition
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.util.quoteIfNeeded
+import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Identifier}
+import org.apache.spark.sql.types.DataType
+
+/**
+ * A resolved SQL variable, which contains all the information of it.
+ */
+case class VariableReference(
+ // When creating a temp view, we need to record all the temp SQL variables that are referenced
+ // by the temp view, so that they can still be resolved as temp variables when reading the view
+ // again. Here we store the original name parts of the variables, as temp view keeps the
+ // original SQL string of the view query.
+ originalNameParts: Seq[String],
+ catalog: CatalogPlugin,
+ identifier: Identifier,
+ varDef: VariableDefinition,
+ // This flag will be false if the `VariableReference` is used to manage the variable, like
+ // setting a new value, where we shouldn't constant-fold it.
+ canFold: Boolean = true)
+ extends LeafExpression {
+
+ override def dataType: DataType = varDef.currentValue.dataType
+ override def nullable: Boolean = varDef.currentValue.nullable
+ override def foldable: Boolean = canFold
+
+ override def toString: String = {
+ val qualifiedName = (catalog.name +: identifier.namespace :+ identifier.name).map(quoteIfNeeded)
+ s"$prettyName(${qualifiedName.mkString(".")}=${varDef.currentValue.sql})"
+ }
+
+ override def sql: String = toString
+
+ // Delegate to the underlying `Literal` for actual execution.
+ override def eval(input: InternalRow): Any = varDef.currentValue.value + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + varDef.currentValue.doGenCode(ctx, ev) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 0635e6a1b44fc..27b88197f86e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -450,7 +450,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { UpdateTable(aliasedTable, assignments, predicate) } - private def withAssignments(assignCtx: SqlBaseParser.AssignmentListContext): Seq[Assignment] = + protected def withAssignments(assignCtx: SqlBaseParser.AssignmentListContext): Seq[Assignment] = withOrigin(assignCtx) { assignCtx.assignment().asScala.map { assign => Assignment(UnresolvedAttribute(visitMultipartIdentifier(assign.key)), @@ -3154,7 +3154,9 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { ctx.asScala.headOption.map(visitLocationSpec) } - private def verifyAndGetExpression(exprCtx: ExpressionContext, place: String): String = { + private def getDefaultExpression( + exprCtx: ExpressionContext, + place: String): DefaultValueExpression = { // Make sure it can be converted to Catalyst expressions. val expr = expression(exprCtx) if (expr.containsPattern(PARAMETER)) { @@ -3166,7 +3168,8 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { // get the text from the underlying char stream instead. val start = exprCtx.getStart.getStartIndex val end = exprCtx.getStop.getStopIndex - exprCtx.getStart.getInputStream.getText(new Interval(start, end)) + val originalSQL = exprCtx.getStart.getInputStream.getText(new Interval(start, end)) + DefaultValueExpression(expr, originalSQL) } /** @@ -3174,7 +3177,16 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { */ override def visitDefaultExpression(ctx: DefaultExpressionContext): String = withOrigin(ctx) { - verifyAndGetExpression(ctx.expression(), "DEFAULT") + getDefaultExpression(ctx.expression(), "DEFAULT").originalSQL + } + + /** + * Create `DefaultValueExpression` for a SQL variable. + */ + override def visitVariableDefaultExpression( + ctx: VariableDefaultExpressionContext): DefaultValueExpression = + withOrigin(ctx) { + getDefaultExpression(ctx.expression(), "DEFAULT") } /** @@ -3182,7 +3194,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { */ override def visitGenerationExpression(ctx: GenerationExpressionContext): String = withOrigin(ctx) { - verifyAndGetExpression(ctx.expression(), "GENERATED") + getDefaultExpression(ctx.expression(), "GENERATED").originalSQL } /** @@ -5044,4 +5056,83 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { ctx: PosParameterLiteralContext): Expression = withOrigin(ctx) { PosParameter(ctx.QUESTION().getSymbol.getStartIndex) } + + /** + * Create a [[CreateVariable]] command. + * + * For example: + * {{{ + * DECLARE [OR REPLACE] [VARIABLE] [db_name.]variable_name + * [dataType] [defaultExpression]; + * }}} + * + * We will add CREATE VARIABLE for persisted variable definitions to this, hence the name. 
+ */ + override def visitCreateVariable(ctx: CreateVariableContext): LogicalPlan = withOrigin(ctx) { + val dataTypeOpt = Option(ctx.dataType()).map(typedVisit[DataType]) + val defaultExpression = if (ctx.variableDefaultExpression() == null) { + if (dataTypeOpt.isEmpty) { + throw new ParseException( + errorClass = "INVALID_SQL_SYNTAX.VARIABLE_TYPE_OR_DEFAULT_REQUIRED", + messageParameters = Map.empty, + ctx.identifierReference) + } + DefaultValueExpression(Literal(null, dataTypeOpt.get), "null") + } else { + val default = visitVariableDefaultExpression(ctx.variableDefaultExpression()) + dataTypeOpt.map { dt => default.copy(child = Cast(default.child, dt)) }.getOrElse(default) + } + CreateVariable( + withIdentClause(ctx.identifierReference(), UnresolvedIdentifier(_)), + defaultExpression, + ctx.REPLACE() != null + ) + } + + /** + * Create a [[DropVariable]] command. + * + * For example: + * {{{ + * DROP TEMPORARY VARIABLE [IF EXISTS] variable; + * }}} + */ + override def visitDropVariable(ctx: DropVariableContext): LogicalPlan = withOrigin(ctx) { + DropVariable( + withIdentClause(ctx.identifierReference(), UnresolvedIdentifier(_)), + ctx.EXISTS() != null + ) + } + + /** + * Create a [[SetVariable]] command. + * + * For example: + * {{{ + * SET VARIABLE var1 = v1, var2 = v2, ... + * SET VARIABLE (var1, var2, ...) = (SELECT ...) + * }}} + */ + override def visitSetVariable(ctx: SetVariableContext): LogicalPlan = withOrigin(ctx) { + if (ctx.query() != null) { + // The SET variable source is a query + val variables = ctx.multipartIdentifierList.multipartIdentifier.asScala.map { variableIdent => + val varName = visitMultipartIdentifier(variableIdent) + UnresolvedAttribute(varName) + }.toSeq + SetVariable(variables, visitQuery(ctx.query())) + } else { + // The SET variable source is list of expressions. + val (variables, values) = ctx.assignmentList().assignment().asScala.map { assign => + val varIdent = visitMultipartIdentifier(assign.key) + val varExpr = expression(assign.value) + val varNamedExpr = varExpr match { + case n: NamedExpression => n + case e => Alias(e, varIdent.last)() + } + (UnresolvedAttribute(varIdent), varNamedExpr) + }.toSeq.unzip + SetVariable(variables, Project(values, OneRowRelation())) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 5c83da1a96aae..ac4098d4e4101 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -1468,3 +1468,48 @@ case class TableSpec( TableSpec(properties, provider, options, newLocation, comment, serde, external) } } + +/** + * A fake expression which holds the default value expression and its original SQL text. + */ +case class DefaultValueExpression(child: Expression, originalSQL: String) + extends UnaryExpression with Unevaluable { + override def dataType: DataType = child.dataType + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) +} + +/** + * The logical plan of the DECLARE [OR REPLACE] TEMPORARY VARIABLE command. 
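+ * For example, `DECLARE VARIABLE var1 INT DEFAULT 1` becomes a `CreateVariable` node whose
+ * `name` child is an unresolved identifier for `var1` and whose default is
+ * `DefaultValueExpression(Cast(1, IntegerType), "1")`.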
+ */ +case class CreateVariable( + name: LogicalPlan, + defaultExpr: DefaultValueExpression, + replace: Boolean) extends UnaryCommand with SupportsSubquery { + override def child: LogicalPlan = name + override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = + copy(name = newChild) +} + +/** + * The logical plan of the DROP TEMPORARY VARIABLE command. + */ +case class DropVariable( + name: LogicalPlan, + ifExists: Boolean) extends UnaryCommand { + override def child: LogicalPlan = name + override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = + copy(name = newChild) +} + +/** + * The logical plan of the SET VARIABLE command. + */ +case class SetVariable( + targetVariables: Seq[Expression], + sourceQuery: LogicalPlan) + extends UnaryCommand { + override def child: LogicalPlan = sourceQuery + override protected def withNewChildInternal(newChild: LogicalPlan): SetVariable = + copy(sourceQuery = newChild) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala index caf679f3e7a7a..c170d34e1700c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala @@ -92,6 +92,7 @@ object RuleIdCollection { "org.apache.spark.sql.catalyst.analysis.ResolveLateralColumnAliasReference" :: "org.apache.spark.sql.catalyst.analysis.ResolveOrderByAll" :: "org.apache.spark.sql.catalyst.analysis.ResolveRowLevelCommandAssignments" :: + "org.apache.spark.sql.catalyst.analysis.ResolveSetVariable" :: "org.apache.spark.sql.catalyst.analysis.ResolveTableSpec" :: "org.apache.spark.sql.catalyst.analysis.ResolveTimeZone" :: "org.apache.spark.sql.catalyst.analysis.ResolveUnion" :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala index cf9dd7fdf4767..16c387a82373b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.SQLConfHelper -import org.apache.spark.sql.catalyst.catalog.SessionCatalog +import org.apache.spark.sql.catalyst.catalog.{SessionCatalog, TempVariableManager} import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf @@ -46,6 +46,9 @@ class CatalogManager( private val catalogs = mutable.HashMap.empty[String, CatalogPlugin] + // TODO: create a real SYSTEM catalog to host `TempVariableManager` under the SESSION namespace. 
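+  // Temp variables are session-scoped and addressable as `system.session.<varName>` (see
+  // SYSTEM_CATALOG_NAME and SESSION_NAMESPACE below); the shorter forms `<varName>` and
+  // `session.<varName>` resolve to the same variable.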
+ val tempVariableManager: TempVariableManager = new TempVariableManager + def catalog(name: String): CatalogPlugin = synchronized { if (name.equalsIgnoreCase(SESSION_CATALOG_NAME)) { v2SessionCatalog @@ -150,4 +153,6 @@ class CatalogManager( private[sql] object CatalogManager { val SESSION_CATALOG_NAME: String = "spark_catalog" + val SYSTEM_CATALOG_NAME = "system" + val SESSION_NAMESPACE = "session" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 1e4f779e565af..5b3c3daa75b2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -629,10 +629,10 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat def aliasNumberNotMatchColumnNumberError( columnSize: Int, outputSize: Int, t: TreeNode[_]): Throwable = { new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_1028", + errorClass = "ASSIGNMENT_ARITY_MISMATCH", messageParameters = Map( - "columnSize" -> columnSize.toString, - "outputSize" -> outputSize.toString), + "numExpr" -> columnSize.toString, + "numTarget" -> outputSize.toString), origin = t.origin) } @@ -712,6 +712,14 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat messageParameters = Map("dt" -> dt.toString)) } + def unresolvedVariableError(name: Seq[String], searchPath: Seq[String]): Throwable = { + new AnalysisException( + errorClass = "UNRESOLVED_VARIABLE", + messageParameters = Map( + "variableName" -> toSQLId(name), + "searchPath" -> toSQLId(searchPath))) + } + def unresolvedRoutineError(name: FunctionIdentifier, searchPath: Seq[String]): Throwable = { new AnalysisException( errorClass = "UNRESOLVED_ROUTINE", @@ -1058,7 +1066,7 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat def corruptedTableNameContextInCatalogError(numParts: Int, index: Int): Throwable = { new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_1084", + errorClass = "INTERNAL_ERROR_METADATA_CATALOG.TABLE_NAME_CONTEXT", messageParameters = Map( "numParts" -> numParts.toString, "index" -> index.toString)) @@ -1066,14 +1074,14 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat def corruptedViewSQLConfigsInCatalogError(e: Exception): Throwable = { new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_1085", + errorClass = "INTERNAL_ERROR_METADATA_CATALOG.SQL_CONFIG", messageParameters = Map.empty, cause = Some(e)) } def corruptedViewQueryOutputColumnsInCatalogError(numCols: String, index: Int): Throwable = { new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_1086", + errorClass = "INTERNAL_ERROR_METADATA_CATALOG.VIEW_QUERY_COLUMN_ARITY", messageParameters = Map( "numCols" -> numCols, "index" -> index.toString)) @@ -1081,14 +1089,14 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat def corruptedViewReferredTempViewInCatalogError(e: Exception): Throwable = { new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_1087", + errorClass = "INTERNAL_ERROR_METADATA_CATALOG.TEMP_VIEW_REFERENCE", messageParameters = Map.empty, cause = Some(e)) } def corruptedViewReferredTempFunctionsInCatalogError(e: Exception): Throwable = { new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_1088", + errorClass = "INTERNAL_ERROR_METADATA_CATALOG.TEMP_FUNCTION_REFERENCE", 
messageParameters = Map.empty, cause = Some(e)) } @@ -2895,6 +2903,18 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat "tempObjName" -> toSQLId(funcName))) } + def notAllowedToCreatePermanentViewByReferencingTempVarError( + name: TableIdentifier, + varName: String): Throwable = { + new AnalysisException( + errorClass = "INVALID_TEMP_OBJ_REFERENCE", + messageParameters = Map( + "obj" -> "VIEW", + "objName" -> toSQLId(name.nameParts), + "tempObj" -> "VARIABLE", + "tempObjName" -> toSQLId(varName))) + } + def queryFromRawFilesIncludeCorruptRecordColumnError(): Throwable = { new AnalysisException( errorClass = "_LEGACY_ERROR_TEMP_1285", @@ -3300,7 +3320,7 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat // this is not allowed. def defaultReferencesNotAllowedInComplexExpressionsInInsertValuesList(): Throwable = { new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_1339", + errorClass = "DEFAULT_PLACEMENT_INVALID", messageParameters = Map.empty) } @@ -3308,13 +3328,13 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat // DEFAULT column references and explicit column lists, since this is not implemented yet. def defaultReferencesNotAllowedInComplexExpressionsInUpdateSetClause(): Throwable = { new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_1340", + errorClass = "DEFAULT_PLACEMENT_INVALID", messageParameters = Map.empty) } def defaultReferencesNotAllowedInComplexExpressionsInMergeInsertsOrUpdates(): Throwable = { new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_1343", + errorClass = "DEFAULT_PLACEMENT_INVALID", messageParameters = Map.empty) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index f3c5fb4bef3b5..1cc79a92c4ce2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -2552,7 +2552,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE cause = null) } - def multipleRowSubqueryError(context: SQLQueryContext): Throwable = { + def multipleRowScalarSubqueryError(context: SQLQueryContext): Throwable = { new SparkException( errorClass = "SCALAR_SUBQUERY_TOO_MANY_ROWS", messageParameters = Map.empty, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SQLKeywordSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SQLKeywordSuite.scala index 2c8bb8a6ac92c..74f7277f90ea6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SQLKeywordSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SQLKeywordSuite.scala @@ -160,7 +160,9 @@ class SQLKeywordSuite extends SQLKeywordUtils { val documentedKeywords = keywordsInDoc.map(_.head).toSet if (allCandidateKeywords != documentedKeywords) { val undocumented = (allCandidateKeywords -- documentedKeywords).toSeq.sorted - fail("Some keywords are not documented: " + undocumented.mkString(", ")) + val overdocumented = (documentedKeywords -- allCandidateKeywords).toSeq.sorted + fail("Some keywords are not documented: " + undocumented.mkString(", ") + + " Extras: " + overdocumented.mkString(", ")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index e2e980073307d..6fd4ce7e08404 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -1137,13 +1137,13 @@ class AnalysisErrorSuite extends AnalysisTest { expectedMessageParameters = Map("sqlExpr" -> "\"scalarsubquery(c1)\"")) } - errorTest( + errorClassTest( "SPARK-34920: error code to error message", testRelation2.where($"bad_column" > 1).groupBy($"a")(UnresolvedAlias(max($"b"))), - "[UNRESOLVED_COLUMN.WITH_SUGGESTION] A column or function parameter with name " + - "`bad_column` cannot be resolved. Did you mean one of the following? " + - "[`a`, `c`, `d`, `b`, `e`]" - :: Nil) + errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + messageParameters = Map( + "objectName" -> "`bad_column`", + "proposal" -> "`a`, `c`, `d`, `b`, `e`")) errorClassTest( "SPARK-39783: backticks in error message for candidate column with dots", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 2bb3439da0160..539b32e7b3b24 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -562,14 +562,16 @@ class AnalysisSuite extends AnalysisTest with Matchers { ).select(star()) } assertAnalysisSuccess(tableColumnsWithAliases("col1" :: "col2" :: "col3" :: "col4" :: Nil)) - assertAnalysisError( + assertAnalysisErrorClass( tableColumnsWithAliases("col1" :: Nil), - Seq("Number of column aliases does not match number of columns. " + - "Number of column aliases: 1; number of columns: 4.")) - assertAnalysisError( + "ASSIGNMENT_ARITY_MISMATCH", + Map("numExpr" -> "1", "numTarget" -> "4") + ) + assertAnalysisErrorClass( tableColumnsWithAliases("col1" :: "col2" :: "col3" :: "col4" :: "col5" :: Nil), - Seq("Number of column aliases does not match number of columns. " + - "Number of column aliases: 5; number of columns: 4.")) + "ASSIGNMENT_ARITY_MISMATCH", + Map("numExpr" -> "5", "numTarget" -> "4") + ) } test("SPARK-20962 Support subquery column aliases in FROM clause") { @@ -582,14 +584,16 @@ class AnalysisSuite extends AnalysisTest with Matchers { ).select(star()) } assertAnalysisSuccess(tableColumnsWithAliases("col1" :: "col2" :: "col3" :: "col4" :: Nil)) - assertAnalysisError( + assertAnalysisErrorClass( tableColumnsWithAliases("col1" :: Nil), - Seq("Number of column aliases does not match number of columns. " + - "Number of column aliases: 1; number of columns: 4.")) - assertAnalysisError( + "ASSIGNMENT_ARITY_MISMATCH", + Map("numExpr" -> "1", "numTarget" -> "4") + ) + assertAnalysisErrorClass( tableColumnsWithAliases("col1" :: "col2" :: "col3" :: "col4" :: "col5" :: Nil), - Seq("Number of column aliases does not match number of columns. 
" + - "Number of column aliases: 5; number of columns: 4.")) + "ASSIGNMENT_ARITY_MISMATCH", + Map("numExpr" -> "5", "numTarget" -> "4") + ) } test("SPARK-20963 Support aliases for join relations in FROM clause") { @@ -604,14 +608,16 @@ class AnalysisSuite extends AnalysisTest with Matchers { ).select(star()) } assertAnalysisSuccess(joinRelationWithAliases("col1" :: "col2" :: "col3" :: "col4" :: Nil)) - assertAnalysisError( + assertAnalysisErrorClass( joinRelationWithAliases("col1" :: Nil), - Seq("Number of column aliases does not match number of columns. " + - "Number of column aliases: 1; number of columns: 4.")) - assertAnalysisError( + "ASSIGNMENT_ARITY_MISMATCH", + Map("numExpr" -> "1", "numTarget" -> "4") + ) + assertAnalysisErrorClass( joinRelationWithAliases("col1" :: "col2" :: "col3" :: "col4" :: "col5" :: Nil), - Seq("Number of column aliases does not match number of columns. " + - "Number of column aliases: 5; number of columns: 4.")) + "ASSIGNMENT_ARITY_MISMATCH", + Map("numExpr" -> "5", "numTarget" -> "4") + ) } test("SPARK-22614 RepartitionByExpression partitioning") { @@ -753,9 +759,11 @@ class AnalysisSuite extends AnalysisTest with Matchers { } test("CTE with non-matching column alias") { - assertAnalysisError(parsePlan("WITH t(x, y) AS (SELECT 1) SELECT * FROM t WHERE x = 1"), - Seq("Number of column aliases does not match number of columns. Number of column aliases: " + - "2; number of columns: 1.")) + assertAnalysisErrorClass(parsePlan("WITH t(x, y) AS (SELECT 1) SELECT * FROM t WHERE x = 1"), + "ASSIGNMENT_ARITY_MISMATCH", + Map("numExpr" -> "2", "numTarget" -> "1"), + Array(ExpectedContext("t(x, y) AS (SELECT 1)", 5, 25)) + ) } test("SPARK-28251: Insert into non-existing table error message is user friendly") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 6f7dd852cfed9..da3159319f98e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.adaptive.LogicalQueryStageStrategy +import org.apache.spark.sql.execution.command.v2.V2CommandStrategy import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, FileSourceStrategy} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Strategy @@ -36,6 +37,7 @@ class SparkPlanner(val session: SparkSession, val experimentalMethods: Experimen LogicalQueryStageStrategy :: PythonEvals :: new DataSourceV2Strategy(session) :: + V2CommandStrategy :: FileSourceStrategy :: DataSourceStrategy :: SpecialLimits :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala index 72502a7626b08..672417f1adbf0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala @@ -18,15 +18,16 @@ package org.apache.spark.sql.execution.command import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.Attribute 
+import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.IgnoreCachedData import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes +import org.apache.spark.sql.errors.QueryCompilationErrors.toSQLId import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.types.{StringType, StructField, StructType} - /** * Command that runs * {{{ @@ -91,6 +92,23 @@ case class SetCommand(kv: Option[(String, Option[String])]) // Configures a single property. case Some((key, Some(value))) => val runFunc = (sparkSession: SparkSession) => { + /** + * Be nice and detect if the key matches a SQL variable. + * If it does, give a meaningful error pointing the user to SET VARIABLE. + */ + val varName = try { + sparkSession.sessionState.sqlParser.parseMultipartIdentifier(key) + } catch { + case _: ParseException => + Seq() + } + if (varName.nonEmpty && varName.length <= 3) { + if (sparkSession.sessionState.analyzer.lookupVariable(varName).isDefined) { + throw new AnalysisException( + errorClass = "UNSUPPORTED_FEATURE.SET_VARIABLE_USING_SET", + messageParameters = Map("variableName" -> toSQLId(varName))) + } + } if (sparkSession.conf.get(CATALOG_IMPLEMENTATION.key).equals("hive") && key.startsWith("hive.")) { logWarning(s"'SET $key=$value' might not work, since Spark doesn't support changing " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/CreateVariableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/CreateVariableExec.scala new file mode 100644 index 0000000000000..0ed1c104edb92 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/CreateVariableExec.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command.v2 + +import java.util.Locale + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, ExpressionsEvaluator, Literal} +import org.apache.spark.sql.catalyst.plans.logical.DefaultValueExpression +import org.apache.spark.sql.execution.datasources.v2.LeafV2CommandExec + +/** + * Physical plan node for creating a variable.
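+ * For example, `DECLARE VARIABLE var1 INT DEFAULT 1` eagerly evaluates the default expression
+ * and registers `var1` in the session's `TempVariableManager` with the resulting literal value.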
+ */ +case class CreateVariableExec(name: String, defaultExpr: DefaultValueExpression, replace: Boolean) + extends LeafV2CommandExec with ExpressionsEvaluator { + + override protected def run(): Seq[InternalRow] = { + val variableManager = session.sessionState.catalogManager.tempVariableManager + val exprs = prepareExpressions(Seq(defaultExpr.child), subExprEliminationEnabled = false) + initializeExprs(exprs, 0) + val initValue = Literal(exprs.head.eval(), defaultExpr.dataType) + val normalizedName = if (session.sessionState.conf.caseSensitiveAnalysis) { + name + } else { + name.toLowerCase(Locale.ROOT) + } + variableManager.create( + normalizedName, defaultExpr.originalSQL, initValue, replace) + Nil + } + + override def output: Seq[Attribute] = Nil +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/DropVariableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/DropVariableExec.scala new file mode 100644 index 0000000000000..22538076879f7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/DropVariableExec.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command.v2 + +import java.util.Locale + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.connector.catalog.CatalogManager +import org.apache.spark.sql.errors.DataTypeErrorsBase +import org.apache.spark.sql.execution.datasources.v2.LeafV2CommandExec + +/** + * Physical plan node for dropping a variable. 
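+ * For example, `DROP TEMPORARY VARIABLE IF EXISTS var1` removes `var1` from the session's
+ * `TempVariableManager`; without IF EXISTS, dropping a missing variable fails with
+ * VARIABLE_NOT_FOUND.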
+ */ +case class DropVariableExec(name: String, ifExists: Boolean) extends LeafV2CommandExec + with DataTypeErrorsBase { + + override protected def run(): Seq[InternalRow] = { + val variableManager = session.sessionState.catalogManager.tempVariableManager + val normalizedName = if (session.sessionState.conf.caseSensitiveAnalysis) { + name + } else { + name.toLowerCase(Locale.ROOT) + } + if (!variableManager.remove(normalizedName)) { + // The variable does not exist + if (!ifExists) { + throw new AnalysisException( + errorClass = "VARIABLE_NOT_FOUND", + Map("variableName" -> toSQLId( + Seq(CatalogManager.SYSTEM_CATALOG_NAME, CatalogManager.SESSION_NAMESPACE, name)))) + } + } + Nil + } + + override def output: Seq[Attribute] = Nil +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/SetVariableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/SetVariableExec.scala new file mode 100644 index 0000000000000..a5d90b4d154ce --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/SetVariableExec.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command.v2 + +import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.catalog.TempVariableManager +import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal, VariableReference} +import org.apache.spark.sql.catalyst.trees.UnaryLike +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.datasources.v2.V2CommandExec + +/** + * Physical plan node for setting a variable. 
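+ * For example, `SET VARIABLE var1 = 7` plans the right-hand side as a one-row query and assigns
+ * its result to `var1`; an empty source query sets the variable to NULL, and a source query
+ * returning more than one row fails with ROW_SUBQUERY_TOO_MANY_ROWS.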
+ */ +case class SetVariableExec(variables: Seq[VariableReference], query: SparkPlan) + extends V2CommandExec with UnaryLike[SparkPlan] { + + override protected def run(): Seq[InternalRow] = { + val variableManager = session.sessionState.catalogManager.tempVariableManager + val values = query.executeCollect() + if (values.length == 0) { + variables.foreach { v => + createVariable(variableManager, v, null) + } + } else if (values.length > 1) { + throw new SparkException( + errorClass = "ROW_SUBQUERY_TOO_MANY_ROWS", + messageParameters = Map.empty, + cause = null) + } else { + val row = values(0) + variables.zipWithIndex.foreach { case (v, index) => + val value = row.get(index, v.dataType) + createVariable(variableManager, v, value) + } + } + Seq.empty + } + + private def createVariable( + variableManager: TempVariableManager, + variable: VariableReference, + value: Any): Unit = { + variableManager.create( + variable.identifier.name, + variable.varDef.defaultValueSQL, + Literal(value, variable.dataType), + overrideIfExists = true) + } + + override def output: Seq[Attribute] = Seq.empty + override def child: SparkPlan = query + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = { + copy(query = newChild) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/V2CommandStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/V2CommandStrategy.scala new file mode 100644 index 0000000000000..ebc2e83e9c5fc --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/V2CommandStrategy.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command.v2 + +import org.apache.spark.sql.Strategy +import org.apache.spark.sql.catalyst.analysis.ResolvedIdentifier +import org.apache.spark.sql.catalyst.expressions.VariableReference +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.SparkPlan + +object V2CommandStrategy extends Strategy { + + // TODO: move v2 commands that are not data source v2 related here.
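+  // For now this strategy covers the SQL variable commands: CreateVariable, DropVariable and
+  // SetVariable are planned into CreateVariableExec, DropVariableExec and SetVariableExec
+  // respectively, with SetVariable planning its source query via planLater.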
+ override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case CreateVariable(ident: ResolvedIdentifier, defaultExpr, replace) => + CreateVariableExec(ident.identifier.name, defaultExpr, replace) :: Nil + + case DropVariable(ident: ResolvedIdentifier, ifExists) => + DropVariableExec(ident.identifier.name, ifExists) :: Nil + + case SetVariable(variables, query) => + SetVariableExec(variables.map(_.asInstanceOf[VariableReference]), planLater(query)) :: Nil + + case _ => Nil + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index 8ac982b9bdd91..7b95d34e6b6e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.{SQLConfHelper, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, GlobalTempView, LocalTempView, ViewType} import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType, TemporaryViewRelation} -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, SubqueryExpression, VariableReference} import org.apache.spark.sql.catalyst.plans.logical.{AnalysisOnlyCommand, CTEInChildren, CTERelationDef, LogicalPlan, Project, View, WithCTE} import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.NamespaceHelper @@ -439,14 +439,19 @@ object ViewHelper extends SQLConfHelper with Logging { * Convert the temporary object names to `properties`. */ private def referredTempNamesToProps( - viewNames: Seq[Seq[String]], functionsNames: Seq[String]): Map[String, String] = { + viewNames: Seq[Seq[String]], + functionsNames: Seq[String], + variablesNames: Seq[Seq[String]]): Map[String, String] = { val viewNamesJson = JArray(viewNames.map(nameParts => JArray(nameParts.map(JString).toList)).toList) val functionsNamesJson = JArray(functionsNames.map(JString).toList) + val variablesNamesJson = + JArray(variablesNames.map(nameParts => JArray(nameParts.map(JString).toList)).toList) val props = new mutable.HashMap[String, String] props.put(VIEW_REFERRED_TEMP_VIEW_NAMES, compact(render(viewNamesJson))) props.put(VIEW_REFERRED_TEMP_FUNCTION_NAMES, compact(render(functionsNamesJson))) + props.put(VIEW_REFERRED_TEMP_VARIABLE_NAMES, compact(render(variablesNamesJson))) props.toMap } @@ -458,7 +463,8 @@ object ViewHelper extends SQLConfHelper with Logging { // while `CatalogTable` should be serializable. 
properties.filterNot { case (key, _) => key.startsWith(VIEW_REFERRED_TEMP_VIEW_NAMES) || - key.startsWith(VIEW_REFERRED_TEMP_FUNCTION_NAMES) + key.startsWith(VIEW_REFERRED_TEMP_FUNCTION_NAMES) || + key.startsWith(VIEW_REFERRED_TEMP_VARIABLE_NAMES) } } @@ -481,7 +487,8 @@ object ViewHelper extends SQLConfHelper with Logging { analyzedPlan: LogicalPlan, fieldNames: Array[String], tempViewNames: Seq[Seq[String]] = Seq.empty, - tempFunctionNames: Seq[String] = Seq.empty): Map[String, String] = { + tempFunctionNames: Seq[String] = Seq.empty, + tempVariableNames: Seq[Seq[String]] = Seq.empty): Map[String, String] = { // for createViewCommand queryOutput may be different from fieldNames val queryOutput = analyzedPlan.schema.fieldNames @@ -497,7 +504,7 @@ object ViewHelper extends SQLConfHelper with Logging { catalogAndNamespaceToProps(manager.currentCatalog.name, manager.currentNamespace) ++ sqlConfigsToProps(conf) ++ generateQueryColumnNames(queryOutput) ++ - referredTempNamesToProps(tempViewNames, tempFunctionNames) + referredTempNamesToProps(tempViewNames, tempFunctionNames, tempVariableNames) } /** @@ -579,6 +586,11 @@ object ViewHelper extends SQLConfHelper with Logging { throw QueryCompilationErrors.notAllowedToCreatePermanentViewByReferencingTempFuncError( name, funcName) } + val tempVars = collectTemporaryVariables(child) + tempVars.foreach { nameParts => + throw QueryCompilationErrors.notAllowedToCreatePermanentViewByReferencingTempVarError( + name, nameParts.quoted) + } } } @@ -600,6 +612,22 @@ object ViewHelper extends SQLConfHelper with Logging { collectTempViews(child) } + /** + * Collect all temporary SQL variables and return the identifiers separately. + */ + private def collectTemporaryVariables(child: LogicalPlan): Seq[Seq[String]] = { + def collectTempVars(child: LogicalPlan): Seq[Seq[String]] = { + child.flatMap { plan => + plan.expressions.flatMap(_.flatMap { + case e: SubqueryExpression => collectTempVars(e.plan) + case r: VariableReference => Seq(r.originalNameParts) + case _ => Seq.empty + }) + }.distinct + } + collectTempVars(child) + } + /** * Returns a [[TemporaryViewRelation]] that contains information about a temporary view * to create, given an analyzed plan of the view. 
If a temp view is to be replaced and it is @@ -684,10 +712,12 @@ object ViewHelper extends SQLConfHelper with Logging { val catalog = session.sessionState.catalog val tempViews = collectTemporaryViews(analyzedPlan) + val tempVariables = collectTemporaryVariables(analyzedPlan) // TBLPROPERTIES is not allowed for temporary view, so we don't use it for // generating temporary view properties val newProperties = generateViewProperties( - Map.empty, session, analyzedPlan, viewSchema.fieldNames, tempViews, tempFunctions) + Map.empty, session, analyzedPlan, viewSchema.fieldNames, tempViews, + tempFunctions, tempVariables) CatalogTable( identifier = viewName, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index 709402571cad4..2a28f6848aaa6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -81,7 +81,7 @@ case class ScalarSubquery( def updateResult(): Unit = { val rows = plan.executeCollect() if (rows.length > 1) { - throw QueryExecutionErrors.multipleRowSubqueryError(getContextOrNull()) + throw QueryExecutionErrors.multipleRowScalarSubqueryError(getContextOrNull()) } if (rows.length == 1) { assert(rows(0).numFields == 1, diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out new file mode 100644 index 0000000000000..45bfbf69db325 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out @@ -0,0 +1,2112 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +SET spark.sql.ansi.enabled = true +-- !query analysis +SetCommand (spark.sql.ansi.enabled,Some(true)) + + +-- !query +DECLARE title STRING +-- !query analysis +CreateVariable defaultvalueexpression(null, null), false ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.title + + +-- !query +SET VARIABLE title = '-- Basic sanity --' +-- !query analysis +SetVariable [variablereference(system.session.title=CAST(NULL AS STRING))] ++- Project [-- Basic sanity -- AS title#x] + +- OneRowRelation + + +-- !query +DECLARE var1 INT = 5 +-- !query analysis +CreateVariable defaultvalueexpression(cast(5 as int), 5), false ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +SELECT var1 +-- !query analysis +Project [variablereference(system.session.var1=5) AS var1#x] ++- OneRowRelation + + +-- !query +SET VARIABLE var1 = 6 +-- !query analysis +SetVariable [variablereference(system.session.var1=5)] ++- Project [6 AS var1#x] + +- OneRowRelation + + +-- !query +SELECT var1 +-- !query analysis +Project [variablereference(system.session.var1=6) AS var1#x] ++- OneRowRelation + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query analysis +DropVariable false ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +SET VARIABLE title = 'Create Variable - Success Cases' +-- !query analysis +SetVariable [variablereference(system.session.title='-- Basic sanity --')] ++- Project [Create Variable - Success Cases AS title#x] + +- OneRowRelation + + +-- !query +DECLARE VARIABLE var1 INT +-- !query analysis +CreateVariable defaultvalueexpression(null, null), false ++- ResolvedIdentifier 
org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +SELECT 'Expect: INT, NULL', typeof(var1), var1 +-- !query analysis +Project [Expect: INT, NULL AS Expect: INT, NULL#x, typeof(variablereference(system.session.var1=CAST(NULL AS INT))) AS typeof(variablereference(system.session.var1=CAST(NULL AS INT)) AS var1)#x, variablereference(system.session.var1=CAST(NULL AS INT)) AS var1#x] ++- OneRowRelation + + +-- !query +DECLARE OR REPLACE VARIABLE var1 DOUBLE +-- !query analysis +CreateVariable defaultvalueexpression(null, null), true ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +SELECT 'Expect: DOUBLE, NULL', typeof(var1), var1 +-- !query analysis +Project [Expect: DOUBLE, NULL AS Expect: DOUBLE, NULL#x, typeof(variablereference(system.session.var1=CAST(NULL AS DOUBLE))) AS typeof(variablereference(system.session.var1=CAST(NULL AS DOUBLE)) AS var1)#x, variablereference(system.session.var1=CAST(NULL AS DOUBLE)) AS var1#x] ++- OneRowRelation + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query analysis +DropVariable false ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +DECLARE OR REPLACE VARIABLE var1 TIMESTAMP +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT 'Expect: TIMESTAMP, NULL', typeof(var1), var1 +-- !query analysis +Project [Expect: TIMESTAMP, NULL AS Expect: TIMESTAMP, NULL#x, typeof(variablereference(system.session.var1=CAST(NULL AS TIMESTAMP))) AS typeof(variablereference(system.session.var1=CAST(NULL AS TIMESTAMP)) AS var1)#x, variablereference(system.session.var1=CAST(NULL AS TIMESTAMP)) AS var1#x] ++- OneRowRelation + + +-- !query +SET VARIABLE title = 'Create Variable - Failure Cases' +-- !query analysis +SetVariable [variablereference(system.session.title='Create Variable - Success Cases')] ++- Project [Create Variable - Failure Cases AS title#x] + +- OneRowRelation + + +-- !query +DECLARE VARIABLE IF NOT EXISTS var1 INT +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42601", + "messageParameters" : { + "error" : "'EXISTS'", + "hint" : "" + } +} + + +-- !query +DROP TEMPORARY VARIABLE IF EXISTS var1 +-- !query analysis +DropVariable true ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +SET VARIABLE title = 'Drop Variable' +-- !query analysis +SetVariable [variablereference(system.session.title='Create Variable - Failure Cases')] ++- Project [Drop Variable AS title#x] + +- OneRowRelation + + +-- !query +DECLARE VARIABLE var1 INT +-- !query analysis +CreateVariable defaultvalueexpression(null, null), false ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +SELECT var1 +-- !query analysis +Project [variablereference(system.session.var1=CAST(NULL AS INT)) AS var1#x] ++- OneRowRelation + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query analysis +DropVariable false ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +SELECT var1 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`var1`" + }, + 
"queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 11, + "fragment" : "var1" + } ] +} + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "VARIABLE_NOT_FOUND", + "sqlState" : "42883", + "messageParameters" : { + "variableName" : "`system`.`session`.`var1`" + } +} + + +-- !query +DROP TEMPORARY VARIABLE IF EXISTS var1 +-- !query analysis +DropVariable true ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +DECLARE VARIABLE var1 INT +-- !query analysis +CreateVariable defaultvalueexpression(null, null), false ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +DROP VARIABLE var1 +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42601", + "messageParameters" : { + "error" : "'VARIABLE'", + "hint" : "" + } +} + + +-- !query +DROP VARIABLE system.session.var1 +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42601", + "messageParameters" : { + "error" : "'VARIABLE'", + "hint" : "" + } +} + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query analysis +DropVariable false ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +SET VARIABLE title = 'Test qualifiers - success' +-- !query analysis +SetVariable [variablereference(system.session.title='Drop Variable')] ++- Project [Test qualifiers - success AS title#x] + +- OneRowRelation + + +-- !query +DECLARE VARIABLE var1 INT DEFAULT 1 +-- !query analysis +CreateVariable defaultvalueexpression(cast(1 as int), 1), false ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +SELECT 1 as Expected, var1 as Unqualified, session.var1 AS SchemaQualified, system.session.var1 AS fullyQualified +-- !query analysis +Project [1 AS Expected#x, variablereference(system.session.var1=1) AS Unqualified#x, variablereference(system.session.var1=1) AS SchemaQualified#x, variablereference(system.session.var1=1) AS fullyQualified#x] ++- OneRowRelation + + +-- !query +SET VARIABLE var1 = 2 +-- !query analysis +SetVariable [variablereference(system.session.var1=1)] ++- Project [2 AS var1#x] + +- OneRowRelation + + +-- !query +SELECT 2 as Expected, var1 as Unqualified, session.var1 AS SchemaQualified, system.session.var1 AS fullyQualified +-- !query analysis +Project [2 AS Expected#x, variablereference(system.session.var1=2) AS Unqualified#x, variablereference(system.session.var1=2) AS SchemaQualified#x, variablereference(system.session.var1=2) AS fullyQualified#x] ++- OneRowRelation + + +-- !query +DECLARE OR REPLACE VARIABLE session.var1 INT DEFAULT 1 +-- !query analysis +CreateVariable defaultvalueexpression(cast(1 as int), 1), true ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +SELECT 1 as Expected, var1 as Unqualified, session.var1 AS SchemaQualified, system.session.var1 AS fullyQualified +-- !query analysis +Project [1 AS Expected#x, variablereference(system.session.var1=1) AS Unqualified#x, variablereference(system.session.var1=1) AS SchemaQualified#x, variablereference(system.session.var1=1) AS fullyQualified#x] ++- OneRowRelation + + +-- !query +SET 
VARIABLE session.var1 = 2 +-- !query analysis +SetVariable [variablereference(system.session.var1=1)] ++- Project [2 AS var1#x] + +- OneRowRelation + + +-- !query +SELECT 2 as Expected, var1 as Unqualified, session.var1 AS SchemaQualified, system.session.var1 AS fullyQualified +-- !query analysis +Project [2 AS Expected#x, variablereference(system.session.var1=2) AS Unqualified#x, variablereference(system.session.var1=2) AS SchemaQualified#x, variablereference(system.session.var1=2) AS fullyQualified#x] ++- OneRowRelation + + +-- !query +DECLARE OR REPLACE VARIABLE system.session.var1 INT DEFAULT 1 +-- !query analysis +CreateVariable defaultvalueexpression(cast(1 as int), 1), true ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +SELECT 1 as Expected, var1 as Unqualified, session.var1 AS SchemaQualified, system.session.var1 AS fullyQualified +-- !query analysis +Project [1 AS Expected#x, variablereference(system.session.var1=1) AS Unqualified#x, variablereference(system.session.var1=1) AS SchemaQualified#x, variablereference(system.session.var1=1) AS fullyQualified#x] ++- OneRowRelation + + +-- !query +SET VARIABLE system.session.var1 = 2 +-- !query analysis +SetVariable [variablereference(system.session.var1=1)] ++- Project [2 AS var1#x] + +- OneRowRelation + + +-- !query +SELECT 2 as Expected, var1 as Unqualified, session.var1 AS SchemaQualified, system.session.var1 AS fullyQualified +-- !query analysis +Project [2 AS Expected#x, variablereference(system.session.var1=2) AS Unqualified#x, variablereference(system.session.var1=2) AS SchemaQualified#x, variablereference(system.session.var1=2) AS fullyQualified#x] ++- OneRowRelation + + +-- !query +DECLARE OR REPLACE VARIABLE sySteM.sEssIon.vAr1 INT DEFAULT 1 +-- !query analysis +CreateVariable defaultvalueexpression(cast(1 as int), 1), true ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.vAr1 + + +-- !query +SELECT 1 as Expected, var1 as Unqualified, sessIon.Var1 AS SchemaQualified, System.sessiOn.var1 AS fullyQualified +-- !query analysis +Project [1 AS Expected#x, variablereference(system.session.var1=1) AS Unqualified#x, variablereference(system.session.var1=1) AS SchemaQualified#x, variablereference(system.session.var1=1) AS fullyQualified#x] ++- OneRowRelation + + +-- !query +SET VARIABLE sYstem.sesSiOn.vaR1 = 2 +-- !query analysis +SetVariable [variablereference(system.session.var1=1)] ++- Project [2 AS vaR1#x] + +- OneRowRelation + + +-- !query +SELECT 2 as Expected, VAR1 as Unqualified, SESSION.VAR1 AS SchemaQualified, SYSTEM.SESSION.VAR1 AS fullyQualified +-- !query analysis +Project [2 AS Expected#x, variablereference(system.session.var1=2) AS Unqualified#x, variablereference(system.session.var1=2) AS SchemaQualified#x, variablereference(system.session.var1=2) AS fullyQualified#x] ++- OneRowRelation + + +-- !query +DECLARE OR REPLACE VARIABLE var1 INT +-- !query analysis +CreateVariable defaultvalueexpression(null, null), true ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query analysis +DropVariable false ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "VARIABLE_NOT_FOUND", + "sqlState" : "42883", + "messageParameters" : { + 
"variableName" : "`system`.`session`.`var1`" + } +} + + +-- !query +DECLARE OR REPLACE VARIABLE var1 INT +-- !query analysis +CreateVariable defaultvalueexpression(null, null), true ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +DROP TEMPORARY VARIABLE session.var1 +-- !query analysis +DropVariable false ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "VARIABLE_NOT_FOUND", + "sqlState" : "42883", + "messageParameters" : { + "variableName" : "`system`.`session`.`var1`" + } +} + + +-- !query +DECLARE OR REPLACE VARIABLE var1 INT +-- !query analysis +CreateVariable defaultvalueexpression(null, null), true ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +DROP TEMPORARY VARIABLE system.session.var1 +-- !query analysis +DropVariable false ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "VARIABLE_NOT_FOUND", + "sqlState" : "42883", + "messageParameters" : { + "variableName" : "`system`.`session`.`var1`" + } +} + + +-- !query +DECLARE OR REPLACE VARIABLE var1 INT +-- !query analysis +CreateVariable defaultvalueexpression(null, null), true ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +DROP TEMPORARY VARIABLE sysTem.sesSion.vAr1 +-- !query analysis +DropVariable false ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.vAr1 + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "VARIABLE_NOT_FOUND", + "sqlState" : "42883", + "messageParameters" : { + "variableName" : "`system`.`session`.`var1`" + } +} + + +-- !query +SET VARIABLE title = 'Test qualifiers - fail' +-- !query analysis +SetVariable [variablereference(system.session.title='Test qualifiers - success')] ++- Project [Test qualifiers - fail AS title#x] + +- OneRowRelation + + +-- !query +DECLARE OR REPLACE VARIABLE builtin.var1 INT +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNRESOLVED_VARIABLE", + "sqlState" : "42883", + "messageParameters" : { + "searchPath" : "`system`.`session`", + "variableName" : "`builtin`.`var1`" + } +} + + +-- !query +DECLARE OR REPLACE VARIABLE system.sesion.var1 INT +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNRESOLVED_VARIABLE", + "sqlState" : "42883", + "messageParameters" : { + "searchPath" : "`system`.`session`", + "variableName" : "`system`.`sesion`.`var1`" + } +} + + +-- !query +DECLARE OR REPLACE VARIABLE sys.session.var1 INT +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNRESOLVED_VARIABLE", + "sqlState" : "42883", + "messageParameters" : { + "searchPath" : "`system`.`session`", + "variableName" : "`sys`.`session`.`var1`" + } +} + + +-- !query +DECLARE OR REPLACE VARIABLE var1 INT +-- !query analysis +CreateVariable defaultvalueexpression(null, null), true ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +SELECT var +-- !query analysis 
+org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`var`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 10, + "fragment" : "var" + } ] +} + + +-- !query +SELECT ses.var1 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`ses`.`var1`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 15, + "fragment" : "ses.var1" + } ] +} + + +-- !query +SELECT b.sesson.var1 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`b`.`sesson`.`var1`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 20, + "fragment" : "b.sesson.var1" + } ] +} + + +-- !query +SELECT builtn.session.var1 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`builtn`.`session`.`var1`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 26, + "fragment" : "builtn.session.var1" + } ] +} + + +-- !query +SET VARIABLE ses.var1 = 1 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNRESOLVED_VARIABLE", + "sqlState" : "42883", + "messageParameters" : { + "searchPath" : "`SYSTEM`.`SESSION`", + "variableName" : "`ses`.`var1`" + } +} + + +-- !query +SET VARIABLE builtn.session.var1 = 1 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNRESOLVED_VARIABLE", + "sqlState" : "42883", + "messageParameters" : { + "searchPath" : "`SYSTEM`.`SESSION`", + "variableName" : "`builtn`.`session`.`var1`" + } +} + + +-- !query +SET VARIABLE title = 'Test DEFAULT on create - success' +-- !query analysis +SetVariable [variablereference(system.session.title='Test qualifiers - fail')] ++- Project [Test DEFAULT on create - success AS title#x] + +- OneRowRelation + + +-- !query +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 1 +-- !query analysis +CreateVariable defaultvalueexpression(cast(1 as int), 1), true ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +SELECT 1 AS Expected, var1 AS result +-- !query analysis +Project [1 AS Expected#x, variablereference(system.session.var1=1) AS result#x] ++- OneRowRelation + + +-- !query +DECLARE OR REPLACE VARIABLE var1 DOUBLE DEFAULT 1 + RAND(5) +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT true AS Expected, var1 >= 1 AS result +-- !query analysis +Project [true AS Expected#x, (variablereference(system.session.var1=1.023906964275029D) >= cast(1 as double)) AS result#x] ++- OneRowRelation + + +-- !query +DECLARE OR REPLACE VARIABLE var1 = 'Hello' +-- !query analysis +CreateVariable defaultvalueexpression(Hello, 'Hello'), true ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +SELECT 'STRING, Hello' AS Expected, typeof(var1) AS type, var1 AS result +-- !query analysis 
+Project [STRING, Hello AS Expected#x, typeof(variablereference(system.session.var1='Hello')) AS type#x, variablereference(system.session.var1='Hello') AS result#x] ++- OneRowRelation + + +-- !query +DECLARE OR REPLACE VARIABLE var1 DEFAULT NULL +-- !query analysis +CreateVariable defaultvalueexpression(null, NULL), true ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +SELECT 'VOID, NULL' AS Expected, typeof(var1) AS type, var1 AS result +-- !query analysis +Project [VOID, NULL AS Expected#x, typeof(variablereference(system.session.var1=NULL)) AS type#x, variablereference(system.session.var1=NULL) AS result#x] ++- OneRowRelation + + +-- !query +DECLARE OR REPLACE VARIABLE INT DEFAULT 5.0 +-- !query analysis +CreateVariable defaultvalueexpression(5.0, 5.0), true ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.INT + + +-- !query +SELECT 'INT, 5' AS Expected, typeof(var1) AS type, var1 AS result +-- !query analysis +Project [INT, 5 AS Expected#x, typeof(variablereference(system.session.var1=NULL)) AS type#x, variablereference(system.session.var1=NULL) AS result#x] ++- OneRowRelation + + +-- !query +DECLARE OR REPLACE VARIABLE var1 MAP DEFAULT MAP('Hello', 5.1, 'World', -7.1E10) +-- !query analysis +CreateVariable defaultvalueexpression(cast(map(Hello, cast(5.1 as double), World, -7.1E10) as map), MAP('Hello', 5.1, 'World', -7.1E10)), true ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +SELECT 'MAP, [Hello -> 5.1, World -> -7E10]' AS Expected, typeof(var1) AS type, var1 AS result +-- !query analysis +Project [MAP, [Hello -> 5.1, World -> -7E10] AS Expected#x, typeof(variablereference(system.session.var1=MAP('Hello', 5.1D, 'World', -7.1E10D))) AS type#x, variablereference(system.session.var1=MAP('Hello', 5.1D, 'World', -7.1E10D)) AS result#x] ++- OneRowRelation + + +-- !query +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT NULL +-- !query analysis +CreateVariable defaultvalueexpression(cast(null as int), NULL), true ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +SELECT 'NULL' AS Expected, var1 AS result +-- !query analysis +Project [NULL AS Expected#x, variablereference(system.session.var1=CAST(NULL AS INT)) AS result#x] ++- OneRowRelation + + +-- !query +DECLARE OR REPLACE VARIABLE var1 STRING DEFAULT CURRENT_DATABASE() +-- !query analysis +CreateVariable defaultvalueexpression(cast(current_database() as string), CURRENT_DATABASE()), true ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +SELECT 'true' AS Expected, length(var1) > 0 AS result +-- !query analysis +Project [true AS Expected#x, (length(variablereference(system.session.var1='default')) > 0) AS result#x] ++- OneRowRelation + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query analysis +DropVariable false ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +DECLARE var1 +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "INVALID_SQL_SYNTAX.VARIABLE_TYPE_OR_DEFAULT_REQUIRED", + "sqlState" : "42000", + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 12, + "fragment" : "DECLARE var1" + } ] +} + + +-- !query +SET VARIABLE title = 'Test DEFAULT on 
create - failures' +-- !query analysis +SetVariable [variablereference(system.session.title='Test DEFAULT on create - success')] ++- Project [Test DEFAULT on create - failures AS title#x] + +- OneRowRelation + + +-- !query +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT (SELECT c1 FROM VALUES(1) AS T(c1)) +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INVALID_DEFAULT_VALUE.SUBQUERY_EXPRESSION", + "messageParameters" : { + "colName" : "`system`.`session`.`var1`", + "defaultValue" : "(SELECT c1 FROM VALUES(1) AS T(c1))", + "statement" : "CRETE VARIABLE" + } +} + + +-- !query +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 'hello' +-- !query analysis +org.apache.spark.SparkNumberFormatException +{ + "errorClass" : "CAST_INVALID_INPUT", + "sqlState" : "22018", + "messageParameters" : { + "ansiConfig" : "\"spark.sql.ansi.enabled\"", + "expression" : "'hello'", + "sourceType" : "\"STRING\"", + "targetType" : "\"INT\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 52, + "fragment" : "DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 'hello'" + } ] +} + + +-- !query +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 1 / 0 +-- !query analysis +org.apache.spark.SparkArithmeticException +{ + "errorClass" : "DIVIDE_BY_ZERO", + "sqlState" : "22012", + "messageParameters" : { + "config" : "\"spark.sql.ansi.enabled\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 46, + "stopIndex" : 50, + "fragment" : "1 / 0" + } ] +} + + +-- !query +DECLARE OR REPLACE VARIABLE var1 SMALLINT DEFAULT 100000 +-- !query analysis +org.apache.spark.SparkArithmeticException +{ + "errorClass" : "CAST_OVERFLOW", + "sqlState" : "22003", + "messageParameters" : { + "ansiConfig" : "\"spark.sql.ansi.enabled\"", + "sourceType" : "\"INT\"", + "targetType" : "\"SMALLINT\"", + "value" : "100000" + } +} + + +-- !query +SET VARIABLE title = 'SET VARIABLE - single target' +-- !query analysis +SetVariable [variablereference(system.session.title='Test DEFAULT on create - failures')] ++- Project [SET VARIABLE - single target AS title#x] + +- OneRowRelation + + +-- !query +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 5 +-- !query analysis +CreateVariable defaultvalueexpression(cast(5 as int), 5), true ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +SET VARIABLE var1 = 7 +-- !query analysis +SetVariable [variablereference(system.session.var1=5)] ++- Project [7 AS var1#x] + +- OneRowRelation + + +-- !query +SELECT var1 +-- !query analysis +Project [variablereference(system.session.var1=7) AS var1#x] ++- OneRowRelation + + +-- !query +SET VAR var1 = 8 +-- !query analysis +SetVariable [variablereference(system.session.var1=7)] ++- Project [8 AS var1#x] + +- OneRowRelation + + +-- !query +SELECT var1 +-- !query analysis +Project [variablereference(system.session.var1=8) AS var1#x] ++- OneRowRelation + + +-- !query +SET VARIABLE var1 = (SELECT c1 FROM VALUES(1) AS T(c1)) +-- !query analysis +SetVariable [variablereference(system.session.var1=8)] ++- Project [scalar-subquery#x [] AS var1#x] + : +- Project [c1#x] + : +- SubqueryAlias T + : +- LocalRelation [c1#x] + +- OneRowRelation + + +-- !query +SELECT var1 +-- !query analysis +Project [variablereference(system.session.var1=1) AS var1#x] ++- OneRowRelation + + +-- !query +SET VARIABLE var1 = (SELECT c1 FROM VALUES(1) AS T(c1) WHERE 1=0) +-- !query analysis +SetVariable 
[variablereference(system.session.var1=1)] ++- Project [scalar-subquery#x [] AS var1#x] + : +- Project [c1#x] + : +- Filter (1 = 0) + : +- SubqueryAlias T + : +- LocalRelation [c1#x] + +- OneRowRelation + + +-- !query +SELECT var1 AS `null` +-- !query analysis +Project [variablereference(system.session.var1=CAST(NULL AS INT)) AS null#x] ++- OneRowRelation + + +-- !query +SET VARIABLE var1 = (SELECT c1 FROM VALUES(1.0) AS T(c1)) +-- !query analysis +SetVariable [variablereference(system.session.var1=CAST(NULL AS INT))] ++- Project [cast(var1#x as int) AS var1#x] + +- Project [scalar-subquery#x [] AS var1#x] + : +- Project [c1#x] + : +- SubqueryAlias T + : +- LocalRelation [c1#x] + +- OneRowRelation + + +-- !query +SELECT var1 +-- !query analysis +Project [variablereference(system.session.var1=1) AS var1#x] ++- OneRowRelation + + +-- !query +SET VARIABLE var1 = (SELECT c1 FROM VALUES(1.0E10) AS T(c1)) +-- !query analysis +org.apache.spark.SparkArithmeticException +{ + "errorClass" : "CAST_OVERFLOW", + "sqlState" : "22003", + "messageParameters" : { + "ansiConfig" : "\"spark.sql.ansi.enabled\"", + "sourceType" : "\"DOUBLE\"", + "targetType" : "\"INT\"", + "value" : "1.0E10D" + } +} + + +-- !query +SELECT var1 +-- !query analysis +Project [variablereference(system.session.var1=1) AS var1#x] ++- OneRowRelation + + +-- !query +SET VARIABLE var1 = (SELECT c1 FROM VALUES(1), (2) AS T(c1)) +-- !query analysis +org.apache.spark.SparkException +{ + "errorClass" : "SCALAR_SUBQUERY_TOO_MANY_ROWS", + "sqlState" : "21000", + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 21, + "stopIndex" : 60, + "fragment" : "(SELECT c1 FROM VALUES(1), (2) AS T(c1))" + } ] +} + + +-- !query +SET VARIABLE var1 = (SELECT c1, c1 FROM VALUES(1), (2) AS T(c1)) +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INVALID_SUBQUERY_EXPRESSION.SCALAR_SUBQUERY_RETURN_MORE_THAN_ONE_OUTPUT_COLUMN", + "sqlState" : "42823", + "messageParameters" : { + "number" : "2" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 21, + "stopIndex" : 64, + "fragment" : "(SELECT c1, c1 FROM VALUES(1), (2) AS T(c1))" + } ] +} + + +-- !query +SET VARIABLE var1 = (SELECT c1 FROM VALUES('hello') AS T(c1)) +-- !query analysis +org.apache.spark.SparkNumberFormatException +{ + "errorClass" : "CAST_INVALID_INPUT", + "sqlState" : "22018", + "messageParameters" : { + "ansiConfig" : "\"spark.sql.ansi.enabled\"", + "expression" : "'hello'", + "sourceType" : "\"STRING\"", + "targetType" : "\"INT\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 61, + "fragment" : "SET VARIABLE var1 = (SELECT c1 FROM VALUES('hello') AS T(c1))" + } ] +} + + +-- !query +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 5 +-- !query analysis +CreateVariable defaultvalueexpression(cast(5 as int), 5), true ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +SET VARIABLE var1 = var1 + 1 +-- !query analysis +SetVariable [variablereference(system.session.var1=5)] ++- Project [(variablereference(system.session.var1=5) + 1) AS var1#x] + +- OneRowRelation + + +-- !query +SELECT var1 +-- !query analysis +Project [variablereference(system.session.var1=6) AS var1#x] ++- OneRowRelation + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query analysis +DropVariable false ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, 
session.var1 + + +-- !query +SET VARIABLE title = 'SET VARIABLE - comma separated target' +-- !query analysis +SetVariable [variablereference(system.session.title='SET VARIABLE - single target')] ++- Project [SET VARIABLE - comma separated target AS title#x] + +- OneRowRelation + + +-- !query +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 5 +-- !query analysis +CreateVariable defaultvalueexpression(cast(5 as int), 5), true ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +DECLARE OR REPLACE VARIABLE var2 STRING DEFAULT 'hello' +-- !query analysis +CreateVariable defaultvalueexpression(cast(hello as string), 'hello'), true ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var2 + + +-- !query +DECLARE OR REPLACE VARIABLE var3 DOUBLE DEFAULT 2 +-- !query analysis +CreateVariable defaultvalueexpression(cast(2 as double), 2), true ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var3 + + +-- !query +SET VARIABLE var1 = 6, var2 = 'world', var3 = pi() +-- !query analysis +SetVariable [variablereference(system.session.var1=5), variablereference(system.session.var2='hello'), variablereference(system.session.var3=2.0D)] ++- Project [6 AS var1#x, world AS var2#x, PI() AS var3#x] + +- OneRowRelation + + +-- !query +SELECT var1 AS `6`, var2 AS `world` , var3 as `3.14...` +-- !query analysis +Project [variablereference(system.session.var1=6) AS 6#x, variablereference(system.session.var2='world') AS world#x, variablereference(system.session.var3=3.141592653589793D) AS 3.14...#x] ++- OneRowRelation + + +-- !query +SET VAR var1 = 7, var2 = 'universe', var3 = -1 +-- !query analysis +SetVariable [variablereference(system.session.var1=6), variablereference(system.session.var2='world'), variablereference(system.session.var3=3.141592653589793D)] ++- Project [var1#x, var2#x, cast(var3#x as double) AS var3#x] + +- Project [7 AS var1#x, universe AS var2#x, -1 AS var3#x] + +- OneRowRelation + + +-- !query +SELECT var1 AS `7`, var2 AS `universe` , var3 as `-1` +-- !query analysis +Project [variablereference(system.session.var1=7) AS 7#x, variablereference(system.session.var2='universe') AS universe#x, variablereference(system.session.var3=-1.0D) AS -1#x] ++- OneRowRelation + + +-- !query +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 5 +-- !query analysis +CreateVariable defaultvalueexpression(cast(5 as int), 5), true ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +DECLARE OR REPLACE VARIABLE var2 STRING DEFAULT 'hello' +-- !query analysis +CreateVariable defaultvalueexpression(cast(hello as string), 'hello'), true ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var2 + + +-- !query +DECLARE OR REPLACE VARIABLE var3 DOUBLE DEFAULT 2 +-- !query analysis +CreateVariable defaultvalueexpression(cast(2 as double), 2), true ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var3 + + +-- !query +SET VARIABLE var1 = var3, var2 = ascii(var1), var3 = var1 +-- !query analysis +SetVariable [variablereference(system.session.var1=5), variablereference(system.session.var2='hello'), variablereference(system.session.var3=2.0D)] ++- Project [cast(var3#x as int) AS var1#x, cast(var2#x as string) AS var2#x, cast(var1#x as double) AS var3#x] + +- Project [variablereference(system.session.var3=2.0D) AS 
var3#x, ascii(cast(variablereference(system.session.var1=5) as string)) AS var2#x, variablereference(system.session.var1=5) AS var1#x] + +- OneRowRelation + + +-- !query +SELECT var1 AS `2`, var2 AS `104`, var3 AS `5` +-- !query analysis +Project [variablereference(system.session.var1=2) AS 2#x, variablereference(system.session.var2='53') AS 104#x, variablereference(system.session.var3=5.0D) AS 5#x] ++- OneRowRelation + + +-- !query +SET VARIABLE var1 = var3, var2 = INTERVAL'5' HOUR, var3 = var1 +-- !query analysis +SetVariable [variablereference(system.session.var1=2), variablereference(system.session.var2='53'), variablereference(system.session.var3=5.0D)] ++- Project [cast(var3#x as int) AS var1#x, cast(var2#x as string) AS var2#x, cast(var1#x as double) AS var3#x] + +- Project [variablereference(system.session.var3=5.0D) AS var3#x, INTERVAL '05' HOUR AS var2#x, variablereference(system.session.var1=2) AS var1#x] + +- OneRowRelation + + +-- !query +SET VARIABLE var1 = 1, var2 = 0, vAr1 = 1 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "DUPLICATE_ASSIGNMENTS", + "sqlState" : "42701", + "messageParameters" : { + "nameList" : "`var1`" + } +} + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query analysis +DropVariable false ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +DROP TEMPORARY VARIABLE var2 +-- !query analysis +DropVariable false ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var2 + + +-- !query +DROP TEMPORARY VARIABLE var3 +-- !query analysis +DropVariable false ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var3 + + +-- !query +SET VARIABLE title = 'SET VARIABLE - row assignment' +-- !query analysis +SetVariable [variablereference(system.session.title='SET VARIABLE - comma separated target')] ++- Project [SET VARIABLE - row assignment AS title#x] + +- OneRowRelation + + +-- !query +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 5 +-- !query analysis +CreateVariable defaultvalueexpression(cast(5 as int), 5), true ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +DECLARE OR REPLACE VARIABLE var2 STRING DEFAULT 'hello' +-- !query analysis +CreateVariable defaultvalueexpression(cast(hello as string), 'hello'), true ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var2 + + +-- !query +DECLARE OR REPLACE VARIABLE var3 DOUBLE DEFAULT 2 +-- !query analysis +CreateVariable defaultvalueexpression(cast(2 as double), 2), true ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var3 + + +-- !query +SET VARIABLE (var1) = (SELECT c1 FROM VALUES(1) AS T(c1)) +-- !query analysis +SetVariable [variablereference(system.session.var1=5)] ++- Project [c1#x] + +- SubqueryAlias T + +- LocalRelation [c1#x] + + +-- !query +SELECT var1 +-- !query analysis +Project [variablereference(system.session.var1=1) AS var1#x] ++- OneRowRelation + + +-- !query +SET VAR (var1) = (SELECT c1 FROM VALUES(2) AS T(c1)) +-- !query analysis +SetVariable [variablereference(system.session.var1=1)] ++- Project [c1#x] + +- SubqueryAlias T + +- LocalRelation [c1#x] + + +-- !query +SELECT var1 +-- !query analysis +Project [variablereference(system.session.var1=2) AS var1#x] ++- OneRowRelation + + +-- !query +SET VARIABLE (var1, var2) = (SELECT c1, c2 FROM 
VALUES(10, 11) AS T(c1, c2)) +-- !query analysis +SetVariable [variablereference(system.session.var1=2), variablereference(system.session.var2='hello')] ++- Project [c1#x, cast(c2#x as string) AS var2#x] + +- Project [c1#x, c2#x] + +- SubqueryAlias T + +- LocalRelation [c1#x, c2#x] + + +-- !query +SELECT var1 AS `10`, var2 AS `11` +-- !query analysis +Project [variablereference(system.session.var1=10) AS 10#x, variablereference(system.session.var2='11') AS 11#x] ++- OneRowRelation + + +-- !query +SET VARIABLE (var1, var2, var3) = (SELECT c1, c2, c3 FROM VALUES(100, 110, 120) AS T(c1, c2, c3)) +-- !query analysis +SetVariable [variablereference(system.session.var1=10), variablereference(system.session.var2='11'), variablereference(system.session.var3=2.0D)] ++- Project [c1#x, cast(c2#x as string) AS var2#x, cast(c3#x as double) AS var3#x] + +- Project [c1#x, c2#x, c3#x] + +- SubqueryAlias T + +- LocalRelation [c1#x, c2#x, c3#x] + + +-- !query +SELECT var1 AS `100`, var2 AS `110`, var3 AS `120` +-- !query analysis +Project [variablereference(system.session.var1=100) AS 100#x, variablereference(system.session.var2='110') AS 110#x, variablereference(system.session.var3=120.0D) AS 120#x] ++- OneRowRelation + + +-- !query +SET VARIABLE (var1, var2, var3) = (SELECT c1, c2, c3 FROM VALUES(100, 110, 120) AS T(c1, c2, c3) WHERE 1 = 0) +-- !query analysis +SetVariable [variablereference(system.session.var1=100), variablereference(system.session.var2='110'), variablereference(system.session.var3=120.0D)] ++- Project [c1#x, cast(c2#x as string) AS var2#x, cast(c3#x as double) AS var3#x] + +- Project [c1#x, c2#x, c3#x] + +- Filter (1 = 0) + +- SubqueryAlias T + +- LocalRelation [c1#x, c2#x, c3#x] + + +-- !query +SELECT var1 AS `NULL`, var2 AS `NULL`, var3 AS `NULL` +-- !query analysis +Project [variablereference(system.session.var1=CAST(NULL AS INT)) AS NULL#x, variablereference(system.session.var2=CAST(NULL AS STRING)) AS NULL#x, variablereference(system.session.var3=CAST(NULL AS DOUBLE)) AS NULL#x] ++- OneRowRelation + + +-- !query +SET VARIABLE () = (SELECT 1) +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "INVALID_SET_SYNTAX", + "sqlState" : "42000", + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 28, + "fragment" : "SET VARIABLE () = (SELECT 1)" + } ] +} + + +-- !query +SET VARIABLE (var1, var2, var3) = (SELECT c1, c2, c3 FROM VALUES(100, 110, 120), (-100, -110, -120) AS T(c1, c2, c3)) +-- !query analysis +org.apache.spark.SparkException +{ + "errorClass" : "ROW_SUBQUERY_TOO_MANY_ROWS", + "sqlState" : "21000" +} + + +-- !query +SET VARIABLE (var1, var2, var3) = (SELECT c1, c2 FROM VALUES(100, 110, 120) AS T(c1, c2, c3)) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "ASSIGNMENT_ARITY_MISMATCH", + "sqlState" : "42802", + "messageParameters" : { + "numExpr" : "2", + "numTarget" : "3" + } +} + + +-- !query +SET VARIABLE (var1, var2, var3) = (SELECT c1, c2, c3, c1 FROM VALUES(100, 110, 120) AS T(c1, c2, c3)) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "ASSIGNMENT_ARITY_MISMATCH", + "sqlState" : "42802", + "messageParameters" : { + "numExpr" : "4", + "numTarget" : "3" + } +} + + +-- !query +SET VARIABLE (var1, var2, var1) = (SELECT c1, c2, c3, c1 FROM VALUES(100, 110, 120) AS T(c1, c2, c3)) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "DUPLICATE_ASSIGNMENTS", + "sqlState" : "42701", + "messageParameters" 
: { + "nameList" : "`var1`" + } +} + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query analysis +DropVariable false ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +DROP TEMPORARY VARIABLE var2 +-- !query analysis +DropVariable false ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var2 + + +-- !query +DROP TEMPORARY VARIABLE var3 +-- !query analysis +DropVariable false ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var3 + + +-- !query +SET VARIABLE title = 'DEFAULT expression usage' +-- !query analysis +SetVariable [variablereference(system.session.title='SET VARIABLE - row assignment')] ++- Project [DEFAULT expression usage AS title#x] + +- OneRowRelation + + +-- !query +DECLARE OR REPLACE VARIABLE var1 STRING DEFAULT 'default1' +-- !query analysis +CreateVariable defaultvalueexpression(cast(default1 as string), 'default1'), true ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +DECLARE OR REPLACE VARIABLE var2 STRING DEFAULT 'default2' +-- !query analysis +CreateVariable defaultvalueexpression(cast(default2 as string), 'default2'), true ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var2 + + +-- !query +DECLARE OR REPLACE VARIABLE var3 STRING DEFAULT 'default3' +-- !query analysis +CreateVariable defaultvalueexpression(cast(default3 as string), 'default3'), true ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var3 + + +-- !query +SET VARIABLE var1 = 'hello' +-- !query analysis +SetVariable [variablereference(system.session.var1='default1')] ++- Project [hello AS var1#x] + +- OneRowRelation + + +-- !query +SET VARIABLE var1 = DEFAULT +-- !query analysis +SetVariable [variablereference(system.session.var1='hello')] ++- Project [default1 AS DEFAULT#x] + +- OneRowRelation + + +-- !query +SELECT var1 AS `default` +-- !query analysis +Project [variablereference(system.session.var1='default1') AS default#x] ++- OneRowRelation + + +-- !query +SET VARIABLE var1 = 'hello1' +-- !query analysis +SetVariable [variablereference(system.session.var1='default1')] ++- Project [hello1 AS var1#x] + +- OneRowRelation + + +-- !query +SET VARIABLE var1 = 'hello2' +-- !query analysis +SetVariable [variablereference(system.session.var1='hello1')] ++- Project [hello2 AS var1#x] + +- OneRowRelation + + +-- !query +SET VARIABLE var1 = 'hello3' +-- !query analysis +SetVariable [variablereference(system.session.var1='hello2')] ++- Project [hello3 AS var1#x] + +- OneRowRelation + + +-- !query +SET VARIABLE var1 = DEFAULT, var2 = DEFAULT, var3 = DEFAULT +-- !query analysis +SetVariable [variablereference(system.session.var1='hello3'), variablereference(system.session.var2='default2'), variablereference(system.session.var3='default3')] ++- Project [default1 AS DEFAULT#x, default2 AS DEFAULT#x, default3 AS DEFAULT#x] + +- OneRowRelation + + +-- !query +SELECT var1 AS `default1`, var2 AS `default2`, var3 AS `default3` +-- !query analysis +Project [variablereference(system.session.var1='default1') AS default1#x, variablereference(system.session.var2='default2') AS default2#x, variablereference(system.session.var3='default3') AS default3#x] ++- OneRowRelation + + +-- !query +SET VARIABLE var1 = 'hello' +-- !query analysis +SetVariable 
[variablereference(system.session.var1='default1')] ++- Project [hello AS var1#x] + +- OneRowRelation + + +-- !query +SET VARIABLE (var1) = (SELECT DEFAULT FROM VALUES(1) AS T(c1)) +-- !query analysis +SetVariable [variablereference(system.session.var1='hello')] ++- Project [default1 AS DEFAULT#x] + +- SubqueryAlias T + +- LocalRelation [c1#x] + + +-- !query +SELECT var1 AS `default` +-- !query analysis +Project [variablereference(system.session.var1='default1') AS default#x] ++- OneRowRelation + + +-- !query +SET VARIABLE var1 = 'hello' +-- !query analysis +SetVariable [variablereference(system.session.var1='default1')] ++- Project [hello AS var1#x] + +- OneRowRelation + + +-- !query +SET VARIABLE (var1) = (SELECT DEFAULT FROM VALUES('world') AS T(default)) +-- !query analysis +SetVariable [variablereference(system.session.var1='hello')] ++- Project [DEFAULT#x] + +- SubqueryAlias T + +- LocalRelation [default#x] + + +-- !query +SELECT var1 AS `world` +-- !query analysis +Project [variablereference(system.session.var1='world') AS world#x] ++- OneRowRelation + + +-- !query +SET VARIABLE var1 = 'hello' +-- !query analysis +SetVariable [variablereference(system.session.var1='world')] ++- Project [hello AS var1#x] + +- OneRowRelation + + +-- !query +SET VARIABLE (var1) = (SELECT DEFAULT FROM VALUES(1) AS T(c1) LIMIT 1) +-- !query analysis +SetVariable [variablereference(system.session.var1='hello')] ++- GlobalLimit 1 + +- LocalLimit 1 + +- Project [default1 AS DEFAULT#x] + +- SubqueryAlias T + +- LocalRelation [c1#x] + + +-- !query +SELECT var1 AS `default` +-- !query analysis +Project [variablereference(system.session.var1='default1') AS default#x] ++- OneRowRelation + + +-- !query +SET VARIABLE var1 = 'hello' +-- !query analysis +SetVariable [variablereference(system.session.var1='default1')] ++- Project [hello AS var1#x] + +- OneRowRelation + + +-- !query +SET VARIABLE (var1) = (SELECT DEFAULT FROM VALUES(1),(2),(3) AS T(c1) LIMIT 1 OFFSET 1) +-- !query analysis +SetVariable [variablereference(system.session.var1='hello')] ++- GlobalLimit 1 + +- LocalLimit 1 + +- Offset 1 + +- Project [default1 AS DEFAULT#x] + +- SubqueryAlias T + +- LocalRelation [c1#x] + + +-- !query +SELECT var1 AS `default` +-- !query analysis +Project [variablereference(system.session.var1='default1') AS default#x] ++- OneRowRelation + + +-- !query +SET VARIABLE var1 = 'hello' +-- !query analysis +SetVariable [variablereference(system.session.var1='default1')] ++- Project [hello AS var1#x] + +- OneRowRelation + + +-- !query +SET VARIABLE (var1) = (SELECT DEFAULT FROM VALUES(1),(2),(3) AS T(c1) OFFSET 1) +-- !query analysis +org.apache.spark.SparkException +{ + "errorClass" : "ROW_SUBQUERY_TOO_MANY_ROWS", + "sqlState" : "21000" +} + + +-- !query +SELECT var1 AS `default` +-- !query analysis +Project [variablereference(system.session.var1='hello') AS default#x] ++- OneRowRelation + + +-- !query +SET VARIABLE var1 = 'hello' +-- !query analysis +SetVariable [variablereference(system.session.var1='hello')] ++- Project [hello AS var1#x] + +- OneRowRelation + + +-- !query +SET VARIABLE (var1) = (WITH v1(c1) AS (VALUES(1) AS T(c1)) SELECT DEFAULT FROM VALUES(1),(2),(3) AS T(c1)) +-- !query analysis +org.apache.spark.SparkException +{ + "errorClass" : "ROW_SUBQUERY_TOO_MANY_ROWS", + "sqlState" : "21000" +} + + +-- !query +SELECT var1 AS `default` +-- !query analysis +Project [variablereference(system.session.var1='hello') AS default#x] ++- OneRowRelation + + +-- !query +SET VARIABLE var1 = 'Hello' || DEFAULT +-- !query analysis 
+org.apache.spark.sql.AnalysisException +{ + "errorClass" : "DEFAULT_PLACEMENT_INVALID", + "sqlState" : "42608" +} + + +-- !query +SET VARIABLE (var1) = (VALUES(DEFAULT)) +-- !query analysis +SetVariable [variablereference(system.session.var1='hello')] ++- LocalRelation [col1#x] + + +-- !query +SET VARIABLE (var1) = (WITH v1(c1) AS (VALUES(1) AS T(c1)) SELECT DEFAULT + 1 FROM VALUES(1),(2),(3) AS T(c1)) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "DEFAULT_PLACEMENT_INVALID", + "sqlState" : "42608" +} + + +-- !query +SET VARIABLE var1 = session.default +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`session`.`default`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 21, + "stopIndex" : 35, + "fragment" : "session.default" + } ] +} + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query analysis +DropVariable false ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +DROP TEMPORARY VARIABLE var2 +-- !query analysis +DropVariable false ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var2 + + +-- !query +DROP TEMPORARY VARIABLE var3 +-- !query analysis +DropVariable false ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var3 + + +-- !query +SET VARIABLE title = 'SET command' +-- !query analysis +SetVariable [variablereference(system.session.title='DEFAULT expression usage')] ++- Project [SET command AS title#x] + +- OneRowRelation + + +-- !query +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 1 +-- !query analysis +CreateVariable defaultvalueexpression(cast(1 as int), 1), true ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +SET x.var1 = 5 +-- !query analysis +SetCommand (x.var1,Some(5)) + + +-- !query +SET x = 5 +-- !query analysis +SetCommand (x,Some(5)) + + +-- !query +SET system.x.var = 5 +-- !query analysis +SetCommand (system.x.var,Some(5)) + + +-- !query +SET x.session.var1 = 5 +-- !query analysis +SetCommand (x.session.var1,Some(5)) + + +-- !query +SET var1 = 5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNSUPPORTED_FEATURE.SET_VARIABLE_USING_SET", + "sqlState" : "0A000", + "messageParameters" : { + "variableName" : "`var1`" + } +} + + +-- !query +SET session.var1 = 5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNSUPPORTED_FEATURE.SET_VARIABLE_USING_SET", + "sqlState" : "0A000", + "messageParameters" : { + "variableName" : "`session`.`var1`" + } +} + + +-- !query +SET system.session.var1 = 5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNSUPPORTED_FEATURE.SET_VARIABLE_USING_SET", + "sqlState" : "0A000", + "messageParameters" : { + "variableName" : "`system`.`session`.`var1`" + } +} + + +-- !query +SET vAr1 = 5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNSUPPORTED_FEATURE.SET_VARIABLE_USING_SET", + "sqlState" : "0A000", + "messageParameters" : { + "variableName" : "`vAr1`" + } +} + + +-- !query +SET seSSion.var1 = 5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNSUPPORTED_FEATURE.SET_VARIABLE_USING_SET", + "sqlState" : "0A000", + 
"messageParameters" : { + "variableName" : "`seSSion`.`var1`" + } +} + + +-- !query +SET sYStem.session.var1 = 5 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNSUPPORTED_FEATURE.SET_VARIABLE_USING_SET", + "sqlState" : "0A000", + "messageParameters" : { + "variableName" : "`sYStem`.`session`.`var1`" + } +} + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query analysis +DropVariable false ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 1 +-- !query analysis +CreateVariable defaultvalueexpression(cast(1 as int), 1), true ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +SELECT var1 AS `2` FROM VALUES(2) AS T(var1) +-- !query analysis +Project [var1#x AS 2#x] ++- SubqueryAlias T + +- LocalRelation [var1#x] + + +-- !query +SELECT c1 AS `2` FROM VALUES(2) AS T(var1), LATERAL(SELECT var1) AS TT(c1) +-- !query analysis +Project [c1#x AS 2#x] ++- LateralJoin lateral-subquery#x [], Inner + : +- SubqueryAlias TT + : +- Project [var1#x AS c1#x] + : +- Project [variablereference(system.session.var1=1) AS var1#x] + : +- OneRowRelation + +- SubqueryAlias T + +- LocalRelation [var1#x] + + +-- !query +SELECT session.var1 AS `1` FROM VALUES(2) AS T(var1) +-- !query analysis +Project [variablereference(system.session.var1=1) AS 1#x] ++- SubqueryAlias T + +- LocalRelation [var1#x] + + +-- !query +SELECT c1 AS `1` FROM VALUES(2) AS T(var1), LATERAL(SELECT session.var1) AS TT(c1) +-- !query analysis +Project [c1#x AS 1#x] ++- LateralJoin lateral-subquery#x [], Inner + : +- SubqueryAlias TT + : +- Project [var1#x AS c1#x] + : +- Project [variablereference(system.session.var1=1) AS var1#x] + : +- OneRowRelation + +- SubqueryAlias T + +- LocalRelation [var1#x] + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query analysis +DropVariable false ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +SET VARIABLE title = 'variable references -- visibility' +-- !query analysis +SetVariable [variablereference(system.session.title='SET command')] ++- Project [variable references -- visibility AS title#x] + +- OneRowRelation + + +-- !query +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 1 +-- !query analysis +CreateVariable defaultvalueexpression(cast(1 as int), 1), true ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +VALUES (var1) +-- !query analysis +LocalRelation [col1#x] + + +-- !query +SELECT var1 +-- !query analysis +Project [variablereference(system.session.var1=1) AS var1#x] ++- OneRowRelation + + +-- !query +SELECT sum(var1) FROM VALUES(1) +-- !query analysis +Aggregate [sum(variablereference(system.session.var1=1)) AS sum(variablereference(system.session.var1=1) AS var1)#xL] ++- LocalRelation [col1#x] + + +-- !query +SELECT var1 + SUM(0) FROM VALUES(1) +-- !query analysis +Aggregate [(cast(variablereference(system.session.var1=1) as bigint) + sum(0)) AS (variablereference(system.session.var1=1) AS var1 + sum(0))#xL] ++- LocalRelation [col1#x] + + +-- !query +SELECT substr('12345', var1, 1) +-- !query analysis +Project [substr(12345, variablereference(system.session.var1=1), 1) AS substr(12345, variablereference(system.session.var1=1) AS var1, 1)#x] ++- OneRowRelation + + +-- !query +SELECT 1 FROM VALUES(1, 2) AS T(c1, c2) GROUP BY c1 + var1 +-- !query 
analysis +Aggregate [(c1#x + variablereference(system.session.var1=1))], [1 AS 1#x] ++- SubqueryAlias T + +- LocalRelation [c1#x, c2#x] + + +-- !query +SELECT c1, sum(c2) FROM VALUES(1, 2) AS T(c1, c2) GROUP BY c1 HAVING sum(c1) != var1 +-- !query analysis +Project [c1#x, sum(c2)#xL] ++- Filter NOT (sum(c1#x)#xL = cast(variablereference(system.session.var1=1) as bigint)) + +- Aggregate [c1#x], [c1#x, sum(c2#x) AS sum(c2)#xL, sum(c1#x) AS sum(c1#x)#xL] + +- SubqueryAlias T + +- LocalRelation [c1#x, c2#x] + + +-- !query +SELECT 1 FROM VALUES(1) AS T(c1) WHERE c1 IN (var1) +-- !query analysis +Project [1 AS 1#x] ++- Filter c1#x IN (variablereference(system.session.var1=1)) + +- SubqueryAlias T + +- LocalRelation [c1#x] + + +-- !query +SELECT sum(c1) FILTER (c1 != var1) FROM VALUES(1, 2), (2, 3) AS T(c1, c2) +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42601", + "messageParameters" : { + "error" : "'('", + "hint" : "" + } +} + + +-- !query +SELECT array(1, 2, 4)[var1] +-- !query analysis +Project [array(1, 2, 4)[variablereference(system.session.var1=1)] AS array(1, 2, 4)[variablereference(system.session.var1=1) AS var1]#x] ++- OneRowRelation + + +-- !query +SELECT 1 FROM VALUES(1) AS T(c1) WHERE c1 = var1 +-- !query analysis +Project [1 AS 1#x] ++- Filter (c1#x = variablereference(system.session.var1=1)) + +- SubqueryAlias T + +- LocalRelation [c1#x] + + +-- !query +WITH v1 AS (SELECT var1 AS c1) SELECT c1 AS `1` FROM v1 +-- !query analysis +WithCTE +:- CTERelationDef xxxx, false +: +- SubqueryAlias v1 +: +- Project [variablereference(system.session.var1=1) AS c1#x] +: +- OneRowRelation ++- Project [c1#x AS 1#x] + +- SubqueryAlias v1 + +- CTERelationRef xxxx, true, [c1#x] + + +-- !query +CREATE OR REPLACE TEMPORARY VIEW v AS SELECT var1 AS c1 +-- !query analysis +CreateViewCommand `v`, SELECT var1 AS c1, false, true, LocalTempView, true + +- Project [variablereference(system.session.var1=1) AS c1#x] + +- OneRowRelation + + +-- !query +SELECT * FROM v +-- !query analysis +Project [c1#x] ++- SubqueryAlias v + +- View (`v`, [c1#x]) + +- Project [cast(c1#x as int) AS c1#x] + +- Project [variablereference(system.session.var1=1) AS c1#x] + +- OneRowRelation + + +-- !query +DROP VIEW v +-- !query analysis +DropTempViewCommand v + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query analysis +DropVariable false ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +SET VARIABLE title = 'variable references -- prohibited' +-- !query analysis +SetVariable [variablereference(system.session.title='variable references -- visibility')] ++- Project [variable references -- prohibited AS title#x] + +- OneRowRelation + + +-- !query +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 1 +-- !query analysis +CreateVariable defaultvalueexpression(cast(1 as int), 1), true ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 + + +-- !query +CREATE OR REPLACE VIEW v AS SELECT var1 AS c1 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "INVALID_TEMP_OBJ_REFERENCE", + "messageParameters" : { + "obj" : "VIEW", + "objName" : "`spark_catalog`.`default`.`v`", + "tempObj" : "VARIABLE", + "tempObjName" : "`var1`" + } +} + + +-- !query +DROP VIEW IF EXISTS V +-- !query analysis +DropTableCommand `spark_catalog`.`default`.`V`, true, true, false + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query 
analysis +DropVariable false ++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/table-aliases.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/table-aliases.sql.out index 0c8d0b8f2693f..cfd36d29d270d 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/table-aliases.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/table-aliases.sql.out @@ -57,10 +57,11 @@ SELECT * FROM testData AS t(col1, col2, col3) -- !query analysis org.apache.spark.sql.AnalysisException { - "errorClass" : "_LEGACY_ERROR_TEMP_1028", + "errorClass" : "ASSIGNMENT_ARITY_MISMATCH", + "sqlState" : "42802", "messageParameters" : { - "columnSize" : "3", - "outputSize" : "2" + "numExpr" : "3", + "numTarget" : "2" }, "queryContext" : [ { "objectType" : "", @@ -77,10 +78,11 @@ SELECT * FROM testData AS t(col1) -- !query analysis org.apache.spark.sql.AnalysisException { - "errorClass" : "_LEGACY_ERROR_TEMP_1028", + "errorClass" : "ASSIGNMENT_ARITY_MISMATCH", + "sqlState" : "42802", "messageParameters" : { - "columnSize" : "1", - "outputSize" : "2" + "numExpr" : "1", + "numTarget" : "2" }, "queryContext" : [ { "objectType" : "", diff --git a/sql/core/src/test/resources/sql-tests/inputs/sql-session-variables.sql b/sql/core/src/test/resources/sql-tests/inputs/sql-session-variables.sql new file mode 100644 index 0000000000000..4992453603ced --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/sql-session-variables.sql @@ -0,0 +1,374 @@ +SET spark.sql.ansi.enabled = true; + +DECLARE title STRING; + +SET VARIABLE title = '-- Basic sanity --'; +DECLARE var1 INT = 5; +SELECT var1; +SET VARIABLE var1 = 6; +SELECT var1; +DROP TEMPORARY VARIABLE var1; + +SET VARIABLE title = 'Create Variable - Success Cases'; +DECLARE VARIABLE var1 INT; +SELECT 'Expect: INT, NULL', typeof(var1), var1; + +DECLARE OR REPLACE VARIABLE var1 DOUBLE; +SELECT 'Expect: DOUBLE, NULL', typeof(var1), var1; + +DROP TEMPORARY VARIABLE var1; +DECLARE OR REPLACE VARIABLE var1 TIMESTAMP; +SELECT 'Expect: TIMESTAMP, NULL', typeof(var1), var1; + +SET VARIABLE title = 'Create Variable - Failure Cases'; +-- No support for IF NOT EXISTS +DECLARE VARIABLE IF NOT EXISTS var1 INT; +DROP TEMPORARY VARIABLE IF EXISTS var1; + +SET VARIABLE title = 'Drop Variable'; +DECLARE VARIABLE var1 INT; +SELECT var1; +DROP TEMPORARY VARIABLE var1; + +-- Variable is gone +SELECT var1; +DROP TEMPORARY VARIABLE var1; + +-- Success +DROP TEMPORARY VARIABLE IF EXISTS var1; + +-- Fail: TEMPORARY is mandatory on DROP +DECLARE VARIABLE var1 INT; +DROP VARIABLE var1; +DROP VARIABLE system.session.var1; +DROP TEMPORARY VARIABLE var1; + +SET VARIABLE title = 'Test qualifiers - success'; +DECLARE VARIABLE var1 INT DEFAULT 1; +SELECT 1 as Expected, var1 as Unqualified, session.var1 AS SchemaQualified, system.session.var1 AS fullyQualified; +SET VARIABLE var1 = 2; +SELECT 2 as Expected, var1 as Unqualified, session.var1 AS SchemaQualified, system.session.var1 AS fullyQualified; + +DECLARE OR REPLACE VARIABLE session.var1 INT DEFAULT 1; +SELECT 1 as Expected, var1 as Unqualified, session.var1 AS SchemaQualified, system.session.var1 AS fullyQualified; +SET VARIABLE session.var1 = 2; +SELECT 2 as Expected, var1 as Unqualified, session.var1 AS SchemaQualified, system.session.var1 AS fullyQualified; + +DECLARE OR REPLACE VARIABLE system.session.var1 INT DEFAULT 1; +SELECT 1 as Expected, var1 as Unqualified, session.var1 AS 
SchemaQualified, system.session.var1 AS fullyQualified; +SET VARIABLE system.session.var1 = 2; +SELECT 2 as Expected, var1 as Unqualified, session.var1 AS SchemaQualified, system.session.var1 AS fullyQualified; + +DECLARE OR REPLACE VARIABLE sySteM.sEssIon.vAr1 INT DEFAULT 1; +SELECT 1 as Expected, var1 as Unqualified, sessIon.Var1 AS SchemaQualified, System.sessiOn.var1 AS fullyQualified; +SET VARIABLE sYstem.sesSiOn.vaR1 = 2; +SELECT 2 as Expected, VAR1 as Unqualified, SESSION.VAR1 AS SchemaQualified, SYSTEM.SESSION.VAR1 AS fullyQualified; + +DECLARE OR REPLACE VARIABLE var1 INT; +DROP TEMPORARY VARIABLE var1; +DROP TEMPORARY VARIABLE var1; + +DECLARE OR REPLACE VARIABLE var1 INT; +DROP TEMPORARY VARIABLE session.var1; +DROP TEMPORARY VARIABLE var1; + +DECLARE OR REPLACE VARIABLE var1 INT; +DROP TEMPORARY VARIABLE system.session.var1; +DROP TEMPORARY VARIABLE var1; + +DECLARE OR REPLACE VARIABLE var1 INT; +DROP TEMPORARY VARIABLE sysTem.sesSion.vAr1; +DROP TEMPORARY VARIABLE var1; + +SET VARIABLE title = 'Test qualifiers - fail'; +DECLARE OR REPLACE VARIABLE builtin.var1 INT; +DECLARE OR REPLACE VARIABLE system.sesion.var1 INT; +DECLARE OR REPLACE VARIABLE sys.session.var1 INT; + +DECLARE OR REPLACE VARIABLE var1 INT; +SELECT var; +SELECT ses.var1; +SELECT b.sesson.var1; +SELECT builtn.session.var1; + +SET VARIABLE ses.var1 = 1; +SET VARIABLE builtn.session.var1 = 1; + +SET VARIABLE title = 'Test DEFAULT on create - success'; +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 1; +SELECT 1 AS Expected, var1 AS result; + +DECLARE OR REPLACE VARIABLE var1 DOUBLE DEFAULT 1 + RAND(5); +SELECT true AS Expected, var1 >= 1 AS result; + +DECLARE OR REPLACE VARIABLE var1 = 'Hello'; +SELECT 'STRING, Hello' AS Expected, typeof(var1) AS type, var1 AS result; + +DECLARE OR REPLACE VARIABLE var1 DEFAULT NULL; +SELECT 'VOID, NULL' AS Expected, typeof(var1) AS type, var1 AS result; + +DECLARE OR REPLACE VARIABLE INT DEFAULT 5.0; +SELECT 'INT, 5' AS Expected, typeof(var1) AS type, var1 AS result; + +DECLARE OR REPLACE VARIABLE var1 MAP DEFAULT MAP('Hello', 5.1, 'World', -7.1E10); +SELECT 'MAP, [Hello -> 5.1, World -> -7E10]' AS Expected, typeof(var1) AS type, var1 AS result; + +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT NULL; +SELECT 'NULL' AS Expected, var1 AS result; + +DECLARE OR REPLACE VARIABLE var1 STRING DEFAULT CURRENT_DATABASE(); +SELECT 'true' AS Expected, length(var1) > 0 AS result; + +DROP TEMPORARY VARIABLE var1; + +-- No type and no default is not allowed +DECLARE var1; + +-- TBD: Store assignment cast test + +SET VARIABLE title = 'Test DEFAULT on create - failures'; + +-- No subqueries allowed in DEFAULT expression +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT (SELECT c1 FROM VALUES(1) AS T(c1)); + +-- Incompatible type +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 'hello'; + +-- Runtime error +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 1 / 0; + +-- Runtime overflow on assignment +DECLARE OR REPLACE VARIABLE var1 SMALLINT DEFAULT 100000; + +SET VARIABLE title = 'SET VARIABLE - single target'; + +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 5; + +SET VARIABLE var1 = 7; +SELECT var1; + +SET VAR var1 = 8; +SELECT var1; + +SET VARIABLE var1 = (SELECT c1 FROM VALUES(1) AS T(c1)); +SELECT var1; + +SET VARIABLE var1 = (SELECT c1 FROM VALUES(1) AS T(c1) WHERE 1=0); +SELECT var1 AS `null`; + +SET VARIABLE var1 = (SELECT c1 FROM VALUES(1.0) AS T(c1)); +SELECT var1; + +SET VARIABLE var1 = (SELECT c1 FROM VALUES(1.0E10) AS T(c1)); +SELECT var1; + +SET VARIABLE var1 = (SELECT c1 FROM VALUES(1), (2) AS 
T(c1)); + +SET VARIABLE var1 = (SELECT c1, c1 FROM VALUES(1), (2) AS T(c1)); + +SET VARIABLE var1 = (SELECT c1 FROM VALUES('hello') AS T(c1)); + +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 5; +SET VARIABLE var1 = var1 + 1; +SELECT var1; + +-- TBD store assignment cast test + +DROP TEMPORARY VARIABLE var1; + +SET VARIABLE title = 'SET VARIABLE - comma separated target'; + +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 5; +DECLARE OR REPLACE VARIABLE var2 STRING DEFAULT 'hello'; +DECLARE OR REPLACE VARIABLE var3 DOUBLE DEFAULT 2; + +SET VARIABLE var1 = 6, var2 = 'world', var3 = pi(); +SELECT var1 AS `6`, var2 AS `world` , var3 as `3.14...`; + +SET VAR var1 = 7, var2 = 'universe', var3 = -1; +SELECT var1 AS `7`, var2 AS `universe` , var3 as `-1`; + +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 5; +DECLARE OR REPLACE VARIABLE var2 STRING DEFAULT 'hello'; +DECLARE OR REPLACE VARIABLE var3 DOUBLE DEFAULT 2; + +SET VARIABLE var1 = var3, var2 = ascii(var1), var3 = var1; +SELECT var1 AS `2`, var2 AS `104`, var3 AS `5`; + +SET VARIABLE var1 = var3, var2 = INTERVAL'5' HOUR, var3 = var1; + +-- Duplicates check +SET VARIABLE var1 = 1, var2 = 0, vAr1 = 1; + +DROP TEMPORARY VARIABLE var1; +DROP TEMPORARY VARIABLE var2; +DROP TEMPORARY VARIABLE var3; + +SET VARIABLE title = 'SET VARIABLE - row assignment'; + +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 5; +DECLARE OR REPLACE VARIABLE var2 STRING DEFAULT 'hello'; +DECLARE OR REPLACE VARIABLE var3 DOUBLE DEFAULT 2; + +-- Must have at least one target +SET VARIABLE (var1) = (SELECT c1 FROM VALUES(1) AS T(c1)); +SELECT var1; + +SET VAR (var1) = (SELECT c1 FROM VALUES(2) AS T(c1)); +SELECT var1; + +SET VARIABLE (var1, var2) = (SELECT c1, c2 FROM VALUES(10, 11) AS T(c1, c2)); +SELECT var1 AS `10`, var2 AS `11`; + +SET VARIABLE (var1, var2, var3) = (SELECT c1, c2, c3 FROM VALUES(100, 110, 120) AS T(c1, c2, c3)); +SELECT var1 AS `100`, var2 AS `110`, var3 AS `120`; + +SET VARIABLE (var1, var2, var3) = (SELECT c1, c2, c3 FROM VALUES(100, 110, 120) AS T(c1, c2, c3) WHERE 1 = 0); +SELECT var1 AS `NULL`, var2 AS `NULL`, var3 AS `NULL`; + +-- Fail no target +SET VARIABLE () = (SELECT 1); + +-- Fail, more than one row +SET VARIABLE (var1, var2, var3) = (SELECT c1, c2, c3 FROM VALUES(100, 110, 120), (-100, -110, -120) AS T(c1, c2, c3)); + +-- Fail, not enough columns +SET VARIABLE (var1, var2, var3) = (SELECT c1, c2 FROM VALUES(100, 110, 120) AS T(c1, c2, c3)); + +-- Fail, too many columns +SET VARIABLE (var1, var2, var3) = (SELECT c1, c2, c3, c1 FROM VALUES(100, 110, 120) AS T(c1, c2, c3)); + +-- Fail, duplicated target +SET VARIABLE (var1, var2, var1) = (SELECT c1, c2, c3, c1 FROM VALUES(100, 110, 120) AS T(c1, c2, c3)); + +DROP TEMPORARY VARIABLE var1; +DROP TEMPORARY VARIABLE var2; +DROP TEMPORARY VARIABLE var3; + +SET VARIABLE title = 'DEFAULT expression usage'; + +DECLARE OR REPLACE VARIABLE var1 STRING DEFAULT 'default1'; +DECLARE OR REPLACE VARIABLE var2 STRING DEFAULT 'default2'; +DECLARE OR REPLACE VARIABLE var3 STRING DEFAULT 'default3'; + +SET VARIABLE var1 = 'hello'; + +SET VARIABLE var1 = DEFAULT; +SELECT var1 AS `default`; + +SET VARIABLE var1 = 'hello1'; +SET VARIABLE var1 = 'hello2'; +SET VARIABLE var1 = 'hello3'; +SET VARIABLE var1 = DEFAULT, var2 = DEFAULT, var3 = DEFAULT; +SELECT var1 AS `default1`, var2 AS `default2`, var3 AS `default3`; + +SET VARIABLE var1 = 'hello'; +SET VARIABLE (var1) = (SELECT DEFAULT FROM VALUES(1) AS T(c1)); +SELECT var1 AS `default`; + +SET VARIABLE var1 = 'hello'; +SET VARIABLE (var1) = (SELECT DEFAULT FROM 
VALUES('world') AS T(default)); +SELECT var1 AS `world`; + +SET VARIABLE var1 = 'hello'; +SET VARIABLE (var1) = (SELECT DEFAULT FROM VALUES(1) AS T(c1) LIMIT 1); +SELECT var1 AS `default`; + +SET VARIABLE var1 = 'hello'; +SET VARIABLE (var1) = (SELECT DEFAULT FROM VALUES(1),(2),(3) AS T(c1) LIMIT 1 OFFSET 1); +SELECT var1 AS `default`; + +SET VARIABLE var1 = 'hello'; +SET VARIABLE (var1) = (SELECT DEFAULT FROM VALUES(1),(2),(3) AS T(c1) OFFSET 1); +SELECT var1 AS `default`; + +SET VARIABLE var1 = 'hello'; +SET VARIABLE (var1) = (WITH v1(c1) AS (VALUES(1) AS T(c1)) SELECT DEFAULT FROM VALUES(1),(2),(3) AS T(c1)); +SELECT var1 AS `default`; + +-- Failure +SET VARIABLE var1 = 'Hello' || DEFAULT; + +SET VARIABLE (var1) = (VALUES(DEFAULT)); + +SET VARIABLE (var1) = (WITH v1(c1) AS (VALUES(1) AS T(c1)) SELECT DEFAULT + 1 FROM VALUES(1),(2),(3) AS T(c1)); + +SET VARIABLE var1 = session.default; + +DROP TEMPORARY VARIABLE var1; +DROP TEMPORARY VARIABLE var2; +DROP TEMPORARY VARIABLE var3; + +SET VARIABLE title = 'SET command'; + +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 1; + +-- Sanity: These are all configs +SET x.var1 = 5; +SET x = 5; +SET system.x.var = 5; +SET x.session.var1 = 5; + +-- These raise errors: UNSUPPORTED_FEATURE.SET_VARIABLE_IN_SET +SET var1 = 5; +SET session.var1 = 5; +SET system.session.var1 = 5; +SET vAr1 = 5; +SET seSSion.var1 = 5; +SET sYStem.session.var1 = 5; + +DROP TEMPORARY VARIABLE var1; + +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 1; + +SELECT var1 AS `2` FROM VALUES(2) AS T(var1); + +SELECT c1 AS `2` FROM VALUES(2) AS T(var1), LATERAL(SELECT var1) AS TT(c1); + +SELECT session.var1 AS `1` FROM VALUES(2) AS T(var1); + +SELECT c1 AS `1` FROM VALUES(2) AS T(var1), LATERAL(SELECT session.var1) AS TT(c1); + +DROP TEMPORARY VARIABLE var1; + +SET VARIABLE title = 'variable references -- visibility'; +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 1; + +VALUES (var1); + +SELECT var1; + +SELECT sum(var1) FROM VALUES(1); +SELECT var1 + SUM(0) FROM VALUES(1); +SELECT substr('12345', var1, 1); +SELECT 1 FROM VALUES(1, 2) AS T(c1, c2) GROUP BY c1 + var1; +SELECT c1, sum(c2) FROM VALUES(1, 2) AS T(c1, c2) GROUP BY c1 HAVING sum(c1) != var1; +SELECT 1 FROM VALUES(1) AS T(c1) WHERE c1 IN (var1); +SELECT sum(c1) FILTER (c1 != var1) FROM VALUES(1, 2), (2, 3) AS T(c1, c2); +SELECT array(1, 2, 4)[var1]; + +-- TBD usage in body of lambda function + +SELECT 1 FROM VALUES(1) AS T(c1) WHERE c1 = var1; + +WITH v1 AS (SELECT var1 AS c1) SELECT c1 AS `1` FROM v1; + +CREATE OR REPLACE TEMPORARY VIEW v AS SELECT var1 AS c1; +SELECT * FROM v; +DROP VIEW v; + +DROP TEMPORARY VARIABLE var1; + +SET VARIABLE title = 'variable references -- prohibited'; + +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 1; + +-- Known broken for parameters as well +--DROP TABLE IF EXISTS T; +--CREATE TABLE T(c1 INT DEFAULT (var1)); +--DROP TABLE IF EXISTS T; + +CREATE OR REPLACE VIEW v AS SELECT var1 AS c1; +DROP VIEW IF EXISTS V; + +DROP TEMPORARY VARIABLE var1; diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out index fe7bec0219162..2abb5cee11914 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out @@ -77,6 +77,7 @@ DAYS false DBPROPERTIES false DEC false DECIMAL false +DECLARE false DEFAULT false DEFINED false DELETE false @@ -315,7 +316,9 @@ USE false USER true USING true VALUES false +VAR false VARCHAR false 
+VARIABLE false VERSION false VIEW false VIEWS false diff --git a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out index a4fd9c82cf095..716e2a32e7fce 100644 --- a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out @@ -77,6 +77,7 @@ DAYS false DBPROPERTIES false DEC false DECIMAL false +DECLARE false DEFAULT false DEFINED false DELETE false @@ -315,7 +316,9 @@ USE false USER false USING false VALUES false +VAR false VARCHAR false +VARIABLE false VERSION false VIEW false VIEWS false diff --git a/sql/core/src/test/resources/sql-tests/results/sql-session-variables.sql.out b/sql/core/src/test/resources/sql-tests/results/sql-session-variables.sql.out new file mode 100644 index 0000000000000..b3146e645c525 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/sql-session-variables.sql.out @@ -0,0 +1,2286 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +SET spark.sql.ansi.enabled = true +-- !query schema +struct +-- !query output +spark.sql.ansi.enabled true + + +-- !query +DECLARE title STRING +-- !query schema +struct<> +-- !query output + + + +-- !query +SET VARIABLE title = '-- Basic sanity --' +-- !query schema +struct<> +-- !query output + + + +-- !query +DECLARE var1 INT = 5 +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT var1 +-- !query schema +struct +-- !query output +5 + + +-- !query +SET VARIABLE var1 = 6 +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT var1 +-- !query schema +struct +-- !query output +6 + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query schema +struct<> +-- !query output + + + +-- !query +SET VARIABLE title = 'Create Variable - Success Cases' +-- !query schema +struct<> +-- !query output + + + +-- !query +DECLARE VARIABLE var1 INT +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT 'Expect: INT, NULL', typeof(var1), var1 +-- !query schema +struct +-- !query output +Expect: INT, NULL int NULL + + +-- !query +DECLARE OR REPLACE VARIABLE var1 DOUBLE +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT 'Expect: DOUBLE, NULL', typeof(var1), var1 +-- !query schema +struct +-- !query output +Expect: DOUBLE, NULL double NULL + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query schema +struct<> +-- !query output + + + +-- !query +DECLARE OR REPLACE VARIABLE var1 TIMESTAMP +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT 'Expect: TIMESTAMP, NULL', typeof(var1), var1 +-- !query schema +struct +-- !query output +Expect: TIMESTAMP, NULL timestamp NULL + + +-- !query +SET VARIABLE title = 'Create Variable - Failure Cases' +-- !query schema +struct<> +-- !query output + + + +-- !query +DECLARE VARIABLE IF NOT EXISTS var1 INT +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42601", + "messageParameters" : { + "error" : "'EXISTS'", + "hint" : "" + } +} + + +-- !query +DROP TEMPORARY VARIABLE IF EXISTS var1 +-- !query schema +struct<> +-- !query output + + + +-- !query +SET VARIABLE title = 'Drop Variable' +-- !query schema +struct<> +-- !query output + + + +-- !query +DECLARE VARIABLE var1 INT +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT var1 +-- !query schema +struct +-- !query output +NULL + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query schema 
+struct<> +-- !query output + + + +-- !query +SELECT var1 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`var1`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 11, + "fragment" : "var1" + } ] +} + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "VARIABLE_NOT_FOUND", + "sqlState" : "42883", + "messageParameters" : { + "variableName" : "`system`.`session`.`var1`" + } +} + + +-- !query +DROP TEMPORARY VARIABLE IF EXISTS var1 +-- !query schema +struct<> +-- !query output + + + +-- !query +DECLARE VARIABLE var1 INT +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP VARIABLE var1 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42601", + "messageParameters" : { + "error" : "'VARIABLE'", + "hint" : "" + } +} + + +-- !query +DROP VARIABLE system.session.var1 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42601", + "messageParameters" : { + "error" : "'VARIABLE'", + "hint" : "" + } +} + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query schema +struct<> +-- !query output + + + +-- !query +SET VARIABLE title = 'Test qualifiers - success' +-- !query schema +struct<> +-- !query output + + + +-- !query +DECLARE VARIABLE var1 INT DEFAULT 1 +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT 1 as Expected, var1 as Unqualified, session.var1 AS SchemaQualified, system.session.var1 AS fullyQualified +-- !query schema +struct +-- !query output +1 1 1 1 + + +-- !query +SET VARIABLE var1 = 2 +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT 2 as Expected, var1 as Unqualified, session.var1 AS SchemaQualified, system.session.var1 AS fullyQualified +-- !query schema +struct +-- !query output +2 2 2 2 + + +-- !query +DECLARE OR REPLACE VARIABLE session.var1 INT DEFAULT 1 +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT 1 as Expected, var1 as Unqualified, session.var1 AS SchemaQualified, system.session.var1 AS fullyQualified +-- !query schema +struct +-- !query output +1 1 1 1 + + +-- !query +SET VARIABLE session.var1 = 2 +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT 2 as Expected, var1 as Unqualified, session.var1 AS SchemaQualified, system.session.var1 AS fullyQualified +-- !query schema +struct +-- !query output +2 2 2 2 + + +-- !query +DECLARE OR REPLACE VARIABLE system.session.var1 INT DEFAULT 1 +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT 1 as Expected, var1 as Unqualified, session.var1 AS SchemaQualified, system.session.var1 AS fullyQualified +-- !query schema +struct +-- !query output +1 1 1 1 + + +-- !query +SET VARIABLE system.session.var1 = 2 +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT 2 as Expected, var1 as Unqualified, session.var1 AS SchemaQualified, system.session.var1 AS fullyQualified +-- !query schema +struct +-- !query output +2 2 2 2 + + +-- !query +DECLARE OR REPLACE VARIABLE sySteM.sEssIon.vAr1 INT DEFAULT 1 +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT 1 as 
Expected, var1 as Unqualified, sessIon.Var1 AS SchemaQualified, System.sessiOn.var1 AS fullyQualified +-- !query schema +struct +-- !query output +1 1 1 1 + + +-- !query +SET VARIABLE sYstem.sesSiOn.vaR1 = 2 +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT 2 as Expected, VAR1 as Unqualified, SESSION.VAR1 AS SchemaQualified, SYSTEM.SESSION.VAR1 AS fullyQualified +-- !query schema +struct +-- !query output +2 2 2 2 + + +-- !query +DECLARE OR REPLACE VARIABLE var1 INT +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "VARIABLE_NOT_FOUND", + "sqlState" : "42883", + "messageParameters" : { + "variableName" : "`system`.`session`.`var1`" + } +} + + +-- !query +DECLARE OR REPLACE VARIABLE var1 INT +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP TEMPORARY VARIABLE session.var1 +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "VARIABLE_NOT_FOUND", + "sqlState" : "42883", + "messageParameters" : { + "variableName" : "`system`.`session`.`var1`" + } +} + + +-- !query +DECLARE OR REPLACE VARIABLE var1 INT +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP TEMPORARY VARIABLE system.session.var1 +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "VARIABLE_NOT_FOUND", + "sqlState" : "42883", + "messageParameters" : { + "variableName" : "`system`.`session`.`var1`" + } +} + + +-- !query +DECLARE OR REPLACE VARIABLE var1 INT +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP TEMPORARY VARIABLE sysTem.sesSion.vAr1 +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "VARIABLE_NOT_FOUND", + "sqlState" : "42883", + "messageParameters" : { + "variableName" : "`system`.`session`.`var1`" + } +} + + +-- !query +SET VARIABLE title = 'Test qualifiers - fail' +-- !query schema +struct<> +-- !query output + + + +-- !query +DECLARE OR REPLACE VARIABLE builtin.var1 INT +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNRESOLVED_VARIABLE", + "sqlState" : "42883", + "messageParameters" : { + "searchPath" : "`system`.`session`", + "variableName" : "`builtin`.`var1`" + } +} + + +-- !query +DECLARE OR REPLACE VARIABLE system.sesion.var1 INT +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNRESOLVED_VARIABLE", + "sqlState" : "42883", + "messageParameters" : { + "searchPath" : "`system`.`session`", + "variableName" : "`system`.`sesion`.`var1`" + } +} + + +-- !query +DECLARE OR REPLACE VARIABLE sys.session.var1 INT +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNRESOLVED_VARIABLE", + "sqlState" : "42883", + "messageParameters" : { + "searchPath" : "`system`.`session`", + "variableName" : "`sys`.`session`.`var1`" + } +} + + +-- !query +DECLARE OR REPLACE VARIABLE var1 INT +-- !query schema +struct<> +-- 
!query output + + + +-- !query +SELECT var +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`var`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 10, + "fragment" : "var" + } ] +} + + +-- !query +SELECT ses.var1 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`ses`.`var1`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 15, + "fragment" : "ses.var1" + } ] +} + + +-- !query +SELECT b.sesson.var1 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`b`.`sesson`.`var1`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 20, + "fragment" : "b.sesson.var1" + } ] +} + + +-- !query +SELECT builtn.session.var1 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`builtn`.`session`.`var1`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 26, + "fragment" : "builtn.session.var1" + } ] +} + + +-- !query +SET VARIABLE ses.var1 = 1 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNRESOLVED_VARIABLE", + "sqlState" : "42883", + "messageParameters" : { + "searchPath" : "`SYSTEM`.`SESSION`", + "variableName" : "`ses`.`var1`" + } +} + + +-- !query +SET VARIABLE builtn.session.var1 = 1 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNRESOLVED_VARIABLE", + "sqlState" : "42883", + "messageParameters" : { + "searchPath" : "`SYSTEM`.`SESSION`", + "variableName" : "`builtn`.`session`.`var1`" + } +} + + +-- !query +SET VARIABLE title = 'Test DEFAULT on create - success' +-- !query schema +struct<> +-- !query output + + + +-- !query +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 1 +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT 1 AS Expected, var1 AS result +-- !query schema +struct +-- !query output +1 1 + + +-- !query +DECLARE OR REPLACE VARIABLE var1 DOUBLE DEFAULT 1 + RAND(5) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT true AS Expected, var1 >= 1 AS result +-- !query schema +struct +-- !query output +true true + + +-- !query +DECLARE OR REPLACE VARIABLE var1 = 'Hello' +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT 'STRING, Hello' AS Expected, typeof(var1) AS type, var1 AS result +-- !query schema +struct +-- !query output +STRING, Hello string Hello + + +-- !query +DECLARE OR REPLACE VARIABLE var1 DEFAULT NULL +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT 'VOID, NULL' AS Expected, typeof(var1) AS type, var1 AS result +-- !query schema +struct +-- !query output +VOID, NULL void NULL + + +-- !query +DECLARE OR REPLACE VARIABLE INT DEFAULT 5.0 +-- !query schema 
+struct<> +-- !query output + + + +-- !query +SELECT 'INT, 5' AS Expected, typeof(var1) AS type, var1 AS result +-- !query schema +struct +-- !query output +INT, 5 void NULL + + +-- !query +DECLARE OR REPLACE VARIABLE var1 MAP DEFAULT MAP('Hello', 5.1, 'World', -7.1E10) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT 'MAP, [Hello -> 5.1, World -> -7E10]' AS Expected, typeof(var1) AS type, var1 AS result +-- !query schema +struct> +-- !query output +MAP, [Hello -> 5.1, World -> -7E10] map {"Hello":5.1,"World":-7.1E10} + + +-- !query +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT NULL +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT 'NULL' AS Expected, var1 AS result +-- !query schema +struct +-- !query output +NULL NULL + + +-- !query +DECLARE OR REPLACE VARIABLE var1 STRING DEFAULT CURRENT_DATABASE() +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT 'true' AS Expected, length(var1) > 0 AS result +-- !query schema +struct +-- !query output +true true + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query schema +struct<> +-- !query output + + + +-- !query +DECLARE var1 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "INVALID_SQL_SYNTAX.VARIABLE_TYPE_OR_DEFAULT_REQUIRED", + "sqlState" : "42000", + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 12, + "fragment" : "DECLARE var1" + } ] +} + + +-- !query +SET VARIABLE title = 'Test DEFAULT on create - failures' +-- !query schema +struct<> +-- !query output + + + +-- !query +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT (SELECT c1 FROM VALUES(1) AS T(c1)) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INVALID_DEFAULT_VALUE.SUBQUERY_EXPRESSION", + "messageParameters" : { + "colName" : "`system`.`session`.`var1`", + "defaultValue" : "(SELECT c1 FROM VALUES(1) AS T(c1))", + "statement" : "CRETE VARIABLE" + } +} + + +-- !query +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 'hello' +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkNumberFormatException +{ + "errorClass" : "CAST_INVALID_INPUT", + "sqlState" : "22018", + "messageParameters" : { + "ansiConfig" : "\"spark.sql.ansi.enabled\"", + "expression" : "'hello'", + "sourceType" : "\"STRING\"", + "targetType" : "\"INT\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 52, + "fragment" : "DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 'hello'" + } ] +} + + +-- !query +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 1 / 0 +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkArithmeticException +{ + "errorClass" : "DIVIDE_BY_ZERO", + "sqlState" : "22012", + "messageParameters" : { + "config" : "\"spark.sql.ansi.enabled\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 46, + "stopIndex" : 50, + "fragment" : "1 / 0" + } ] +} + + +-- !query +DECLARE OR REPLACE VARIABLE var1 SMALLINT DEFAULT 100000 +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkArithmeticException +{ + "errorClass" : "CAST_OVERFLOW", + "sqlState" : "22003", + "messageParameters" : { + "ansiConfig" : "\"spark.sql.ansi.enabled\"", + "sourceType" : "\"INT\"", + "targetType" : "\"SMALLINT\"", + "value" : "100000" + } +} + + +-- !query +SET VARIABLE title = 'SET VARIABLE - single target' +-- !query schema +struct<> +-- !query output + + + 
+-- !query +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 5 +-- !query schema +struct<> +-- !query output + + + +-- !query +SET VARIABLE var1 = 7 +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT var1 +-- !query schema +struct +-- !query output +7 + + +-- !query +SET VAR var1 = 8 +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT var1 +-- !query schema +struct +-- !query output +8 + + +-- !query +SET VARIABLE var1 = (SELECT c1 FROM VALUES(1) AS T(c1)) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT var1 +-- !query schema +struct +-- !query output +1 + + +-- !query +SET VARIABLE var1 = (SELECT c1 FROM VALUES(1) AS T(c1) WHERE 1=0) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT var1 AS `null` +-- !query schema +struct +-- !query output +NULL + + +-- !query +SET VARIABLE var1 = (SELECT c1 FROM VALUES(1.0) AS T(c1)) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT var1 +-- !query schema +struct +-- !query output +1 + + +-- !query +SET VARIABLE var1 = (SELECT c1 FROM VALUES(1.0E10) AS T(c1)) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkArithmeticException +{ + "errorClass" : "CAST_OVERFLOW", + "sqlState" : "22003", + "messageParameters" : { + "ansiConfig" : "\"spark.sql.ansi.enabled\"", + "sourceType" : "\"DOUBLE\"", + "targetType" : "\"INT\"", + "value" : "1.0E10D" + } +} + + +-- !query +SELECT var1 +-- !query schema +struct +-- !query output +1 + + +-- !query +SET VARIABLE var1 = (SELECT c1 FROM VALUES(1), (2) AS T(c1)) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkException +{ + "errorClass" : "SCALAR_SUBQUERY_TOO_MANY_ROWS", + "sqlState" : "21000", + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 21, + "stopIndex" : 60, + "fragment" : "(SELECT c1 FROM VALUES(1), (2) AS T(c1))" + } ] +} + + +-- !query +SET VARIABLE var1 = (SELECT c1, c1 FROM VALUES(1), (2) AS T(c1)) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INVALID_SUBQUERY_EXPRESSION.SCALAR_SUBQUERY_RETURN_MORE_THAN_ONE_OUTPUT_COLUMN", + "sqlState" : "42823", + "messageParameters" : { + "number" : "2" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 21, + "stopIndex" : 64, + "fragment" : "(SELECT c1, c1 FROM VALUES(1), (2) AS T(c1))" + } ] +} + + +-- !query +SET VARIABLE var1 = (SELECT c1 FROM VALUES('hello') AS T(c1)) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkNumberFormatException +{ + "errorClass" : "CAST_INVALID_INPUT", + "sqlState" : "22018", + "messageParameters" : { + "ansiConfig" : "\"spark.sql.ansi.enabled\"", + "expression" : "'hello'", + "sourceType" : "\"STRING\"", + "targetType" : "\"INT\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 61, + "fragment" : "SET VARIABLE var1 = (SELECT c1 FROM VALUES('hello') AS T(c1))" + } ] +} + + +-- !query +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 5 +-- !query schema +struct<> +-- !query output + + + +-- !query +SET VARIABLE var1 = var1 + 1 +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT var1 +-- !query schema +struct +-- !query output +6 + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query schema +struct<> +-- !query output + + + +-- !query +SET VARIABLE title = 'SET VARIABLE - comma separated target' +-- !query schema +struct<> +-- !query output + + + +-- !query +DECLARE OR 
REPLACE VARIABLE var1 INT DEFAULT 5 +-- !query schema +struct<> +-- !query output + + + +-- !query +DECLARE OR REPLACE VARIABLE var2 STRING DEFAULT 'hello' +-- !query schema +struct<> +-- !query output + + + +-- !query +DECLARE OR REPLACE VARIABLE var3 DOUBLE DEFAULT 2 +-- !query schema +struct<> +-- !query output + + + +-- !query +SET VARIABLE var1 = 6, var2 = 'world', var3 = pi() +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT var1 AS `6`, var2 AS `world` , var3 as `3.14...` +-- !query schema +struct<6:int,world:string,3.14...:double> +-- !query output +6 world 3.141592653589793 + + +-- !query +SET VAR var1 = 7, var2 = 'universe', var3 = -1 +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT var1 AS `7`, var2 AS `universe` , var3 as `-1` +-- !query schema +struct<7:int,universe:string,-1:double> +-- !query output +7 universe -1.0 + + +-- !query +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 5 +-- !query schema +struct<> +-- !query output + + + +-- !query +DECLARE OR REPLACE VARIABLE var2 STRING DEFAULT 'hello' +-- !query schema +struct<> +-- !query output + + + +-- !query +DECLARE OR REPLACE VARIABLE var3 DOUBLE DEFAULT 2 +-- !query schema +struct<> +-- !query output + + + +-- !query +SET VARIABLE var1 = var3, var2 = ascii(var1), var3 = var1 +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT var1 AS `2`, var2 AS `104`, var3 AS `5` +-- !query schema +struct<2:int,104:string,5:double> +-- !query output +2 53 5.0 + + +-- !query +SET VARIABLE var1 = var3, var2 = INTERVAL'5' HOUR, var3 = var1 +-- !query schema +struct<> +-- !query output + + + +-- !query +SET VARIABLE var1 = 1, var2 = 0, vAr1 = 1 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "DUPLICATE_ASSIGNMENTS", + "sqlState" : "42701", + "messageParameters" : { + "nameList" : "`var1`" + } +} + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP TEMPORARY VARIABLE var2 +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP TEMPORARY VARIABLE var3 +-- !query schema +struct<> +-- !query output + + + +-- !query +SET VARIABLE title = 'SET VARIABLE - row assignment' +-- !query schema +struct<> +-- !query output + + + +-- !query +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 5 +-- !query schema +struct<> +-- !query output + + + +-- !query +DECLARE OR REPLACE VARIABLE var2 STRING DEFAULT 'hello' +-- !query schema +struct<> +-- !query output + + + +-- !query +DECLARE OR REPLACE VARIABLE var3 DOUBLE DEFAULT 2 +-- !query schema +struct<> +-- !query output + + + +-- !query +SET VARIABLE (var1) = (SELECT c1 FROM VALUES(1) AS T(c1)) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT var1 +-- !query schema +struct +-- !query output +1 + + +-- !query +SET VAR (var1) = (SELECT c1 FROM VALUES(2) AS T(c1)) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT var1 +-- !query schema +struct +-- !query output +2 + + +-- !query +SET VARIABLE (var1, var2) = (SELECT c1, c2 FROM VALUES(10, 11) AS T(c1, c2)) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT var1 AS `10`, var2 AS `11` +-- !query schema +struct<10:int,11:string> +-- !query output +10 11 + + +-- !query +SET VARIABLE (var1, var2, var3) = (SELECT c1, c2, c3 FROM VALUES(100, 110, 120) AS T(c1, c2, c3)) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT var1 AS `100`, var2 AS `110`, var3 AS `120` +-- !query schema 
+struct<100:int,110:string,120:double> +-- !query output +100 110 120.0 + + +-- !query +SET VARIABLE (var1, var2, var3) = (SELECT c1, c2, c3 FROM VALUES(100, 110, 120) AS T(c1, c2, c3) WHERE 1 = 0) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT var1 AS `NULL`, var2 AS `NULL`, var3 AS `NULL` +-- !query schema +struct +-- !query output +NULL NULL NULL + + +-- !query +SET VARIABLE () = (SELECT 1) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "INVALID_SET_SYNTAX", + "sqlState" : "42000", + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 28, + "fragment" : "SET VARIABLE () = (SELECT 1)" + } ] +} + + +-- !query +SET VARIABLE (var1, var2, var3) = (SELECT c1, c2, c3 FROM VALUES(100, 110, 120), (-100, -110, -120) AS T(c1, c2, c3)) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkException +{ + "errorClass" : "ROW_SUBQUERY_TOO_MANY_ROWS", + "sqlState" : "21000" +} + + +-- !query +SET VARIABLE (var1, var2, var3) = (SELECT c1, c2 FROM VALUES(100, 110, 120) AS T(c1, c2, c3)) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "ASSIGNMENT_ARITY_MISMATCH", + "sqlState" : "42802", + "messageParameters" : { + "numExpr" : "2", + "numTarget" : "3" + } +} + + +-- !query +SET VARIABLE (var1, var2, var3) = (SELECT c1, c2, c3, c1 FROM VALUES(100, 110, 120) AS T(c1, c2, c3)) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "ASSIGNMENT_ARITY_MISMATCH", + "sqlState" : "42802", + "messageParameters" : { + "numExpr" : "4", + "numTarget" : "3" + } +} + + +-- !query +SET VARIABLE (var1, var2, var1) = (SELECT c1, c2, c3, c1 FROM VALUES(100, 110, 120) AS T(c1, c2, c3)) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "DUPLICATE_ASSIGNMENTS", + "sqlState" : "42701", + "messageParameters" : { + "nameList" : "`var1`" + } +} + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP TEMPORARY VARIABLE var2 +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP TEMPORARY VARIABLE var3 +-- !query schema +struct<> +-- !query output + + + +-- !query +SET VARIABLE title = 'DEFAULT expression usage' +-- !query schema +struct<> +-- !query output + + + +-- !query +DECLARE OR REPLACE VARIABLE var1 STRING DEFAULT 'default1' +-- !query schema +struct<> +-- !query output + + + +-- !query +DECLARE OR REPLACE VARIABLE var2 STRING DEFAULT 'default2' +-- !query schema +struct<> +-- !query output + + + +-- !query +DECLARE OR REPLACE VARIABLE var3 STRING DEFAULT 'default3' +-- !query schema +struct<> +-- !query output + + + +-- !query +SET VARIABLE var1 = 'hello' +-- !query schema +struct<> +-- !query output + + + +-- !query +SET VARIABLE var1 = DEFAULT +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT var1 AS `default` +-- !query schema +struct +-- !query output +default1 + + +-- !query +SET VARIABLE var1 = 'hello1' +-- !query schema +struct<> +-- !query output + + + +-- !query +SET VARIABLE var1 = 'hello2' +-- !query schema +struct<> +-- !query output + + + +-- !query +SET VARIABLE var1 = 'hello3' +-- !query schema +struct<> +-- !query output + + + +-- !query +SET VARIABLE var1 = DEFAULT, var2 = DEFAULT, var3 = DEFAULT +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT var1 AS `default1`, var2 AS 
`default2`, var3 AS `default3` +-- !query schema +struct +-- !query output +default1 default2 default3 + + +-- !query +SET VARIABLE var1 = 'hello' +-- !query schema +struct<> +-- !query output + + + +-- !query +SET VARIABLE (var1) = (SELECT DEFAULT FROM VALUES(1) AS T(c1)) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT var1 AS `default` +-- !query schema +struct +-- !query output +default1 + + +-- !query +SET VARIABLE var1 = 'hello' +-- !query schema +struct<> +-- !query output + + + +-- !query +SET VARIABLE (var1) = (SELECT DEFAULT FROM VALUES('world') AS T(default)) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT var1 AS `world` +-- !query schema +struct +-- !query output +world + + +-- !query +SET VARIABLE var1 = 'hello' +-- !query schema +struct<> +-- !query output + + + +-- !query +SET VARIABLE (var1) = (SELECT DEFAULT FROM VALUES(1) AS T(c1) LIMIT 1) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT var1 AS `default` +-- !query schema +struct +-- !query output +default1 + + +-- !query +SET VARIABLE var1 = 'hello' +-- !query schema +struct<> +-- !query output + + + +-- !query +SET VARIABLE (var1) = (SELECT DEFAULT FROM VALUES(1),(2),(3) AS T(c1) LIMIT 1 OFFSET 1) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT var1 AS `default` +-- !query schema +struct +-- !query output +default1 + + +-- !query +SET VARIABLE var1 = 'hello' +-- !query schema +struct<> +-- !query output + + + +-- !query +SET VARIABLE (var1) = (SELECT DEFAULT FROM VALUES(1),(2),(3) AS T(c1) OFFSET 1) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkException +{ + "errorClass" : "ROW_SUBQUERY_TOO_MANY_ROWS", + "sqlState" : "21000" +} + + +-- !query +SELECT var1 AS `default` +-- !query schema +struct +-- !query output +hello + + +-- !query +SET VARIABLE var1 = 'hello' +-- !query schema +struct<> +-- !query output + + + +-- !query +SET VARIABLE (var1) = (WITH v1(c1) AS (VALUES(1) AS T(c1)) SELECT DEFAULT FROM VALUES(1),(2),(3) AS T(c1)) +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkException +{ + "errorClass" : "ROW_SUBQUERY_TOO_MANY_ROWS", + "sqlState" : "21000" +} + + +-- !query +SELECT var1 AS `default` +-- !query schema +struct +-- !query output +hello + + +-- !query +SET VARIABLE var1 = 'Hello' || DEFAULT +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "DEFAULT_PLACEMENT_INVALID", + "sqlState" : "42608" +} + + +-- !query +SET VARIABLE (var1) = (VALUES(DEFAULT)) +-- !query schema +struct<> +-- !query output + + + +-- !query +SET VARIABLE (var1) = (WITH v1(c1) AS (VALUES(1) AS T(c1)) SELECT DEFAULT + 1 FROM VALUES(1),(2),(3) AS T(c1)) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "DEFAULT_PLACEMENT_INVALID", + "sqlState" : "42608" +} + + +-- !query +SET VARIABLE var1 = session.default +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`session`.`default`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 21, + "stopIndex" : 35, + "fragment" : "session.default" + } ] +} + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP TEMPORARY VARIABLE var2 +-- !query schema +struct<> +-- !query output + + + +-- 
!query +DROP TEMPORARY VARIABLE var3 +-- !query schema +struct<> +-- !query output + + + +-- !query +SET VARIABLE title = 'SET command' +-- !query schema +struct<> +-- !query output + + + +-- !query +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 1 +-- !query schema +struct<> +-- !query output + + + +-- !query +SET x.var1 = 5 +-- !query schema +struct +-- !query output +x.var1 5 + + +-- !query +SET x = 5 +-- !query schema +struct +-- !query output +x 5 + + +-- !query +SET system.x.var = 5 +-- !query schema +struct +-- !query output +system.x.var 5 + + +-- !query +SET x.session.var1 = 5 +-- !query schema +struct +-- !query output +x.session.var1 5 + + +-- !query +SET var1 = 5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNSUPPORTED_FEATURE.SET_VARIABLE_USING_SET", + "sqlState" : "0A000", + "messageParameters" : { + "variableName" : "`var1`" + } +} + + +-- !query +SET session.var1 = 5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNSUPPORTED_FEATURE.SET_VARIABLE_USING_SET", + "sqlState" : "0A000", + "messageParameters" : { + "variableName" : "`session`.`var1`" + } +} + + +-- !query +SET system.session.var1 = 5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNSUPPORTED_FEATURE.SET_VARIABLE_USING_SET", + "sqlState" : "0A000", + "messageParameters" : { + "variableName" : "`system`.`session`.`var1`" + } +} + + +-- !query +SET vAr1 = 5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNSUPPORTED_FEATURE.SET_VARIABLE_USING_SET", + "sqlState" : "0A000", + "messageParameters" : { + "variableName" : "`vAr1`" + } +} + + +-- !query +SET seSSion.var1 = 5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNSUPPORTED_FEATURE.SET_VARIABLE_USING_SET", + "sqlState" : "0A000", + "messageParameters" : { + "variableName" : "`seSSion`.`var1`" + } +} + + +-- !query +SET sYStem.session.var1 = 5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNSUPPORTED_FEATURE.SET_VARIABLE_USING_SET", + "sqlState" : "0A000", + "messageParameters" : { + "variableName" : "`sYStem`.`session`.`var1`" + } +} + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query schema +struct<> +-- !query output + + + +-- !query +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 1 +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT var1 AS `2` FROM VALUES(2) AS T(var1) +-- !query schema +struct<2:int> +-- !query output +2 + + +-- !query +SELECT c1 AS `2` FROM VALUES(2) AS T(var1), LATERAL(SELECT var1) AS TT(c1) +-- !query schema +struct<2:int> +-- !query output +1 + + +-- !query +SELECT session.var1 AS `1` FROM VALUES(2) AS T(var1) +-- !query schema +struct<1:int> +-- !query output +1 + + +-- !query +SELECT c1 AS `1` FROM VALUES(2) AS T(var1), LATERAL(SELECT session.var1) AS TT(c1) +-- !query schema +struct<1:int> +-- !query output +1 + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query schema +struct<> +-- !query output + + + +-- !query +SET VARIABLE title = 'variable references -- visibility' +-- !query schema +struct<> +-- !query output + + + +-- !query +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 1 +-- !query schema +struct<> +-- !query output + + + +-- !query +VALUES (var1) +-- !query schema +struct +-- !query output +1 + + +-- !query +SELECT var1 +-- !query schema +struct +-- 
!query output +1 + + +-- !query +SELECT sum(var1) FROM VALUES(1) +-- !query schema +struct +-- !query output +1 + + +-- !query +SELECT var1 + SUM(0) FROM VALUES(1) +-- !query schema +struct<(variablereference(system.session.var1=1) AS var1 + sum(0)):bigint> +-- !query output +1 + + +-- !query +SELECT substr('12345', var1, 1) +-- !query schema +struct +-- !query output +1 + + +-- !query +SELECT 1 FROM VALUES(1, 2) AS T(c1, c2) GROUP BY c1 + var1 +-- !query schema +struct<1:int> +-- !query output +1 + + +-- !query +SELECT c1, sum(c2) FROM VALUES(1, 2) AS T(c1, c2) GROUP BY c1 HAVING sum(c1) != var1 +-- !query schema +struct +-- !query output + + + +-- !query +SELECT 1 FROM VALUES(1) AS T(c1) WHERE c1 IN (var1) +-- !query schema +struct<1:int> +-- !query output +1 + + +-- !query +SELECT sum(c1) FILTER (c1 != var1) FROM VALUES(1, 2), (2, 3) AS T(c1, c2) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42601", + "messageParameters" : { + "error" : "'('", + "hint" : "" + } +} + + +-- !query +SELECT array(1, 2, 4)[var1] +-- !query schema +struct +-- !query output +2 + + +-- !query +SELECT 1 FROM VALUES(1) AS T(c1) WHERE c1 = var1 +-- !query schema +struct<1:int> +-- !query output +1 + + +-- !query +WITH v1 AS (SELECT var1 AS c1) SELECT c1 AS `1` FROM v1 +-- !query schema +struct<1:int> +-- !query output +1 + + +-- !query +CREATE OR REPLACE TEMPORARY VIEW v AS SELECT var1 AS c1 +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT * FROM v +-- !query schema +struct +-- !query output +1 + + +-- !query +DROP VIEW v +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query schema +struct<> +-- !query output + + + +-- !query +SET VARIABLE title = 'variable references -- prohibited' +-- !query schema +struct<> +-- !query output + + + +-- !query +DECLARE OR REPLACE VARIABLE var1 INT DEFAULT 1 +-- !query schema +struct<> +-- !query output + + + +-- !query +CREATE OR REPLACE VIEW v AS SELECT var1 AS c1 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "INVALID_TEMP_OBJ_REFERENCE", + "messageParameters" : { + "obj" : "VIEW", + "objName" : "`spark_catalog`.`default`.`v`", + "tempObj" : "VARIABLE", + "tempObjName" : "`var1`" + } +} + + +-- !query +DROP VIEW IF EXISTS V +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP TEMPORARY VARIABLE var1 +-- !query schema +struct<> +-- !query output + diff --git a/sql/core/src/test/resources/sql-tests/results/table-aliases.sql.out b/sql/core/src/test/resources/sql-tests/results/table-aliases.sql.out index 22de4faf1ce32..5c05bb3f4c22b 100644 --- a/sql/core/src/test/resources/sql-tests/results/table-aliases.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/table-aliases.sql.out @@ -40,10 +40,11 @@ struct<> -- !query output org.apache.spark.sql.AnalysisException { - "errorClass" : "_LEGACY_ERROR_TEMP_1028", + "errorClass" : "ASSIGNMENT_ARITY_MISMATCH", + "sqlState" : "42802", "messageParameters" : { - "columnSize" : "3", - "outputSize" : "2" + "numExpr" : "3", + "numTarget" : "2" }, "queryContext" : [ { "objectType" : "", @@ -62,10 +63,11 @@ struct<> -- !query output org.apache.spark.sql.AnalysisException { - "errorClass" : "_LEGACY_ERROR_TEMP_1028", + "errorClass" : "ASSIGNMENT_ARITY_MISMATCH", + "sqlState" : "42802", "messageParameters" : { - "columnSize" : "1", - "outputSize" : "2" + "numExpr" : "1", + "numTarget" : "2" 
}, "queryContext" : [ { "objectType" : "", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuiteBase.scala index a2f3d872a68e9..6f9cc66f24769 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuiteBase.scala @@ -25,7 +25,7 @@ import org.mockito.invocation.InvocationOnMock import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, Analyzer, FunctionRegistry, NoSuchTableException, ResolveSessionCatalog} -import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog, TempVariableManager} import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -169,6 +169,8 @@ abstract class AlignAssignmentsSuiteBase extends AnalysisTest { private val v2SessionCatalog = new V2SessionCatalog(v1SessionCatalog) + private val tempVariableManager = new TempVariableManager + private val catalogManager = { val manager = mock(classOf[CatalogManager]) when(manager.catalog(any())).thenAnswer((invocation: InvocationOnMock) => { @@ -182,6 +184,7 @@ abstract class AlignAssignmentsSuiteBase extends AnalysisTest { when(manager.currentNamespace).thenReturn(Array.empty[String]) when(manager.v1SessionCatalog).thenReturn(v1SessionCatalog) when(manager.v2SessionCatalog).thenReturn(v2SessionCatalog) + when(manager.tempVariableManager).thenReturn(tempVariableManager) manager } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 8eb0d5456c111..4eb65305de838 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AnalysisTest, Analyzer, EmptyFunctionRegistry, NoSuchTableException, ResolvedFieldName, ResolvedIdentifier, ResolvedTable, ResolveSessionCatalog, UnresolvedAttribute, UnresolvedInlineTable, UnresolvedRelation, UnresolvedSubqueryColumnAliases, UnresolvedTable} -import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, InMemoryCatalog, SessionCatalog, TempVariableManager} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, EqualTo, Expression, InSubquery, IntegerLiteral, ListQuery, Literal, StringLiteral} import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException} @@ -184,6 +184,8 @@ class PlanResolutionSuite extends AnalysisTest { new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)) 
createTempView(v1SessionCatalog, "v", LocalRelation(Nil), false) + private val tempVariableManager = new TempVariableManager + private val catalogManagerWithDefault = { val manager = mock(classOf[CatalogManager]) when(manager.catalog(any())).thenAnswer((invocation: InvocationOnMock) => { @@ -199,6 +201,7 @@ class PlanResolutionSuite extends AnalysisTest { when(manager.currentCatalog).thenReturn(testCat) when(manager.currentNamespace).thenReturn(Array.empty[String]) when(manager.v1SessionCatalog).thenReturn(v1SessionCatalog) + when(manager.tempVariableManager).thenReturn(tempVariableManager) manager } @@ -215,6 +218,7 @@ class PlanResolutionSuite extends AnalysisTest { when(manager.currentCatalog).thenReturn(v2SessionCatalog) when(manager.currentNamespace).thenReturn(Array("default")) when(manager.v1SessionCatalog).thenReturn(v1SessionCatalog) + when(manager.tempVariableManager).thenReturn(tempVariableManager) manager } @@ -1957,7 +1961,7 @@ class PlanResolutionSuite extends AnalysisTest { exception = intercept[AnalysisException] { parseAndResolve(mergeWithDefaultReferenceAsPartOfComplexExpression) }, - errorClass = "_LEGACY_ERROR_TEMP_1343", + errorClass = "DEFAULT_PLACEMENT_INVALID", parameters = Map.empty) val mergeWithDefaultReferenceForNonNullableCol = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index c6bfd8c14ddf7..ff82d178c34ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -1205,7 +1205,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("insert into t values(false, default + 1)") }, - errorClass = "_LEGACY_ERROR_TEMP_1339", + errorClass = "DEFAULT_PLACEMENT_INVALID", parameters = Map.empty ) } @@ -1216,7 +1216,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("insert into t select false, default + 1") }, - errorClass = "_LEGACY_ERROR_TEMP_1339", + errorClass = "DEFAULT_PLACEMENT_INVALID", parameters = Map.empty ) } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala index 45d1f70956a41..fd4c68e8ac25c 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala @@ -213,7 +213,7 @@ trait ThriftServerWithSparkContextSuite extends SharedThriftServer { val sessionHandle = client.openSession(user, "") val infoValue = client.getInfo(sessionHandle, GetInfoType.CLI_ODBC_KEYWORDS) // scalastyle:off line.size.limit - assert(infoValue.getStringValue == 
"ADD,AFTER,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BETWEEN,BIGINT,BINARY,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPUTE,CONCATENATE,CONSTRAINT,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DATE,CURRENT_TIME,CURRENT_TIMESTAMP,CURRENT_USER,DATA,DATABASE,DATABASES,DATE,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAY,DAYOFYEAR,DAYS,DBPROPERTIES,DEC,DECIMAL,DEFAULT,DEFINED,DELETE,DELIMITED,DESC,DESCRIBE,DFS,DIRECTORIES,DIRECTORY,DISTINCT,DISTRIBUTE,DIV,DOUBLE,DROP,ELSE,END,ESCAPE,ESCAPED,EXCEPT,EXCHANGE,EXCLUDE,EXISTS,EXPLAIN,EXPORT,EXTENDED,EXTERNAL,EXTRACT,FALSE,FETCH,FIELDS,FILEFORMAT,FILTER,FIRST,FLOAT,FOLLOWING,FOR,FOREIGN,FORMAT,FORMATTED,FROM,FULL,FUNCTION,FUNCTIONS,GENERATED,GLOBAL,GRANT,GROUP,GROUPING,HAVING,HOUR,HOURS,IDENTIFIER,IF,IGNORE,ILIKE,IMPORT,IN,INCLUDE,INDEX,INDEXES,INNER,INPATH,INPUTFORMAT,INSERT,INT,INTEGER,INTERSECT,INTERVAL,INTO,IS,ITEMS,JOIN,KEYS,LAST,LATERAL,LAZY,LEADING,LEFT,LIKE,LIMIT,LINES,LIST,LOAD,LOCAL,LOCATION,LOCK,LOCKS,LOGICAL,LONG,MACRO,MAP,MATCHED,MERGE,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTE,MINUTES,MONTH,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NATURAL,NO,NULL,NULLS,NUMERIC,OF,OFFSET,ON,ONLY,OPTION,OPTIONS,OR,ORDER,OUT,OUTER,OUTPUTFORMAT,OVER,OVERLAPS,OVERLAY,OVERWRITE,PARTITION,PARTITIONED,PARTITIONS,PERCENT,PERCENTILE_CONT,PERCENTILE_DISC,PIVOT,PLACING,POSITION,PRECEDING,PRIMARY,PRINCIPALS,PROPERTIES,PURGE,QUARTER,QUERY,RANGE,REAL,RECORDREADER,RECORDWRITER,RECOVER,REDUCE,REFERENCES,REFRESH,RENAME,REPAIR,REPEATABLE,REPLACE,RESET,RESPECT,RESTRICT,REVOKE,RIGHT,ROLE,ROLES,ROLLBACK,ROLLUP,ROW,ROWS,SCHEMA,SCHEMAS,SECOND,SECONDS,SELECT,SEMI,SEPARATED,SERDE,SERDEPROPERTIES,SESSION_USER,SET,SETS,SHORT,SHOW,SINGLE,SKEWED,SMALLINT,SOME,SORT,SORTED,SOURCE,START,STATISTICS,STORED,STRATIFY,STRING,STRUCT,SUBSTR,SUBSTRING,SYNC,SYSTEM_TIME,SYSTEM_VERSION,TABLE,TABLES,TABLESAMPLE,TARGET,TBLPROPERTIES,TERMINATED,THEN,TIME,TIMESTAMP,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TO,TOUCH,TRAILING,TRANSACTION,TRANSACTIONS,TRANSFORM,TRIM,TRUE,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUNDED,UNCACHE,UNION,UNIQUE,UNKNOWN,UNLOCK,UNPIVOT,UNSET,UPDATE,USE,USER,USING,VALUES,VARCHAR,VERSION,VIEW,VIEWS,VOID,WEEK,WEEKS,WHEN,WHERE,WINDOW,WITH,WITHIN,X,YEAR,YEARS,ZONE") + assert(infoValue.getStringValue == 
"ADD,AFTER,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BETWEEN,BIGINT,BINARY,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPUTE,CONCATENATE,CONSTRAINT,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DATE,CURRENT_TIME,CURRENT_TIMESTAMP,CURRENT_USER,DATA,DATABASE,DATABASES,DATE,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAY,DAYOFYEAR,DAYS,DBPROPERTIES,DEC,DECIMAL,DECLARE,DEFAULT,DEFINED,DELETE,DELIMITED,DESC,DESCRIBE,DFS,DIRECTORIES,DIRECTORY,DISTINCT,DISTRIBUTE,DIV,DOUBLE,DROP,ELSE,END,ESCAPE,ESCAPED,EXCEPT,EXCHANGE,EXCLUDE,EXISTS,EXPLAIN,EXPORT,EXTENDED,EXTERNAL,EXTRACT,FALSE,FETCH,FIELDS,FILEFORMAT,FILTER,FIRST,FLOAT,FOLLOWING,FOR,FOREIGN,FORMAT,FORMATTED,FROM,FULL,FUNCTION,FUNCTIONS,GENERATED,GLOBAL,GRANT,GROUP,GROUPING,HAVING,HOUR,HOURS,IDENTIFIER,IF,IGNORE,ILIKE,IMPORT,IN,INCLUDE,INDEX,INDEXES,INNER,INPATH,INPUTFORMAT,INSERT,INT,INTEGER,INTERSECT,INTERVAL,INTO,IS,ITEMS,JOIN,KEYS,LAST,LATERAL,LAZY,LEADING,LEFT,LIKE,LIMIT,LINES,LIST,LOAD,LOCAL,LOCATION,LOCK,LOCKS,LOGICAL,LONG,MACRO,MAP,MATCHED,MERGE,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTE,MINUTES,MONTH,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NATURAL,NO,NULL,NULLS,NUMERIC,OF,OFFSET,ON,ONLY,OPTION,OPTIONS,OR,ORDER,OUT,OUTER,OUTPUTFORMAT,OVER,OVERLAPS,OVERLAY,OVERWRITE,PARTITION,PARTITIONED,PARTITIONS,PERCENT,PERCENTILE_CONT,PERCENTILE_DISC,PIVOT,PLACING,POSITION,PRECEDING,PRIMARY,PRINCIPALS,PROPERTIES,PURGE,QUARTER,QUERY,RANGE,REAL,RECORDREADER,RECORDWRITER,RECOVER,REDUCE,REFERENCES,REFRESH,RENAME,REPAIR,REPEATABLE,REPLACE,RESET,RESPECT,RESTRICT,REVOKE,RIGHT,ROLE,ROLES,ROLLBACK,ROLLUP,ROW,ROWS,SCHEMA,SCHEMAS,SECOND,SECONDS,SELECT,SEMI,SEPARATED,SERDE,SERDEPROPERTIES,SESSION_USER,SET,SETS,SHORT,SHOW,SINGLE,SKEWED,SMALLINT,SOME,SORT,SORTED,SOURCE,START,STATISTICS,STORED,STRATIFY,STRING,STRUCT,SUBSTR,SUBSTRING,SYNC,SYSTEM_TIME,SYSTEM_VERSION,TABLE,TABLES,TABLESAMPLE,TARGET,TBLPROPERTIES,TERMINATED,THEN,TIME,TIMESTAMP,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TO,TOUCH,TRAILING,TRANSACTION,TRANSACTIONS,TRANSFORM,TRIM,TRUE,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUNDED,UNCACHE,UNION,UNIQUE,UNKNOWN,UNLOCK,UNPIVOT,UNSET,UPDATE,USE,USER,USING,VALUES,VAR,VARCHAR,VARIABLE,VERSION,VIEW,VIEWS,VOID,WEEK,WEEKS,WHEN,WHERE,WINDOW,WITH,WITHIN,X,YEAR,YEARS,ZONE") // scalastyle:on line.size.limit } } From c73660c3e7279f61fe6e2f6bbf88f410f7ce25a1 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Wed, 9 Aug 2023 18:08:23 +0900 Subject: [PATCH 30/30] [SPARK-44738][PYTHON][CONNECT] Add missing client metadata to calls ### What changes were proposed in this pull request? The refactoring for the re-attachable execution missed properly propagating the client metadata for the individual RPC calls. ### Why are the changes needed? Compatibility. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing UT Closes #42409 from grundprinzip/SPARK-44738. 
Authored-by: Martin Grund Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/connect/client/reattach.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/connect/client/reattach.py b/python/pyspark/sql/connect/client/reattach.py index c5c45904c9baa..c6b1beaa121b8 100644 --- a/python/pyspark/sql/connect/client/reattach.py +++ b/python/pyspark/sql/connect/client/reattach.py @@ -135,7 +135,9 @@ def _has_next(self) -> bool: if not attempt.is_first_try(): # on retry, the iterator is borked, so we need a new one self._iterator = iter( - self._stub.ReattachExecute(self._create_reattach_execute_request()) + self._stub.ReattachExecute( + self._create_reattach_execute_request(), metadata=self._metadata + ) ) if self._current is None: @@ -154,7 +156,8 @@ def _has_next(self) -> bool: while not has_next: self._iterator = iter( self._stub.ReattachExecute( - self._create_reattach_execute_request() + self._create_reattach_execute_request(), + metadata=self._metadata, ) ) # shouldn't change @@ -192,7 +195,7 @@ def target() -> None: can_retry=SparkConnectClient.retry_exception, **self._retry_policy ): with attempt: - self._stub.ReleaseExecute(request) + self._stub.ReleaseExecute(request, metadata=self._metadata) except Exception as e: warnings.warn(f"ReleaseExecute failed with exception: {e}.") @@ -220,7 +223,7 @@ def target() -> None: can_retry=SparkConnectClient.retry_exception, **self._retry_policy ): with attempt: - self._stub.ReleaseExecute(request) + self._stub.ReleaseExecute(request, metadata=self._metadata) except Exception as e: warnings.warn(f"ReleaseExecute failed with exception: {e}.")
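
For context, a minimal sketch (not PySpark source) of the pattern the SPARK-44738 patch restores: generated gRPC Python stubs accept an optional `metadata=` keyword of (key, value) tuples on every call, so a wrapper that rebuilds its iterator on retry has to forward that metadata itself each time it issues a new call. `FakeStub` and `Reattacher` below are hypothetical stand-ins, not the real SparkConnectService stub or the reattaching iterator in `reattach.py`.

```python
# Sketch only: illustrates why gRPC client metadata must accompany every
# stub call, including retry/reattach paths. Names here are hypothetical.

class FakeStub:
    """Stands in for a generated gRPC stub; real stubs take the same kwargs."""

    def ReattachExecute(self, request, metadata=None):
        # A real call would send `metadata` as request headers; dropping it
        # means the server never sees the session or user-agent headers.
        return {"request": request, "headers": dict(metadata or [])}


class Reattacher:
    def __init__(self, stub, metadata):
        self._stub = stub
        # Metadata is captured once at construction time ...
        self._metadata = metadata

    def reattach(self, request):
        # ... but it must be passed explicitly on every call; gRPC does not
        # remember it between calls on the same stub.
        return self._stub.ReattachExecute(request, metadata=self._metadata)


if __name__ == "__main__":
    client = Reattacher(FakeStub(), [("x-session-id", "abc123")])
    print(client.reattach("plan-1")["headers"])  # {'x-session-id': 'abc123'}
```

The diff above applies the same idea to `ReattachExecute` and `ReleaseExecute`: each retry or release path builds a fresh stub call, so `metadata=self._metadata` has to accompany each of them.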