From c5293ecb017b55ff661ea05353e4463a08d0073c Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 4 Sep 2024 08:05:56 -0400 Subject: [PATCH] [SPARK-49426][CONNECT][SQL] Create a shared interface for DataFrameWriterV2 ### What changes were proposed in this pull request? This PR creates a shared interface for DataFrameWriterV2. ### Why are the changes needed? We are creating a shared Scala Spark SQL interface for Classic and Connect. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #47962 from hvanhovell/SPARK-49426. Authored-by: Herman van Hovell Signed-off-by: Herman van Hovell --- .../apache/spark/sql/DataFrameWriterV2.scala | 293 ------------------ .../scala/org/apache/spark/sql/Dataset.scala | 24 +- .../sql/internal/DataFrameWriterV2Impl.scala | 124 ++++++++ project/MimaExcludes.scala | 7 +- .../apache/spark/sql/DataFrameWriterV2.scala | 213 ++----------- .../org/apache/spark/sql/api/Dataset.scala | 23 +- .../analysis/noSuchItemsExceptions.scala | 9 + .../CannotReplaceMissingTableException.scala | 33 -- .../scala/org/apache/spark/sql/Dataset.scala | 24 +- .../sql/internal/DataFrameWriterV2Impl.scala | 245 +++++++++++++++ 10 files changed, 440 insertions(+), 555 deletions(-) delete mode 100644 connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala create mode 100644 connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterV2Impl.scala rename sql/{core => api}/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala (53%) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CannotReplaceMissingTableException.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterV2Impl.scala diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala deleted file mode 100644 index 3f9b224003914..0000000000000 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala +++ /dev/null @@ -1,293 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import scala.collection.mutable -import scala.jdk.CollectionConverters._ - -import org.apache.spark.annotation.Experimental -import org.apache.spark.connect.proto - -/** - * Interface used to write a [[org.apache.spark.sql.Dataset]] to external storage using the v2 - * API. 
- * - * @since 3.4.0 - */ -@Experimental -final class DataFrameWriterV2[T] private[sql] (table: String, ds: Dataset[T]) - extends CreateTableWriter[T] { - import ds.sparkSession.RichColumn - - private var provider: Option[String] = None - - private val options = new mutable.HashMap[String, String]() - - private val properties = new mutable.HashMap[String, String]() - - private var partitioning: Option[Seq[proto.Expression]] = None - - private var clustering: Option[Seq[String]] = None - - private var overwriteCondition: Option[proto.Expression] = None - - override def using(provider: String): CreateTableWriter[T] = { - this.provider = Some(provider) - this - } - - override def option(key: String, value: String): DataFrameWriterV2[T] = { - this.options.put(key, value) - this - } - - override def options(options: scala.collection.Map[String, String]): DataFrameWriterV2[T] = { - options.foreach { case (key, value) => - this.options.put(key, value) - } - this - } - - override def options(options: java.util.Map[String, String]): DataFrameWriterV2[T] = { - this.options(options.asScala) - this - } - - override def tableProperty(property: String, value: String): CreateTableWriter[T] = { - this.properties.put(property, value) - this - } - - @scala.annotation.varargs - override def partitionedBy(column: Column, columns: Column*): CreateTableWriter[T] = { - val asTransforms = (column +: columns).map(_.expr) - this.partitioning = Some(asTransforms) - this - } - - @scala.annotation.varargs - override def clusterBy(colName: String, colNames: String*): CreateTableWriter[T] = { - this.clustering = Some(colName +: colNames) - this - } - - override def create(): Unit = { - executeWriteOperation(proto.WriteOperationV2.Mode.MODE_CREATE) - } - - override def replace(): Unit = { - executeWriteOperation(proto.WriteOperationV2.Mode.MODE_REPLACE) - } - - override def createOrReplace(): Unit = { - executeWriteOperation(proto.WriteOperationV2.Mode.MODE_CREATE_OR_REPLACE) - } - - /** - * Append the contents of the data frame to the output table. - * - * If the output table does not exist, this operation will fail. The data frame will be - * validated to ensure it is compatible with the existing table. - */ - def append(): Unit = { - executeWriteOperation(proto.WriteOperationV2.Mode.MODE_APPEND) - } - - /** - * Overwrite rows matching the given filter condition with the contents of the data frame in the - * output table. - * - * If the output table does not exist, this operation will fail. The data frame will be - * validated to ensure it is compatible with the existing table. - */ - def overwrite(condition: Column): Unit = { - overwriteCondition = Some(condition.expr) - executeWriteOperation(proto.WriteOperationV2.Mode.MODE_OVERWRITE) - } - - /** - * Overwrite all partition for which the data frame contains at least one row with the contents - * of the data frame in the output table. - * - * This operation is equivalent to Hive's `INSERT OVERWRITE ... PARTITION`, which replaces - * partitions dynamically depending on the contents of the data frame. - * - * If the output table does not exist, this operation will fail. The data frame will be - * validated to ensure it is compatible with the existing table. 
- */ - def overwritePartitions(): Unit = { - executeWriteOperation(proto.WriteOperationV2.Mode.MODE_OVERWRITE_PARTITIONS) - } - - private def executeWriteOperation(mode: proto.WriteOperationV2.Mode): Unit = { - val builder = proto.WriteOperationV2.newBuilder() - - builder.setInput(ds.plan.getRoot) - builder.setTableName(table) - provider.foreach(builder.setProvider) - - partitioning.foreach(columns => builder.addAllPartitioningColumns(columns.asJava)) - clustering.foreach(columns => builder.addAllClusteringColumns(columns.asJava)) - - options.foreach { case (k, v) => - builder.putOptions(k, v) - } - properties.foreach { case (k, v) => - builder.putTableProperties(k, v) - } - - builder.setMode(mode) - - overwriteCondition.foreach(builder.setOverwriteCondition) - - ds.sparkSession.execute(proto.Command.newBuilder().setWriteOperationV2(builder).build()) - } -} - -/** - * Configuration methods common to create/replace operations and insert/overwrite operations. - * @tparam R - * builder type to return - * @since 3.4.0 - */ -trait WriteConfigMethods[R] { - - /** - * Add a write option. - * - * @since 3.4.0 - */ - def option(key: String, value: String): R - - /** - * Add a boolean output option. - * - * @since 3.4.0 - */ - def option(key: String, value: Boolean): R = option(key, value.toString) - - /** - * Add a long output option. - * - * @since 3.4.0 - */ - def option(key: String, value: Long): R = option(key, value.toString) - - /** - * Add a double output option. - * - * @since 3.4.0 - */ - def option(key: String, value: Double): R = option(key, value.toString) - - /** - * Add write options from a Scala Map. - * - * @since 3.4.0 - */ - def options(options: scala.collection.Map[String, String]): R - - /** - * Add write options from a Java Map. - * - * @since 3.4.0 - */ - def options(options: java.util.Map[String, String]): R -} - -/** - * Trait to restrict calls to create and replace operations. - * - * @since 3.4.0 - */ -trait CreateTableWriter[T] extends WriteConfigMethods[CreateTableWriter[T]] { - - /** - * Create a new table from the contents of the data frame. - * - * The new table's schema, partition layout, properties, and other configuration will be based - * on the configuration set on this writer. - * - * If the output table exists, this operation will fail. - */ - def create(): Unit - - /** - * Replace an existing table with the contents of the data frame. - * - * The existing table's schema, partition layout, properties, and other configuration will be - * replaced with the contents of the data frame and the configuration set on this writer. - * - * If the output table does not exist, this operation will fail. - */ - def replace(): Unit - - /** - * Create a new table or replace an existing table with the contents of the data frame. - * - * The output table's schema, partition layout, properties, and other configuration will be - * based on the contents of the data frame and the configuration set on this writer. If the - * table exists, its configuration and data will be replaced. - */ - def createOrReplace(): Unit - - /** - * Partition the output table created by `create`, `createOrReplace`, or `replace` using the - * given columns or transforms. - * - * When specified, the table data will be stored by these values for efficient reads. - * - * For example, when a table is partitioned by day, it may be stored in a directory layout like: - *
- *   - `table/day=2019-06-01/`
- *   - `table/day=2019-06-02/`
- * - * Partitioning is one of the most widely used techniques to optimize physical data layout. It - * provides a coarse-grained index for skipping unnecessary data reads when queries have - * predicates on the partitioned columns. In order for partitioning to work well, the number of - * distinct values in each column should typically be less than tens of thousands. - * - * @since 3.4.0 - */ - @scala.annotation.varargs - def partitionedBy(column: Column, columns: Column*): CreateTableWriter[T] - - /** - * Clusters the output by the given columns on the storage. The rows with matching values in the - * specified clustering columns will be consolidated within the same group. - * - * For instance, if you cluster a dataset by date, the data sharing the same date will be stored - * together in a file. This arrangement improves query efficiency when you apply selective - * filters to these clustering columns, thanks to data skipping. - * - * @since 4.0.0 - */ - @scala.annotation.varargs - def clusterBy(colName: String, colNames: String*): CreateTableWriter[T] - - /** - * Specifies a provider for the underlying output data source. Spark's default catalog supports - * "parquet", "json", etc. - * - * @since 3.4.0 - */ - def using(provider: String): CreateTableWriter[T] - - /** - * Add a table property. - */ - def tableProperty(property: String, value: String): CreateTableWriter[T] -} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index 44504487f80e7..778cd153ec2ed 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, StorageLevel import org.apache.spark.sql.errors.DataTypeErrors.toSQLId import org.apache.spark.sql.expressions.SparkUserDefinedFunction import org.apache.spark.sql.functions.{struct, to_json} -import org.apache.spark.sql.internal.{ColumnNodeToProtoConverter, DataFrameWriterImpl, ToScalaUDF, UDFAdaptors, UnresolvedAttribute, UnresolvedRegex} +import org.apache.spark.sql.internal.{ColumnNodeToProtoConverter, DataFrameWriterImpl, DataFrameWriterV2Impl, ToScalaUDF, UDFAdaptors, UnresolvedAttribute, UnresolvedRegex} import org.apache.spark.sql.streaming.DataStreamWriter import org.apache.spark.sql.types.{Metadata, StructType} import org.apache.spark.storage.StorageLevel @@ -1018,27 +1018,9 @@ class Dataset[T] private[sql] ( new DataFrameWriterImpl[T](this) } - /** - * Create a write configuration builder for v2 sources. - * - * This builder is used to configure and execute write operations. 
For example, to append to an - * existing table, run: - * - * {{{ - * df.writeTo("catalog.db.table").append() - * }}} - * - * This can also be used to create or replace existing tables: - * - * {{{ - * df.writeTo("catalog.db.table").partitionedBy($"col").createOrReplace() - * }}} - * - * @group basic - * @since 3.4.0 - */ + /** @inheritdoc */ def writeTo(table: String): DataFrameWriterV2[T] = { - new DataFrameWriterV2[T](table, this) + new DataFrameWriterV2Impl[T](table, this) } /** diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterV2Impl.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterV2Impl.scala new file mode 100644 index 0000000000000..4afa8b6d566c5 --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterV2Impl.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import scala.jdk.CollectionConverters._ + +import org.apache.spark.annotation.Experimental +import org.apache.spark.connect.proto +import org.apache.spark.sql.{Column, DataFrameWriterV2, Dataset} + +/** + * Interface used to write a [[org.apache.spark.sql.Dataset]] to external storage using the v2 + * API. 
+ * + * @since 3.4.0 + */ +@Experimental +final class DataFrameWriterV2Impl[T] private[sql] (table: String, ds: Dataset[T]) + extends DataFrameWriterV2[T] { + import ds.sparkSession.RichColumn + + private val builder = proto.WriteOperationV2 + .newBuilder() + .setInput(ds.plan.getRoot) + .setTableName(table) + + /** @inheritdoc */ + override def using(provider: String): this.type = { + builder.setProvider(provider) + this + } + + /** @inheritdoc */ + override def option(key: String, value: String): this.type = { + builder.putOptions(key, value) + this + } + + /** @inheritdoc */ + override def options(options: scala.collection.Map[String, String]): this.type = { + builder.putAllOptions(options.asJava) + this + } + + /** @inheritdoc */ + override def options(options: java.util.Map[String, String]): this.type = { + builder.putAllOptions(options) + this + } + + /** @inheritdoc */ + override def tableProperty(property: String, value: String): this.type = { + builder.putTableProperties(property, value) + this + } + + /** @inheritdoc */ + @scala.annotation.varargs + override def partitionedBy(column: Column, columns: Column*): this.type = { + builder.addAllPartitioningColumns((column +: columns).map(_.expr).asJava) + this + } + + /** @inheritdoc */ + @scala.annotation.varargs + override def clusterBy(colName: String, colNames: String*): this.type = { + builder.addAllClusteringColumns((colName +: colNames).asJava) + this + } + + /** @inheritdoc */ + override def create(): Unit = { + executeWriteOperation(proto.WriteOperationV2.Mode.MODE_CREATE) + } + + /** @inheritdoc */ + override def replace(): Unit = { + executeWriteOperation(proto.WriteOperationV2.Mode.MODE_REPLACE) + } + + /** @inheritdoc */ + override def createOrReplace(): Unit = { + executeWriteOperation(proto.WriteOperationV2.Mode.MODE_CREATE_OR_REPLACE) + } + + /** @inheritdoc */ + def append(): Unit = { + executeWriteOperation(proto.WriteOperationV2.Mode.MODE_APPEND) + } + + /** @inheritdoc */ + def overwrite(condition: Column): Unit = { + builder.setOverwriteCondition(condition.expr) + executeWriteOperation(proto.WriteOperationV2.Mode.MODE_OVERWRITE) + } + + /** @inheritdoc */ + def overwritePartitions(): Unit = { + executeWriteOperation(proto.WriteOperationV2.Mode.MODE_OVERWRITE_PARTITIONS) + } + + private def executeWriteOperation(mode: proto.WriteOperationV2.Mode): Unit = { + val command = proto.Command + .newBuilder() + .setWriteOperationV2(builder.setMode(mode)) + .build() + ds.sparkSession.execute(command) + } +} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 21638e4816309..03b7b6efca1b0 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -126,7 +126,12 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Observation$"), // SPARK-49425: Create a shared DataFrameWriter interface. 
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrameWriter") + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrameWriter"), + + // SPARK-49426: Shared DataFrameWriterV2 + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CreateTableWriter"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrameWriterV2"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.WriteConfigMethods"), ) // Default exclude rules diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala b/sql/api/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala similarity index 53% rename from sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala rename to sql/api/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala index 576d8276b56ef..ddc89178cd835 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala @@ -14,150 +14,52 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.spark.sql -import scala.collection.mutable -import scala.jdk.CollectionConverters._ +import java.util import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException, UnresolvedFunction, UnresolvedIdentifier, UnresolvedRelation} -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Literal} -import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, LogicalPlan, OptionList, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, UnresolvedTableSpec} -import org.apache.spark.sql.connector.catalog.TableWritePrivilege._ -import org.apache.spark.sql.connector.expressions.{ClusterByTransform, FieldReference, LogicalExpressions, NamedReference, Transform} -import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.execution.QueryExecution -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException} /** - * Interface used to write a [[org.apache.spark.sql.Dataset]] to external storage using the v2 API. + * Interface used to write a [[org.apache.spark.sql.api.Dataset]] to external storage + * using the v2 API. 
* * @since 3.0.0 */ @Experimental -final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) - extends CreateTableWriter[T] { - - private val df: DataFrame = ds.toDF() - - private val sparkSession = ds.sparkSession - import sparkSession.expression - - private val tableName = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(table) - - private val logicalPlan = df.queryExecution.logical - - private var provider: Option[String] = None - - private val options = new mutable.HashMap[String, String]() +abstract class DataFrameWriterV2[T] extends CreateTableWriter[T] { + /** @inheritdoc */ + override def using(provider: String): this.type - private val properties = new mutable.HashMap[String, String]() + /** @inheritdoc */ + override def option(key: String, value: Boolean): this.type = option(key, value.toString) - private var partitioning: Option[Seq[Transform]] = None + /** @inheritdoc */ + override def option(key: String, value: Long): this.type = option(key, value.toString) - private var clustering: Option[ClusterByTransform] = None + /** @inheritdoc */ + override def option(key: String, value: Double): this.type = option(key, value.toString) - override def using(provider: String): CreateTableWriter[T] = { - this.provider = Some(provider) - this - } + /** @inheritdoc */ + override def option(key: String, value: String): this.type - override def option(key: String, value: String): DataFrameWriterV2[T] = { - this.options.put(key, value) - this - } + /** @inheritdoc */ + override def options(options: scala.collection.Map[String, String]): this.type - override def options(options: scala.collection.Map[String, String]): DataFrameWriterV2[T] = { - options.foreach { - case (key, value) => - this.options.put(key, value) - } - this - } + /** @inheritdoc */ + override def options(options: util.Map[String, String]): this.type - override def options(options: java.util.Map[String, String]): DataFrameWriterV2[T] = { - this.options(options.asScala) - this - } - - override def tableProperty(property: String, value: String): CreateTableWriter[T] = { - this.properties.put(property, value) - this - } + /** @inheritdoc */ + override def tableProperty(property: String, value: String): this.type + /** @inheritdoc */ @scala.annotation.varargs - override def partitionedBy(column: Column, columns: Column*): CreateTableWriter[T] = { - def ref(name: String): NamedReference = LogicalExpressions.parseReference(name) - - val asTransforms = (column +: columns).map(expression).map { - case PartitionTransform.YEARS(Seq(attr: Attribute)) => - LogicalExpressions.years(ref(attr.name)) - case PartitionTransform.MONTHS(Seq(attr: Attribute)) => - LogicalExpressions.months(ref(attr.name)) - case PartitionTransform.DAYS(Seq(attr: Attribute)) => - LogicalExpressions.days(ref(attr.name)) - case PartitionTransform.HOURS(Seq(attr: Attribute)) => - LogicalExpressions.hours(ref(attr.name)) - case PartitionTransform.BUCKET(Seq(Literal(numBuckets: Int, IntegerType), attr: Attribute)) => - LogicalExpressions.bucket(numBuckets, Array(ref(attr.name))) - case PartitionTransform.BUCKET(Seq(numBuckets, e)) => - throw QueryCompilationErrors.invalidBucketsNumberError(numBuckets.toString, e.toString) - case attr: Attribute => - LogicalExpressions.identity(ref(attr.name)) - case expr => - throw QueryCompilationErrors.invalidPartitionTransformationError(expr) - } - - this.partitioning = Some(asTransforms) - validatePartitioning() - this - } + override def partitionedBy(column: Column, columns: Column*): this.type + /** 
@inheritdoc */ @scala.annotation.varargs - override def clusterBy(colName: String, colNames: String*): CreateTableWriter[T] = { - this.clustering = - Some(ClusterByTransform((colName +: colNames).map(col => FieldReference(col)))) - validatePartitioning() - this - } - - /** - * Validate that clusterBy is not used with partitionBy. - */ - private def validatePartitioning(): Unit = { - if (partitioning.nonEmpty && clustering.nonEmpty) { - throw QueryCompilationErrors.clusterByWithPartitionedBy() - } - } - - override def create(): Unit = { - val tableSpec = UnresolvedTableSpec( - properties = properties.toMap, - provider = provider, - optionExpression = OptionList(Seq.empty), - location = None, - comment = None, - serde = None, - external = false) - runCommand( - CreateTableAsSelect( - UnresolvedIdentifier(tableName), - partitioning.getOrElse(Seq.empty) ++ clustering, - logicalPlan, - tableSpec, - options.toMap, - false)) - } - - override def replace(): Unit = { - internalReplace(orCreate = false) - } - - override def createOrReplace(): Unit = { - internalReplace(orCreate = true) - } - + override def clusterBy(colName: String, colNames: String*): this.type /** * Append the contents of the data frame to the output table. @@ -169,12 +71,7 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) * @throws org.apache.spark.sql.catalyst.analysis.NoSuchTableException If the table does not exist */ @throws(classOf[NoSuchTableException]) - def append(): Unit = { - val append = AppendData.byName( - UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT)), - logicalPlan, options.toMap) - runCommand(append) - } + def append(): Unit /** * Overwrite rows matching the given filter condition with the contents of the data frame in @@ -187,12 +84,7 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) * @throws org.apache.spark.sql.catalyst.analysis.NoSuchTableException If the table does not exist */ @throws(classOf[NoSuchTableException]) - def overwrite(condition: Column): Unit = { - val overwrite = OverwriteByExpression.byName( - UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT, DELETE)), - logicalPlan, expression(condition), options.toMap) - runCommand(overwrite) - } + def overwrite(condition: Column): Unit /** * Overwrite all partition for which the data frame contains at least one row with the contents @@ -208,56 +100,7 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) * @throws org.apache.spark.sql.catalyst.analysis.NoSuchTableException If the table does not exist */ @throws(classOf[NoSuchTableException]) - def overwritePartitions(): Unit = { - val dynamicOverwrite = OverwritePartitionsDynamic.byName( - UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT, DELETE)), - logicalPlan, options.toMap) - runCommand(dynamicOverwrite) - } - - /** - * Wrap an action to track the QueryExecution and time cost, then report to the user-registered - * callback functions. 
- */ - private def runCommand(command: LogicalPlan): Unit = { - val qe = new QueryExecution(sparkSession, command, df.queryExecution.tracker) - qe.assertCommandExecuted() - } - - private def internalReplace(orCreate: Boolean): Unit = { - val tableSpec = UnresolvedTableSpec( - properties = properties.toMap, - provider = provider, - optionExpression = OptionList(Seq.empty), - location = None, - comment = None, - serde = None, - external = false) - runCommand(ReplaceTableAsSelect( - UnresolvedIdentifier(tableName), - partitioning.getOrElse(Seq.empty) ++ clustering, - logicalPlan, - tableSpec, - writeOptions = options.toMap, - orCreate = orCreate)) - } -} - -private object PartitionTransform { - class ExtractTransform(name: String) { - private val NAMES = Seq(name) - - def unapply(e: Expression): Option[Seq[Expression]] = e match { - case UnresolvedFunction(NAMES, children, false, None, false, Nil, true) => Option(children) - case _ => None - } - } - - val HOURS = new ExtractTransform("hours") - val DAYS = new ExtractTransform("days") - val MONTHS = new ExtractTransform("months") - val YEARS = new ExtractTransform("years") - val BUCKET = new ExtractTransform("bucket") + def overwritePartitions(): Unit } /** diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala index c98260976b831..38abb63c9dcc3 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala @@ -23,7 +23,7 @@ import _root_.java.util import org.apache.spark.annotation.{DeveloperApi, Stable} import org.apache.spark.api.java.function.{FilterFunction, FlatMapFunction, ForeachFunction, ForeachPartitionFunction, MapFunction, MapPartitionsFunction, ReduceFunction} -import org.apache.spark.sql.{functions, AnalysisException, Column, DataFrameWriter, Encoder, Observation, Row, TypedColumn} +import org.apache.spark.sql.{functions, AnalysisException, Column, DataFrameWriter, DataFrameWriterV2, Encoder, Observation, Row, TypedColumn} import org.apache.spark.sql.internal.{ToScalaUDF, UDFAdaptors} import org.apache.spark.sql.types.{Metadata, StructType} import org.apache.spark.storage.StorageLevel @@ -2837,6 +2837,27 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { protected def createTempView(viewName: String, replace: Boolean, global: Boolean): Unit + /** + * Create a write configuration builder for v2 sources. + * + * This builder is used to configure and execute write operations. For example, to append to an + * existing table, run: + * + * {{{ + * df.writeTo("catalog.db.table").append() + * }}} + * + * This can also be used to create or replace existing tables: + * + * {{{ + * df.writeTo("catalog.db.table").partitionedBy($"col").createOrReplace() + * }}} + * + * @group basic + * @since 3.0.0 + */ + def writeTo(table: String): DataFrameWriterV2[T] + /** * Returns the content of the Dataset as a Dataset of JSON strings. 
* diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/noSuchItemsExceptions.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/noSuchItemsExceptions.scala index 8977d0be24d77..1b836da45e802 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/noSuchItemsExceptions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/noSuchItemsExceptions.scala @@ -201,3 +201,12 @@ class NoSuchIndexException private( this("INDEX_NOT_FOUND", Map("indexName" -> indexName, "tableName" -> tableName), cause) } } + +class CannotReplaceMissingTableException( + tableIdentifier: Identifier, + cause: Option[Throwable] = None) + extends AnalysisException( + errorClass = "TABLE_OR_VIEW_NOT_FOUND", + messageParameters = Map("relationName" + -> quoteNameParts((tableIdentifier.namespace :+ tableIdentifier.name).toImmutableArraySeq)), + cause = cause) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CannotReplaceMissingTableException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CannotReplaceMissingTableException.scala deleted file mode 100644 index f3e0c0aca29ca..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CannotReplaceMissingTableException.scala +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - - -package org.apache.spark.sql.catalyst.analysis - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.util.quoteNameParts -import org.apache.spark.sql.connector.catalog.Identifier -import org.apache.spark.util.ArrayImplicits._ - -class CannotReplaceMissingTableException( - tableIdentifier: Identifier, - cause: Option[Throwable] = None) - extends AnalysisException( - errorClass = "TABLE_OR_VIEW_NOT_FOUND", - messageParameters = Map("relationName" - -> quoteNameParts((tableIdentifier.namespace :+ tableIdentifier.name).toImmutableArraySeq)), - cause = cause) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 9ae89e84df874..f62331710d637 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -60,7 +60,7 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation, FileTable} import org.apache.spark.sql.execution.python.EvaluatePython import org.apache.spark.sql.execution.stat.StatFunctions -import org.apache.spark.sql.internal.{DataFrameWriterImpl, SQLConf, ToScalaUDF} +import org.apache.spark.sql.internal.{DataFrameWriterImpl, DataFrameWriterV2Impl, SQLConf, ToScalaUDF} import org.apache.spark.sql.internal.ExpressionUtils.column import org.apache.spark.sql.internal.TypedAggUtils.withInputType import org.apache.spark.sql.streaming.DataStreamWriter @@ -1595,25 +1595,7 @@ class Dataset[T] private[sql]( new DataFrameWriterImpl[T](this) } - /** - * Create a write configuration builder for v2 sources. - * - * This builder is used to configure and execute write operations. For example, to append to an - * existing table, run: - * - * {{{ - * df.writeTo("catalog.db.table").append() - * }}} - * - * This can also be used to create or replace existing tables: - * - * {{{ - * df.writeTo("catalog.db.table").partitionedBy($"col").createOrReplace() - * }}} - * - * @group basic - * @since 3.0.0 - */ + /** @inheritdoc */ def writeTo(table: String): DataFrameWriterV2[T] = { // TODO: streaming could be adapted to use this interface if (isStreaming) { @@ -1621,7 +1603,7 @@ class Dataset[T] private[sql]( errorClass = "CALL_ON_STREAMING_DATASET_UNSUPPORTED", messageParameters = Map("methodName" -> toSQLId("writeTo"))) } - new DataFrameWriterV2[T](table, this) + new DataFrameWriterV2Impl[T](table, this) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterV2Impl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterV2Impl.scala new file mode 100644 index 0000000000000..0a19e6c47afa9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterV2Impl.scala @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import java.util + +import scala.collection.mutable +import scala.jdk.CollectionConverters.MapHasAsScala + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.{Column, DataFrame, DataFrameWriterV2, Dataset} +import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, UnresolvedFunction, UnresolvedIdentifier, UnresolvedRelation} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Literal} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.connector.catalog.TableWritePrivilege._ +import org.apache.spark.sql.connector.expressions._ +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.types.IntegerType + +/** + * Interface used to write a [[org.apache.spark.sql.Dataset]] to external storage using the v2 API. + * + * @since 3.0.0 + */ +@Experimental +final class DataFrameWriterV2Impl[T] private[sql](table: String, ds: Dataset[T]) + extends DataFrameWriterV2[T] { + + private val df: DataFrame = ds.toDF() + + private val sparkSession = ds.sparkSession + import sparkSession.expression + + private val tableName = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(table) + + private val logicalPlan = df.queryExecution.logical + + private var provider: Option[String] = None + + private val options = new mutable.HashMap[String, String]() + + private val properties = new mutable.HashMap[String, String]() + + private var partitioning: Option[Seq[Transform]] = None + + private var clustering: Option[ClusterByTransform] = None + + /** @inheritdoc */ + override def using(provider: String): this.type = { + this.provider = Some(provider) + this + } + + /** @inheritdoc */ + override def option(key: String, value: String): this.type = { + this.options.put(key, value) + this + } + + /** @inheritdoc */ + override def options(options: scala.collection.Map[String, String]): this.type = { + options.foreach { + case (key, value) => + this.options.put(key, value) + } + this + } + + /** @inheritdoc */ + override def options(options: util.Map[String, String]): this.type = { + this.options(options.asScala) + this + } + + /** @inheritdoc */ + override def tableProperty(property: String, value: String): this.type = { + this.properties.put(property, value) + this + } + + + /** @inheritdoc */ + @scala.annotation.varargs + override def partitionedBy(column: Column, columns: Column*): this.type = { + def ref(name: String): NamedReference = LogicalExpressions.parseReference(name) + + val asTransforms = (column +: columns).map(expression).map { + case PartitionTransform.YEARS(Seq(attr: Attribute)) => + LogicalExpressions.years(ref(attr.name)) + case PartitionTransform.MONTHS(Seq(attr: Attribute)) => + LogicalExpressions.months(ref(attr.name)) + case PartitionTransform.DAYS(Seq(attr: Attribute)) => + LogicalExpressions.days(ref(attr.name)) + case PartitionTransform.HOURS(Seq(attr: Attribute)) => + LogicalExpressions.hours(ref(attr.name)) + case 
PartitionTransform.BUCKET(Seq(Literal(numBuckets: Int, IntegerType), attr: Attribute)) => + LogicalExpressions.bucket(numBuckets, Array(ref(attr.name))) + case PartitionTransform.BUCKET(Seq(numBuckets, e)) => + throw QueryCompilationErrors.invalidBucketsNumberError(numBuckets.toString, e.toString) + case attr: Attribute => + LogicalExpressions.identity(ref(attr.name)) + case expr => + throw QueryCompilationErrors.invalidPartitionTransformationError(expr) + } + + this.partitioning = Some(asTransforms) + validatePartitioning() + this + } + + /** @inheritdoc */ + @scala.annotation.varargs + override def clusterBy(colName: String, colNames: String*): this.type = { + this.clustering = + Some(ClusterByTransform((colName +: colNames).map(col => FieldReference(col)))) + validatePartitioning() + this + } + + /** + * Validate that clusterBy is not used with partitionBy. + */ + private def validatePartitioning(): Unit = { + if (partitioning.nonEmpty && clustering.nonEmpty) { + throw QueryCompilationErrors.clusterByWithPartitionedBy() + } + } + + /** @inheritdoc */ + override def create(): Unit = { + val tableSpec = UnresolvedTableSpec( + properties = properties.toMap, + provider = provider, + optionExpression = OptionList(Seq.empty), + location = None, + comment = None, + serde = None, + external = false) + runCommand( + CreateTableAsSelect( + UnresolvedIdentifier(tableName), + partitioning.getOrElse(Seq.empty) ++ clustering, + logicalPlan, + tableSpec, + options.toMap, + false)) + } + + /** @inheritdoc */ + override def replace(): Unit = { + internalReplace(orCreate = false) + } + + /** @inheritdoc */ + override def createOrReplace(): Unit = { + internalReplace(orCreate = true) + } + + /** @inheritdoc */ + @throws(classOf[NoSuchTableException]) + def append(): Unit = { + val append = AppendData.byName( + UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT)), + logicalPlan, options.toMap) + runCommand(append) + } + + /** @inheritdoc */ + @throws(classOf[NoSuchTableException]) + def overwrite(condition: Column): Unit = { + val overwrite = OverwriteByExpression.byName( + UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT, DELETE)), + logicalPlan, expression(condition), options.toMap) + runCommand(overwrite) + } + + /** @inheritdoc */ + @throws(classOf[NoSuchTableException]) + def overwritePartitions(): Unit = { + val dynamicOverwrite = OverwritePartitionsDynamic.byName( + UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT, DELETE)), + logicalPlan, options.toMap) + runCommand(dynamicOverwrite) + } + + /** + * Wrap an action to track the QueryExecution and time cost, then report to the user-registered + * callback functions. 
+ */ + private def runCommand(command: LogicalPlan): Unit = { + val qe = new QueryExecution(sparkSession, command, df.queryExecution.tracker) + qe.assertCommandExecuted() + } + + private def internalReplace(orCreate: Boolean): Unit = { + val tableSpec = UnresolvedTableSpec( + properties = properties.toMap, + provider = provider, + optionExpression = OptionList(Seq.empty), + location = None, + comment = None, + serde = None, + external = false) + runCommand(ReplaceTableAsSelect( + UnresolvedIdentifier(tableName), + partitioning.getOrElse(Seq.empty) ++ clustering, + logicalPlan, + tableSpec, + writeOptions = options.toMap, + orCreate = orCreate)) + } +} + +private object PartitionTransform { + class ExtractTransform(name: String) { + private val NAMES = Seq(name) + + def unapply(e: Expression): Option[Seq[Expression]] = e match { + case UnresolvedFunction(NAMES, children, false, None, false, Nil, true) => Option(children) + case _ => None + } + } + + val HOURS = new ExtractTransform("hours") + val DAYS = new ExtractTransform("days") + val MONTHS = new ExtractTransform("months") + val YEARS = new ExtractTransform("years") + val BUCKET = new ExtractTransform("bucket") +}