diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml
index 0048bc7ffba0d..2011104a19b8a 100644
--- a/.github/workflows/build_and_test.yml
+++ b/.github/workflows/build_and_test.yml
@@ -285,6 +285,8 @@ jobs:
lint:
name: Linters, licenses, dependencies and documentation generation
runs-on: ubuntu-20.04
+ container:
+ image: dongjoon/apache-spark-github-action-image:20201025
steps:
- name: Checkout Spark repository
uses: actions/checkout@v2
@@ -315,10 +317,6 @@ jobs:
key: docs-maven-${{ hashFiles('**/pom.xml') }}
restore-keys: |
docs-maven-
- - name: Install Java 8
- uses: actions/setup-java@v1
- with:
- java-version: 8
- name: Install Python 3.6
uses: actions/setup-python@v2
with:
@@ -328,30 +326,24 @@ jobs:
run: |
# TODO(SPARK-32407): Sphinx 3.1+ does not correctly index nested classes.
# See also https://github.com/sphinx-doc/sphinx/issues/7551.
- pip3 install flake8 'sphinx<3.1.0' numpy pydata_sphinx_theme ipython nbsphinx mypy numpydoc
- - name: Install R 4.0
- uses: r-lib/actions/setup-r@v1
- with:
- r-version: 4.0
+ python3.6 -m pip install flake8 'sphinx<3.1.0' numpy pydata_sphinx_theme ipython nbsphinx mypy numpydoc
- name: Install R linter dependencies and SparkR
run: |
- sudo apt-get install -y libcurl4-openssl-dev libgit2-dev libssl-dev libxml2-dev
- sudo Rscript -e "install.packages(c('devtools'), repos='https://cloud.r-project.org/')"
- sudo Rscript -e "devtools::install_github('jimhester/lintr@v2.0.0')"
+ apt-get install -y libcurl4-openssl-dev libgit2-dev libssl-dev libxml2-dev
+ Rscript -e "install.packages(c('devtools'), repos='https://cloud.r-project.org/')"
+ Rscript -e "devtools::install_github('jimhester/lintr@v2.0.0')"
./R/install-dev.sh
- - name: Install Ruby 2.7 for documentation generation
- uses: actions/setup-ruby@v1
- with:
- ruby-version: 2.7
- name: Install dependencies for documentation generation
run: |
# pandoc is required to generate PySpark APIs as well in nbsphinx.
- sudo apt-get install -y libcurl4-openssl-dev pandoc
+ apt-get install -y libcurl4-openssl-dev pandoc
# TODO(SPARK-32407): Sphinx 3.1+ does not correctly index nested classes.
# See also https://github.com/sphinx-doc/sphinx/issues/7551.
- pip install 'sphinx<3.1.0' mkdocs numpy pydata_sphinx_theme ipython nbsphinx numpydoc
+ python3.6 -m pip install 'sphinx<3.1.0' mkdocs numpy pydata_sphinx_theme ipython nbsphinx numpydoc
+ apt-get update -y
+ apt-get install -y ruby ruby-dev
gem install jekyll jekyll-redirect-from rouge
- sudo Rscript -e "install.packages(c('devtools', 'testthat', 'knitr', 'rmarkdown', 'roxygen2'), repos='https://cloud.r-project.org/')"
+ Rscript -e "install.packages(c('devtools', 'testthat', 'knitr', 'rmarkdown', 'roxygen2'), repos='https://cloud.r-project.org/')"
- name: Scala linter
run: ./dev/lint-scala
- name: Java linter
@@ -367,6 +359,8 @@ jobs:
- name: Run documentation build
run: |
cd docs
+ export LC_ALL=C.UTF-8
+ export LANG=C.UTF-8
jekyll build
java-11:
diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
index c1269a9c91049..5ae596b03d5fe 100644
--- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
@@ -1588,7 +1588,7 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite {
test("SPARK-23365 Don't update target num executors when killing idle executors") {
val clock = new ManualClock()
val manager = createManager(
- createConf(1, 2, 1).set(config.DYN_ALLOCATION_TESTING, false),
+ createConf(1, 2, 1),
clock = clock)
when(client.requestTotalExecutors(any(), any(), any())).thenReturn(true)
@@ -1616,19 +1616,17 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite {
clock.advance(1000)
manager invokePrivate _updateAndSyncNumExecutorsTarget(clock.nanoTime())
assert(numExecutorsTargetForDefaultProfileId(manager) === 1)
- verify(client, never).killExecutors(any(), any(), any(), any())
+ assert(manager.executorMonitor.executorsPendingToRemove().isEmpty)
// now we cross the idle timeout for executor-1, so we kill it. the really important
// thing here is that we do *not* ask the executor allocation client to adjust the target
// number of executors down
- when(client.killExecutors(Seq("executor-1"), false, false, false))
- .thenReturn(Seq("executor-1"))
clock.advance(3000)
schedule(manager)
assert(maxNumExecutorsNeededPerResourceProfile(manager, defaultProfile) === 1)
assert(numExecutorsTargetForDefaultProfileId(manager) === 1)
// here's the important verify -- we did kill the executors, but did not adjust the target count
- verify(client).killExecutors(Seq("executor-1"), false, false, false)
+ assert(manager.executorMonitor.executorsPendingToRemove() === Set("executor-1"))
}
test("SPARK-26758 check executor target number after idle time out ") {
diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md
index cbb1de53c8896..bd54554baa09d 100644
--- a/docs/sql-migration-guide.md
+++ b/docs/sql-migration-guide.md
@@ -30,6 +30,8 @@ license: |
- In Spark 3.2, `ALTER TABLE .. RENAME TO PARTITION` throws `PartitionAlreadyExistsException` instead of `AnalysisException` for tables from Hive external when the target partition already exists.
+ - In Spark 3.2, the default FIELD DELIMIT for script transform in no-serde mode is `\u0001`. In Spark 3.1 or earlier, the default FIELD DELIMIT is `\t`.
+
## Upgrading from Spark SQL 3.0 to 3.1
- In Spark 3.1, statistical aggregation function includes `std`, `stddev`, `stddev_samp`, `variance`, `var_samp`, `skewness`, `kurtosis`, `covar_samp`, `corr` will return `NULL` instead of `Double.NaN` when `DivideByZero` occurs during expression evaluation, for example, when `stddev_samp` applied on a single element set. In Spark version 3.0 and earlier, it will return `Double.NaN` in such case. To restore the behavior before Spark 3.1, you can set `spark.sql.legacy.statisticalAggregate` to `true`.
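Editor's note: the migration entry added above can be checked end-to-end. A minimal spark-shell sketch, assuming a SparkSession named `spark` and a local `cat` binary (both assumptions, not part of the patch); it mirrors the SPARK-33930 test added later in this diff:

```scala
// Spark-shell sketch: with no input FIELDS TERMINATED BY clause, Spark 3.2 joins
// the TRANSFORM columns with '\u0001', whereas Spark 3.1 used '\t'.
import spark.implicits._

Seq((1, 2, 3)).toDF("a", "b", "c").createOrReplaceTempView("v")
spark.sql(
  """SELECT TRANSFORM(a, b, c)
    |  USING 'cat' AS (a)
    |  ROW FORMAT DELIMITED
    |  FIELDS TERMINATED BY '&'
    |FROM v""".stripMargin).show(truncate = false)
// Spark 3.2 prints 1\u00012\u00013 in column `a`; Spark 3.1 would print 1\t2\t3.
```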
diff --git a/pom.xml b/pom.xml
index 609c9fc0ab0c3..39ce502ab0e3f 100644
--- a/pom.xml
+++ b/pom.xml
@@ -3277,6 +3277,15 @@
-Wconf:cat=other-match-analysis&site=org.apache.spark.sql.catalyst.catalog.SessionCatalog.lookupFunction.catalogFunction:wv
-Wconf:cat=other-pure-statement&site=org.apache.spark.streaming.util.FileBasedWriteAheadLog.readAll.readFile:wv
-Wconf:cat=other-pure-statement&site=org.apache.spark.scheduler.OutputCommitCoordinatorSuite.<local OutputCommitCoordinatorSuite>.futureAction:wv
+
+ -Wconf:msg=^(?=.*?method|value|type|object|trait|inheritance)(?=.*?deprecated)(?=.*?since 2.13).+$:s
+ -Wconf:msg=^(?=.*?Widening conversion from)(?=.*?is deprecated because it loses precision).+$:s
+ -Wconf:msg=Auto-application to \`\(\)\` is deprecated:s
+ -Wconf:msg=method with a single empty parameter list overrides method without any parameter list:s
+ -Wconf:msg=method without a parameter list overrides a method with a single empty one:s
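Editor's note: the `-Wconf` filters above silence deprecation categories that only show up when cross-building against Scala 2.13. An illustrative sketch, not taken from the patch, of code that trips two of the silenced messages (it compiles, but would fail a fatal-warnings 2.13 build without the filters):

```scala
// Illustrative only: snippets that trigger the Scala 2.13 warnings silenced above.
class Widget {
  def size() = 3                       // nullary method declared with parentheses
}

object WconfExamples {
  val w = new Widget
  // "Auto-application to `()` is deprecated": calling a ()-method without parens.
  val n: Int = w.size

  // "Widening conversion from Long to Float is deprecated because it loses
  // precision": implicit numeric widening of a Long literal.
  val f: Float = 1234567890123L
}
```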
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsPartitionManagement.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsPartitionManagement.java
index 409ab3f5f9335..a7008293a3e19 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsPartitionManagement.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsPartitionManagement.java
@@ -139,4 +139,21 @@ Map<String, String> loadPartitionMetadata(InternalRow ident)
* @return an array of Identifiers for the partitions
*/
InternalRow[] listPartitionIdentifiers(String[] names, InternalRow ident);
+
+ /**
+ * Rename an existing partition of the table.
+ *
+ * @param from an existing partition identifier to rename
+ * @param to new partition identifier
+ * @return true if renaming completes successfully otherwise false
+ * @throws UnsupportedOperationException If partition renaming is not supported
+ * @throws PartitionAlreadyExistsException If the `to` partition exists already
+ * @throws NoSuchPartitionException If the `from` partition does not exist
+ */
+ default boolean renamePartition(InternalRow from, InternalRow to)
+ throws UnsupportedOperationException,
+ PartitionAlreadyExistsException,
+ NoSuchPartitionException {
+ throw new UnsupportedOperationException("Partition renaming is not supported");
+ }
}
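Editor's note: a hedged sketch of how a connector might override the new default method while keeping the documented exception contract. The class name and backing map are hypothetical; only the exception constructors mirror what the in-memory test table later in this diff uses:

```scala
// Hypothetical connector-side override; not part of the patch.
import java.util.concurrent.ConcurrentHashMap

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, PartitionAlreadyExistsException}
import org.apache.spark.sql.connector.catalog.SupportsPartitionManagement

abstract class MapBackedPartitionTable extends SupportsPartitionManagement {
  protected val partitions =
    new ConcurrentHashMap[InternalRow, java.util.Map[String, String]]()

  override def renamePartition(from: InternalRow, to: InternalRow): Boolean = {
    if (partitions.containsKey(to)) {
      throw new PartitionAlreadyExistsException(name, to, partitionSchema)
    }
    val metadata = partitions.remove(from)
    if (metadata == null) {
      throw new NoSuchPartitionException(name, from, partitionSchema)
    }
    // Move the metadata to the new identifier; true when the slot was free.
    partitions.put(to, metadata) == null
  }
}
```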
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index a4dfbe85abfd7..89076fbb9ce0f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -215,6 +215,10 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
case s: SubqueryExpression =>
checkSubqueryExpression(operator, s)
s
+
+ case e: ExpressionWithRandomSeed if !e.seedExpression.foldable =>
+ failAnalysis(
+ s"Input argument to ${e.prettyName} must be a constant.")
}
operator match {
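Editor's note: a short spark-shell sketch (SparkSession `spark` assumed) of what the new analysis check rejects and what it still accepts; it mirrors the AnalysisErrorSuite test added later in this diff:

```scala
// With the new check, a non-foldable seed is rejected during analysis rather
// than surfacing later from RDG.seed.
import scala.util.Try

spark.range(3).createOrReplaceTempView("t")

// Fails analysis: "Input argument to rand must be a constant."
println(Try(spark.sql("SELECT rand(id) FROM t")))

// Still accepted: the seed is a foldable int expression, evaluated once.
spark.sql("SELECT rand(1 + 2) FROM t").show()
```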
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala
index 2c2bea6f89d49..84be3f294a6ea 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala
@@ -52,13 +52,14 @@ object ResolvePartitionSpec extends Rule[LogicalPlan] {
requireExactMatchedPartitionSpec(table.name, _, partitionSchema.fieldNames)))
case r @ AlterTableRenamePartition(
- ResolvedTable(_, _, table: SupportsPartitionManagement), from, _) =>
+ ResolvedTable(_, _, table: SupportsPartitionManagement), from, to) =>
val partitionSchema = table.partitionSchema()
- r.copy(from = resolvePartitionSpecs(
+ val Seq(resolvedFrom, resolvedTo) = resolvePartitionSpecs(
table.name,
- Seq(from),
+ Seq(from, to),
partitionSchema,
- requireExactMatchedPartitionSpec(table.name, _, partitionSchema.fieldNames)).head)
+ requireExactMatchedPartitionSpec(table.name, _, partitionSchema.fieldNames))
+ r.copy(from = resolvedFrom, to = resolvedTo)
case r @ ShowPartitions(ResolvedTable(_, _, table: SupportsPartitionManagement), partSpecs) =>
r.copy(pattern = resolvePartitionSpecs(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
index 8978d55b98251..d987704b269f0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
@@ -51,7 +51,8 @@ case class CsvToStructs(
schema: StructType,
options: Map[String, String],
child: Expression,
- timeZoneId: Option[String] = None)
+ timeZoneId: Option[String] = None,
+ requiredSchema: Option[StructType] = None)
extends UnaryExpression
with TimeZoneAwareExpression
with CodegenFallback
@@ -113,7 +114,12 @@ case class CsvToStructs(
val actualSchema =
StructType(nullableSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
- val rawParser = new UnivocityParser(actualSchema, actualSchema, parsedOptions)
+ val actualRequiredSchema =
+ StructType(requiredSchema.map(_.asNullable).getOrElse(nullableSchema)
+ .filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
+ val rawParser = new UnivocityParser(actualSchema,
+ actualRequiredSchema,
+ parsedOptions)
new FailureSafeParser[String](
input => rawParser.parse(input),
mode,
@@ -121,7 +127,7 @@ case class CsvToStructs(
parsedOptions.columnNameOfCorruptRecord)
}
- override def dataType: DataType = nullableSchema
+ override def dataType: DataType = requiredSchema.getOrElse(schema).asNullable
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = {
copy(timeZoneId = Option(timeZoneId))
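Editor's note: the new `requiredSchema` parameter is the hook the optimizer uses to parse only the referenced CSV column. A spark-shell sketch (SparkSession `spark` assumed) of the observable effect:

```scala
// When only one field of a from_csv result is referenced, the optimizer can hand
// CsvToStructs a pruned requiredSchema, so the Univocity parser converts just
// that column.
import spark.implicits._

val parsed = Seq("1,2").toDF("csv")
  .selectExpr("from_csv(csv, 'a int, b int') AS parsed")

// With spark.sql.optimizer.enableCsvExpressionOptimization=true (the default),
// the optimized plan should show from_csv carrying only field `a`.
parsed.selectExpr("parsed.a").explain(true)
parsed.selectExpr("parsed.a").show()
```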
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index 630c934f79533..0a4c6e27d51d9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedSeed
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
@@ -47,10 +46,8 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Stateful
override def seedExpression: Expression = child
@transient protected lazy val seed: Long = seedExpression match {
- case Literal(s, IntegerType) => s.asInstanceOf[Int]
- case Literal(s, LongType) => s.asInstanceOf[Long]
- case _ => throw new AnalysisException(
- s"Input argument to $prettyName must be an integer, long or null literal.")
+ case e if e.dataType == IntegerType => e.eval().asInstanceOf[Int]
+ case e if e.dataType == LongType => e.eval().asInstanceOf[Long]
}
override def nullable: Boolean = false
@@ -64,7 +61,7 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Stateful
* Represents the behavior of expressions which have a random seed and can renew the seed.
* Usually the random seed needs to be renewed at each execution under streaming queries.
*/
-trait ExpressionWithRandomSeed {
+trait ExpressionWithRandomSeed extends Expression {
def seedExpression: Expression
def withNewSeed(seed: Long): Expression
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCsvJsonExprs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCsvJsonExprs.scala
new file mode 100644
index 0000000000000..9c32f8be736a4
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCsvJsonExprs.scala
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{ArrayType, StructType}
+
+/**
+ * Simplify redundant csv/json related expressions.
+ *
+ * The optimization includes:
+ * 1. JsonToStructs(StructsToJson(child)) => child.
+ * 2. Prune unnecessary columns from GetStructField/GetArrayStructFields + JsonToStructs.
+ * 3. CreateNamedStruct(JsonToStructs(json).col1, JsonToStructs(json).col2, ...) =>
+ * If(IsNull(json), nullStruct, KnownNotNull(JsonToStructs(prunedSchema, ..., json)))
+ * if JsonToStructs(json) is shared among all fields of CreateNamedStruct. `prunedSchema`
+ * contains all accessed fields in original CreateNamedStruct.
+ * 4. Prune unnecessary columns from GetStructField + CsvToStructs.
+ */
+object OptimizeCsvJsonExprs extends Rule[LogicalPlan] {
+ private def nameOfCorruptRecord = SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD)
+
+ override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case p =>
+ val optimized = if (SQLConf.get.jsonExpressionOptimization) {
+ p.transformExpressions(jsonOptimization)
+ } else {
+ p
+ }
+
+ if (SQLConf.get.csvExpressionOptimization) {
+ optimized.transformExpressions(csvOptimization)
+ } else {
+ optimized
+ }
+ }
+
+ private val jsonOptimization: PartialFunction[Expression, Expression] = {
+ case c: CreateNamedStruct
+ // If we create struct from various fields of the same `JsonToStructs`.
+ if c.valExprs.forall { v =>
+ v.isInstanceOf[GetStructField] &&
+ v.asInstanceOf[GetStructField].child.isInstanceOf[JsonToStructs] &&
+ v.children.head.semanticEquals(c.valExprs.head.children.head)
+ } =>
+ val jsonToStructs = c.valExprs.map(_.children.head)
+ val sameFieldName = c.names.zip(c.valExprs).forall {
+ case (name, valExpr: GetStructField) =>
+ name.toString == valExpr.childSchema(valExpr.ordinal).name
+ case _ => false
+ }
+
+ // Although `CreateNamedStruct` allows duplicated field names, e.g. "a int, a int",
+ // `JsonToStructs` does not support parsing json with duplicated field names.
+ val duplicateFields = c.names.map(_.toString).distinct.length != c.names.length
+
+ // If we create struct from various fields of the same `JsonToStructs` and we don't
+ // alias field names and there is no duplicated field in the struct.
+ if (sameFieldName && !duplicateFields) {
+ val fromJson = jsonToStructs.head.asInstanceOf[JsonToStructs].copy(schema = c.dataType)
+ val nullFields = c.children.grouped(2).flatMap {
+ case Seq(name, value) => Seq(name, Literal(null, value.dataType))
+ }.toSeq
+
+ If(IsNull(fromJson.child), c.copy(children = nullFields), KnownNotNull(fromJson))
+ } else {
+ c
+ }
+
+ case jsonToStructs @ JsonToStructs(_, options1,
+ StructsToJson(options2, child, timeZoneId2), timeZoneId1)
+ if options1.isEmpty && options2.isEmpty && timeZoneId1 == timeZoneId2 &&
+ jsonToStructs.dataType == child.dataType =>
+ // `StructsToJson` only fails when `JacksonGenerator` encounters data types it
+ // cannot convert to JSON. But `StructsToJson.checkInputDataTypes` already
+ // verifies its child's data types is convertible to JSON. But in
+ // `StructsToJson(JsonToStructs(...))` case, we cannot verify input json string
+ // so `JsonToStructs` might throw error in runtime. Thus we cannot optimize
+ // this case similarly.
+ child
+
+ case g @ GetStructField(j @ JsonToStructs(schema: StructType, _, _, _), ordinal, _)
+ if schema.length > 1 =>
+ val prunedSchema = StructType(Seq(schema(ordinal)))
+ g.copy(child = j.copy(schema = prunedSchema), ordinal = 0)
+
+ case g @ GetArrayStructFields(j @ JsonToStructs(schema: ArrayType, _, _, _), _, _, _, _)
+ if schema.elementType.asInstanceOf[StructType].length > 1 =>
+ val prunedSchema = ArrayType(StructType(Seq(g.field)), g.containsNull)
+ g.copy(child = j.copy(schema = prunedSchema), ordinal = 0, numFields = 1)
+ }
+
+ private val csvOptimization: PartialFunction[Expression, Expression] = {
+ case g @ GetStructField(c @ CsvToStructs(schema: StructType, _, _, _, None), ordinal, _)
+ if schema.length > 1 && c.options.isEmpty && schema(ordinal).name != nameOfCorruptRecord =>
+ // When the parse mode is permissive, and corrupt column is not selected, we can prune here
+ // from `GetStructField`. To be more conservative, it does not optimize when any option
+ // is set.
+ val prunedSchema = StructType(Seq(schema(ordinal)))
+ g.copy(child = c.copy(requiredSchema = Some(prunedSchema)), ordinal = 0)
+ }
+}
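Editor's note: a spark-shell sketch (SparkSession `spark` assumed, schemas illustrative) of rewrites (1) and (2) from the scaladoc of the new rule; whether (1) fires depends on the options being empty and the data types lining up, as the guard in the rule requires:

```scala
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StructType
import spark.implicits._

// (1) from_json(to_json(struct)) collapses back to the struct itself when both
//     sides use no options and the data types match.
val roundTrip = Seq(Tuple1(Option(1L))).toDF("a")
  .select(struct($"a").as("s"))
  .select(from_json(to_json($"s"), StructType.fromDDL("a bigint")).as("s2"))
roundTrip.explain(true)

// (2) Selecting a single field of from_json prunes the parsing schema, so only
//     field `a` is extracted from the JSON string.
Seq("""{"a": 1, "b": "x"}""").toDF("json")
  .select(from_json($"json", StructType.fromDDL("a int, b string")).getField("a"))
  .explain(true)
```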
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeJsonExprs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeJsonExprs.scala
deleted file mode 100644
index ce86d8cdd4999..0000000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeJsonExprs.scala
+++ /dev/null
@@ -1,96 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.optimizer
-
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{ArrayType, StructType}
-
-/**
- * Simplify redundant json related expressions.
- *
- * The optimization includes:
- * 1. JsonToStructs(StructsToJson(child)) => child.
- * 2. Prune unnecessary columns from GetStructField/GetArrayStructFields + JsonToStructs.
- * 3. CreateNamedStruct(JsonToStructs(json).col1, JsonToStructs(json).col2, ...) =>
- * If(IsNull(json), nullStruct, KnownNotNull(JsonToStructs(prunedSchema, ..., json)))
- * if JsonToStructs(json) is shared among all fields of CreateNamedStruct. `prunedSchema`
- * contains all accessed fields in original CreateNamedStruct.
- */
-object OptimizeJsonExprs extends Rule[LogicalPlan] {
- override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case p if SQLConf.get.jsonExpressionOptimization => p.transformExpressions {
-
- case c: CreateNamedStruct
- // If we create struct from various fields of the same `JsonToStructs`.
- if c.valExprs.forall { v =>
- v.isInstanceOf[GetStructField] &&
- v.asInstanceOf[GetStructField].child.isInstanceOf[JsonToStructs] &&
- v.children.head.semanticEquals(c.valExprs.head.children.head)
- } =>
- val jsonToStructs = c.valExprs.map(_.children.head)
- val sameFieldName = c.names.zip(c.valExprs).forall {
- case (name, valExpr: GetStructField) =>
- name.toString == valExpr.childSchema(valExpr.ordinal).name
- case _ => false
- }
-
- // Although `CreateNamedStruct` allows duplicated field names, e.g. "a int, a int",
- // `JsonToStructs` does not support parsing json with duplicated field names.
- val duplicateFields = c.names.map(_.toString).distinct.length != c.names.length
-
- // If we create struct from various fields of the same `JsonToStructs` and we don't
- // alias field names and there is no duplicated field in the struct.
- if (sameFieldName && !duplicateFields) {
- val fromJson = jsonToStructs.head.asInstanceOf[JsonToStructs].copy(schema = c.dataType)
- val nullFields = c.children.grouped(2).flatMap {
- case Seq(name, value) => Seq(name, Literal(null, value.dataType))
- }.toSeq
-
- If(IsNull(fromJson.child), c.copy(children = nullFields), KnownNotNull(fromJson))
- } else {
- c
- }
-
- case jsonToStructs @ JsonToStructs(_, options1,
- StructsToJson(options2, child, timeZoneId2), timeZoneId1)
- if options1.isEmpty && options2.isEmpty && timeZoneId1 == timeZoneId2 &&
- jsonToStructs.dataType == child.dataType =>
- // `StructsToJson` only fails when `JacksonGenerator` encounters data types it
- // cannot convert to JSON. But `StructsToJson.checkInputDataTypes` already
- // verifies its child's data types is convertible to JSON. But in
- // `StructsToJson(JsonToStructs(...))` case, we cannot verify input json string
- // so `JsonToStructs` might throw error in runtime. Thus we cannot optimize
- // this case similarly.
- child
-
- case g @ GetStructField(j @ JsonToStructs(schema: StructType, _, _, _), ordinal, _)
- if schema.length > 1 =>
- val prunedSchema = StructType(Seq(schema(ordinal)))
- g.copy(child = j.copy(schema = prunedSchema), ordinal = 0)
-
- case g @ GetArrayStructFields(j @ JsonToStructs(schema: ArrayType, _, _, _), _, _, _, _)
- if schema.elementType.asInstanceOf[StructType].length > 1 =>
- val prunedSchema = ArrayType(StructType(Seq(g.field)), g.containsNull)
- g.copy(child = j.copy(schema = prunedSchema), ordinal = 0, numFields = 1)
-
- }
- }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 7b9b99bba5574..47260cfb59bb1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -114,7 +114,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
RemoveNoopOperators,
OptimizeUpdateFields,
SimplifyExtractValueOps,
- OptimizeJsonExprs,
+ OptimizeCsvJsonExprs,
CombineConcats) ++
extendedOperatorOptimizationRules
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index b2625bddeecf4..1b93d514964e6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -486,6 +486,11 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
case If(cond, FalseLiteral, l @ Literal(null, _)) if !cond.nullable => And(Not(cond), l)
case If(cond, TrueLiteral, l @ Literal(null, _)) if !cond.nullable => Or(cond, l)
+ case CaseWhen(Seq((cond, TrueLiteral)), Some(FalseLiteral)) =>
+ if (cond.nullable) EqualNullSafe(cond, TrueLiteral) else cond
+ case CaseWhen(Seq((cond, FalseLiteral)), Some(TrueLiteral)) =>
+ if (cond.nullable) Not(EqualNullSafe(cond, TrueLiteral)) else Not(cond)
+
case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) =>
// If there are branches that are always false, remove them.
// If there are no more branches left, just use the else value.
@@ -548,41 +553,68 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper {
foldables.nonEmpty && others.length < 2
}
+ // Not all UnaryExpression can be pushed into (if / case) branches, e.g. Alias.
+ private def supportedUnaryExpression(e: UnaryExpression): Boolean = e match {
+ case _: IsNull | _: IsNotNull => true
+ case _: UnaryMathExpression | _: Abs | _: Bin | _: Factorial | _: Hex => true
+ case _: String2StringExpression | _: Ascii | _: Base64 | _: BitLength | _: Chr | _: Length =>
+ true
+ case _: CastBase => true
+ case _: GetDateField | _: LastDay => true
+ case _: ExtractIntervalPart => true
+ case _: ArraySetLike => true
+ case _: ExtractValue => true
+ case _ => false
+ }
+
+ // Not all BinaryExpression can be pushed into (if / case) branches.
+ private def supportedBinaryExpression(e: BinaryExpression): Boolean = e match {
+ case _: BinaryComparison | _: StringPredicate | _: StringRegexExpression => true
+ case _: BinaryArithmetic => true
+ case _: BinaryMathExpression => true
+ case _: AddMonths | _: DateAdd | _: DateAddInterval | _: DateDiff | _: DateSub => true
+ case _: FindInSet | _: RoundBase => true
+ case _ => false
+ }
+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
- case a: Alias => a // Skip an alias.
case u @ UnaryExpression(i @ If(_, trueValue, falseValue))
- if atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
+ if supportedUnaryExpression(u) && atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
i.copy(
trueValue = u.withNewChildren(Array(trueValue)),
falseValue = u.withNewChildren(Array(falseValue)))
case u @ UnaryExpression(c @ CaseWhen(branches, elseValue))
- if atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
+ if supportedUnaryExpression(u) && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
c.copy(
branches.map(e => e.copy(_2 = u.withNewChildren(Array(e._2)))),
elseValue.map(e => u.withNewChildren(Array(e))))
case b @ BinaryExpression(i @ If(_, trueValue, falseValue), right)
- if right.foldable && atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
+ if supportedBinaryExpression(b) && right.foldable &&
+ atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
i.copy(
trueValue = b.withNewChildren(Array(trueValue, right)),
falseValue = b.withNewChildren(Array(falseValue, right)))
case b @ BinaryExpression(left, i @ If(_, trueValue, falseValue))
- if left.foldable && atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
+ if supportedBinaryExpression(b) && left.foldable &&
+ atMostOneUnfoldable(Seq(trueValue, falseValue)) =>
i.copy(
trueValue = b.withNewChildren(Array(left, trueValue)),
falseValue = b.withNewChildren(Array(left, falseValue)))
case b @ BinaryExpression(c @ CaseWhen(branches, elseValue), right)
- if right.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
+ if supportedBinaryExpression(b) && right.foldable &&
+ atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
c.copy(
branches.map(e => e.copy(_2 = b.withNewChildren(Array(e._2, right)))),
elseValue.map(e => b.withNewChildren(Array(e, right))))
case b @ BinaryExpression(left, c @ CaseWhen(branches, elseValue))
- if left.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
+ if supportedBinaryExpression(b) && left.foldable &&
+ atMostOneUnfoldable(branches.map(_._2) ++ elseValue) =>
c.copy(
branches.map(e => e.copy(_2 = b.withNewChildren(Array(left, e._2)))),
elseValue.map(e => b.withNewChildren(Array(left, e))))
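Editor's note: a Catalyst-level sketch of the two new `CaseWhen` rewrites in SimplifyConditionals, assuming the catalyst test DSL is on the classpath; attribute names are made up. The null-safe comparison appears only for nullable predicates, because a NULL predicate must still evaluate to false:

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.CaseWhen
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.optimizer.SimplifyConditionals
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

val flag = 'flag.boolean.notNull   // non-nullable predicate
val cond = 'cond.boolean           // nullable predicate
val query = LocalRelation(flag, cond).select(
  CaseWhen(Seq((flag, TrueLiteral)), Some(FalseLiteral)).as("x"),  // ==> flag
  CaseWhen(Seq((cond, FalseLiteral)), Some(TrueLiteral)).as("y"))  // ==> NOT(cond <=> true)

// Applying the rule once shows the simplified projections.
println(SimplifyConditionals(query.analyze))
```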
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index c5707812e44bb..771bb5a1708b0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -3845,7 +3845,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
ctx.multipartIdentifier,
"ALTER TABLE ... RENAME TO PARTITION"),
UnresolvedPartitionSpec(visitNonOptionalPartitionSpec(ctx.from)),
- visitNonOptionalPartitionSpec(ctx.to))
+ UnresolvedPartitionSpec(visitNonOptionalPartitionSpec(ctx.to)))
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index 02fb3a86db5d5..c51291d370c80 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -694,9 +694,11 @@ case class AlterTableDropPartition(
case class AlterTableRenamePartition(
child: LogicalPlan,
from: PartitionSpec,
- to: TablePartitionSpec) extends Command {
+ to: PartitionSpec) extends Command {
override lazy val resolved: Boolean =
- childrenResolved && from.isInstanceOf[ResolvedPartitionSpec]
+ childrenResolved &&
+ from.isInstanceOf[ResolvedPartitionSpec] &&
+ to.isInstanceOf[ResolvedPartitionSpec]
override def children: Seq[LogicalPlan] = child :: Nil
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index d14d136a81e7f..6fcab887dd6af 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1631,6 +1631,14 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val CSV_EXPRESSION_OPTIMIZATION =
+ buildConf("spark.sql.optimizer.enableCsvExpressionOptimization")
+ .doc("Whether to optimize CSV expressions in SQL optimizer. It includes pruning " +
+ "unnecessary columns from from_csv.")
+ .version("3.2.0")
+ .booleanConf
+ .createWithDefault(true)
+
val FILE_SINK_LOG_DELETION = buildConf("spark.sql.streaming.fileSink.log.deletion")
.internal()
.doc("Whether to delete the expired log files in file stream sink.")
@@ -3489,6 +3497,8 @@ class SQLConf extends Serializable with Logging {
def jsonExpressionOptimization: Boolean = getConf(SQLConf.JSON_EXPRESSION_OPTIMIZATION)
+ def csvExpressionOptimization: Boolean = getConf(SQLConf.CSV_EXPRESSION_OPTIMIZATION)
+
def parallelFileListingInStatsComputation: Boolean =
getConf(SQLConf.PARALLEL_FILE_LISTING_IN_STATS_COMPUTATION)
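Editor's note: the new flag defaults to true; a minimal sketch of toggling it per session (SparkSession `spark` assumed), which can help when comparing optimized plans with and without the from_csv pruning:

```scala
spark.conf.set("spark.sql.optimizer.enableCsvExpressionOptimization", "false")
spark.sql(
  "SELECT parsed.a FROM (SELECT from_csv('1,2', 'a INT, b INT') AS parsed)").explain(true)
spark.conf.set("spark.sql.optimizer.enableCsvExpressionOptimization", "true")
```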
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 44128c4419951..004d577c7ad52 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -700,4 +700,17 @@ class AnalysisErrorSuite extends AnalysisTest {
UnresolvedRelation(TableIdentifier("t", Option("nonexist")))))))
assertAnalysisError(plan, "Table or view not found:" :: Nil)
}
+
+ test("SPARK-33909: Check rand functions seed is legal at analyer side") {
+ Seq(Rand("a".attr), Randn("a".attr)).foreach { r =>
+ val plan = Project(Seq(r.as("r")), testRelation)
+ assertAnalysisError(plan,
+ s"Input argument to ${r.prettyName} must be a constant." :: Nil)
+ }
+ Seq(Rand(1.0), Rand("1"), Randn("a")).foreach { r =>
+ val plan = Project(Seq(r.as("r")), testRelation)
+ assertAnalysisError(plan,
+ s"data type mismatch: argument 1 requires (int or bigint) type" :: Nil)
+ }
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala
index 729a1e9f06ca5..d4b85b036b64c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala
@@ -60,7 +60,7 @@ class LeftSemiPushdownSuite extends PlanTest {
test("Project: LeftSemiAnti join no pushdown because of non-deterministic proj exprs") {
val originalQuery = testRelation
- .select(Rand('a), 'b, 'c)
+ .select(Rand(1), 'b, 'c)
.join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
val optimized = Optimize.execute(originalQuery.analyze)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCsvExprsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCsvExprsSuite.scala
new file mode 100644
index 0000000000000..9b208cf2b57c4
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCsvExprsSuite.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+
+class OptimizeCsvExprsSuite extends PlanTest with ExpressionEvalHelper {
+
+ private var csvExpressionOptimizeEnabled: Boolean = _
+ protected override def beforeAll(): Unit = {
+ csvExpressionOptimizeEnabled = SQLConf.get.csvExpressionOptimization
+ }
+
+ protected override def afterAll(): Unit = {
+ SQLConf.get.setConf(SQLConf.CSV_EXPRESSION_OPTIMIZATION, csvExpressionOptimizeEnabled)
+ }
+
+ object Optimizer extends RuleExecutor[LogicalPlan] {
+ val batches = Batch("Csv optimization", FixedPoint(10), OptimizeCsvJsonExprs) :: Nil
+ }
+
+ val schema = StructType.fromDDL("a int, b int")
+
+ private val csvAttr = 'csv.string
+ private val testRelation = LocalRelation(csvAttr)
+
+ test("SPARK-32968: prune unnecessary columns from GetStructField + from_csv") {
+ val options = Map.empty[String, String]
+
+ val query1 = testRelation
+ .select(GetStructField(CsvToStructs(schema, options, 'csv), 0))
+ val optimized1 = Optimizer.execute(query1.analyze)
+
+ val prunedSchema1 = StructType.fromDDL("a int")
+ val expected1 = testRelation
+ .select(GetStructField(CsvToStructs(schema, options, 'csv, None, Some(prunedSchema1)), 0))
+ .analyze
+ comparePlans(optimized1, expected1)
+
+ val query2 = testRelation
+ .select(GetStructField(CsvToStructs(schema, options, 'csv), 1))
+ val optimized2 = Optimizer.execute(query2.analyze)
+
+ val prunedSchema2 = StructType.fromDDL("b int")
+ val expected2 = testRelation
+ .select(GetStructField(CsvToStructs(schema, options, 'csv, None, Some(prunedSchema2)), 0))
+ .analyze
+ comparePlans(optimized2, expected2)
+ }
+
+ test("SPARK-32968: don't prune columns if options is not empty") {
+ val options = Map("mode" -> "failfast")
+
+ val query = testRelation
+ .select(GetStructField(CsvToStructs(schema, options, 'csv), 0))
+ val optimized = Optimizer.execute(query.analyze)
+
+ val expected = query.analyze
+ comparePlans(optimized, expected)
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeJsonExprsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeJsonExprsSuite.scala
index 4129a37eb69a2..05d47706ba297 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeJsonExprsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeJsonExprsSuite.scala
@@ -39,7 +39,7 @@ class OptimizeJsonExprsSuite extends PlanTest with ExpressionEvalHelper {
}
object Optimizer extends RuleExecutor[LogicalPlan] {
- val batches = Batch("Json optimization", FixedPoint(10), OptimizeJsonExprs) :: Nil
+ val batches = Batch("Json optimization", FixedPoint(10), OptimizeCsvJsonExprs) :: Nil
}
val schema = StructType.fromDDL("a int, b int")
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala
index 7c9a67d7554e2..0d5218ac629e3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala
@@ -141,7 +141,7 @@ class PushFoldableIntoBranchesSuite
CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), Literal(1))), Some(Literal(2)))
assert(!nonDeterministic.deterministic)
assertEquivalent(EqualTo(nonDeterministic, Literal(2)),
- CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), FalseLiteral)), Some(TrueLiteral)))
+ GreaterThanOrEqual(Rand(1), Literal(0.5)))
assertEquivalent(EqualTo(nonDeterministic, Literal(3)),
CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), FalseLiteral)), Some(FalseLiteral)))
@@ -269,4 +269,13 @@ class PushFoldableIntoBranchesSuite
Literal.create(null, BooleanType))
}
}
+
+ test("SPARK-33884: simplify CaseWhen clauses with (true and false) and (false and true)") {
+ assertEquivalent(
+ EqualTo(CaseWhen(Seq(('a > 10, Literal(0))), Literal(1)), Literal(0)),
+ 'a > 10 <=> TrueLiteral)
+ assertEquivalent(
+ EqualTo(CaseWhen(Seq(('a > 10, Literal(0))), Literal(1)), Literal(1)),
+ Not('a > 10 <=> TrueLiteral))
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
index 317984eba2261..f3edd70bcfb12 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
@@ -243,4 +243,40 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P
Literal.create(null, IntegerType))
}
}
+
+ test("SPARK-33884: simplify CaseWhen clauses with (true and false) and (false and true)") {
+ // verify the boolean equivalence of all transformations involved
+ val fields = Seq(
+ 'cond.boolean.notNull,
+ 'cond_nullable.boolean,
+ 'a.boolean,
+ 'b.boolean
+ )
+ val Seq(cond, cond_nullable, a, b) = fields.zipWithIndex.map { case (f, i) => f.at(i) }
+
+ val exprs = Seq(
+ // actual expressions of the transformations: original -> transformed
+ CaseWhen(Seq((cond, TrueLiteral)), FalseLiteral) -> cond,
+ CaseWhen(Seq((cond, FalseLiteral)), TrueLiteral) -> !cond,
+ CaseWhen(Seq((cond_nullable, TrueLiteral)), FalseLiteral) -> (cond_nullable <=> true),
+ CaseWhen(Seq((cond_nullable, FalseLiteral)), TrueLiteral) -> (!(cond_nullable <=> true)))
+
+ // check plans
+ for ((originalExpr, expectedExpr) <- exprs) {
+ assertEquivalent(originalExpr, expectedExpr)
+ }
+
+ // check evaluation
+ val binaryBooleanValues = Seq(true, false)
+ val ternaryBooleanValues = Seq(true, false, null)
+ for (condVal <- binaryBooleanValues;
+ condNullableVal <- ternaryBooleanValues;
+ aVal <- ternaryBooleanValues;
+ bVal <- ternaryBooleanValues;
+ (originalExpr, expectedExpr) <- exprs) {
+ val inputRow = create_row(condVal, condNullableVal, aVal, bVal)
+ val optimizedVal = evaluateWithoutCodegen(expectedExpr, inputRow)
+ checkEvaluation(originalExpr, optimizedVal, inputRow)
+ }
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryPartitionTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryPartitionTable.scala
index e29c78c59f769..a3d610af2c06d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryPartitionTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryPartitionTable.scala
@@ -84,7 +84,7 @@ class InMemoryPartitionTable(
}
override protected def addPartitionKey(key: Seq[Any]): Unit = {
- memoryTablePartitions.put(InternalRow.fromSeq(key), Map.empty[String, String].asJava)
+ memoryTablePartitions.putIfAbsent(InternalRow.fromSeq(key), Map.empty[String, String].asJava)
}
override def listPartitionIdentifiers(
@@ -107,4 +107,17 @@ class InMemoryPartitionTable(
currentRow == ident
}.toArray
}
+
+ override def renamePartition(from: InternalRow, to: InternalRow): Boolean = {
+ if (memoryTablePartitions.containsKey(to)) {
+ throw new PartitionAlreadyExistsException(name, to, partitionSchema)
+ } else {
+ val partValue = memoryTablePartitions.remove(from)
+ if (partValue == null) {
+ throw new NoSuchPartitionException(name, from, partitionSchema)
+ }
+ memoryTablePartitions.put(to, partValue) == null &&
+ renamePartitionKey(partitionSchema, from.toSeq(schema), to.toSeq(schema))
+ }
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
index c4c5835d9d1f5..201d67a815bea 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
@@ -165,6 +165,28 @@ class InMemoryTable(
protected def addPartitionKey(key: Seq[Any]): Unit = {}
+ protected def renamePartitionKey(
+ partitionSchema: StructType,
+ from: Seq[Any],
+ to: Seq[Any]): Boolean = {
+ val rows = dataMap.remove(from).getOrElse(new BufferedRows(from.mkString("/")))
+ val newRows = new BufferedRows(to.mkString("/"))
+ rows.rows.foreach { r =>
+ val newRow = new GenericInternalRow(r.numFields)
+ for (i <- 0 until r.numFields) newRow.update(i, r.get(i, schema(i).dataType))
+ for (i <- 0 until partitionSchema.length) {
+ val j = schema.fieldIndex(partitionSchema(i).name)
+ newRow.update(j, to(i))
+ }
+ newRows.withRow(newRow)
+ }
+ dataMap.put(to, newRows).foreach { _ =>
+ throw new IllegalStateException(
+ s"The ${to.mkString("[", ", ", "]")} partition exists already")
+ }
+ true
+ }
+
def withData(data: Array[BufferedRows]): InMemoryTable = dataMap.synchronized {
data.foreach(_.rows.foreach { row =>
val key = getKey(row)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsPartitionManagementSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsPartitionManagementSuite.scala
index 31494c7c2dd50..99441c81d9add 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsPartitionManagementSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/SupportsPartitionManagementSuite.scala
@@ -23,6 +23,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, PartitionAlreadyExistsException}
import org.apache.spark.sql.connector.{InMemoryPartitionTable, InMemoryPartitionTableCatalog, InMemoryTableCatalog}
import org.apache.spark.sql.connector.expressions.{LogicalExpressions, NamedReference}
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
@@ -214,4 +215,22 @@ class SupportsPartitionManagementSuite extends SparkFunSuite {
}.getMessage
assert(errMsg.contains("The identifier might not refer to one partition"))
}
+
+ test("renamePartition") {
+ val partTable = createMultiPartTable()
+
+ val errMsg1 = intercept[PartitionAlreadyExistsException] {
+ partTable.renamePartition(InternalRow(0, "abc"), InternalRow(1, "abc"))
+ }.getMessage
+ assert(errMsg1.contains("Partition already exists"))
+
+ val newPart = InternalRow(2, "xyz")
+ val errMsg2 = intercept[NoSuchPartitionException] {
+ partTable.renamePartition(newPart, InternalRow(3, "abc"))
+ }.getMessage
+ assert(errMsg2.contains("Partition not found"))
+
+ assert(partTable.renamePartition(InternalRow(0, "abc"), newPart))
+ assert(partTable.partitionExists(newPart))
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
index 925c7741eefe3..dec1300d66f35 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
@@ -444,11 +444,10 @@ class ResolveSessionCatalog(
ifNotExists)
case AlterTableRenamePartition(
- ResolvedV1TableIdentifier(ident), UnresolvedPartitionSpec(from, _), to) =>
- AlterTableRenamePartitionCommand(
- ident.asTableIdentifier,
- from,
- to)
+ ResolvedV1TableIdentifier(ident),
+ UnresolvedPartitionSpec(from, _),
+ UnresolvedPartitionSpec(to, _)) =>
+ AlterTableRenamePartitionCommand(ident.asTableIdentifier, from, to)
case AlterTableDropPartition(
ResolvedV1TableIdentifier(ident), specs, ifExists, purge) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala
index b924a4ac3b856..47d9979a26c29 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala
@@ -367,7 +367,7 @@ case class ScriptTransformationIOSchema(
object ScriptTransformationIOSchema {
val defaultFormat = Map(
- ("TOK_TABLEROWFORMATFIELD", "\t"),
+ ("TOK_TABLEROWFORMATFIELD", "\u0001"),
("TOK_TABLEROWFORMATLINES", "\n"),
("TOK_TABLEROWFORMATCOLLITEMS", "\u0002"),
("TOK_TABLEROWFORMATMAPKEYS", "\u0003")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AlterTableRenamePartitionExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AlterTableRenamePartitionExec.scala
new file mode 100644
index 0000000000000..38b83e3ad74e7
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AlterTableRenamePartitionExec.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.ResolvedPartitionSpec
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.connector.catalog.SupportsPartitionManagement
+
+/**
+ * Physical plan node for renaming a table partition.
+ */
+case class AlterTableRenamePartitionExec(
+ table: SupportsPartitionManagement,
+ from: ResolvedPartitionSpec,
+ to: ResolvedPartitionSpec) extends V2CommandExec {
+
+ override def output: Seq[Attribute] = Seq.empty
+
+ override protected def run(): Seq[InternalRow] = {
+ table.renamePartition(from.ident, to.ident)
+ Seq.empty
+ }
+}
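Editor's note: a spark-shell sketch of the command this node now executes. The catalog name `testcat` and the `foo` provider are placeholders for a v2 catalog whose tables implement `SupportsPartitionManagement`:

```scala
spark.sql("CREATE TABLE testcat.ns.tbl (id BIGINT, city STRING) USING foo PARTITIONED BY (city)")
spark.sql("ALTER TABLE testcat.ns.tbl ADD PARTITION (city = 'NY')")
// Before this change the command failed with "ALTER TABLE ... RENAME TO PARTITION
// is not supported for v2 tables."; it is now planned as AlterTableRenamePartitionExec.
spark.sql("ALTER TABLE testcat.ns.tbl PARTITION (city = 'NY') RENAME TO PARTITION (city = 'SF')")
```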
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index 4667bb7cca998..2674aaf4f2e88 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -352,9 +352,12 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
AlterTableDropPartitionExec(
table, parts.asResolvedPartitionSpecs, ignoreIfNotExists, purge) :: Nil
- case AlterTableRenamePartition(_: ResolvedTable, _: ResolvedPartitionSpec, _) =>
- throw new AnalysisException(
- "ALTER TABLE ... RENAME TO PARTITION is not supported for v2 tables.")
+ case AlterTableRenamePartition(
+ ResolvedTable(_, _, table: SupportsPartitionManagement), from, to) =>
+ AlterTableRenamePartitionExec(
+ table,
+ Seq(from).asResolvedPartitionSpecs.head,
+ Seq(to).asResolvedPartitionSpecs.head) :: Nil
case AlterTableRecoverPartitions(_: ResolvedTable) =>
throw new AnalysisException(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala
index abccaf19084b2..16b92d6d11c91 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala
@@ -250,4 +250,52 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession {
| """.stripMargin)
checkAnswer(toDF("yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), toDF("yyyy-MM-dd'T'HH:mm:ss[.SSS][XXX]"))
}
+
+ test("SPARK-32968: Pruning csv field should not change result") {
+ Seq("true", "false").foreach { enabled =>
+ withSQLConf(SQLConf.CSV_EXPRESSION_OPTIMIZATION.key -> enabled) {
+ val df1 = sparkContext.parallelize(Seq("a,b")).toDF("csv")
+ .selectExpr("from_csv(csv, 'a string, b string', map('mode', 'failfast')) as parsed")
+ checkAnswer(df1.selectExpr("parsed.a"), Seq(Row("a")))
+ checkAnswer(df1.selectExpr("parsed.b"), Seq(Row("b")))
+
+ val df2 = sparkContext.parallelize(Seq("a,b")).toDF("csv")
+ .selectExpr("from_csv(csv, 'a string, b string') as parsed")
+ checkAnswer(df2.selectExpr("parsed.a"), Seq(Row("a")))
+ checkAnswer(df2.selectExpr("parsed.b"), Seq(Row("b")))
+ }
+ }
+ }
+
+ test("SPARK-32968: bad csv input with csv pruning optimization") {
+ Seq("true", "false").foreach { enabled =>
+ withSQLConf(SQLConf.CSV_EXPRESSION_OPTIMIZATION.key -> enabled) {
+ val df = sparkContext.parallelize(Seq("1,\u0001\u0000\u0001234")).toDF("csv")
+ .selectExpr("from_csv(csv, 'a int, b int', map('mode', 'failfast')) as parsed")
+
+ val err1 = intercept[SparkException] {
+ df.selectExpr("parsed.a").collect
+ }
+
+ val err2 = intercept[SparkException] {
+ df.selectExpr("parsed.b").collect
+ }
+
+ assert(err1.getMessage.contains("Malformed records are detected in record parsing"))
+ assert(err2.getMessage.contains("Malformed records are detected in record parsing"))
+ }
+ }
+ }
+
+ test("SPARK-32968: csv pruning optimization with corrupt record field") {
+ Seq("true", "false").foreach { enabled =>
+ withSQLConf(SQLConf.CSV_EXPRESSION_OPTIMIZATION.key -> enabled) {
+ val df = sparkContext.parallelize(Seq("a,b,c,d")).toDF("csv")
+ .selectExpr("from_csv(csv, 'a string, b string, _corrupt_record string') as parsed")
+ .selectExpr("parsed._corrupt_record")
+
+ checkAnswer(df, Seq(Row("a,b,c,d")))
+ }
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index ed4ea567e4f65..b8d58217efa6e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -2581,6 +2581,27 @@ class DataSourceV2SQLSuite
"ALTER VIEW ... AS")
}
+ test("SPARK-33924: INSERT INTO .. PARTITION preserves the partition location") {
+ val t = "testpart.ns1.ns2.tbl"
+ withTable(t) {
+ sql(s"""
+ |CREATE TABLE $t (id bigint, city string, data string)
+ |USING foo
+ |PARTITIONED BY (id, city)""".stripMargin)
+ val partTable = catalog("testpart").asTableCatalog
+ .loadTable(Identifier.of(Array("ns1", "ns2"), "tbl")).asInstanceOf[InMemoryPartitionTable]
+
+ val loc = "partition_location"
+ sql(s"ALTER TABLE $t ADD PARTITION (id = 1, city = 'NY') LOCATION '$loc'")
+
+ val ident = InternalRow.fromSeq(Seq(1, UTF8String.fromString("NY")))
+ assert(partTable.loadPartitionMetadata(ident).get("location") === loc)
+
+ sql(s"INSERT INTO $t PARTITION(id = 1, city = 'NY') SELECT 'abc'")
+ assert(partTable.loadPartitionMetadata(ident).get("location") === loc)
+ }
+ }
+
private def testNotSupportedV2Command(sqlCommand: String, sqlParams: String): Unit = {
val e = intercept[AnalysisException] {
sql(s"$sqlCommand $sqlParams")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala
index 2e45fd7242c32..4d6faae983514 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala
@@ -28,6 +28,7 @@ import org.scalatest.exceptions.TestFailedException
import org.apache.spark.{SparkException, TaskContext, TestUtils}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, GenericInternalRow}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
@@ -123,7 +124,11 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU
s"""
|SELECT
|TRANSFORM(a, b, c, d, e)
- |USING 'python $scriptFilePath' AS (a, b, c, d, e)
+ | ROW FORMAT DELIMITED
+ | FIELDS TERMINATED BY '\t'
+ | USING 'python $scriptFilePath' AS (a, b, c, d, e)
+ | ROW FORMAT DELIMITED
+ | FIELDS TERMINATED BY '\t'
|FROM v
""".stripMargin)
@@ -538,6 +543,28 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU
}.getMessage
assert(e.contains("Number of levels of nesting supported for Spark SQL" +
" script transform is 7 Unable to work with level 8"))
+ test("SPARK-33930: Script Transform default FIELD DELIMIT should be \u0001 (no serde)") {
+ withTempView("v") {
+ val df = Seq(
+ (1, 2, 3),
+ (2, 3, 4),
+ (3, 4, 5)
+ ).toDF("a", "b", "c")
+ df.createTempView("v")
+
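+      // 'cat' echoes its input unchanged; since only one output column is declared,
+      // the default '\u0001' input field delimiter stays visible in the result column.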
+ checkAnswer(
+ sql(
+ s"""
+ |SELECT TRANSFORM(a, b, c)
+ | ROW FORMAT DELIMITED
+ | USING 'cat' AS (a)
+ | ROW FORMAT DELIMITED
+ | FIELDS TERMINATED BY '&'
+ |FROM v
+ """.stripMargin), identity,
+ Row("1\u00012\u00013") ::
+ Row("2\u00013\u00014") ::
+ Row("3\u00014\u00015") :: Nil)
+    }
+  }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddPartitionSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddPartitionSuiteBase.scala
index aa0668ccaaf53..2705adb8b3c67 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddPartitionSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddPartitionSuiteBase.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.command
import org.apache.spark.sql.{AnalysisException, QueryTest}
import org.apache.spark.sql.catalyst.analysis.PartitionsAlreadyExistException
-import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.internal.SQLConf
/**
@@ -39,8 +38,6 @@ import org.apache.spark.sql.internal.SQLConf
trait AlterTableAddPartitionSuiteBase extends QueryTest with DDLCommandTestUtils {
override val command = "ALTER TABLE .. ADD PARTITION"
- protected def checkLocation(t: String, spec: TablePartitionSpec, expected: String): Unit
-
test("one partition") {
withNamespaceAndTable("ns", "tbl") { t =>
sql(s"CREATE TABLE $t (id bigint, data string) $defaultUsing PARTITIONED BY (id)")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenamePartitionParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenamePartitionParserSuite.scala
index db6506c85bcec..c9a6732796729 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenamePartitionParserSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenamePartitionParserSuite.scala
@@ -32,7 +32,7 @@ class AlterTableRenamePartitionParserSuite extends AnalysisTest with SharedSpark
val expected = AlterTableRenamePartition(
UnresolvedTable(Seq("a", "b", "c"), "ALTER TABLE ... RENAME TO PARTITION"),
UnresolvedPartitionSpec(Map("ds" -> "2017-06-10")),
- Map("ds" -> "2018-06-10"))
+ UnresolvedPartitionSpec(Map("ds" -> "2018-06-10")))
comparePlans(parsed, expected)
}
@@ -45,7 +45,7 @@ class AlterTableRenamePartitionParserSuite extends AnalysisTest with SharedSpark
val expected = AlterTableRenamePartition(
UnresolvedTable(Seq("table_name"), "ALTER TABLE ... RENAME TO PARTITION"),
UnresolvedPartitionSpec(Map("dt" -> "2008-08-08", "country" -> "us")),
- Map("dt" -> "2008-09-09", "country" -> "uk"))
+ UnresolvedPartitionSpec(Map("dt" -> "2008-09-09", "country" -> "uk")))
comparePlans(parsed, expected)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenamePartitionSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenamePartitionSuiteBase.scala
index 40c167ce424a0..58055262d3f11 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenamePartitionSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenamePartitionSuiteBase.scala
@@ -17,7 +17,9 @@
package org.apache.spark.sql.execution.command
-import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
+import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, PartitionAlreadyExistsException}
+import org.apache.spark.sql.internal.SQLConf
/**
* This base suite contains unified tests for the `ALTER TABLE .. RENAME PARTITION` command that
@@ -35,4 +37,130 @@ import org.apache.spark.sql.QueryTest
*/
trait AlterTableRenamePartitionSuiteBase extends QueryTest with DDLCommandTestUtils {
override val command = "ALTER TABLE .. RENAME PARTITION"
+
+ protected def createSinglePartTable(t: String): Unit = {
+ sql(s"CREATE TABLE $t (id bigint, data string) $defaultUsing PARTITIONED BY (id)")
+ sql(s"INSERT INTO $t PARTITION (id = 1) SELECT 'abc'")
+ }
+
+ test("rename without explicitly specifying database") {
+ withSQLConf(SQLConf.DEFAULT_CATALOG.key -> catalog) {
+ createSinglePartTable("t")
+ checkPartitions("t", Map("id" -> "1"))
+
+ sql(s"ALTER TABLE t PARTITION (id = 1) RENAME TO PARTITION (id = 2)")
+ checkPartitions("t", Map("id" -> "2"))
+ checkAnswer(sql(s"SELECT id, data FROM t"), Row(2, "abc"))
+ }
+ }
+
+ test("table to alter does not exist") {
+ withNamespace(s"$catalog.ns") {
+ sql(s"CREATE NAMESPACE $catalog.ns")
+ val errMsg = intercept[AnalysisException] {
+ sql(s"ALTER TABLE $catalog.ns.no_tbl PARTITION (id=1) RENAME TO PARTITION (id=2)")
+ }.getMessage
+ assert(errMsg.contains("Table not found"))
+ }
+ }
+
+ test("partition to rename does not exist") {
+ withNamespaceAndTable("ns", "tbl") { t =>
+ createSinglePartTable(t)
+ checkPartitions(t, Map("id" -> "1"))
+ val errMsg = intercept[NoSuchPartitionException] {
+ sql(s"ALTER TABLE $t PARTITION (id = 3) RENAME TO PARTITION (id = 2)")
+ }.getMessage
+ assert(errMsg.contains("Partition not found in table"))
+ }
+ }
+
+ test("target partition exists") {
+ withNamespaceAndTable("ns", "tbl") { t =>
+ createSinglePartTable(t)
+ sql(s"INSERT INTO $t PARTITION (id = 2) SELECT 'def'")
+ checkPartitions(t, Map("id" -> "1"), Map("id" -> "2"))
+ val errMsg = intercept[PartitionAlreadyExistsException] {
+ sql(s"ALTER TABLE $t PARTITION (id = 1) RENAME TO PARTITION (id = 2)")
+ }.getMessage
+ assert(errMsg.contains("Partition already exists"))
+ }
+ }
+
+ test("single part partition") {
+ withNamespaceAndTable("ns", "tbl") { t =>
+ createSinglePartTable(t)
+ checkPartitions(t, Map("id" -> "1"))
+
+ sql(s"ALTER TABLE $t PARTITION (id = 1) RENAME TO PARTITION (id = 2)")
+ checkPartitions(t, Map("id" -> "2"))
+ checkAnswer(sql(s"SELECT id, data FROM $t"), Row(2, "abc"))
+ }
+ }
+
+ test("multi part partition") {
+ withNamespaceAndTable("ns", "tbl") { t =>
+ createWideTable(t)
+ checkPartitions(t,
+ Map(
+ "year" -> "2016",
+ "month" -> "3",
+ "hour" -> "10",
+ "minute" -> "10",
+ "sec" -> "10",
+ "extra" -> "1"),
+ Map(
+ "year" -> "2016",
+ "month" -> "4",
+ "hour" -> "10",
+ "minute" -> "10",
+ "sec" -> "10",
+ "extra" -> "1"))
+
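+      // Rename only changes `sec` here; the month = 4 partition must stay intact.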
+ sql(s"""
+ |ALTER TABLE $t
+ |PARTITION (
+ | year = 2016, month = 3, hour = 10, minute = 10, sec = 10, extra = 1
+ |) RENAME TO PARTITION (
+ | year = 2016, month = 3, hour = 10, minute = 10, sec = 123, extra = 1
+ |)""".stripMargin)
+ checkPartitions(t,
+ Map(
+ "year" -> "2016",
+ "month" -> "3",
+ "hour" -> "10",
+ "minute" -> "10",
+ "sec" -> "123",
+ "extra" -> "1"),
+ Map(
+ "year" -> "2016",
+ "month" -> "4",
+ "hour" -> "10",
+ "minute" -> "10",
+ "sec" -> "10",
+ "extra" -> "1"))
+ checkAnswer(sql(s"SELECT month, sec, price FROM $t"), Row(3, 123, 3))
+ }
+ }
+
+ test("partition spec in RENAME PARTITION should be case insensitive") {
+ withNamespaceAndTable("ns", "tbl") { t =>
+ createSinglePartTable(t)
+ checkPartitions(t, Map("id" -> "1"))
+
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+ val errMsg = intercept[AnalysisException] {
+ sql(s"ALTER TABLE $t PARTITION (ID = 1) RENAME TO PARTITION (id = 2)")
+ }.getMessage
+ assert(errMsg.contains("ID is not a valid partition column"))
+ checkPartitions(t, Map("id" -> "1"))
+ }
+
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+ sql(s"ALTER TABLE $t PARTITION (ID = 1) RENAME TO PARTITION (id = 2)")
+ checkPartitions(t, Map("id" -> "2"))
+ checkAnswer(sql(s"SELECT id, data FROM $t"), Row(2, "abc"))
+ }
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandTestUtils.scala
index a613978ce375a..f4b84d8ee0059 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandTestUtils.scala
@@ -21,6 +21,7 @@ import org.scalactic.source.Position
import org.scalatest.Tag
import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.execution.datasources.PartitioningUtils
import org.apache.spark.sql.test.SQLTestUtils
@@ -88,4 +89,6 @@ trait DDLCommandTestUtils extends SQLTestUtils {
|ADD PARTITION(year = 2016, month = 4, hour = 10, minute = 10, sec = 10, extra = 1)
|""".stripMargin)
}
+
+ protected def checkLocation(t: String, spec: TablePartitionSpec, expected: String): Unit
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableAddPartitionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableAddPartitionSuite.scala
index 808eab8340524..b3c118def70b7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableAddPartitionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableAddPartitionSuite.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql.execution.command.v1
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.execution.command
/**
@@ -32,21 +31,6 @@ import org.apache.spark.sql.execution.command
* `org.apache.spark.sql.hive.execution.command.AlterTableAddPartitionSuite`
*/
trait AlterTableAddPartitionSuiteBase extends command.AlterTableAddPartitionSuiteBase {
- override protected def checkLocation(
- t: String,
- spec: TablePartitionSpec,
- expected: String): Unit = {
- val tablePath = t.split('.')
- val tableName = tablePath.last
- val ns = tablePath.init.mkString(".")
- val partSpec = spec.map { case (key, value) => s"$key = $value"}.mkString(", ")
- val information = sql(s"SHOW TABLE EXTENDED IN $ns LIKE '$tableName' PARTITION($partSpec)")
- .select("information")
- .first().getString(0)
- val location = information.split("\\r?\\n").filter(_.startsWith("Location:")).head
- assert(location.endsWith(expected))
- }
-
test("empty string as partition value") {
withNamespaceAndTable("ns", "tbl") { t =>
sql(s"CREATE TABLE $t (col1 INT, p1 STRING) $defaultUsing PARTITIONED BY (p1)")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableRenamePartitionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableRenamePartitionSuite.scala
index d923886fbdb9a..bde77106a3ab7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableRenamePartitionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableRenamePartitionSuite.scala
@@ -17,10 +17,8 @@
package org.apache.spark.sql.execution.command.v1
-import org.apache.spark.sql.{AnalysisException, Row}
-import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, PartitionAlreadyExistsException}
+import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.command
-import org.apache.spark.sql.internal.SQLConf
/**
* This base suite contains unified tests for the `ALTER TABLE .. RENAME PARTITION` command that
@@ -33,143 +31,19 @@ import org.apache.spark.sql.internal.SQLConf
* `org.apache.spark.sql.hive.execution.command.AlterTableRenamePartitionSuite`
*/
trait AlterTableRenamePartitionSuiteBase extends command.AlterTableRenamePartitionSuiteBase {
- protected def createSinglePartTable(t: String): Unit = {
- sql(s"CREATE TABLE $t (id bigint, data string) $defaultUsing PARTITIONED BY (id)")
- sql(s"INSERT INTO $t PARTITION (id = 1) SELECT 'abc'")
- }
-
- test("rename without explicitly specifying database") {
- val t = "tbl"
- withTable(t) {
- createSinglePartTable(t)
- checkPartitions(t, Map("id" -> "1"))
-
- sql(s"ALTER TABLE $t PARTITION (id = 1) RENAME TO PARTITION (id = 2)")
- checkPartitions(t, Map("id" -> "2"))
- checkAnswer(sql(s"SELECT id, data FROM $t"), Row(2, "abc"))
- }
- }
-
- test("table to alter does not exist") {
- withNamespace(s"$catalog.ns") {
- sql(s"CREATE NAMESPACE $catalog.ns")
- val errMsg = intercept[AnalysisException] {
- sql(s"ALTER TABLE $catalog.ns.no_tbl PARTITION (id=1) RENAME TO PARTITION (id=2)")
- }.getMessage
- assert(errMsg.contains("Table not found"))
- }
- }
-
- test("partition to rename does not exist") {
- withNamespaceAndTable("ns", "tbl") { t =>
- createSinglePartTable(t)
- checkPartitions(t, Map("id" -> "1"))
- val errMsg = intercept[NoSuchPartitionException] {
- sql(s"ALTER TABLE $t PARTITION (id = 3) RENAME TO PARTITION (id = 2)")
- }.getMessage
- assert(errMsg.contains("Partition not found in table"))
- }
- }
-
- test("target partition exists") {
- withNamespaceAndTable("ns", "tbl") { t =>
- createSinglePartTable(t)
- sql(s"INSERT INTO $t PARTITION (id = 2) SELECT 'def'")
- checkPartitions(t, Map("id" -> "1"), Map("id" -> "2"))
- val errMsg = intercept[PartitionAlreadyExistsException] {
- sql(s"ALTER TABLE $t PARTITION (id = 1) RENAME TO PARTITION (id = 2)")
- }.getMessage
- assert(errMsg.contains("Partition already exists"))
- }
- }
-
- test("single part partition") {
- withNamespaceAndTable("ns", "tbl") { t =>
- createSinglePartTable(t)
- checkPartitions(t, Map("id" -> "1"))
-
- sql(s"ALTER TABLE $t PARTITION (id = 1) RENAME TO PARTITION (id = 2)")
- checkPartitions(t, Map("id" -> "2"))
- checkAnswer(sql(s"SELECT id, data FROM $t"), Row(2, "abc"))
- }
- }
-
- test("multi part partition") {
- withNamespaceAndTable("ns", "tbl") { t =>
- createWideTable(t)
- checkPartitions(t,
- Map(
- "year" -> "2016",
- "month" -> "3",
- "hour" -> "10",
- "minute" -> "10",
- "sec" -> "10",
- "extra" -> "1"),
- Map(
- "year" -> "2016",
- "month" -> "4",
- "hour" -> "10",
- "minute" -> "10",
- "sec" -> "10",
- "extra" -> "1"))
-
- sql(s"""
- |ALTER TABLE $t
- |PARTITION (
- | year = 2016, month = 3, hour = 10, minute = 10, sec = 10, extra = 1
- |) RENAME TO PARTITION (
- | year = 2016, month = 3, hour = 10, minute = 10, sec = 123, extra = 1
- |)""".stripMargin)
- checkPartitions(t,
- Map(
- "year" -> "2016",
- "month" -> "3",
- "hour" -> "10",
- "minute" -> "10",
- "sec" -> "123",
- "extra" -> "1"),
- Map(
- "year" -> "2016",
- "month" -> "4",
- "hour" -> "10",
- "minute" -> "10",
- "sec" -> "10",
- "extra" -> "1"))
- checkAnswer(sql(s"SELECT month, sec, price FROM $t"), Row(3, 123, 3))
- }
- }
-
test("with location") {
withNamespaceAndTable("ns", "tbl") { t =>
createSinglePartTable(t)
sql(s"ALTER TABLE $t ADD PARTITION (id = 2) LOCATION 'loc1'")
sql(s"INSERT INTO $t PARTITION (id = 2) SELECT 'def'")
checkPartitions(t, Map("id" -> "1"), Map("id" -> "2"))
+ checkLocation(t, Map("id" -> "2"), "loc1")
sql(s"ALTER TABLE $t PARTITION (id = 2) RENAME TO PARTITION (id = 3)")
checkPartitions(t, Map("id" -> "1"), Map("id" -> "3"))
- checkAnswer(sql(s"SELECT id, data FROM $t"), Seq(Row(1, "abc"), Row(3, "def")))
- }
- }
-
- test("partition spec in RENAME PARTITION should be case insensitive") {
- withNamespaceAndTable("ns", "tbl") { t =>
- createSinglePartTable(t)
- checkPartitions(t, Map("id" -> "1"))
-
- withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
- val errMsg = intercept[AnalysisException] {
- sql(s"ALTER TABLE $t PARTITION (ID = 1) RENAME TO PARTITION (id = 2)")
- }.getMessage
- assert(errMsg.contains("ID is not a valid partition column"))
- checkPartitions(t, Map("id" -> "1"))
- }
-
- withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
- sql(s"ALTER TABLE $t PARTITION (ID = 1) RENAME TO PARTITION (id = 2)")
- checkPartitions(t, Map("id" -> "2"))
- checkAnswer(sql(s"SELECT id, data FROM $t"), Row(2, "abc"))
- }
+ // V1 catalogs rename the partition location of managed tables
+ checkLocation(t, Map("id" -> "3"), "id=3")
+ checkAnswer(sql(s"SELECT id, data FROM $t WHERE id = 3"), Row(3, "def"))
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/CommandSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/CommandSuiteBase.scala
index c4ecf1c98bb6e..80c552de567ba 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/CommandSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/CommandSuiteBase.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution.command.v1
+import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.test.SharedSparkSession
@@ -30,4 +31,20 @@ trait CommandSuiteBase extends SharedSparkSession {
def version: String = "V1" // The prefix is added to test names
def catalog: String = CatalogManager.SESSION_CATALOG_NAME
def defaultUsing: String = "USING parquet" // The clause is used in creating tables under testing
+
+ // TODO(SPARK-33393): Move this to `DDLCommandTestUtils`
+ def checkLocation(
+ t: String,
+ spec: TablePartitionSpec,
+ expected: String): Unit = {
+ val tablePath = t.split('.')
+ val tableName = tablePath.last
+ val ns = tablePath.init.mkString(".")
+ val partSpec = spec.map { case (key, value) => s"$key = $value"}.mkString(", ")
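+    // SHOW TABLE EXTENDED exposes a textual `information` column;
+    // the partition path appears on its `Location:` line.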
+ val information = sql(s"SHOW TABLE EXTENDED IN $ns LIKE '$tableName' PARTITION($partSpec)")
+ .select("information")
+ .first().getString(0)
+ val location = information.split("\\r?\\n").filter(_.startsWith("Location:")).head
+ assert(location.endsWith(expected))
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableAddPartitionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableAddPartitionSuite.scala
index 0f0f8fa389321..65494a7266756 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableAddPartitionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableAddPartitionSuite.scala
@@ -18,10 +18,6 @@
package org.apache.spark.sql.execution.command.v2
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.analysis.ResolvePartitionSpec
-import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
-import org.apache.spark.sql.connector.InMemoryPartitionTable
-import org.apache.spark.sql.connector.catalog.{CatalogV2Implicits, Identifier}
import org.apache.spark.sql.execution.command
/**
@@ -31,29 +27,6 @@ import org.apache.spark.sql.execution.command
class AlterTableAddPartitionSuite
extends command.AlterTableAddPartitionSuiteBase
with CommandSuiteBase {
-
- import CatalogV2Implicits._
-
- override protected def checkLocation(
- t: String,
- spec: TablePartitionSpec,
- expected: String): Unit = {
- val tablePath = t.split('.')
- val catalogName = tablePath.head
- val namespaceWithTable = tablePath.tail
- val namespaces = namespaceWithTable.init
- val tableName = namespaceWithTable.last
- val catalogPlugin = spark.sessionState.catalogManager.catalog(catalogName)
- val partTable = catalogPlugin.asTableCatalog
- .loadTable(Identifier.of(namespaces, tableName))
- .asInstanceOf[InMemoryPartitionTable]
- val ident = ResolvePartitionSpec.convertToPartIdent(spec, partTable.partitionSchema.fields)
- val partMetadata = partTable.loadPartitionMetadata(ident)
-
- assert(partMetadata.containsKey("location"))
- assert(partMetadata.get("location") === expected)
- }
-
test("SPARK-33650: add partition into a table which doesn't support partition management") {
withNamespaceAndTable("ns", "tbl", s"non_part_$catalog") { t =>
sql(s"CREATE TABLE $t (id bigint, data string) $defaultUsing")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableRenamePartitionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableRenamePartitionSuite.scala
index d1c252adde369..bb06818da48b1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableRenamePartitionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableRenamePartitionSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution.command.v2
-import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.command
/**
@@ -28,14 +28,20 @@ class AlterTableRenamePartitionSuite
extends command.AlterTableRenamePartitionSuiteBase
with CommandSuiteBase {
- // TODO(SPARK-33859): Support V2 ALTER TABLE .. RENAME PARTITION
- test("single part partition") {
+ test("with location") {
withNamespaceAndTable("ns", "tbl") { t =>
- sql(s"CREATE TABLE $t (id bigint, data string) $defaultUsing PARTITIONED BY (id)")
- val errMsg = intercept[AnalysisException] {
- sql(s"ALTER TABLE $t PARTITION (id=1) RENAME TO PARTITION (id=2)")
- }.getMessage
- assert(errMsg.contains("ALTER TABLE ... RENAME TO PARTITION is not supported for v2 tables"))
+ createSinglePartTable(t)
+ val loc = "location1"
+ sql(s"ALTER TABLE $t ADD PARTITION (id = 2) LOCATION '$loc'")
+ sql(s"INSERT INTO $t PARTITION (id = 2) SELECT 'def'")
+ checkPartitions(t, Map("id" -> "1"), Map("id" -> "2"))
+ checkLocation(t, Map("id" -> "2"), loc)
+
+ sql(s"ALTER TABLE $t PARTITION (id = 2) RENAME TO PARTITION (id = 3)")
+ checkPartitions(t, Map("id" -> "1"), Map("id" -> "3"))
+ // `InMemoryPartitionTableCatalog` should keep the original location
+ checkLocation(t, Map("id" -> "3"), loc)
+ checkAnswer(sql(s"SELECT id, data FROM $t WHERE id = 3"), Row(3, "def"))
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CommandSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CommandSuiteBase.scala
index 0978126f27fd1..2dd80b7bb6a02 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CommandSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CommandSuiteBase.scala
@@ -18,7 +18,10 @@
package org.apache.spark.sql.execution.command.v2
import org.apache.spark.SparkConf
-import org.apache.spark.sql.connector.{InMemoryPartitionTableCatalog, InMemoryTableCatalog}
+import org.apache.spark.sql.catalyst.analysis.ResolvePartitionSpec
+import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
+import org.apache.spark.sql.connector.{InMemoryPartitionTable, InMemoryPartitionTableCatalog, InMemoryTableCatalog}
+import org.apache.spark.sql.connector.catalog.{CatalogV2Implicits, Identifier}
import org.apache.spark.sql.test.SharedSparkSession
/**
@@ -36,4 +39,26 @@ trait CommandSuiteBase extends SharedSparkSession {
override def sparkConf: SparkConf = super.sparkConf
.set(s"spark.sql.catalog.$catalog", classOf[InMemoryPartitionTableCatalog].getName)
.set(s"spark.sql.catalog.non_part_$catalog", classOf[InMemoryTableCatalog].getName)
+
+ def checkLocation(
+ t: String,
+ spec: TablePartitionSpec,
+ expected: String): Unit = {
+ import CatalogV2Implicits._
+
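+    // Read the partition's `location` metadata directly from the in-memory V2 catalog.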
+ val tablePath = t.split('.')
+ val catalogName = tablePath.head
+ val namespaceWithTable = tablePath.tail
+ val namespaces = namespaceWithTable.init
+ val tableName = namespaceWithTable.last
+ val catalogPlugin = spark.sessionState.catalogManager.catalog(catalogName)
+ val partTable = catalogPlugin.asTableCatalog
+ .loadTable(Identifier.of(namespaces, tableName))
+ .asInstanceOf[InMemoryPartitionTable]
+ val ident = ResolvePartitionSpec.convertToPartIdent(spec, partTable.partitionSchema.fields)
+ val partMetadata = partTable.loadPartitionMetadata(ident)
+
+ assert(partMetadata.containsKey("location"))
+ assert(partMetadata.get("location") === expected)
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/CommandSuiteBase.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/CommandSuiteBase.scala
index 39b4be61449cb..a1c808647c891 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/CommandSuiteBase.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/CommandSuiteBase.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.hive.execution.command
+import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.hive.test.TestHiveSingleton
@@ -30,4 +31,20 @@ trait CommandSuiteBase extends TestHiveSingleton {
def version: String = "Hive V1" // The prefix is added to test names
def catalog: String = CatalogManager.SESSION_CATALOG_NAME
def defaultUsing: String = "USING HIVE" // The clause is used in creating tables under testing
+
+ def checkLocation(
+ t: String,
+ spec: TablePartitionSpec,
+ expected: String): Unit = {
+ val tablePath = t.split('.')
+ val tableName = tablePath.last
+ val ns = tablePath.init.mkString(".")
+ val partSpec = spec.map { case (key, value) => s"$key = $value"}.mkString(", ")
+ val information =
+ spark.sql(s"SHOW TABLE EXTENDED IN $ns LIKE '$tableName' PARTITION($partSpec)")
+ .select("information")
+ .first().getString(0)
+ val location = information.split("\\r?\\n").filter(_.startsWith("Location:")).head
+ assert(location.endsWith(expected))
+ }
}