From d910141b5d3d86e557aa9ff7665346ac3b0f2865 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Thu, 7 May 2015 20:57:19 -0700
Subject: [PATCH] Updated rest of the files

---
 python/pyspark/sql/dataframe.py                              | 2 +-
 .../src/main/scala/org/apache/spark/sql/GroupedData.scala    | 5 +++--
 .../org/apache/spark/sql/execution/stat/StatFunctions.scala  | 2 +-
 .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 4 ++--
 4 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index cee804f5cc1f7..553a690297396 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1069,7 +1069,7 @@ def agg(self, *exprs):
 
         >>> from pyspark.sql import functions as F
         >>> gdf.agg(F.min(df.age)).collect()
-        [Row(MIN(age)=2), Row(MIN(age)=5)]
+        [Row(name=u'Alice', MIN(age)=2), Row(name=u'Bob', MIN(age)=5)]
         """
         assert exprs, "exprs should not be empty"
         if len(exprs) == 1 and isinstance(exprs[0], dict):
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 56f9cfae40580..003a620dcc8ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -135,8 +135,9 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
   }
 
   /**
-   * Compute aggregates by specifying a series of aggregate columns. Unlike other methods in this
-   * class, the resulting [[DataFrame]] won't automatically include the grouping columns.
+   * Compute aggregates by specifying a series of aggregate columns. Note that this function by
+   * default retains the grouping columns in its output. To not retain grouping columns, set
+   * `spark.sql.retainGroupColumns` to false.
    *
    * The available aggregate methods are defined in [[org.apache.spark.sql.functions]].
    *
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index 386ac969f1e7d..7e30652b54c7b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -102,7 +102,7 @@ private[sql] object StatFunctions extends Logging {
   /** Generate a table of frequencies for the elements of two columns. */
   private[sql] def crossTabulate(df: DataFrame, col1: String, col2: String): DataFrame = {
     val tableName = s"${col1}_$col2"
-    val counts = df.groupBy(col1, col2).agg(col(col1), col(col2), count("*")).take(1e6.toInt)
+    val counts = df.groupBy(col1, col2).agg(count("*")).take(1e6.toInt)
     if (counts.length == 1e6.toInt) {
       logWarning("The maximum limit of 1e6 pairs have been collected, which may not be all of " +
         "the pairs. Please try reducing the amount of distinct items in your columns.")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 15e841c160bfa..0634bb95b41ca 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -62,7 +62,7 @@ class DataFrameSuite extends QueryTest {
     val df = Seq((1,(1,1))).toDF()
 
     checkAnswer(
-      df.groupBy("_1").agg(col("_1"), sum("_2._1")).toDF("key", "total"),
+      df.groupBy("_1").agg(sum("_2._1")).toDF("key", "total"),
       Row(1, 1) :: Nil)
   }
 
@@ -127,7 +127,7 @@ class DataFrameSuite extends QueryTest {
       df2
         .select('_1 as 'letter, 'number)
         .groupBy('letter)
-        .agg('letter, countDistinct('number)),
+        .agg(countDistinct('number)),
       Row("a", 3) :: Row("b", 2) :: Row("c", 1) :: Nil
     )
   }
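
Editor's note: the user-visible effect of this patch is that GroupedData.agg now
keeps the grouping columns in its output, which is why the doctest above gains
name=u'Alice' / name=u'Bob' entries and the test suite drops the redundant
col("_1") and 'letter arguments from agg. A minimal sketch of the new behavior
follows, assuming a Spark 1.4-era shell with an existing SparkContext bound to
`sc`; `sqlContext` and the sample DataFrame are illustrative, not part of the
commit.

    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.functions._

    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val df = Seq(("Alice", 2), ("Bob", 5)).toDF("name", "age")

    // With this patch, the grouping column is retained by default:
    // the result schema is [name, MIN(age)], matching the updated doctest.
    df.groupBy("name").agg(min("age")).show()

    // The previous behavior (aggregate columns only) can be restored via the
    // new configuration flag introduced by this change:
    sqlContext.setConf("spark.sql.retainGroupColumns", "false")
    df.groupBy("name").agg(min("age")).show()  // result schema: [MIN(age)]

The crossTabulate change follows the same logic: grouping columns col1 and col2
now come back automatically, so passing them to agg explicitly would duplicate
them in the output.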