Skip to content

Commit

Permalink
address comments from davis
Browse files Browse the repository at this point in the history
  • Loading branch information
adrian-wang committed Jul 13, 2015
1 parent 7a6bdbb commit 0f1bff2
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -349,17 +349,17 @@ case class Least(children: Expression*) extends Expression {
val evalChildren = children.map(_.gen(ctx))
def updateEval(i: Int): String =
s"""
if (${ev.isNull} || (!${evalChildren(i).isNull} && ${
ctx.genComp(dataType, evalChildren(i).primitive, ev.primitive)} < 0)) {
${ev.isNull} = ${evalChildren(i).isNull};
if (!${evalChildren(i).isNull} && (${ev.isNull} ||
${ctx.genComp(dataType, evalChildren(i).primitive, ev.primitive)} < 0)) {
${ev.isNull} = false;
${ev.primitive} = ${evalChildren(i).primitive};
}
"""
s"""
${evalChildren.map(_.code).mkString("\n")}
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
${(0 to children.length - 1).map(updateEval).mkString("\n")}
${(0 until children.length).map(updateEval).mkString("\n")}
"""
}
}
Expand Down Expand Up @@ -399,17 +399,17 @@ case class Greatest(children: Expression*) extends Expression {
val evalChildren = children.map(_.gen(ctx))
def updateEval(i: Int): String =
s"""
if (${ev.isNull} || (!${evalChildren(i).isNull} && ${
ctx.genComp(dataType, evalChildren(i).primitive, ev.primitive)} > 0)) {
${ev.isNull} = ${evalChildren(i).isNull};
if (!${evalChildren(i).isNull} && (${ev.isNull} ||
${ctx.genComp(dataType, evalChildren(i).primitive, ev.primitive)} > 0)) {
${ev.isNull} = false;
${ev.primitive} = ${evalChildren(i).primitive};
}
"""
s"""
${evalChildren.map(_.code).mkString("\n")}
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
${(0 to children.length - 1).map(updateEval).mkString("\n")}
${(0 until children.length).map(updateEval).mkString("\n")}
"""
}
}
22 changes: 18 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1079,7 +1079,11 @@ object functions {
* @since 1.5.0
*/
@scala.annotation.varargs
def greatest(exprs: Column*): Column = Greatest(exprs.map(_.expr): _*)
def greatest(exprs: Column*): Column = if (exprs.length < 2) {
sys.error("GREATEST takes at least 2 parameters")
} else {
Greatest(exprs.map(_.expr): _*)
}

/**
* Returns the greatest value of the list of column names.
Expand All @@ -1088,8 +1092,11 @@ object functions {
* @since 1.5.0
*/
@scala.annotation.varargs
def greatest(columnName: String, columnNames: String*): Column =
def greatest(columnName: String, columnNames: String*): Column = if (columnNames.isEmpty) {
sys.error("GREATEST takes at least 2 parameters")
} else {
greatest((columnName +: columnNames).map(Column.apply): _*)
}

/**
* Computes hex value of the given column.
Expand Down Expand Up @@ -1197,7 +1204,11 @@ object functions {
* @since 1.5.0
*/
@scala.annotation.varargs
def least(exprs: Column*): Column = Least(exprs.map(_.expr): _*)
def least(exprs: Column*): Column = if (exprs.length < 2) {
sys.error("LEAST takes at least 2 parameters")
} else {
Least(exprs.map(_.expr): _*)
}

/**
* Returns the least value of the list of column names.
Expand All @@ -1206,8 +1217,11 @@ object functions {
* @since 1.5.0
*/
@scala.annotation.varargs
def least(columnName: String, columnNames: String*): Column =
def least(columnName: String, columnNames: String*): Column = if (columnNames.isEmpty) {
sys.error("LEAST takes at least 2 parameters")
} else {
least((columnName +: columnNames).map(Column.apply): _*)
}

/**
* Computes the natural logarithm of the given value.
Expand Down

0 comments on commit 0f1bff2

Please sign in to comment.