Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wangyum committed May 26, 2020
1 parent 2c943cc commit b127c41
Showing 1 changed file with 29 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@ package org.apache.spark.sql.expressions
import scala.collection.parallel.immutable.ParVector

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.expressions.{NonSQLExpression, _}
import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.HiveResult.hiveResultString
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
Expand Down Expand Up @@ -159,73 +158,37 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
}
}

test("Check whether should extend NullIntolerant") {
// Only check expressions extended from these expressions
val parentExpressionNames = Seq(classOf[UnaryExpression], classOf[BinaryExpression],
classOf[TernaryExpression], classOf[QuaternaryExpression],
classOf[SeptenaryExpression]).map(_.getName)
// Do not check these expressions
val whiteList = Seq(
classOf[IntegralDivide], classOf[Divide], classOf[Remainder], classOf[Pmod],
classOf[CheckOverflow], classOf[NormalizeNaNAndZero], classOf[InSet],
classOf[PrintToStderr], classOf[CodegenFallbackExpression]).map(_.getName)

spark.sessionState.functionRegistry.listFunction()
.map(spark.sessionState.catalog.lookupFunctionInfo).map(_.getClassName)
.filterNot(c => whiteList.exists(_.equals(c))).foreach { className =>
if (needToCheckNullIntolerant(className)) {
val evalExist = checkIfEvalOverrode(className)
val nullIntolerantExist = checkIfNullIntolerantMixedIn(className)
if (evalExist && nullIntolerantExist) {
fail(s"$className should not extend ${classOf[NullIntolerant].getSimpleName}")
} else if (!evalExist && !nullIntolerantExist) {
fail(s"$className should extend ${classOf[NullIntolerant].getSimpleName}")
} else {
assert((!evalExist && nullIntolerantExist) || (evalExist && !nullIntolerantExist))
}
}
}
test("Check whether SQL expressions should extend NullIntolerant") {
// Only check expressions extended from these expressions because these expressions are
// NullIntolerant by default.
val exprTypesToCheck = Seq(classOf[UnaryExpression], classOf[BinaryExpression],
classOf[TernaryExpression], classOf[QuaternaryExpression], classOf[SeptenaryExpression])

def needToCheckNullIntolerant(className: String): Boolean = {
var clazz: Class[_] = Utils.classForName(className)
val isNonSQLExpr =
clazz.getInterfaces.exists(_.getName.equals(classOf[NonSQLExpression].getName))
var checkNullIntolerant: Boolean = false
while (!checkNullIntolerant && clazz.getSuperclass != null) {
checkNullIntolerant = parentExpressionNames.exists(_.equals(clazz.getSuperclass.getName))
if (!checkNullIntolerant) {
clazz = clazz.getSuperclass
}
}
checkNullIntolerant && !isNonSQLExpr
}
// Do not check these expressions, because these expressions extend NullIntolerant
// and override the eval function.
val ignoreSet = Set(classOf[IntegralDivide], classOf[Divide], classOf[Remainder], classOf[Pmod])

def checkIfNullIntolerantMixedIn(className: String): Boolean = {
val nullIntolerantName = classOf[NullIntolerant].getName
var clazz: Class[_] = Utils.classForName(className)
var nullIntolerantMixedIn = false
while (!nullIntolerantMixedIn && !parentExpressionNames.exists(_.equals(clazz.getName))) {
nullIntolerantMixedIn = clazz.getInterfaces.exists(_.getName.equals(nullIntolerantName)) ||
clazz.getInterfaces.exists { i =>
Utils.classForName(i.getName).getInterfaces.exists(_.getName.equals(nullIntolerantName))
}
if (!nullIntolerantMixedIn) {
clazz = clazz.getSuperclass
}
}
nullIntolerantMixedIn
}

def checkIfEvalOverrode(className: String): Boolean = {
var clazz: Class[_] = Utils.classForName(className)
var evalOverrode: Boolean = false
while (!evalOverrode && !parentExpressionNames.exists(_.equals(clazz.getName))) {
evalOverrode = clazz.getDeclaredMethods.exists(_.getName.equals("eval"))
if (!evalOverrode) {
clazz = clazz.getSuperclass
val candidateExprsToCheck = spark.sessionState.functionRegistry.listFunction()
.map(spark.sessionState.catalog.lookupFunctionInfo).map(_.getClassName)
.filterNot(c => ignoreSet.exists(_.getName.equals(c)))
.map(name => Utils.classForName(name))
.filterNot(classOf[NonSQLExpression].isAssignableFrom)

exprTypesToCheck.foreach { superClass =>
candidateExprsToCheck.filter(superClass.isAssignableFrom).foreach { clazz =>
val isEvalOverrode = clazz.getMethod("eval", classOf[InternalRow]) !=
superClass.getMethod("eval", classOf[InternalRow])
val isNullIntolerantMixedIn = classOf[NullIntolerant].isAssignableFrom(clazz)
if (isEvalOverrode && isNullIntolerantMixedIn) {
fail(s"${clazz.getName} should not extend ${classOf[NullIntolerant].getSimpleName}, " +
s"or add ${clazz.getName} in the ignoreSet of this test.")
} else if (!isEvalOverrode && !isNullIntolerantMixedIn) {
fail(s"${clazz.getName} should extend ${classOf[NullIntolerant].getSimpleName}.")
} else {
assert((!isEvalOverrode && isNullIntolerantMixedIn) ||
(isEvalOverrode && !isNullIntolerantMixedIn))
}
}
evalOverrode
}
}
}

0 comments on commit b127c41

Please sign in to comment.