diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 518901d142f56..be0f904dc14d9 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -1838,7 +1838,7 @@ class DataFrame(object):
         department = sqlContext.parquetFile("...")
 
         people.filter(people.age > 30).join(department, people.deptId == department.id)) \
-          .groupby(department.name, "gender").agg({"salary": "avg", "age": "max"})
+          .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"})
     """
 
     def __init__(self, jdf, sql_ctx):
@@ -2178,13 +2178,13 @@ def filter(self, condition):
 
     where = filter
 
-    def groupby(self, *cols):
+    def groupBy(self, *cols):
         """ Group the [[DataFrame]] using the specified columns,
         so we can run aggregation on them. See :class:`GroupedDataFrame`
         for all the available aggregate functions::
 
-            df.groupby(df.department).avg()
-            df.groupby("department", "gender").agg({
+            df.groupBy(df.department).avg()
+            df.groupBy("department", "gender").agg({
                 "salary": "avg",
                 "age": "max",
             })
@@ -2194,16 +2194,16 @@ def groupby(self, *cols):
         else:
             cols = [c._jc for c in cols]
         jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client)
-        jdf = self._jdf.groupby(self._jdf.toColumnArray(jcols))
+        jdf = self._jdf.groupBy(self._jdf.toColumnArray(jcols))
         return GroupedDataFrame(jdf, self.sql_ctx)
 
     def agg(self, *exprs):
         """ Aggregate on the entire [[DataFrame]] without groups
-        (shorthand for df.groupby.agg())::
+        (shorthand for df.groupBy.agg())::
 
             df.agg({"age": "max", "salary": "avg"})
         """
-        return self.groupby().agg(*exprs)
+        return self.groupBy().agg(*exprs)
 
     def unionAll(self, other):
         """ Return a new DataFrame containing union of rows in this
@@ -2266,7 +2266,7 @@ class GroupedDataFrame(object):
 
     """
     A set of methods for aggregations on a :class:`DataFrame`,
-    created by DataFrame.groupby().
+    created by DataFrame.groupBy().
     """
 
     def __init__(self, jdf, sql_ctx):
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index c8df2fc6ef956..eb48102229837 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -986,7 +986,7 @@ def test_column_select(self):
     def test_aggregator(self):
         from pyspark.sql import Aggregator as Agg
         df = self.df
-        g = df.groupby()
+        g = df.groupBy()
         self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
         self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
         # self.assertEqual((0, '100'), tuple(g.agg(Agg.first(df.key), Agg.last(df.value)).first()))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 72145eb1cdf56..a6c17ad7468ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -77,7 +77,7 @@ import org.apache.spark.util.Utils
  *
  *   people.filter("age" > 30)
  *     .join(department, people("deptId") === department("id"))
- *     .groupby(department("name"), "gender")
+ *     .groupBy(department("name"), "gender")
  *     .agg(avg(people("salary")), max(people("age")))
  * }}}
  */
@@ -331,17 +331,17 @@ class DataFrame protected[sql](
   *
   * {{{
   *   // Compute the average for all numeric columns grouped by department.
-  *   df.groupby($"department").avg()
+  *   df.groupBy($"department").avg()
   *
   *   // Compute the max age and average salary, grouped by department and gender.
-  *   df.groupby($"department", $"gender").agg(Map(
+  *   df.groupBy($"department", $"gender").agg(Map(
   *     "salary" -> "avg",
   *     "age" -> "max"
   *   ))
   * }}}
   */
  @scala.annotation.varargs
-  override def groupby(cols: Column*): GroupedDataFrame = {
+  override def groupBy(cols: Column*): GroupedDataFrame = {
    new GroupedDataFrame(this, cols.map(_.expr))
  }
 
@@ -349,22 +349,22 @@ class DataFrame protected[sql](
   * Group the [[DataFrame]] using the specified columns, so we can run aggregation on them.
   * See [[GroupedDataFrame]] for all the available aggregate functions.
   *
-  * This is a variant of groupby that can only group by existing columns using column names
+  * This is a variant of groupBy that can only group by existing columns using column names
   * (i.e. cannot construct expressions).
   *
   * {{{
   *   // Compute the average for all numeric columns grouped by department.
-  *   df.groupby("department").avg()
+  *   df.groupBy("department").avg()
   *
   *   // Compute the max age and average salary, grouped by department and gender.
-  *   df.groupby($"department", $"gender").agg(Map(
+  *   df.groupBy($"department", $"gender").agg(Map(
   *     "salary" -> "avg",
   *     "age" -> "max"
   *   ))
   * }}}
   */
  @scala.annotation.varargs
-  override def groupby(col1: String, cols: String*): GroupedDataFrame = {
+  override def groupBy(col1: String, cols: String*): GroupedDataFrame = {
    val colNames: Seq[String] = col1 +: cols
    new GroupedDataFrame(this, colNames.map(colName => resolve(colName)))
  }
@@ -372,23 +372,23 @@ class DataFrame protected[sql](
  /**
   * Aggregate on the entire [[DataFrame]] without groups.
   * {{
-  *   // df.agg(...) is a shorthand for df.groupby().agg(...)
+  *   // df.agg(...) is a shorthand for df.groupBy().agg(...)
   *   df.agg(Map("age" -> "max", "salary" -> "avg"))
-  *   df.groupby().agg(Map("age" -> "max", "salary" -> "avg"))
+  *   df.groupBy().agg(Map("age" -> "max", "salary" -> "avg"))
   * }}
   */
-  override def agg(exprs: Map[String, String]): DataFrame = groupby().agg(exprs)
+  override def agg(exprs: Map[String, String]): DataFrame = groupBy().agg(exprs)
 
  /**
   * Aggregate on the entire [[DataFrame]] without groups.
   * {{
-  *   // df.agg(...) is a shorthand for df.groupby().agg(...)
+  *   // df.agg(...) is a shorthand for df.groupBy().agg(...)
   *   df.agg(max($"age"), avg($"salary"))
-  *   df.groupby().agg(max($"age"), avg($"salary"))
+  *   df.groupBy().agg(max($"age"), avg($"salary"))
   * }}
   */
  @scala.annotation.varargs
-  override def agg(expr: Column, exprs: Column*): DataFrame = groupby().agg(expr, exprs :_*)
+  override def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs :_*)
 
  /**
   * Return a new [[DataFrame]] by taking the first `n` rows. The difference between this function
@@ -484,7 +484,14 @@ class DataFrame protected[sql](
  /**
   * Return the number of rows in the [[DataFrame]].
   */
-  override def count(): Long = groupby().count().rdd.collect().head.getLong(0)
+  override def count(): Long = groupBy().count().rdd.collect().head.getLong(0)
+
+  /**
+   * Return a new [[DataFrame]] that has exactly `numPartitions` partitions.
+   */
+  override def repartition(numPartitions: Int): DataFrame = {
+    sqlContext.applySchema(rdd.repartition(numPartitions), schema)
+  }
 
  override def persist(): this.type = {
    sqlContext.cacheQuery(this)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala
index 2e1ef7cf976ef..1f1e9bd9899f6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Aggregate
 
 
 /**
- * A set of methods for aggregations on a [[DataFrame]], created by [[DataFrame.groupby]].
+ * A set of methods for aggregations on a [[DataFrame]], created by [[DataFrame.groupBy]].
  */
 class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
   extends GroupedDataFrameApi {
@@ -62,7 +62,7 @@ class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expressi
    * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`.
    * {{{
    *   // Selects the age of the oldest employee and the aggregate expense for each department
-   *   df.groupby("department").agg(Map(
+   *   df.groupBy("department").agg(Map(
    *     "age" -> "max"
    *     "sum" -> "expense"
    *   ))
@@ -80,7 +80,7 @@ class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expressi
    * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`.
    * {{{
    *   // Selects the age of the oldest employee and the aggregate expense for each department
-   *   df.groupby("department").agg(Map(
+   *   df.groupBy("department").agg(Map(
    *     "age" -> "max"
    *     "sum" -> "expense"
    *   ))
@@ -96,7 +96,7 @@ class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expressi
    * {{{
    *   // Selects the age of the oldest employee and the aggregate expense for each department
    *   import org.apache.spark.sql.dsl._
-   *   df.groupby("department").agg(max($"age"), sum($"expense"))
+   *   df.groupBy("department").agg(max($"age"), sum($"expense"))
    * }}}
    */
   @scala.annotation.varargs
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api.scala b/sql/core/src/main/scala/org/apache/spark/sql/api.scala
index f934eac779284..0d71c42de4e61 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api.scala
@@ -54,6 +54,7 @@ trait RDDApi[T] {
 
   def count(): Long
 
+  def repartition(numPartitions: Int): DataFrame
 }
 
 
@@ -97,10 +98,10 @@ trait DataFrameSpecificApi {
   def where(condition: Column): DataFrame
 
   @scala.annotation.varargs
-  def groupby(cols: Column*): GroupedDataFrame
+  def groupBy(cols: Column*): GroupedDataFrame
 
   @scala.annotation.varargs
-  def groupby(col1: String, cols: String*): GroupedDataFrame
+  def groupBy(col1: String, cols: String*): GroupedDataFrame
 
   def agg(exprs: Map[String, String]): DataFrame
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index 3bdeb1871aca1..a03815c2be8bc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -42,11 +42,11 @@ class DslQuerySuite extends QueryTest {
 
   test("agg") {
     checkAnswer(
-      testData2.groupby("a").agg($"a", sum($"b")),
+      testData2.groupBy("a").agg($"a", sum($"b")),
       Seq(Row(1,3), Row(2,3), Row(3,3))
     )
     checkAnswer(
-      testData2.groupby("a").agg($"a", sum($"b").as("totB")).agg(sum('totB)),
+      testData2.groupBy("a").agg($"a", sum($"b").as("totB")).agg(sum('totB)),
       Row(9)
     )
     checkAnswer(
@@ -205,12 +205,12 @@ class DslQuerySuite extends QueryTest {
 
   test("null count") {
     checkAnswer(
-      testData3.groupby('a).agg('a, count('b)),
+      testData3.groupBy('a).agg('a, count('b)),
       Seq(Row(1,0), Row(2, 1))
     )
 
     checkAnswer(
-      testData3.groupby('a).agg('a, count('a + 'b)),
+      testData3.groupBy('a).agg('a, count('a + 'b)),
       Seq(Row(1,0), Row(2, 1))
     )
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 385cd2f6f8970..be5e63c76f42e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -42,7 +42,7 @@ class PlannerSuite extends FunSuite {
   }
 
   test("count is partially aggregated") {
-    val query = testData.groupby('value).agg(count('key)).queryExecution.analyzed
+    val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed
     val planned = HashAggregation(query).head
     val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }
 
@@ -50,14 +50,14 @@ class PlannerSuite extends FunSuite {
   }
 
   test("count distinct is partially aggregated") {
-    val query = testData.groupby('value).agg(countDistinct('key)).queryExecution.analyzed
+    val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed
     val planned = HashAggregation(query)
     assert(planned.nonEmpty)
   }
 
   test("mixed aggregates are partially aggregated") {
     val query =
-      testData.groupby('value).agg(count('value), countDistinct('key)).queryExecution.analyzed
+      testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed
     val planned = HashAggregation(query)
     assert(planned.nonEmpty)
   }