Skip to content

Commit

Permalink
Address comment
Browse files Browse the repository at this point in the history
  • Loading branch information
wangyum committed May 28, 2020
1 parent 82e97e3 commit 1bdff95
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -309,17 +309,11 @@ trait DivModLike extends BinaryArithmetic {

override def nullable: Boolean = true

final override def eval(input: InternalRow): Any = {
val input2 = right.eval(input)
if (input2 == null || input2 == 0) {
final override def nullSafeEval(input1: Any, input2: Any): Any = {
if (input2 == 0) {
null
} else {
val input1 = left.eval(input)
if (input1 == null) {
null
} else {
evalOperation(input1, input2)
}
evalOperation(input1, input2)
}
}

Expand Down Expand Up @@ -516,24 +510,18 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {

override def nullable: Boolean = true

override def eval(input: InternalRow): Any = {
val input2 = right.eval(input)
if (input2 == null || input2 == 0) {
override def nullSafeEval(input1: Any, input2: Any): Any = {
if (input2 == 0) {
null
} else {
val input1 = left.eval(input)
if (input1 == null) {
null
} else {
input1 match {
case i: Integer => pmod(i, input2.asInstanceOf[java.lang.Integer])
case l: Long => pmod(l, input2.asInstanceOf[java.lang.Long])
case s: Short => pmod(s, input2.asInstanceOf[java.lang.Short])
case b: Byte => pmod(b, input2.asInstanceOf[java.lang.Byte])
case f: Float => pmod(f, input2.asInstanceOf[java.lang.Float])
case d: Double => pmod(d, input2.asInstanceOf[java.lang.Double])
case d: Decimal => pmod(d, input2.asInstanceOf[Decimal])
}
input1 match {
case i: Integer => pmod(i, input2.asInstanceOf[java.lang.Integer])
case l: Long => pmod(l, input2.asInstanceOf[java.lang.Long])
case s: Short => pmod(s, input2.asInstanceOf[java.lang.Short])
case b: Byte => pmod(b, input2.asInstanceOf[java.lang.Byte])
case f: Float => pmod(f, input2.asInstanceOf[java.lang.Float])
case d: Double => pmod(d, input2.asInstanceOf[java.lang.Double])
case d: Decimal => pmod(d, input2.asInstanceOf[Decimal])
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,13 +164,8 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
val exprTypesToCheck = Seq(classOf[UnaryExpression], classOf[BinaryExpression],
classOf[TernaryExpression], classOf[QuaternaryExpression], classOf[SeptenaryExpression])

// Do not check these expressions, because these expressions extend NullIntolerant
// and override the eval function.
val ignoreSet = Set(classOf[IntegralDivide], classOf[Divide], classOf[Remainder], classOf[Pmod])

val candidateExprsToCheck = spark.sessionState.functionRegistry.listFunction()
.map(spark.sessionState.catalog.lookupFunctionInfo).map(_.getClassName)
.filterNot(c => ignoreSet.exists(_.getName.equals(c)))
.map(name => Utils.classForName(name))
.filterNot(classOf[NonSQLExpression].isAssignableFrom)

Expand All @@ -180,8 +175,9 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
superClass.getMethod("eval", classOf[InternalRow])
val isNullIntolerantMixedIn = classOf[NullIntolerant].isAssignableFrom(clazz)
if (isEvalOverrode && isNullIntolerantMixedIn) {
fail(s"${clazz.getName} should not extend ${classOf[NullIntolerant].getSimpleName}, " +
s"or add ${clazz.getName} in the ignoreSet of this test.")
fail(s"${clazz.getName} overrode the eval method and extended " +
s"${classOf[NullIntolerant].getSimpleName}, which may be incorrect. " +
s"You may need to override the nullSafeEval method.")
} else if (!isEvalOverrode && !isNullIntolerantMixedIn) {
fail(s"${clazz.getName} should extend ${classOf[NullIntolerant].getSimpleName}.")
} else {
Expand Down

0 comments on commit 1bdff95

Please sign in to comment.