[SPARK-49306][SQL] Create new SQL functions 'zeroifnull' and 'nullifzero' #47817

Closed
@@ -376,6 +376,7 @@ object FunctionRegistry {
expression[Least]("least"),
expression[NaNvl]("nanvl"),
expression[NullIf]("nullif"),
expression[NullIfZero]("nullifzero"),
expression[Nvl]("nvl"),
expression[Nvl2]("nvl2"),
expression[PosExplode]("posexplode"),
@@ -384,6 +385,7 @@
expression[Rand]("random", true, Some("3.0.0")),
expression[Randn]("randn"),
expression[Stack]("stack"),
expression[ZeroIfNull]("zeroifnull"),
Member:
Should probably add them into functions.scala and functions.py. Could be done in a separate PR, though.

Contributor Author:
Thanks, I will add DataFrame support in a separate PR :)
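For reference, the follow-up would mostly be thin wrappers over the registry entries above. Below is a minimal sketch of what the functions.scala side could look like, assuming the Column.fn helper that neighboring definitions in that file use; the names, doc strings, and placement are illustrative, not part of this PR.

    // Hypothetical follow-up additions to org.apache.spark.sql.functions
    // (not in this PR). Both delegate to the SQL functions registered
    // in FunctionRegistry above.

    /** Returns null if `col` is equal to zero, or `col` otherwise. */
    def nullifzero(col: Column): Column = Column.fn("nullifzero", col)

    /** Returns zero if `col` is null, or `col` otherwise. */
    def zeroifnull(col: Column): Column = Column.fn("zeroifnull", col)

Until such wrappers land, the functions remain reachable from the DataFrame API via selectExpr("nullifzero(x)") or the public call_function("nullifzero", col) helper.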

CaseWhen.registryEntry,

// math functions
@@ -177,6 +177,47 @@ case class NullIf(left: Expression, right: Expression, replacement: Expression)
}
}

@ExpressionDescription(
usage = "_FUNC_(expr) - Returns null if `expr` is equal to zero, or `expr` otherwise.",
examples = """
Examples:
> SELECT _FUNC_(0);
NULL
> SELECT _FUNC_(2);
2
""",
since = "4.0.0",
group = "conditional_funcs")
case class NullIfZero(input: Expression, replacement: Expression)
extends RuntimeReplaceable with InheritAnalysisRules {
def this(input: Expression) = this(input, If(EqualTo(input, Literal(0)), Literal(null), input))

override def parameters: Seq[Expression] = Seq(input)

override protected def withNewChildInternal(newInput: Expression): Expression =
copy(replacement = newInput)
}
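Since NullIfZero is RuntimeReplaceable with InheritAnalysisRules, the analyzer substitutes the `replacement` built in the auxiliary constructor, so the expression needs no evaluation or codegen logic of its own: nullifzero(e) is analyzed as if(e = 0, null, e). A small sketch of the observable behavior; the local session and view name are illustrative:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[1]").appName("nullifzero-demo").getOrCreate()
    import spark.implicits._

    Seq(0, 1, 2).toDF("x").createOrReplaceTempView("t")
    // Both columns are identical after analysis:
    // x = 0 -> NULL, x = 1 -> 1, x = 2 -> 2
    spark.sql("SELECT nullifzero(x), IF(x = 0, NULL, x) FROM t").show()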

@ExpressionDescription(
usage = "_FUNC_(expr) - Returns zero if `expr` is equal to null, or `expr` otherwise.",
examples = """
Examples:
> SELECT _FUNC_(NULL);
0
> SELECT _FUNC_(2);
2
""",
since = "4.0.0",
group = "conditional_funcs")
case class ZeroIfNull(input: Expression, replacement: Expression)
extends RuntimeReplaceable with InheritAnalysisRules {
def this(input: Expression) = this(input, new Nvl(input, Literal(0)))

override def parameters: Seq[Expression] = Seq(input)

override protected def withNewChildInternal(newInput: Expression): Expression =
copy(replacement = newInput)
}
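ZeroIfNull follows the same pattern but delegates to the existing Nvl expression, so zeroifnull(e) behaves exactly like nvl(e, 0), i.e. coalesce(e, 0). A brief sketch, reusing the hypothetical session from the previous example:

    // Option[Int] produces a nullable column: None -> NULL, Some(1) -> 1.
    Seq(Option(1), None).toDF("x").createOrReplaceTempView("u")
    // Both columns agree: the NULL row yields 0, the other row yields 1.
    spark.sql("SELECT zeroifnull(x), coalesce(x, 0) FROM u").show()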

@ExpressionDescription(
usage = "_FUNC_(expr1, expr2) - Returns `expr2` if `expr1` is null, or `expr1` otherwise.",
@@ -238,6 +238,7 @@
| org.apache.spark.sql.catalyst.expressions.Now | now | SELECT now() | struct<now():timestamp> |
| org.apache.spark.sql.catalyst.expressions.NthValue | nth_value | SELECT a, b, nth_value(b, 2) OVER (PARTITION BY a ORDER BY b) FROM VALUES ('A1', 2), ('A1', 1), ('A2', 3), ('A1', 1) tab(a, b) | struct<a:string,b:int,nth_value(b, 2) OVER (PARTITION BY a ORDER BY b ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):int> |
| org.apache.spark.sql.catalyst.expressions.NullIf | nullif | SELECT nullif(2, 2) | struct<nullif(2, 2):int> |
| org.apache.spark.sql.catalyst.expressions.NullIfZero | nullifzero | SELECT nullifzero(0) | struct<nullifzero(0):int> |
| org.apache.spark.sql.catalyst.expressions.Nvl | ifnull | SELECT ifnull(NULL, array('2')) | struct<ifnull(NULL, array(2)):array<string>> |
| org.apache.spark.sql.catalyst.expressions.Nvl | nvl | SELECT nvl(NULL, array('2')) | struct<nvl(NULL, array(2)):array<string>> |
| org.apache.spark.sql.catalyst.expressions.Nvl2 | nvl2 | SELECT nvl2(NULL, 2, 1) | struct<nvl2(NULL, 2, 1):int> |
@@ -384,6 +385,7 @@
| org.apache.spark.sql.catalyst.expressions.XmlToStructs | from_xml | SELECT from_xml('<p><a>1</a><b>0.8</b></p>', 'a INT, b DOUBLE') | struct<from_xml(<p><a>1</a><b>0.8</b></p>):struct<a:int,b:double>> |
| org.apache.spark.sql.catalyst.expressions.XxHash64 | xxhash64 | SELECT xxhash64('Spark', array(123), 2) | struct<xxhash64(Spark, array(123), 2):bigint> |
| org.apache.spark.sql.catalyst.expressions.Year | year | SELECT year('2016-07-30') | struct<year(2016-07-30):int> |
| org.apache.spark.sql.catalyst.expressions.ZeroIfNull | zeroifnull | SELECT zeroifnull(NULL) | struct<zeroifnull(NULL):int> |
| org.apache.spark.sql.catalyst.expressions.ZipWith | zip_with | SELECT zip_with(array(1, 2, 3), array('a', 'b', 'c'), (x, y) -> (y, x)) | struct<zip_with(array(1, 2, 3), array(a, b, c), lambdafunction(named_struct(y, namedlambdavariable(), x, namedlambdavariable()), namedlambdavariable(), namedlambdavariable())):array<struct<y:string,x:int>>> |
| org.apache.spark.sql.catalyst.expressions.aggregate.AnyValue | any_value | SELECT any_value(col) FROM VALUES (10), (5), (20) AS tab(col) | struct<any_value(col):int> |
| org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile | approx_percentile | SELECT approx_percentile(col, array(0.5, 0.4, 0.1), 100) FROM VALUES (0), (1), (2), (10) AS tab(col) | struct<approx_percentile(col, array(0.5, 0.4, 0.1), 100):array<int>> |
@@ -17,7 +17,7 @@

package org.apache.spark.sql

import org.apache.spark.{SPARK_REVISION, SPARK_VERSION_SHORT}
import org.apache.spark.{SPARK_REVISION, SPARK_VERSION_SHORT, SparkNumberFormatException}
import org.apache.spark.sql.catalyst.expressions.Hex
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.functions._
@@ -285,6 +285,48 @@ class MiscFunctionsSuite extends QueryTest with SharedSparkSession {
assert(df.selectExpr("random(1)").collect() != null)
assert(df.select(random(lit(1))).collect() != null)
}

test("SPARK-49306 nullifzero and zeroifnull functions") {
Member:
This test suite is mostly for misc functions. Could you move the expression tests to ConditionalExpressionSuite and the end-to-end tests to conditional-functions.sql, please?

Contributor Author:
Sure, it turns out these were all end-to-end tests, so I just moved them all to conditional-functions.sql.

val df = Seq((1, 2, 3)).toDF("a", "b", "c")
checkAnswer(df.selectExpr("nullifzero(0)"), Row(null))
checkAnswer(df.selectExpr("nullifzero(cast(0 as tinyint))"), Row(null))
checkAnswer(df.selectExpr("nullifzero(cast(0 as bigint))"), Row(null))
checkAnswer(df.selectExpr("nullifzero('0')"), Row(null))
checkAnswer(df.selectExpr("nullifzero(0.0)"), Row(null))
checkAnswer(df.selectExpr("nullifzero(1)"), Row(1))
checkAnswer(df.selectExpr("nullifzero(null)"), Row(null))
var expr = "nullifzero('abc')"
checkError(
exception = intercept[SparkNumberFormatException] {
checkAnswer(df.selectExpr(expr), Row(null))
},
errorClass = "CAST_INVALID_INPUT",
parameters = Map(
"expression" -> "'abc'",
"sourceType" -> "\"STRING\"",
"targetType" -> "\"BIGINT\"",
"ansiConfig" -> "\"spark.sql.ansi.enabled\""
),
context = ExpectedContext("", "", 0, expr.length - 1, expr))

checkAnswer(df.selectExpr("zeroifnull(null)"), Row(0))
checkAnswer(df.selectExpr("zeroifnull(1)"), Row(1))
checkAnswer(df.selectExpr("zeroifnull(cast(1 as tinyint))"), Row(1))
checkAnswer(df.selectExpr("zeroifnull(cast(1 as bigint))"), Row(1))
expr = "zeroifnull('abc')"
checkError(
exception = intercept[SparkNumberFormatException] {
checkAnswer(df.selectExpr(expr), Row(null))
},
errorClass = "CAST_INVALID_INPUT",
parameters = Map(
"expression" -> "'abc'",
"sourceType" -> "\"STRING\"",
"targetType" -> "\"BIGINT\"",
"ansiConfig" -> "\"spark.sql.ansi.enabled\""
),
context = ExpectedContext("", "", 0, expr.length - 1, expr))
}
}
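Note that the CAST_INVALID_INPUT errors asserted above come from implicit type coercion rather than from the new expressions themselves: nullifzero('abc') rewrites to if('abc' = 0, null, 'abc'), and comparing the string with the integer literal forces a cast of 'abc' to BIGINT, which fails under ANSI mode. With ANSI mode off, the invalid cast yields NULL instead, the comparison is therefore NULL, and the If takes the else branch, returning the original string. A sketch of both modes; the session setup is illustrative:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[1]").appName("ansi-demo").getOrCreate()

    spark.conf.set("spark.sql.ansi.enabled", "true")
    // Throws SparkNumberFormatException (CAST_INVALID_INPUT):
    // 'abc' cannot be cast to BIGINT.
    // spark.sql("SELECT nullifzero('abc')").collect()

    spark.conf.set("spark.sql.ansi.enabled", "false")
    // The invalid numeric cast yields NULL, so 'abc' = 0 is NULL and
    // the If falls through to the else branch: the result is 'abc'.
    spark.sql("SELECT nullifzero('abc')").show()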

object ReflectClass {