Commit

resolve
itholic committed Aug 9, 2023
2 parents 649180a + 9c9a4a8 commit 03058ed
Showing 61 changed files with 1,210 additions and 951 deletions.

This file was deleted.

@@ -2399,18 +2399,27 @@ class Dataset[T] private[sql] (
.addAllColumnNames(cols.asJava)
}

private def buildDropDuplicates(
columns: Option[Seq[String]],
withinWaterMark: Boolean): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
val dropBuilder = builder.getDeduplicateBuilder
.setInput(plan.getRoot)
.setWithinWatermark(withinWaterMark)
if (columns.isDefined) {
dropBuilder.addAllColumnNames(columns.get.asJava)
} else {
dropBuilder.setAllColumnsAsKeys(true)
}
}

/**
* Returns a new Dataset that contains only the unique rows from this Dataset. This is an alias
* for `distinct`.
*
* @group typedrel
* @since 3.4.0
*/
def dropDuplicates(): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
builder.getDeduplicateBuilder
.setInput(plan.getRoot)
.setAllColumnsAsKeys(true)
}
def dropDuplicates(): Dataset[T] = buildDropDuplicates(None, withinWaterMark = false)

/**
* (Scala-specific) Returns a new Dataset with duplicate rows removed, considering only the
@@ -2419,11 +2428,8 @@ class Dataset[T] private[sql] (
* @group typedrel
* @since 3.4.0
*/
def dropDuplicates(colNames: Seq[String]): Dataset[T] = sparkSession.newDataset(encoder) {
builder =>
builder.getDeduplicateBuilder
.setInput(plan.getRoot)
.addAllColumnNames(colNames.asJava)
def dropDuplicates(colNames: Seq[String]): Dataset[T] = {
buildDropDuplicates(Option(colNames), withinWaterMark = false)
}

/**
@@ -2443,16 +2449,14 @@
*/
@scala.annotation.varargs
def dropDuplicates(col1: String, cols: String*): Dataset[T] = {
val colNames: Seq[String] = col1 +: cols
dropDuplicates(colNames)
dropDuplicates(col1 +: cols)
}

def dropDuplicatesWithinWatermark(): Dataset[T] = {
dropDuplicatesWithinWatermark(this.columns)
}
def dropDuplicatesWithinWatermark(): Dataset[T] =
buildDropDuplicates(None, withinWaterMark = true)

def dropDuplicatesWithinWatermark(colNames: Seq[String]): Dataset[T] = {
throw new UnsupportedOperationException("dropDuplicatesWithinWatermark is not implemented.")
buildDropDuplicates(Option(colNames), withinWaterMark = true)
}

def dropDuplicatesWithinWatermark(colNames: Array[String]): Dataset[T] = {
@@ -2461,8 +2465,7 @@

@scala.annotation.varargs
def dropDuplicatesWithinWatermark(col1: String, cols: String*): Dataset[T] = {
val colNames: Seq[String] = col1 +: cols
dropDuplicatesWithinWatermark(colNames)
dropDuplicatesWithinWatermark(col1 +: cols)
}

/**
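The hunk above consolidates every dropDuplicates variant into the single private buildDropDuplicates helper, which builds one Deduplicate relation and toggles the within-watermark flag and the column keys. A minimal usage sketch of the public API after this change (illustrative only; the remote address and data are assumptions):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

// Assumed connection string; any reachable Spark Connect endpoint works.
val spark = SparkSession.builder().remote("sc://localhost:15002").getOrCreate()

val df = spark.range(10).withColumn("key", col("id") % 3)

df.dropDuplicates()                       // all columns as deduplication keys
df.dropDuplicates("key")                  // explicit column list
df.dropDuplicatesWithinWatermark("key")   // same plan shape, withinWatermark = true
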
@@ -211,10 +211,8 @@ class ExecutePlanResponseReattachableIterator(
iterFun(iter)
} catch {
case ex: StatusRuntimeException
if StatusProto
.fromThrowable(ex)
.getMessage
.contains("INVALID_HANDLE.OPERATION_NOT_FOUND") =>
if Option(StatusProto.fromThrowable(ex))
.exists(_.getMessage.contains("INVALID_HANDLE.OPERATION_NOT_FOUND")) =>
if (lastReturnedResponseId.isDefined) {
throw new IllegalStateException(
"OPERATION_NOT_FOUND on the server but responses were already received from it.",
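The change above wraps StatusProto.fromThrowable in Option so a missing status no longer risks a NullPointerException when its message is inspected. A self-contained sketch of the same null-safe pattern (the helper name and inputs are made up for illustration):

// Option(null) is None, so `exists` yields false instead of throwing.
def mentionsOperationNotFound(message: String): Boolean =
  Option(message).exists(_.contains("INVALID_HANDLE.OPERATION_NOT_FOUND"))

mentionsOperationNotFound(null)                                          // false
mentionsOperationNotFound("INVALID_HANDLE.OPERATION_NOT_FOUND: expired") // true
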
@@ -8056,6 +8056,46 @@ object functions {
}
// scalastyle:off line.size.limit

/**
* Defines a deterministic user-defined function (UDF) using a Scala closure. For this variant,
* the caller must specify the output data type, and there is no automatic input type coercion.
* By default the returned UDF is deterministic. To change it to nondeterministic, call the API
* `UserDefinedFunction.asNondeterministic()`.
*
* Note that, although the Scala closure can have primitive-type function arguments, it doesn't
* work well with null values. Because the Scala closure is passed in as Any type, there is no
* type information for the function arguments. Without the type information, Spark may blindly
* pass null to a Scala closure with a primitive-type argument, and the closure will see the
* default value of the Java type for the null argument; e.g. for `udf((x: Int) => x, IntegerType)`,
* the result is 0 for null input.
*
* @param f
* A closure in Scala
* @param dataType
* The output data type of the UDF
*
* @group udf_funcs
* @since 3.5.0
*/
@deprecated(
"Scala `udf` method with return type parameter is deprecated. " +
"Please use Scala `udf` method without return type parameter.",
"3.0.0")
def udf(f: AnyRef, dataType: DataType): UserDefinedFunction = {
ScalarUserDefinedFunction(f, dataType)
}
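
A short illustration of the null-handling caveat documented above (illustrative only; assumes an active SparkSession bound to `spark`): the deprecated variant has no input type information, so a null primitive argument is read as the Java default, while the type-tag based `udf` overload keeps the null.

import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.IntegerType
import spark.implicits._

val df = Seq(Some(1), None).toDF("x")             // second row is a SQL NULL

// Deprecated variant: null is passed blindly, the closure sees 0, the result is 1.
val untyped = udf((i: Int) => i + 1, IntegerType)
df.select(untyped(col("x"))).show()

// Preferred variant: Spark knows the input type and returns null for null input.
val typed = udf((i: Int) => i + 1)
df.select(typed(col("x"))).show()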

/**
* Call a user-defined function.
*
* @group udf_funcs
* @since 3.5.0
*/
@scala.annotation.varargs
@deprecated("Use call_udf")
def callUDF(udfName: String, cols: Column*): Column =
call_function(udfName, cols: _*)
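
Since callUDF is deprecated in favor of call_udf, migration is a rename at the call site. A hedged sketch (illustrative; the UDF name `plusOne` is made up, and it assumes `spark.udf.register` is available in the client):

import org.apache.spark.sql.functions.{call_udf, col, udf}

// Register a UDF under a name, then invoke it by that name.
spark.udf.register("plusOne", udf((i: Long) => i + 1))

// Before: spark.range(3).select(callUDF("plusOne", col("id")))
spark.range(3).select(call_udf("plusOne", col("id"))).show()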

/**
* Call a user-defined function. Example:
* {{{
@@ -1183,6 +1183,26 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM
val joined = ds1.joinWith(ds2, $"a.value._1" === $"b.value._2", "inner")
checkSameResult(Seq((Some((2, 3)), Some((1, 2)))), joined)
}

test("dropDuplicatesWithinWatermark not supported in batch DataFrame") {
def testAndVerify(df: Dataset[_]): Unit = {
val exc = intercept[AnalysisException] {
df.write.format("noop").mode(SaveMode.Append).save()
}

assert(exc.getMessage.contains("dropDuplicatesWithinWatermark is not supported"))
assert(exc.getMessage.contains("batch DataFrames/DataSets"))
}

val result = spark.range(10).dropDuplicatesWithinWatermark()
testAndVerify(result)

val result2 = spark
.range(10)
.withColumn("newcol", col("id"))
.dropDuplicatesWithinWatermark("newcol")
testAndVerify(result2)
}
}

private[sql] case class ClassData(a: String, b: Int)
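The new test above exercises the batch path, where the server now raises the AnalysisException instead of the former client-side UnsupportedOperationException. For context, the operator is meant for streaming Datasets that declare a watermark; a rough sketch of that supported shape (illustrative only; source, columns, and sink are assumptions):

val events = spark.readStream
  .format("rate")                      // test source emitting (timestamp, value)
  .load()

val deduped = events
  .withWatermark("timestamp", "10 minutes")
  .dropDuplicatesWithinWatermark("value")

val query = deduped.writeStream
  .format("console")
  .start()

query.awaitTermination(5000)           // let the sketch run briefly
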
@@ -249,6 +249,8 @@ class FunctionTestSuite extends ConnectFunSuite {
pbFn.to_protobuf(a, "FakeMessage", "fakeBytes".getBytes(), Map.empty[String, String].asJava),
pbFn.to_protobuf(a, "FakeMessage", "fakeBytes".getBytes()))

testEquals("call_udf", callUDF("bob", lit(1)), call_udf("bob", lit(1)))

test("assert_true no message") {
val e = assert_true(a).expr
assert(e.hasUnresolvedFunction)
@@ -24,9 +24,11 @@ import java.util.concurrent.atomic.AtomicLong
import scala.collection.JavaConverters._

import org.apache.spark.api.java.function._
import org.apache.spark.sql.api.java.UDF2
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{PrimitiveIntEncoder, PrimitiveLongEncoder}
import org.apache.spark.sql.connect.client.util.QueryTest
import org.apache.spark.sql.functions.{col, struct, udf}
import org.apache.spark.sql.types.IntegerType

/**
* All tests in this class require the client UDFs defined in this test class to be synced with the server.
@@ -250,4 +250,22 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest {
"b",
"c")
}

test("(deprecated) scala UDF with dataType") {
val session: SparkSession = spark
import session.implicits._
val fn = udf(((i: Long) => (i + 1).toInt), IntegerType)
checkDataset(session.range(2).select(fn($"id")).as[Int], 1, 2)
}

test("java UDF") {
val session: SparkSession = spark
import session.implicits._
val fn = udf(
new UDF2[Long, Long, Int] {
override def call(t1: Long, t2: Long): Int = (t1 + t2 + 1).toInt
},
IntegerType)
checkDataset(session.range(2).select(fn($"id", $"id" + 2)).as[Int], 3, 5)
}
}
@@ -191,8 +191,6 @@ object CheckConnectJvmClientCompatibility {
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.javaRDD"),

// functions
ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.udf"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.callUDF"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.unwrap_udt"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.udaf"),

@@ -214,14 +212,11 @@
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.listenerManager"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.experimental"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.udtf"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.streams"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.createDataFrame"),
ProblemFilters.exclude[Problem](
"org.apache.spark.sql.SparkSession.baseRelationToDataFrame"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.createDataset"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.executeCommand"),
// TODO(SPARK-44068): Support positional parameters in Scala connect client
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sql"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.this"),

// SparkSession#implicits
@@ -266,8 +261,6 @@
"org.apache.spark.sql.streaming.StreamingQueryException.time"),

// Classes missing from streaming API
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ForeachWriter"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.GroupState"),
ProblemFilters.exclude[MissingClassProblem](
"org.apache.spark.sql.streaming.TestGroupState"),
ProblemFilters.exclude[MissingClassProblem](

This file was deleted.

This file was deleted.
