Added sqrt and abs to Spark SQL DSL
sarutak committed Nov 21, 2014
1 parent 90a6a46 commit 0396f89
Showing 4 changed files with 75 additions and 1 deletion.
2 changes: 2 additions & 0 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -147,6 +147,8 @@ package object dsl {
     def max(e: Expression) = Max(e)
     def upper(e: Expression) = Upper(e)
     def lower(e: Expression) = Lower(e)
+    def sqrt(e: Expression) = Sqrt(e)
+    def abs(e: Expression) = Abs(e)

     implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s = sym.name }
     // TODO more implicit class for literal?
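For context, each new DSL function is a thin constructor around the corresponding Catalyst expression node, in the same style as the existing upper and lower helpers above. A minimal sketch of how they compose (assuming the conversions are brought into scope via org.apache.spark.sql.catalyst.dsl.expressions, as catalyst tests typically do; 'key and 'value are illustrative attribute names):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._

object DslSketch extends App {
  // Symbols are implicitly converted to unresolved attribute references,
  // so sqrt('key) builds Sqrt(UnresolvedAttribute("key")) and
  // abs('value) builds Abs(UnresolvedAttribute("value")).
  val s = sqrt('key)
  val a = abs('value)
  println(s"$s, $a")
}
```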
1 change: 0 additions & 1 deletion sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions

 import org.apache.spark.sql.catalyst.analysis.UnresolvedException
 import org.apache.spark.sql.catalyst.types._
-import scala.math.pow

 case class UnaryMinus(child: Expression) extends UnaryExpression {
   type EvaluatedType = Any
69 changes: 69 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql

 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.types.NullType

 /* Implicits */
 import org.apache.spark.sql.catalyst.dsl._
@@ -282,4 +283,72 @@ class DslQuerySuite extends QueryTest {
(1, "1", "11") :: (2, "2", "22") :: (3, "3", "33") :: Nil
)
}

test("sqrt") {
checkAnswer(
testData.select(sqrt('key)).orderBy('key asc),
(1 to 100).map(n => Seq(math.sqrt(n)))
)

checkAnswer(
testData.select(sqrt('value), 'key).orderBy('key asc, 'value asc),
(1 to 100).map(n => Seq(math.sqrt(n), n))
)

checkAnswer(
testData.select(sqrt(Literal(null, NullType))),
(1 to 100).map(_ => Seq(null))
)
}

test("abs") {
checkAnswer(
testData.select(abs('key)).orderBy('key asc),
(1 to 100).map(n => Seq(n))
)

checkAnswer(
negativeData.select(abs('key)).orderBy('key desc),
(1 to 100).map(n => Seq(n))
)

checkAnswer(
testData.select(abs(Literal(null, NullType))),
(1 to 100).map(_ => Seq(null))
)
}

test("upper") {
checkAnswer(
lowerCaseData.select(upper('l)),
('a' to 'd').map(c => Seq(c.toString.toUpperCase()))
)

checkAnswer(
testData.select(upper('value), 'key),
(1 to 100).map(n => Seq(n.toString, n))
)

checkAnswer(
testData.select(upper(Literal(null, NullType))),
(1 to 100).map(n => Seq(null))
)
}

test("lower") {
checkAnswer(
upperCaseData.select(lower('L)),
('A' to 'F').map(c => Seq(c.toString.toLowerCase()))
)

checkAnswer(
testData.select(lower('value), 'key),
(1 to 100).map(n => Seq(n.toString, n))
)

checkAnswer(
testData.select(lower(Literal(null, NullType))),
(1 to 100).map(n => Seq(null))
)
}
}
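The assertions above go through QueryTest.checkAnswer, which runs the query and compares the collected rows against the expected answer. A rough sketch of the idea, not the actual helper (the real one also normalizes row ordering where the plan has none, and reports the failing plan):

```scala
import org.apache.spark.sql.SchemaRDD

// Simplified stand-in for QueryTest.checkAnswer: materialize the query,
// flatten each Row to a plain Seq, and compare with the expected rows.
def checkAnswerSketch(query: SchemaRDD, expected: Seq[Seq[Any]]): Unit = {
  val actual = query.collect().map(_.toSeq).toSeq
  assert(actual == expected, s"expected $expected but got $actual")
}
```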
4 changes: 4 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -32,6 +32,10 @@ object TestData {
     (1 to 100).map(i => TestData(i, i.toString))).toSchemaRDD
   testData.registerTempTable("testData")

+  val negativeData = TestSQLContext.sparkContext.parallelize(
+    (1 to 100).map(i => TestData(-i, (-i).toString))).toSchemaRDD
+  negativeData.registerTempTable("negativeData")
+
   case class LargeAndSmallInts(a: Int, b: Int)
   val largeAndSmallInts =
     TestSQLContext.sparkContext.parallelize(
