From 9be894efc4cbba8d0f72c316a23847a9da49f8af Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Tue, 23 Jun 2015 14:05:44 +0800 Subject: [PATCH] add round functions in o.a.s.sql.functions --- .../org/apache/spark/sql/functions.scala | 24 +++++++++++++++++++ .../spark/sql/MathExpressionsSuite.scala | 10 ++++++++ 2 files changed, 34 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index ffa52f62588dc..cbae58b303088 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1385,6 +1385,30 @@ object functions { */ def rint(columnName: String): Column = rint(Column(columnName)) + /** + * Returns the value of the `e` rounded to 0 decimal places. + * + * @group math_funcs + * @since 1.5.0 + */ + def round(e: Column): Column = Round(Seq(e.expr)) + + /** + * Returns the value of `e` rounded to the value of `scale` decimal places. + * + * @group math_funcs + * @since 1.5.0 + */ + def round(e: Column, scale: Column): Column = Round(Seq(e.expr, scale.expr)) + + /** + * Returns the value of `e` rounded to `scale` decimal places. + * + * @group math_funcs + * @since 1.5.0 + */ + def round(e: Column, scale: Int): Column = round(e, lit(scale)) + /** * Shift the the given value numBits left. If the given value is a long value, this function * will return a long value else it will return an integer value. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 24bef21b999ea..f8bbc5a032083 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -198,6 +198,16 @@ class MathExpressionsSuite extends QueryTest { testOneToOneMathFunction(rint, math.rint) } + test("round") { + checkAnswer( + ctx.sql("SELECT round(-32768), round(1809242.3151111344, 9), round(1809242.3151111344, 9)"), + Seq((1, 2)).toDF().select( + round(lit(-32768)), + round(lit(1809242.3151111344), lit(9)), + round(lit(1809242.3151111344), 9)) + ) + } + test("exp") { testOneToOneMathFunction(exp, math.exp) }