diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index c7511737b2b3f..7c6ad0ea641ae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -4136,6 +4136,26 @@ class Dataset[T] private[sql](
     new MergeIntoWriter[T](table, this, condition)
   }
 
+  def update(set: Map[String, Column], where: Column): Unit = {
+    if (isStreaming) {
+      logicalPlan.failAnalysis(
+        errorClass = "CALL_ON_STREAMING_DATASET_UNSUPPORTED",
+        messageParameters = Map("methodName" -> toSQLId("update")))
+    }
+
+    new UpdateWriter(this, set, Some(where.expr)).update()
+  }
+
+  def update(set: Map[String, Column]): Unit = {
+    if (isStreaming) {
+      logicalPlan.failAnalysis(
+        errorClass = "CALL_ON_STREAMING_DATASET_UNSUPPORTED",
+        messageParameters = Map("methodName" -> toSQLId("update")))
+    }
+
+    new UpdateWriter(this, set, None).update()
+  }
+
   /**
    * Interface for saving the content of the streaming Dataset out into external storage.
    *
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UpdateWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/UpdateWriter.scala
new file mode 100644
index 0000000000000..eef951c782c5e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UpdateWriter.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.logical.{Assignment, UpdateTable}
+import org.apache.spark.sql.functions.expr
+
+/**
+ * `UpdateWriter` provides methods to define and execute an update action based
+ * on the specified conditions.
+ *
+ * @tparam T the type of data in the Dataset.
+ * @param ds the Dataset to update.
+ * @param set A Map of column names to Column expressions representing the updates to be applied.
+ * @param where the update condition.
+ *
+ * @since 4.0.0
+ */
+@Experimental
+class UpdateWriter[T] private[sql](
+    ds: Dataset[T],
+    set: Map[String, Column],
+    where: Option[Expression]) {
+
+  private val df: DataFrame = ds.toDF()
+
+  private val sparkSession = ds.sparkSession
+
+  private val logicalPlan = df.queryExecution.logical
+
+  /**
+   * Executes the update operation.
+   */
+  def update(): Unit = {
+    val merge = UpdateTable(
+      logicalPlan,
+      set.map(x => Assignment(expr(x._1).expr, x._2.expr)).toSeq,
+      where)
+    val qe = sparkSession.sessionState.executePlan(merge)
+    qe.assertCommandExecuted()
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateDataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateDataFrameSuite.scala
new file mode 100644
index 0000000000000..6484190bc4d24
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateDataFrameSuite.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.functions._
+
+class UpdateDataFrameSuite extends RowLevelOperationSuiteBase {
+
+  import testImplicits._
+
+  test("basic update") {
+    createAndInitTable("pk INT, salary INT, dep STRING",
+      """{ "pk": 1, "salary": 300, "dep": 'hr' }
+        |{ "pk": 2, "salary": 150, "dep": 'software' }
+        |{ "pk": 3, "salary": 120, "dep": 'hr' }
+        |""".stripMargin)
+
+    spark.table(tableNameAsString)
+      .update(Map("salary" -> lit(-1)), $"pk" >= 2)
+
+    checkAnswer(
+      sql(s"SELECT * FROM $tableNameAsString"),
+      Seq(
+        Row(1, 300, "hr"),
+        Row(2, -1, "software"),
+        Row(3, -1, "hr")))
+  }
+
+  test("update without where clause") {
+    createAndInitTable("pk INT, salary INT, dep STRING",
+      """{ "pk": 1, "salary": 300, "dep": 'hr' }
+        |{ "pk": 2, "salary": 150, "dep": 'software' }
+        |{ "pk": 3, "salary": 120, "dep": 'hr' }
+        |""".stripMargin)
+
+    spark.table(tableNameAsString)
+      .update(Map("dep" -> lit("software")))
+
+    checkAnswer(
+      sql(s"SELECT * FROM $tableNameAsString"),
+      Seq(
+        Row(1, 300, "software"),
+        Row(2, 150, "software"),
+        Row(3, 120, "software")))
+  }
+}
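
Usage note (not part of the diff): a rough sketch of how the new API is exercised, following the tests above. The table name `catalog.ns.tbl` is illustrative only and assumes a v2 connector that supports row-level UPDATE operations.

    import org.apache.spark.sql.functions._

    // Hypothetical table; any v2 table with row-level operation support works the same way.
    val df = spark.table("catalog.ns.tbl")

    // Conditional update: set salary to -1 for every row with pk >= 2.
    df.update(Map("salary" -> lit(-1)), df("pk") >= 2)

    // Unconditional update: overwrite the dep column for all rows.
    df.update(Map("dep" -> lit("software")))

    // Calling either overload on a streaming Dataset fails analysis with
    // CALL_ON_STREAMING_DATASET_UNSUPPORTED, per the isStreaming checks in Dataset.scala.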