Skip to content

Commit

Permalink
[SQL] Support Update in DataFrameWriterV2
Browse files Browse the repository at this point in the history
  • Loading branch information
szehon-ho committed Jul 1, 2024
1 parent df13ca0 commit 92c65f2
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 0 deletions.
20 changes: 20 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4136,6 +4136,26 @@ class Dataset[T] private[sql](
new MergeIntoWriter[T](table, this, condition)
}

/**
 * Updates the rows of this Dataset's underlying table that satisfy `where`, assigning each
 * column named in `set` the value of its associated [[Column]] expression.
 *
 * The call is rejected through `failAnalysis` with the
 * `CALL_ON_STREAMING_DATASET_UNSUPPORTED` error class when this Dataset is streaming.
 *
 * @param set   column names mapped to the expressions that produce their new values.
 * @param where predicate selecting the rows to update.
 */
def update(set: Map[String, Column], where: Column): Unit = {
  if (isStreaming) {
    val errorParams = Map("methodName" -> toSQLId("update"))
    logicalPlan.failAnalysis(
      errorClass = "CALL_ON_STREAMING_DATASET_UNSUPPORTED",
      messageParameters = errorParams)
  }

  // The writer takes the predicate as an optional Catalyst expression.
  val condition = Some(where.expr)
  val writer = new UpdateWriter(this, set, condition)
  writer.update()
}

/**
 * Updates this Dataset's underlying table without a row predicate, assigning each column
 * named in `set` the value of its associated [[Column]] expression.
 *
 * The call is rejected through `failAnalysis` with the
 * `CALL_ON_STREAMING_DATASET_UNSUPPORTED` error class when this Dataset is streaming.
 *
 * @param set column names mapped to the expressions that produce their new values.
 */
def update(set: Map[String, Column]): Unit = {
  if (isStreaming) {
    val errorParams = Map("methodName" -> toSQLId("update"))
    logicalPlan.failAnalysis(
      errorClass = "CALL_ON_STREAMING_DATASET_UNSUPPORTED",
      messageParameters = errorParams)
  }

  // `None` means no condition is attached to the update.
  val writer = new UpdateWriter(this, set, None)
  writer.update()
}

/**
* Interface for saving the content of the streaming Dataset out into external storage.
*
Expand Down
59 changes: 59 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/UpdateWriter.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.{Assignment, UpdateTable}
import org.apache.spark.sql.functions.expr

/**
 * `UpdateWriter` builds an `UpdateTable` logical plan from a Dataset, a set of column
 * assignments and an optional condition, and executes it through the session state.
 *
 * @tparam T the type of data in the Dataset.
 * @param ds the Dataset whose logical plan identifies the update target.
 * @param set column names mapped to the Column expressions that produce their new values.
 * @param where the optional update condition; `None` attaches no condition.
 *
 * @since 4.0.0
 */
@Experimental
class UpdateWriter[T] private[sql](
    ds: Dataset[T],
    set: Map[String, Column],
    where: Option[Expression]) {

  private val sparkSession = ds.sparkSession

  // Unanalyzed logical plan of the Dataset, used as the target of the update.
  private val logicalPlan = ds.toDF().queryExecution.logical

  /**
   * Executes the update operation.
   */
  def update(): Unit = {
    // Each key is parsed with `expr`, so it is treated as a SQL expression string.
    val assignments = set.toSeq.map { case (column, value) =>
      Assignment(expr(column).expr, value.expr)
    }
    val updatePlan = UpdateTable(logicalPlan, assignments, where)
    sparkSession.sessionState.executePlan(updatePlan).assertCommandExecuted()
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions._

/**
 * Exercises the DataFrame `update` API against the table managed by
 * `RowLevelOperationSuiteBase`.
 */
class UpdateDataFrameSuite extends RowLevelOperationSuiteBase {

  import testImplicits._

  // Creates the table under test and loads the three rows every test starts from.
  private def initEmployeeTable(): Unit = {
    createAndInitTable("pk INT, salary INT, dep STRING",
      """{ "pk": 1, "salary": 300, "dep": 'hr' }
        |{ "pk": 2, "salary": 150, "dep": 'software' }
        |{ "pk": 3, "salary": 120, "dep": 'hr' }
        |""".stripMargin)
  }

  test("basic update") {
    initEmployeeTable()

    val target = spark.table(tableNameAsString)
    target.update(Map("salary" -> lit(-1)), $"pk" >= 2)

    // Only rows with pk >= 2 should have their salary overwritten.
    val expected = Seq(
      Row(1, 300, "hr"),
      Row(2, -1, "software"),
      Row(3, -1, "hr"))
    checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), expected)
  }

  test("update without where clause") {
    initEmployeeTable()

    val target = spark.table(tableNameAsString)
    target.update(Map("dep" -> lit("software")))

    // With no condition, every row should receive the new department.
    val expected = Seq(
      Row(1, 300, "software"),
      Row(2, 150, "software"),
      Row(3, 120, "software"))
    checkAnswer(sql(s"SELECT * FROM $tableNameAsString"), expected)
  }
}

0 comments on commit 92c65f2

Please sign in to comment.