[SPARK-41794][SQL] Add try_remainder function and re-enable column tests

### What changes were proposed in this pull request?
As part of re-enabling the ANSI mode tests for Spark Connect, we discovered that there is no `try_*` equivalent for the remainder operation. This patch adds the `try_remainder` function in Scala, Python, and Spark Connect, along with the required tests.

### Why are the changes needed?
They are needed for ANSI mode support in Spark 4.

### Does this PR introduce _any_ user-facing change?
Yes, it adds the `try_remainder` function that behaves according to ANSI for division by zero.
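
For illustration, a minimal sketch of the new behavior (assuming an active `SparkSession` named `spark`; the expected values come from the `ExpressionDescription` examples added in this patch):

```scala
// Under ANSI mode the plain `%` operator raises a divide-by-zero error,
// while try_remainder returns NULL for the offending row instead.
spark.conf.set("spark.sql.ansi.enabled", "true")
spark.sql("SELECT try_remainder(3, 2) AS ok, try_remainder(1, 0) AS by_zero").show()
// +---+-------+
// | ok|by_zero|
// +---+-------+
// |  1|   NULL|
// +---+-------+
```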

### How was this patch tested?
Added new UT and E2E tests.

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#46434 from grundprinzip/grundprinzip/SPARK-41794.

Lead-authored-by: Martin Grund <martin.grund@databricks.com>
Co-authored-by: Martin Grund <grundprinzip@gmail.com>
Co-authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Gengliang Wang <gengliang@apache.org>
3 people authored and gengliangwang committed May 13, 2024
1 parent b14abb3 commit 8d8cc62
Showing 14 changed files with 156 additions and 13 deletions.
@@ -1932,6 +1932,14 @@ object functions {
*/
def try_divide(left: Column, right: Column): Column = Column.fn("try_divide", left, right)

/**
* Returns the remainder of `dividend``/``divisor`. Its result is always null if `divisor` is 0.
*
* @group math_funcs
* @since 4.0.0
*/
def try_remainder(left: Column, right: Column): Column = Column.fn("try_remainder", left, right)

/**
* Returns `left``*``right` and the result is null on overflow. The acceptable input types are
* the same with the `*` operator.
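
As a usage sketch for the new `Column` function (the data and expected output mirror the Python doctest added later in this patch; the session setup is assumed):

```scala
import org.apache.spark.sql.functions.{col, try_remainder}
import spark.implicits._ // assumes an active SparkSession named `spark`

// The (1234, 0) row would error under ANSI `%`; try_remainder yields NULL instead.
val df = Seq((6000, 15), (3, 2), (1234, 0)).toDF("a", "b")
df.select(try_remainder(col("a"), col("b"))).show()
// +-------------------+
// |try_remainder(a, b)|
// +-------------------+
// |                  0|
// |                  1|
// |               NULL|
// +-------------------+
```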
@@ -121,8 +121,10 @@ class StreamingQueryListenerBus(sparkSession: SparkSession) extends Logging {
}
} catch {
case e: Exception =>
-logWarning("StreamingQueryListenerBus Handler thread received exception, all client" +
-  " side listeners are removed and handler thread is terminated.", e)
+logWarning(
+  "StreamingQueryListenerBus Handler thread received exception, all client" +
+    " side listeners are removed and handler thread is terminated.",
+  e)
lock.synchronized {
executionThread = Option.empty
listeners.forEach(remove(_))
@@ -448,8 +448,7 @@ object CheckConnectJvmClientCompatibility {
"org.apache.spark.sql.streaming.RemoteStreamingQuery$"),
// Skip client side listener specific class
ProblemFilters.exclude[MissingClassProblem](
"org.apache.spark.sql.streaming.StreamingQueryListenerBus"
),
"org.apache.spark.sql.streaming.StreamingQueryListenerBus"),

// Encoders are in the wrong JAR
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Encoders"),
1 change: 1 addition & 0 deletions docs/sql-ref-ansi-compliance.md
@@ -374,6 +374,7 @@ When ANSI mode is on, it throws exceptions for invalid operations.
- `try_subtract`: identical to the subtraction operator `-`, except that it returns `NULL` result instead of throwing an exception on integral value overflow.
- `try_multiply`: identical to the multiplication operator `*`, except that it returns `NULL` result instead of throwing an exception on integral value overflow.
- `try_divide`: identical to the division operator `/`, except that it returns `NULL` result instead of throwing an exception on dividing 0.
- `try_remainder`: identical to the remainder operator `%`, except that it returns `NULL` result instead of throwing an exception on dividing 0.
- `try_sum`: identical to the function `sum`, except that it returns `NULL` result instead of throwing an exception on integral/decimal/interval value overflow.
- `try_avg`: identical to the function `avg`, except that it returns `NULL` result instead of throwing an exception on decimal/interval value overflow.
- `try_element_at`: identical to the function `element_at`, except that it returns `NULL` result instead of throwing an exception on array's index out of bound.
7 changes: 7 additions & 0 deletions python/pyspark/sql/connect/functions/builtin.py
@@ -934,6 +934,13 @@ def try_divide(left: "ColumnOrName", right: "ColumnOrName") -> Column:
try_divide.__doc__ = pysparkfuncs.try_divide.__doc__


def try_remainder(left: "ColumnOrName", right: "ColumnOrName") -> Column:
return _invoke_function_over_columns("try_remainder", left, right)


try_remainder.__doc__ = pysparkfuncs.try_remainder.__doc__


def try_multiply(left: "ColumnOrName", right: "ColumnOrName") -> Column:
return _invoke_function_over_columns("try_multiply", left, right)

52 changes: 51 additions & 1 deletion python/pyspark/sql/functions/builtin.py
@@ -638,7 +638,7 @@ def try_divide(left: "ColumnOrName", right: "ColumnOrName") -> Column:
| 4 months|
+--------------------------------------------------+

-Example 3: Exception druing division, resulting in NULL when ANSI mode is on
+Example 3: Exception during division, resulting in NULL when ANSI mode is on

>>> import pyspark.sql.functions as sf
>>> origin = spark.conf.get("spark.sql.ansi.enabled")
@@ -657,6 +657,56 @@ def try_divide(left: "ColumnOrName", right: "ColumnOrName") -> Column:
return _invoke_function_over_columns("try_divide", left, right)


@_try_remote_functions
def try_remainder(left: "ColumnOrName", right: "ColumnOrName") -> Column:
"""
Returns the remainder after `dividend`/`divisor`. Its result is
always null if `divisor` is 0.

.. versionadded:: 4.0.0

Parameters
----------
left : :class:`~pyspark.sql.Column` or str
dividend
right : :class:`~pyspark.sql.Column` or str
divisor

Examples
--------
Example 1: Integer divided by Integer.

>>> import pyspark.sql.functions as sf
>>> spark.createDataFrame(
... [(6000, 15), (3, 2), (1234, 0)], ["a", "b"]
... ).select(sf.try_remainder("a", "b")).show()
+-------------------+
|try_remainder(a, b)|
+-------------------+
| 0|
| 1|
| NULL|
+-------------------+

Example 2: Exception during division, resulting in NULL when ANSI mode is on

>>> import pyspark.sql.functions as sf
>>> origin = spark.conf.get("spark.sql.ansi.enabled")
>>> spark.conf.set("spark.sql.ansi.enabled", "true")
>>> try:
... df = spark.range(1)
... df.select(sf.try_remainder(df.id, sf.lit(0))).show()
... finally:
... spark.conf.set("spark.sql.ansi.enabled", origin)
+--------------------+
|try_remainder(id, 0)|
+--------------------+
| NULL|
+--------------------+
"""
return _invoke_function_over_columns("try_remainder", left, right)


@_try_remote_functions
def try_multiply(left: "ColumnOrName", right: "ColumnOrName") -> Column:
"""
16 changes: 8 additions & 8 deletions python/pyspark/sql/tests/connect/test_connect_column.py
@@ -772,8 +772,8 @@ def test_column_accessor(self):
sdf.select(sdf.z[0], sdf.z[1], sdf["z"][2]).toPandas(),
)
self.assert_eq(
-cdf.select(CF.col("z")[0], cdf.z[10], CF.col("z")[-10]).toPandas(),
-sdf.select(SF.col("z")[0], sdf.z[10], SF.col("z")[-10]).toPandas(),
+cdf.select(CF.col("z")[0], CF.get(cdf.z, 10), CF.get(CF.col("z"), -10)).toPandas(),
+sdf.select(SF.col("z")[0], SF.get(sdf.z, 10), SF.get(SF.col("z"), -10)).toPandas(),
)
self.assert_eq(
cdf.select(cdf.z.getItem(0), cdf.z.getItem(1), cdf["z"].getField(2)).toPandas(),
@@ -824,8 +824,12 @@ def test_column_arithmetic_ops(self):
)

self.assert_eq(
-cdf.select(cdf.a % cdf["b"], cdf["a"] % 2, 12 % cdf.c).toPandas(),
-sdf.select(sdf.a % sdf["b"], sdf["a"] % 2, 12 % sdf.c).toPandas(),
+cdf.select(
+    cdf.a % cdf["b"], cdf["a"] % 2, CF.try_remainder(CF.lit(12), cdf.c)
+).toPandas(),
+sdf.select(
+    sdf.a % sdf["b"], sdf["a"] % 2, SF.try_remainder(SF.lit(12), sdf.c)
+).toPandas(),
)

self.assert_eq(
@@ -1022,13 +1026,9 @@ def test_distributed_sequence_id(self):


if __name__ == "__main__":
-import os
 import unittest
 from pyspark.sql.tests.connect.test_connect_column import * # noqa: F401

-# TODO(SPARK-41794): Enable ANSI mode in this file.
-os.environ["SPARK_ANSI_SQL_MODE"] = "false"
-
try:
import xmlrunner

@@ -451,6 +451,7 @@ object FunctionRegistry {
// "try_*" function which always return Null instead of runtime error.
expression[TryAdd]("try_add"),
expression[TryDivide]("try_divide"),
expression[TryRemainder]("try_remainder"),
expression[TrySubtract]("try_subtract"),
expression[TryMultiply]("try_multiply"),
expression[TryElementAt]("try_element_at"),
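
Registering the expression here is what makes the name `try_remainder` resolvable from SQL; the registry entry also powers `DESCRIBE FUNCTION`. A quick sketch, assuming an active session:

```scala
// Prints the usage and examples text from the ExpressionDescription below.
spark.sql("DESCRIBE FUNCTION EXTENDED try_remainder").show(truncate = false)
```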
@@ -132,6 +132,43 @@ case class TryDivide(left: Expression, right: Expression, replacement: Expression)
}
}

// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(dividend, divisor) - Returns the remainder after `dividend`/`divisor`. " +
"`dividend` must be a numeric. `divisor` must be a numeric.",
examples = """
Examples:
> SELECT _FUNC_(3, 2);
1
> SELECT _FUNC_(2L, 2L);
0
> SELECT _FUNC_(3.0, 2.0);
1.0
> SELECT _FUNC_(1, 0);
NULL
""",
since = "4.0.0",
group = "math_funcs")
// scalastyle:on line.size.limit
case class TryRemainder(left: Expression, right: Expression, replacement: Expression)
extends RuntimeReplaceable with InheritAnalysisRules {
def this(left: Expression, right: Expression) = this(left, right,
(left.dataType, right.dataType) match {
case (_: NumericType, _: NumericType) => Remainder(left, right, EvalMode.TRY)
// TODO: support TRY eval mode on datetime arithmetic expressions.
case _ => TryEval(Remainder(left, right, EvalMode.ANSI))
}
)

override def prettyName: String = "try_remainder"

override def parameters: Seq[Expression] = Seq(left, right)

override protected def withNewChildInternal(newChild: Expression): Expression = {
copy(replacement = newChild)
}
}
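
A plain-Scala model of the two constructor paths above (illustrative only, not Spark's implementation): for numeric operands the operator itself runs in TRY mode and returns null on a zero divisor, while the fallback evaluates a strict ANSI remainder and converts the runtime error to null.

```scala
// Path 1: Remainder(left, right, EvalMode.TRY) — the operator checks the divisor itself.
def remainderTryMode(a: Long, b: Long): Option[Long] =
  if (b == 0L) None else Some(a % b)

// Path 2: TryEval(Remainder(left, right, EvalMode.ANSI)) — evaluate strictly,
// then turn the runtime error into null.
def tryEvalAnsiRemainder(a: Long, b: Long): Option[Long] =
  try Some(a % b) catch { case _: ArithmeticException => None }
```

Both paths agree on the result; the TRY-mode path simply avoids the throw/catch round trip.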

@ExpressionDescription(
usage = "_FUNC_(expr1, expr2) - Returns `expr1`-`expr2` and the result is null on overflow. " +
"The acceptable input types are the same with the `-` operator.",
@@ -906,6 +906,10 @@ case class Remainder(

override def inputType: AbstractDataType = NumericType

// `try_remainder` has exactly the same behavior as the legacy remainder, so here it only
// executes the error code path when `evalMode` is `ANSI`.
protected override def failOnError: Boolean = evalMode == EvalMode.ANSI

override def symbol: String = "%"
override def decimalMethod: String = "remainder"
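
A condensed model of what the `failOnError` switch above controls (a sketch, not the generated code): LEGACY and TRY share the null-returning path on a zero divisor, and only ANSI takes the error path.

```scala
sealed trait EvalMode
case object Legacy extends EvalMode
case object Try extends EvalMode
case object Ansi extends EvalMode

def remainder(a: Long, b: Long, mode: EvalMode): Option[Long] = {
  val failOnError = mode == Ansi // mirrors the override above
  if (b == 0L) {
    if (failOnError) throw new ArithmeticException("Division by zero") // ANSI error path
    else None // LEGACY and TRY both return null here
  } else {
    Some(a % b)
  }
}
```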

@@ -46,6 +46,19 @@ class TryEvalSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}

test("try_remainder") {
Seq(
(3.0, 2.0, 1.0),
(1.0, 0.0, null),
(-1.0, 0.0, null)
).foreach { case (a, b, expected) =>
val left = Literal(a)
val right = Literal(b)
val input = Remainder(left, right, EvalMode.TRY)
checkEvaluation(input, expected)
}
}

test("try_subtract") {
Seq(
(1, 1, 0),
9 changes: 9 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1937,6 +1937,15 @@ object functions {
*/
def try_divide(left: Column, right: Column): Column = Column.fn("try_divide", left, right)

/**
* Returns the remainder of `dividend``/``divisor`. Its result is
* always null if `divisor` is 0.
*
* @group math_funcs
* @since 4.0.0
*/
def try_remainder(left: Column, right: Column): Column = Column.fn("try_remainder", left, right)

/**
* Returns `left``*``right` and the result is null on overflow. The acceptable input types are
* the same with the `*` operator.
@@ -349,6 +349,7 @@
| org.apache.spark.sql.catalyst.expressions.TryElementAt | try_element_at | SELECT try_element_at(array(1, 2, 3), 2) | struct<try_element_at(array(1, 2, 3), 2):int> |
| org.apache.spark.sql.catalyst.expressions.TryMultiply | try_multiply | SELECT try_multiply(2, 3) | struct<try_multiply(2, 3):int> |
| org.apache.spark.sql.catalyst.expressions.TryReflect | try_reflect | SELECT try_reflect('java.util.UUID', 'randomUUID') | struct<try_reflect(java.util.UUID, randomUUID):string> |
| org.apache.spark.sql.catalyst.expressions.TryRemainder | try_remainder | SELECT try_remainder(3, 2) | struct<try_remainder(3, 2):int> |
| org.apache.spark.sql.catalyst.expressions.TrySubtract | try_subtract | SELECT try_subtract(2, 1) | struct<try_subtract(2, 1):int> |
| org.apache.spark.sql.catalyst.expressions.TryToBinary | try_to_binary | SELECT try_to_binary('abc', 'utf-8') | struct<try_to_binary(abc, utf-8):binary> |
| org.apache.spark.sql.catalyst.expressions.TryToNumber | try_to_number | SELECT try_to_number('454', '999') | struct<try_to_number(454, 999):decimal(3,0)> |
@@ -707,6 +707,17 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
df1.select(try_divide(make_interval(col("year"), col("month")), lit(0))))
}

test("try_remainder") {
val df = Seq((10, 3), (5, 5), (5, 0)).toDF("birth", "age")
checkAnswer(df.selectExpr("try_remainder(birth, age)"), Seq(Row(1), Row(0), Row(null)))

val dfDecimal = Seq(
(BigDecimal(10), BigDecimal(3)),
(BigDecimal(5), BigDecimal(5)),
(BigDecimal(5), BigDecimal(0))).toDF("birth", "age")
checkAnswer(dfDecimal.selectExpr("try_remainder(birth, age)"), Seq(Row(1), Row(0), Row(null)))
}

test("try_element_at") {
val df = Seq((Array(1, 2, 3), 2)).toDF("a", "b")
checkAnswer(df.selectExpr("try_element_at(a, b)"), Seq(Row(2)))