Skip to content

Commit

Permalink
[SPARK-6117] [SQL] add describe function to DataFrame for summary sta…
Browse files Browse the repository at this point in the history
…tistics
  • Loading branch information
azagrebin committed Mar 17, 2015
1 parent e26db9b commit 9daf31e
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 0 deletions.
80 changes: 80 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,86 @@ class DataFrame private[sql](
select(colNames :_*)
}

/**
* Compute specified aggregations for given columns of this [[DataFrame]].
* Each row of the resulting [[DataFrame]] contains column with aggregation name
* and columns with aggregation results for each given column.
* The aggregations are described as a List of mappings of their name to function
* which generates aggregation expression from column name.
*
* Note: can process only simple aggregation expressions
* which can be parsed by spark [[SqlParser]]
*
* {{{
* val aggregations = List(
* "max" -> (col => s"max($col)"), // expression computes max
* "avg" -> (col => s"sum($col)/count($col)")) // expression computes average
* df.multipleAggExpr("summary", aggregations, "age", "height")
*
* // summary age height
* // max 92.0 192.0
* // avg 53.0 178.0
* }}}
*/
@scala.annotation.varargs
private def multipleAggExpr(
aggCol: String,
aggregations: List[(String, String => String)],
cols: String*): DataFrame = {

val sqlParser = new SqlParser()

def addAggNameCol(aggDF: DataFrame, aggName: String = "") =
aggDF.selectExpr(s"'$aggName' as $aggCol"::cols.toList:_*)

def unionWithNextAgg(aggSoFarDF: DataFrame, nextAgg: (String, String => String)) =
nextAgg match { case (aggName, colToAggExpr) =>
val nextAggDF = if (cols.nonEmpty) {
def colToAggCol(col: String) =
Column(sqlParser.parseExpression(colToAggExpr(col))).as(col)
val aggCols = cols.map(colToAggCol)
agg(aggCols.head, aggCols.tail:_*)
} else {
sqlContext.emptyDataFrame
}
val nextAggWithNameDF = addAggNameCol(nextAggDF, aggName)
aggSoFarDF.unionAll(nextAggWithNameDF)
}

val emptyAgg = addAggNameCol(this).limit(0)
aggregations.foldLeft(emptyAgg)(unionWithNextAgg)
}

/**
* Compute numerical statistics for given columns of this [[DataFrame]]:
* count, mean (avg), stddev (standard deviation), min, max.
* Each row of the resulting [[DataFrame]] contains column with statistic name
* and columns with statistic results for each given column.
* If no columns are given then computes for all numerical columns.
*
* {{{
* df.describe("age", "height")
*
* // summary age height
* // count 10.0 10.0
* // mean 53.3 178.05
* // stddev 11.6 15.7
* // min 18.0 163.0
* // max 92.0 192.0
* }}}
*/
@scala.annotation.varargs
def describe(cols: String*): DataFrame = {
val numCols = if (cols.isEmpty) numericColumns.map(_.prettyString) else cols
val aggregations = List[(String, String => String)](
"count" -> (col => s"count($col)"),
"mean" -> (col => s"avg($col)"),
"stddev" -> (col => s"sqrt(avg($col*$col) - avg($col)*avg($col))"),
"min" -> (col => s"min($col)"),
"max" -> (col => s"max($col)"))
multipleAggExpr("summary", aggregations, numCols:_*)
}

/**
* Returns the first `n` rows.
* @group action
Expand Down
20 changes: 20 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,26 @@ class DataFrameSuite extends QueryTest {
assert(df.schema.map(_.name).toSeq === Seq("key", "valueRenamed", "newCol"))
}

test("describe") {
def getSchemaAsSeq(df: DataFrame) = df.schema.map(_.name).toSeq

val describeAllCols = describeTestData.describe("age", "height")
assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "age", "height"))
checkAnswer(describeAllCols, describeResult)

val describeNoCols = describeTestData.describe()
assert(getSchemaAsSeq(describeNoCols) === Seq("summary", "age", "height"))
checkAnswer(describeNoCols, describeResult)

val describeOneCol = describeTestData.describe("age")
assert(getSchemaAsSeq(describeOneCol) === Seq("summary", "age"))
checkAnswer(describeOneCol, describeResult.map { case Row(s, d, _) => Row(s, d)} )

val emptyDescription = describeTestData.limit(0).describe()
assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "age", "height"))
checkAnswer(emptyDescription, emptyDescribeResult)
}

test("apply on query results (SPARK-5462)") {
val df = testData.sqlContext.sql("select key from testData")
checkAnswer(df.select(df("key")), testData.select('key).collect().toSeq)
Expand Down
19 changes: 19 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,25 @@ object TestData {
Salary(1, 1000.0) :: Nil).toDF()
salary.registerTempTable("salary")

case class PersonToDescribe(name: String, age: Int, height: Double)
val describeTestData = TestSQLContext.sparkContext.parallelize(
PersonToDescribe("Bob", 16, 176) ::
PersonToDescribe("Alice", 32, 164) ::
PersonToDescribe("David", 60, 192) ::
PersonToDescribe("Amy", 24, 180) :: Nil).toDF()
val describeResult =
Row("count", 4.0, 4.0) ::
Row("mean", 33.0, 178.0) ::
Row("stddev", 16.583123951777, 10.0) ::
Row("min", 16.0, 164) ::
Row("max", 60.0, 192) :: Nil
val emptyDescribeResult =
Row("count", 0, 0) ::
Row("mean", null, null) ::
Row("stddev", null, null) ::
Row("min", null, null) ::
Row("max", null, null) :: Nil

case class ComplexData(m: Map[Int, String], s: TestData, a: Seq[Int], b: Boolean)
val complexData =
TestSQLContext.sparkContext.parallelize(
Expand Down

0 comments on commit 9daf31e

Please sign in to comment.