From 5486b2d5c445d1bbbbe1fd643ddd318f470266ae Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Tue, 23 Jun 2015 23:40:30 +0800 Subject: [PATCH] DataFrame API modification --- .../src/main/scala/org/apache/spark/sql/functions.scala | 6 +++--- .../scala/org/apache/spark/sql/MathExpressionsSuite.scala | 3 +-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index f6bd19bac61b2..694cf3b39b09d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1394,12 +1394,12 @@ object functions { def round(e: Column): Column = round(e.expr, 0) /** - * Returns the value of the given column `e` rounded to the value of `scale` decimal places. + * Returns the value of the given column rounded to 0 decimal places. * * @group math_funcs * @since 1.5.0 */ - def round(e: Column, scale: Column): Column = Round(e.expr, scale.expr) + def round(columnName: String): Column = round(Column(columnName), 0) /** * Returns the value of `e` rounded to `scale` decimal places. @@ -1407,7 +1407,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def round(e: Column, scale: Int): Column = round(e, lit(scale)) + def round(e: Column, scale: Int): Column = Round(e.expr, Literal(scale)) /** * Returns the value of the given column rounded to `scale` decimal places. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index f8bbc5a032083..8ccfdd5147680 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -200,10 +200,9 @@ class MathExpressionsSuite extends QueryTest { test("round") { checkAnswer( - ctx.sql("SELECT round(-32768), round(1809242.3151111344, 9), round(1809242.3151111344, 9)"), + ctx.sql("SELECT round(-32768), round(1809242.3151111344, 9)"), Seq((1, 2)).toDF().select( round(lit(-32768)), - round(lit(1809242.3151111344), lit(9)), round(lit(1809242.3151111344), 9)) ) }