diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index dd9332ada80dd..41398ff956edd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -44,7 +44,6 @@ abstract class Expression extends TreeNode[Expression] {
    *  - A [[expressions.Cast Cast]] or [[expressions.UnaryMinus UnaryMinus]] is foldable if its
    *    child is foldable.
    */
-  // TODO: Supporting more foldable expressions. For example, deterministic Hive UDFs.
   def foldable: Boolean = false
   def nullable: Boolean
   def references: Set[Attribute]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index e11155539fd75..520101802e25b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -94,25 +94,9 @@ object ConstantFolding extends Rule[LogicalPlan] {
     case q: LogicalPlan => q transformExpressionsDown {
       // Skip redundant folding of literals.
       case l: Literal => l
-      case e @ If(Literal(v, _), trueValue, falseValue) => if(v == true) trueValue else falseValue
-      case e @ In(Literal(v, _), list) if(list.exists(c => c match {
-        case Literal(candidate, _) if(candidate == v) => true
-        case _ => false
-      })) => Literal(true, BooleanType)
-      case e if e.foldable => Literal(e.eval(null), e.dataType)
-    }
-  }
-}
-
-/**
- * The expression may be constant value, due to one or more of its children expressions is null or
- * not null constantly, replaces [[catalyst.expressions.Expression Expressions]] with equivalent
- * [[catalyst.expressions.Literal Literal]] values if possible caused by that.
- */
-object NullPropagation extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case q: LogicalPlan => q transformExpressionsUp {
-      case l: Literal => l
+      case e @ Count(Literal(null, _)) => Literal(null, e.dataType)
+      case e @ Sum(Literal(null, _)) => Literal(null, e.dataType)
+      case e @ Average(Literal(null, _)) => Literal(null, e.dataType)
       case e @ IsNull(Literal(null, _)) => Literal(true, BooleanType)
       case e @ IsNull(Literal(_, _)) => Literal(false, BooleanType)
       case e @ IsNull(c @ Rand) => Literal(false, BooleanType)
@@ -135,6 +119,11 @@ object NullPropagation extends Rule[LogicalPlan] {
           Coalesce(newChildren)
         }
       }
+      case e @ If(Literal(v, _), trueValue, falseValue) => if(v == true) trueValue else falseValue
+      case e @ In(Literal(v, _), list) if(list.exists(c => c match {
+        case Literal(candidate, _) if(candidate == v) => true
+        case _ => false
+      })) => Literal(true, BooleanType)
       // TODO put exceptional cases(Unary & Binary Expression) before here.
       case e: UnaryExpression => e.child match {
         case Literal(null, _) => Literal(null, e.dataType)
@@ -143,6 +132,7 @@ object NullPropagation extends Rule[LogicalPlan] {
         case Literal(null, _) :: right :: Nil => Literal(null, e.dataType)
         case left :: Literal(null, _) :: Nil => Literal(null, e.dataType)
       }
+      case e if e.foldable => Literal(e.eval(null), e.dataType)
     }
   }
 }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
index a09270eb7b134..a4e83c8357439 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.hadoop.hive.common.`type`.HiveDecimal
 import org.apache.hadoop.hive.ql.exec.UDF
 import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry}
+import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType}
 import org.apache.hadoop.hive.ql.udf.generic._
 import org.apache.hadoop.hive.serde2.objectinspector._
 import org.apache.hadoop.hive.serde2.objectinspector.primitive._
@@ -213,6 +214,16 @@ private[hive] case class HiveGenericUdf(name: String, children: Seq[Expression])
   @transient
   protected lazy val returnInspector = function.initialize(argumentInspectors.toArray)
+
+  @transient
+  protected lazy val isUDFDeterministic = {
+    val udfType = function.getClass().getAnnotation(classOf[HiveUDFType])
+    (udfType != null && udfType.deterministic())
+  }
+
+  override def foldable = {
+    isUDFDeterministic && children.foldLeft(true)((prev, n) => prev && n.foldable)
+  }
 
   val dataType: DataType = inspectorToDataType(returnInspector)
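
A minimal sketch (not part of the patch) of the determinism check that the new HiveGenericUdf.foldable relies on: a UDF only becomes foldable when its class carries Hive's UDFType annotation with deterministic = true and every child expression is itself foldable. It assumes hive-exec is on the classpath; the names FoldableCheckSketch, MyDeterministicUdf, MyUnannotatedUdf, and isDeterministic are hypothetical and exist only to exercise the annotation lookup.

    import org.apache.hadoop.hive.ql.udf.UDFType

    // Hypothetical UDF classes, used only to exercise the annotation lookup.
    @UDFType(deterministic = true)
    class MyDeterministicUdf

    class MyUnannotatedUdf

    object FoldableCheckSketch {
      // Mirrors the isUDFDeterministic check added to HiveGenericUdf above:
      // absent annotation or deterministic = false means "not foldable".
      def isDeterministic(udf: AnyRef): Boolean = {
        val udfType = udf.getClass.getAnnotation(classOf[UDFType])
        udfType != null && udfType.deterministic()
      }

      def main(args: Array[String]): Unit = {
        println(isDeterministic(new MyDeterministicUdf))  // true  -> candidate for constant folding
        println(isDeterministic(new MyUnannotatedUdf))    // false -> never folded
      }
    }

Once foldable holds, the optimizer's `case e if e.foldable => Literal(e.eval(null), e.dataType)` rule can replace a deterministic UDF call whose arguments are all foldable with a single Literal at optimization time instead of evaluating it per row.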