[SPARK-49306][PYTHON][SQL] Create DataFrame API support for new 'zeroifnull' and 'nullifzero' SQL functions #47851

Status: Closed. Wants to merge 9 commits.
Changes from 8 commits
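
For orientation before the diff: a minimal PySpark sketch of the two DataFrame functions this PR adds. It mirrors the docstring examples below and assumes an active SparkSession; it is illustrative, not part of the change itself.

from pyspark.sql import SparkSession
from pyspark.sql.functions import nullifzero, zeroifnull

spark = SparkSession.builder.getOrCreate()

# nullifzero: NULL where the value equals zero, the value otherwise.
spark.createDataFrame([(0,), (1,)], ["a"]).select(nullifzero("a").alias("r")).show()

# zeroifnull: 0 where the value is NULL, the value otherwise.
spark.createDataFrame([(None,), (1,)], ["a"]).select(zeroifnull("a").alias("r")).show()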
6 changes: 4 additions & 2 deletions core/src/test/scala/org/apache/spark/SparkFunSuite.scala
@@ -380,8 +380,10 @@ abstract class SparkFunSuite
} else if (actual.contextType() == QueryContextType.DataFrame) {
assert(actual.fragment() === expected.fragment,
"Invalid code fragment of a query context. Actual:" + actual.toString)
assert(actual.callSite().matches(expected.callSitePattern),
"Invalid callSite of a query context. Actual:" + actual.toString)
if (expected.callSitePattern.nonEmpty) {
assert(actual.callSite().matches(expected.callSitePattern),
"Invalid callSite of a query context. Actual:" + actual.toString)
}
}
}
}
2 changes: 2 additions & 0 deletions python/docs/source/reference/pyspark.sql/functions.rst
@@ -48,9 +48,11 @@ Conditional Functions
ifnull
nanvl
nullif
nullifzero
nvl
nvl2
when
zeroifnull


Predicate Functions
14 changes: 14 additions & 0 deletions python/pyspark/sql/connect/functions/builtin.py
@@ -3921,6 +3921,13 @@ def nullif(col1: "ColumnOrName", col2: "ColumnOrName") -> Column:
nullif.__doc__ = pysparkfuncs.nullif.__doc__


def nullifzero(col: "ColumnOrName") -> Column:
return _invoke_function_over_columns("nullifzero", col)


nullifzero.__doc__ = pysparkfuncs.nullifzero.__doc__


def nvl(col1: "ColumnOrName", col2: "ColumnOrName") -> Column:
return _invoke_function_over_columns("nvl", col1, col2)

@@ -3935,6 +3942,13 @@ def nvl2(col1: "ColumnOrName", col2: "ColumnOrName", col3: "ColumnOrName") -> Column:
nvl2.__doc__ = pysparkfuncs.nvl2.__doc__


def zeroifnull(col: "ColumnOrName") -> Column:
return _invoke_function_over_columns("zeroifnull", col)


zeroifnull.__doc__ = pysparkfuncs.zeroifnull.__doc__


def aes_encrypt(
input: "ColumnOrName",
key: "ColumnOrName",
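
The Spark Connect wrappers above simply delegate by function name through _invoke_function_over_columns, so the same user code runs unchanged against a Connect server. A hedged sketch; the endpoint URL is an assumption for illustration:

from pyspark.sql import SparkSession
from pyspark.sql.functions import nullifzero, zeroifnull

# Assumed Connect endpoint; replace with a real server address.
spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

df = spark.createDataFrame([(0,), (None,)], ["a"])
df.select(nullifzero("a"), zeroifnull("a")).show()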
50 changes: 50 additions & 0 deletions python/pyspark/sql/functions/builtin.py
@@ -20681,6 +20681,31 @@ def nullif(col1: "ColumnOrName", col2: "ColumnOrName") -> Column:
return _invoke_function_over_columns("nullif", col1, col2)


@_try_remote_functions
def nullifzero(col: "ColumnOrName") -> Column:
"""
Returns null if `col` is equal to zero, or `col` otherwise.

.. versionadded:: 4.0.0

Parameters
----------
col : :class:`~pyspark.sql.Column` or str

Examples
--------
>>> df = spark.createDataFrame([(0,), (1,)], ["a"])
>>> df.select(nullifzero(df.a).alias("result")).show()
+------+
|result|
+------+
| None|
[Inline review on the "| None|" line above, from a Member] Why do you expect None if the function nullifzero() should return NULL for 0?

[Reply from the Contributor and author] Fixed this; it was indeed a typo and should say NULL instead of None.

| 1|
+------+
"""
return _invoke_function_over_columns("nullifzero", col)


@_try_remote_functions
def nvl(col1: "ColumnOrName", col2: "ColumnOrName") -> Column:
"""
@@ -20724,6 +20749,31 @@ def nvl2(col1: "ColumnOrName", col2: "ColumnOrName", col3: "ColumnOrName") -> Column:
return _invoke_function_over_columns("nvl2", col1, col2, col3)


@_try_remote_functions
def zeroifnull(col: "ColumnOrName") -> Column:
"""
Returns zero if `col` is null, or `col` otherwise.

.. versionadded:: 4.0.0

Parameters
----------
col : :class:`~pyspark.sql.Column` or str

Examples
--------
>>> df = spark.createDataFrame([(None,), (1,)], ["a"])
>>> df.select(zeroifnull(df.a).alias("result")).show()
+------+
|result|
+------+
| 0|
| 1|
+------+
"""
return _invoke_function_over_columns("zeroifnull", col)


@_try_remote_functions
def aes_encrypt(
input: "ColumnOrName",
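
As the error expectations in the Scala suite further down suggest (note the analyzed fragments "(... = 0)" and "coalesce(..., 0)"), the two functions behave like the following longhand Column expressions. This is a sketch of the observable semantics, not the internal implementation:

from pyspark.sql import functions as F

# nullifzero(c) behaves like NULLIF(c, 0).
nullifzero_longhand = F.when(F.col("a") == 0, F.lit(None)).otherwise(F.col("a"))

# zeroifnull(c) behaves like COALESCE(c, 0).
zeroifnull_longhand = F.coalesce(F.col("a"), F.lit(0))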
10 changes: 10 additions & 0 deletions python/pyspark/sql/tests/test_functions.py
@@ -29,6 +29,7 @@
from pyspark.sql import Row, Window, functions as F, types
from pyspark.sql.avro.functions import from_avro, to_avro
from pyspark.sql.column import Column
from pyspark.sql.functions.builtin import nullifzero, zeroifnull
from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils
from pyspark.testing.utils import have_numpy

@@ -1593,6 +1594,15 @@ class IntEnum(Enum):
for r, c, e in zip(result, cols, expected):
self.assertEqual(r, e, str(c))

def test_nullifzero_zeroifnull(self):
df = self.spark.createDataFrame([(0,), (1,)], ["a"])
result = df.select(nullifzero(df.a).alias("r")).collect()
self.assertEqual([Row(r=None), Row(r=1)], result)

df = self.spark.createDataFrame([(None,), (1,)], ["a"])
result = df.select(zeroifnull(df.a).alias("r")).collect()
self.assertEqual([Row(r=0), Row(r=1)], result)


class FunctionsTests(ReusedSQLTestCase, FunctionsTestsMixin):
pass
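
Since SQL functions of the same names already exist, the new DataFrame tests above could be cross-checked through plain SQL as well; a small hedged sketch assuming an active session:

spark.sql("SELECT nullifzero(0) AS n0, nullifzero(1) AS n1").show()
spark.sql("SELECT zeroifnull(NULL) AS z0, zeroifnull(1) AS z1").show()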
16 changes: 16 additions & 0 deletions sql/api/src/main/scala/org/apache/spark/sql/functions.scala
@@ -7832,6 +7832,14 @@
*/
def nullif(col1: Column, col2: Column): Column = Column.fn("nullif", col1, col2)

/**
* Returns null if `col` is equal to zero, or `col` otherwise.
*
* @group conditional_funcs
* @since 4.0.0
*/
def nullifzero(col: Column): Column = Column.fn("nullifzero", col)

/**
* Returns `col2` if `col1` is null, or `col1` otherwise.
*
@@ -7848,6 +7856,14 @@
*/
def nvl2(col1: Column, col2: Column, col3: Column): Column = Column.fn("nvl2", col1, col2, col3)

/**
* Returns zero if `col` is null, or `col` otherwise.
*
* @group conditional_funcs
* @since 4.0.0
*/
def zeroifnull(col: Column): Column = Column.fn("zeroifnull", col)

// scalastyle:off line.size.limit
// scalastyle:off parameter.number

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -23,10 +23,10 @@ import java.sql.{Date, Timestamp}

import scala.util.Random

import org.apache.spark.{SPARK_DOC_ROOT, SparkException, SparkRuntimeException}
import org.apache.spark.{QueryContextType, SPARK_DOC_ROOT, SparkException, SparkRuntimeException}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.Cast._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
@@ -331,6 +331,66 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
checkAnswer(df.select(nullif(lit(5), lit(5))), Seq(Row(null)))
}

test("nullifzero function") {
withTable("t") {
// Here we exercise a non-nullable, non-foldable column.
sql("create table t(col int not null) using csv")
sql("insert into t values (0)")
val df = sql("select col from t")
checkAnswer(df.select(nullifzero($"col")), Seq(Row(null)))
}
// Here we exercise invalid cases including types that do not support ordering.
val df = Seq((0)).toDF("a")
var expr = nullifzero(map(lit(1), lit("a")))
checkError(
intercept[AnalysisException](df.select(expr)),
errorClass = "DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES",
parameters = Map(
"left" -> "\"MAP<INT, STRING>\"",
"right" -> "\"INT\"",
"sqlExpr" -> "\"(map(1, a) = 0)\""),
context = ExpectedContext(
contextType = QueryContextType.DataFrame,
fragment = "nullifzero",
objectType = "",
objectName = "",
callSitePattern = "",
startIndex = 0,
stopIndex = 0))
expr = nullifzero(array(lit(1), lit(2)))
checkError(
intercept[AnalysisException](df.select(expr)),
errorClass = "DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES",
parameters = Map(
"left" -> "\"ARRAY<INT>\"",
"right" -> "\"INT\"",
"sqlExpr" -> "\"(array(1, 2) = 0)\""),
context = ExpectedContext(
contextType = QueryContextType.DataFrame,
fragment = "nullifzero",
objectType = "",
objectName = "",
callSitePattern = "",
startIndex = 0,
stopIndex = 0))
expr = nullifzero(Literal.create(20201231, DateType))
checkError(
intercept[AnalysisException](df.select(expr)),
errorClass = "DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES",
parameters = Map(
"left" -> "\"DATE\"",
"right" -> "\"INT\"",
"sqlExpr" -> "\"(DATE '+57279-02-03' = 0)\""),
context = ExpectedContext(
contextType = QueryContextType.DataFrame,
fragment = "nullifzero",
objectType = "",
objectName = "",
callSitePattern = "",
startIndex = 0,
stopIndex = 0))
}

test("nvl") {
val df = Seq[(Integer, Integer)]((null, 8)).toDF("a", "b")
checkAnswer(df.selectExpr("nvl(a, b)"), Seq(Row(8)))
@@ -349,6 +409,66 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
checkAnswer(df.select(nvl2(col("b"), col("a"), col("c"))), Seq(Row(null)))
}

test("zeroifnull function") {
withTable("t") {
// Here we exercise a non-nullable, non-foldable column.
sql("create table t(col int not null) using csv")
sql("insert into t values (0)")
val df = sql("select col from t")
checkAnswer(df.select(zeroifnull($"col")), Seq(Row(0)))
}
// Here we exercise invalid cases including types that do not support ordering.
val df = Seq((0)).toDF("a")
var expr = zeroifnull(map(lit(1), lit("a")))
checkError(
intercept[AnalysisException](df.select(expr)),
errorClass = "DATATYPE_MISMATCH.DATA_DIFF_TYPES",
parameters = Map(
"functionName" -> "`coalesce`",
"dataType" -> "(\"MAP<INT, STRING>\" or \"INT\")",
"sqlExpr" -> "\"coalesce(map(1, a), 0)\""),
context = ExpectedContext(
contextType = QueryContextType.DataFrame,
fragment = "zeroifnull",
objectType = "",
objectName = "",
callSitePattern = "",
startIndex = 0,
stopIndex = 0))
expr = zeroifnull(array(lit(1), lit(2)))
checkError(
intercept[AnalysisException](df.select(expr)),
errorClass = "DATATYPE_MISMATCH.DATA_DIFF_TYPES",
parameters = Map(
"functionName" -> "`coalesce`",
"dataType" -> "(\"ARRAY<INT>\" or \"INT\")",
"sqlExpr" -> "\"coalesce(array(1, 2), 0)\""),
context = ExpectedContext(
contextType = QueryContextType.DataFrame,
fragment = "zeroifnull",
objectType = "",
objectName = "",
callSitePattern = "",
startIndex = 0,
stopIndex = 0))
expr = zeroifnull(Literal.create(20201231, DateType))
checkError(
intercept[AnalysisException](df.select(expr)),
errorClass = "DATATYPE_MISMATCH.DATA_DIFF_TYPES",
parameters = Map(
"functionName" -> "`coalesce`",
"dataType" -> "(\"DATE\" or \"INT\")",
"sqlExpr" -> "\"coalesce(DATE '+57279-02-03', 0)\""),
context = ExpectedContext(
contextType = QueryContextType.DataFrame,
fragment = "zeroifnull",
objectType = "",
objectName = "",
callSitePattern = "",
startIndex = 0,
stopIndex = 0))
}
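
The invalid-type cases exercised above surface the same error class from Python; a hedged sketch of what a PySpark caller would observe (the session and data are assumed):

from pyspark.sql import functions as F
from pyspark.errors import AnalysisException

df = spark.createDataFrame([(0,)], ["a"])
try:
    # A map cannot be compared to integer zero, so analysis fails.
    df.select(F.nullifzero(F.create_map(F.lit(1), F.lit("a")))).collect()
except AnalysisException as e:
    print(e.getErrorClass())  # expected: DATATYPE_MISMATCH.BINARY_OP_DIFF_TYPES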

test("misc md5 function") {
val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b")
checkAnswer(