diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 7611f479a628b..819e9a24e5c0e 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -9,7 +9,8 @@ export("print.jobj")

 exportClasses("DataFrame")

-exportMethods("cache",
+exportMethods("arrange",
+              "cache",
               "collect",
               "columns",
               "count",
@@ -20,6 +21,7 @@ exportMethods("cache",
               "explain",
               "filter",
               "first",
+              "group_by",
               "groupBy",
               "head",
               "insertInto",
@@ -28,12 +30,15 @@ exportMethods("cache",
               "join",
               "limit",
               "orderBy",
+              "mutate",
               "names",
               "persist",
               "printSchema",
               "registerTempTable",
+              "rename",
               "repartition",
               "sampleDF",
+              "sample_frac",
               "saveAsParquetFile",
               "saveAsTable",
               "saveDF",
@@ -42,7 +47,7 @@ exportMethods("cache",
               "selectExpr",
               "show",
               "showDF",
-              "sortDF",
+              "summarize",
               "take",
               "unionAll",
               "unpersist",
@@ -72,6 +77,8 @@ exportMethods("abs",
               "max",
               "mean",
               "min",
+              "n",
+              "n_distinct",
               "rlike",
               "sqrt",
               "startsWith",
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 354642e7bc307..8a9d2dd45c588 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -480,6 +480,7 @@ setMethod("distinct",
 #' @param withReplacement Sampling with replacement or not
 #' @param fraction The (rough) sample target fraction
 #' @rdname sampleDF
+#' @aliases sample_frac
 #' @export
 #' @examples
 #'\dontrun{
@@ -501,6 +502,15 @@ setMethod("sampleDF",
             dataFrame(sdf)
           })

+#' @rdname sampleDF
+#' @aliases sampleDF
+setMethod("sample_frac",
+          signature(x = "DataFrame", withReplacement = "logical",
+                    fraction = "numeric"),
+          function(x, withReplacement, fraction) {
+            sampleDF(x, withReplacement, fraction)
+          })
+
 #' Count
 #'
 #' Returns the number of rows in a DataFrame
@@ -682,7 +692,8 @@ setMethod("toRDD",
 #' @param x a DataFrame
 #' @return a GroupedData
 #' @seealso GroupedData
-#' @rdname DataFrame
+#' @aliases group_by
+#' @rdname groupBy
 #' @export
 #' @examples
 #' \dontrun{
@@ -705,12 +716,21 @@ setMethod("groupBy",
             groupedData(sgd)
           })

-#' Agg
+#' @rdname groupBy
+#' @aliases group_by
+setMethod("group_by",
+          signature(x = "DataFrame"),
+          function(x, ...) {
+            groupBy(x, ...)
+          })
+
+#' Summarize data across columns
 #'
 #' Compute aggregates by specifying a list of columns
 #'
 #' @param x a DataFrame
 #' @rdname DataFrame
+#' @aliases summarize
 #' @export
 setMethod("agg",
           signature(x = "DataFrame"),
@@ -718,6 +738,14 @@ setMethod("agg",
             agg(groupBy(x), ...)
           })

+#' @rdname DataFrame
+#' @aliases agg
+setMethod("summarize",
+          signature(x = "DataFrame"),
+          function(x, ...) {
+            agg(x, ...)
+          })
+
 ############################## RDD Map Functions ##################################
 # All of the following functions mirror the existing RDD map functions,           #
@@ -886,7 +914,7 @@ setMethod("select",
           signature(x = "DataFrame", col = "list"),
           function(x, col) {
             cols <- lapply(col, function(c) {
-              if (class(c)== "Column") {
+              if (class(c) == "Column") {
                 c@jc
               } else {
                 col(c)@jc
@@ -946,6 +974,42 @@ setMethod("withColumn",
             select(x, x$"*", alias(col, colName))
           })

+#' Mutate
+#'
+#' Return a new DataFrame with the specified columns added.
+#'
+#' @param x A DataFrame
+#' @param col a named argument of the form name = col
+#' @return A new DataFrame with the new columns added.
+#' @rdname withColumn
+#' @aliases withColumn
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' newDF <- mutate(df, newCol = df$col1 * 5, newCol2 = df$col1 * 2)
+#' names(newDF) # Will contain newCol, newCol2
+#' }
+setMethod("mutate",
+          signature(x = "DataFrame"),
+          function(x, ...) {
+            cols <- list(...)
+            stopifnot(length(cols) > 0)
+            stopifnot(class(cols[[1]]) == "Column")
+            ns <- names(cols)
+            if (!is.null(ns)) {
+              for (n in ns) {
+                if (n != "") {
+                  cols[[n]] <- alias(cols[[n]], n)
+                }
+              }
+            }
+            do.call(select, c(x, x$"*", cols))
+          })
+
 #' WithColumnRenamed
 #'
 #' Rename an existing column in a DataFrame.
@@ -977,9 +1041,47 @@ setMethod("withColumnRenamed",
             select(x, cols)
           })

+#' Rename
+#'
+#' Rename an existing column in a DataFrame.
+#'
+#' @param x A DataFrame
+#' @param newCol A named pair of the form new_column_name = existing_column
+#' @return A DataFrame with the column name changed.
+#' @rdname withColumnRenamed
+#' @aliases withColumnRenamed
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' newDF <- rename(df, col1 = df$newCol1)
+#' }
+setMethod("rename",
+          signature(x = "DataFrame"),
+          function(x, ...) {
+            renameCols <- list(...)
+            stopifnot(length(renameCols) > 0)
+            stopifnot(class(renameCols[[1]]) == "Column")
+            newNames <- names(renameCols)
+            oldNames <- lapply(renameCols, function(col) {
+              callJMethod(col@jc, "toString")
+            })
+            cols <- lapply(columns(x), function(c) {
+              if (c %in% oldNames) {
+                alias(col(c), newNames[[match(c, oldNames)]])
+              } else {
+                col(c)
+              }
+            })
+            select(x, cols)
+          })
+
 setClassUnion("characterOrColumn", c("character", "Column"))

-#' SortDF
+#' Arrange
 #'
 #' Sort a DataFrame by the specified column(s).
 #'
@@ -987,7 +1089,7 @@ setClassUnion("characterOrColumn", c("character", "Column"))
 #' @param col Either a Column object or character vector indicating the field to sort on
 #' @param ... Additional sorting fields
 #' @return A DataFrame where all elements are sorted.
-#' @rdname sortDF
+#' @rdname arrange
 #' @export
 #' @examples
 #'\dontrun{
@@ -995,11 +1097,11 @@ setClassUnion("characterOrColumn", c("character", "Column"))
 #' sqlCtx <- sparkRSQL.init(sc)
 #' path <- "path/to/file.json"
 #' df <- jsonFile(sqlCtx, path)
-#' sortDF(df, df$col1)
-#' sortDF(df, "col1")
-#' sortDF(df, asc(df$col1), desc(abs(df$col2)))
+#' arrange(df, df$col1)
+#' arrange(df, "col1")
+#' arrange(df, asc(df$col1), desc(abs(df$col2)))
 #' }
-setMethod("sortDF",
+setMethod("arrange",
           signature(x = "DataFrame", col = "characterOrColumn"),
           function(x, col, ...) {
             if (class(col) == "character") {
@@ -1013,12 +1115,12 @@ setMethod("sortDF",
             dataFrame(sdf)
           })

-#' @rdname sortDF
+#' @rdname arrange
 #' @aliases orderBy,DataFrame,function-method
 setMethod("orderBy",
           signature(x = "DataFrame", col = "characterOrColumn"),
           function(x, col) {
-            sortDF(x, col)
+            arrange(x, col)
           })

 #' Filter
@@ -1026,7 +1128,7 @@ setMethod("orderBy",
 #' Filter the rows of a DataFrame according to a given condition.
 #'
 #' @param x A DataFrame to be sorted.
-#' @param condition The condition to sort on. This may either be a Column expression
+#' @param condition The condition to filter on. This may either be a Column expression
 #' or a string containing a SQL statement
 #' @return A DataFrame containing only the rows that meet the condition.
 #' @rdname filter
@@ -1106,6 +1208,7 @@ setMethod("join",
 #'
 #' Return a new DataFrame containing the union of rows in this DataFrame
 #' and another DataFrame. This is equivalent to `UNION ALL` in SQL.
+#' Note that this does not remove duplicate rows across the two DataFrames.
 #'
 #' @param x A Spark DataFrame
 #' @param y A Spark DataFrame
diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R
index 95fb9ff0887b6..9a68445ab451a 100644
--- a/R/pkg/R/column.R
+++ b/R/pkg/R/column.R
@@ -131,6 +131,8 @@ createMethods()
 #' alias
 #'
 #' Set a new name for a column
+
+#' @rdname column
 setMethod("alias",
           signature(object = "Column"),
           function(object, data) {
@@ -141,8 +143,12 @@ setMethod("alias",
             }
           })

+#' substr
+#'
 #' An expression that returns a substring.
 #'
+#' @rdname column
+#'
 #' @param start starting position
 #' @param stop ending position
 setMethod("substr", signature(x = "Column"),
@@ -152,6 +158,9 @@ setMethod("substr", signature(x = "Column"),
           })

 #' Casts the column to a different data type.
+#'
+#' @rdname column
+#'
 #' @examples
 #' \dontrun{
 #' cast(df$age, "string")
@@ -173,8 +182,8 @@ setMethod("cast",

 #' Approx Count Distinct
 #'
-#' Returns the approximate number of distinct items in a group.
-#'
+#' @rdname column
+#' @return the approximate number of distinct items in a group.
 setMethod("approxCountDistinct",
           signature(x = "Column"),
           function(x, rsd = 0.95) {
@@ -184,8 +193,8 @@ setMethod("approxCountDistinct",

 #' Count Distinct
 #'
-#' returns the number of distinct items in a group.
-#'
+#' @rdname column
+#' @return the number of distinct items in a group.
 setMethod("countDistinct",
           signature(x = "Column"),
           function(x, ...) {
@@ -197,3 +206,18 @@ setMethod("countDistinct",
             column(jc)
           })

+#' @rdname column
+#' @aliases countDistinct
+setMethod("n_distinct",
+          signature(x = "Column"),
+          function(x, ...) {
+            countDistinct(x, ...)
+          })
+
+#' @rdname column
+#' @aliases count
+setMethod("n",
+          signature(x = "Column"),
+          function(x) {
+            count(x)
+          })
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 380e8ebe8c8f4..557128a419f19 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -380,6 +380,14 @@ setGeneric("value", function(bcast) { standardGeneric("value") })

 #################### DataFrame Methods ########################

+#' @rdname agg
+#' @export
+setGeneric("agg", function (x, ...) { standardGeneric("agg") })
+
+#' @rdname arrange
+#' @export
+setGeneric("arrange", function(x, col, ...) { standardGeneric("arrange") })
+
 #' @rdname schema
 #' @export
 setGeneric("columns", function(x) {standardGeneric("columns") })
@@ -404,6 +412,10 @@ setGeneric("except", function(x, y) { standardGeneric("except") })
 #' @export
 setGeneric("filter", function(x, condition) { standardGeneric("filter") })

+#' @rdname groupBy
+#' @export
+setGeneric("group_by", function(x, ...) { standardGeneric("group_by") })
+
 #' @rdname DataFrame
 #' @export
 setGeneric("groupBy", function(x, ...) { standardGeneric("groupBy") })
@@ -424,7 +436,11 @@ setGeneric("isLocal", function(x) { standardGeneric("isLocal") })
 #' @export
 setGeneric("limit", function(x, num) {standardGeneric("limit") })

-#' @rdname sortDF
+#' @rdname withColumn
+#' @export
+setGeneric("mutate", function(x, ...) {standardGeneric("mutate") })
+
+#' @rdname arrange
 #' @export
 setGeneric("orderBy", function(x, col) { standardGeneric("orderBy") })
@@ -432,10 +448,21 @@ setGeneric("orderBy", function(x, col) { standardGeneric("orderBy") })
 #' @export
 setGeneric("printSchema", function(x) { standardGeneric("printSchema") })

+#' @rdname withColumnRenamed
+#' @export
+setGeneric("rename", function(x, ...) { standardGeneric("rename") })
+
 #' @rdname registerTempTable
 #' @export
 setGeneric("registerTempTable", function(x, tableName) { standardGeneric("registerTempTable") })

+#' @rdname sampleDF
+#' @export
+setGeneric("sample_frac",
+           function(x, withReplacement, fraction, seed) {
+             standardGeneric("sample_frac")
+           })
+
 #' @rdname sampleDF
 #' @export
 setGeneric("sampleDF",
@@ -473,9 +500,9 @@ setGeneric("selectExpr", function(x, expr, ...) { standardGeneric("selectExpr")
 #' @export
 setGeneric("showDF", function(x,...) { standardGeneric("showDF") })

-#' @rdname sortDF
+#' @rdname agg
 #' @export
-setGeneric("sortDF", function(x, col, ...) { standardGeneric("sortDF") })
+setGeneric("summarize", function(x,...) { standardGeneric("summarize") })

 # @rdname tojson
 # @export
@@ -564,6 +591,14 @@ setGeneric("like", function(x, ...) { standardGeneric("like") })
 #' @export
 setGeneric("lower", function(x) { standardGeneric("lower") })

+#' @rdname column
+#' @export
+setGeneric("n", function(x) { standardGeneric("n") })
+
+#' @rdname column
+#' @export
+setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") })
+
 #' @rdname column
 #' @export
 setGeneric("rlike", function(x, ...) { standardGeneric("rlike") })
diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R
index 02237b3672d6b..5a7a8a2caba13 100644
--- a/R/pkg/R/group.R
+++ b/R/pkg/R/group.R
@@ -56,6 +56,7 @@ setMethod("show", "GroupedData",
 #'
 #' @param x a GroupedData
 #' @return a DataFrame
+#' @rdname agg
 #' @export
 #' @examples
 #' \dontrun{
@@ -83,8 +84,6 @@ setMethod("count",
 #' df2 <- agg(df, age = "sum") # new column name will be created as 'SUM(age#0)'
 #' df2 <- agg(df, ageSum = sum(df$age)) # Creates a new column named ageSum
 #' }
-setGeneric("agg", function (x, ...) { standardGeneric("agg") })
-
 setMethod("agg",
           signature(x = "GroupedData"),
           function(x, ...) {
@@ -112,6 +111,13 @@ setMethod("agg",
             dataFrame(sdf)
           })

+#' @rdname agg
+#' @aliases agg
+setMethod("summarize",
+          signature(x = "GroupedData"),
+          function(x, ...) {
+            agg(x, ...)
+          })

 # sum/mean/avg/min/max
 methods <- c("sum", "mean", "avg", "min", "max")
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index 7a42e289fcd9e..dbb535e245321 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -428,6 +428,10 @@ test_that("sampleDF on a DataFrame", {
   expect_true(inherits(sampled, "DataFrame"))
   sampled2 <- sampleDF(df, FALSE, 0.1)
   expect_true(count(sampled2) < 3)
+
+  # Also test sample_frac
+  sampled3 <- sample_frac(df, FALSE, 0.1)
+  expect_true(count(sampled3) < 3)
 })

 test_that("select operators", {
@@ -533,6 +537,7 @@ test_that("column functions", {
   c2 <- min(c) + max(c) + sum(c) + avg(c) + count(c) + abs(c) + sqrt(c)
   c3 <- lower(c) + upper(c) + first(c) + last(c)
   c4 <- approxCountDistinct(c) + countDistinct(c) + cast(c, "string")
+  c5 <- n(c) + n_distinct(c)
 })

 test_that("string operators", {
@@ -557,6 +562,13 @@ test_that("group by", {
   expect_true(inherits(df2, "DataFrame"))
   expect_true(3 == count(df2))

+  # Also test group_by, summarize, mean
+  gd1 <- group_by(df, "name")
+  expect_true(inherits(gd1, "GroupedData"))
+  df_summarized <- summarize(gd, mean_age = mean(df$age))
+  expect_true(inherits(df_summarized, "DataFrame"))
+  expect_true(3 == count(df_summarized))
+
   df3 <- agg(gd, age = "sum")
   expect_true(inherits(df3, "DataFrame"))
   expect_true(3 == count(df3))
@@ -573,12 +585,12 @@ test_that("group by", {
   expect_true(3 == count(max(gd, "age")))
 })

-test_that("sortDF() and orderBy() on a DataFrame", {
+test_that("arrange() and orderBy() on a DataFrame", {
   df <- jsonFile(sqlCtx, jsonPath)
-  sorted <- sortDF(df, df$age)
+  sorted <- arrange(df, df$age)
   expect_true(collect(sorted)[1,2] == "Michael")

-  sorted2 <- sortDF(df, "name")
+  sorted2 <- arrange(df, "name")
   expect_true(collect(sorted2)[2,"age"] == 19)

   sorted3 <- orderBy(df, asc(df$age))
@@ -659,17 +671,17 @@ test_that("unionAll(), except(), and intersect() on a DataFrame", {
   writeLines(lines, jsonPath2)
   df2 <- loadDF(sqlCtx, jsonPath2, "json")

-  unioned <- sortDF(unionAll(df, df2), df$age)
+  unioned <- arrange(unionAll(df, df2), df$age)
   expect_true(inherits(unioned, "DataFrame"))
   expect_true(count(unioned) == 6)
   expect_true(first(unioned)$name == "Michael")

-  excepted <- sortDF(except(df, df2), desc(df$age))
+  excepted <- arrange(except(df, df2), desc(df$age))
   expect_true(inherits(unioned, "DataFrame"))
   expect_true(count(excepted) == 2)
   expect_true(first(excepted)$name == "Justin")

-  intersected <- sortDF(intersect(df, df2), df$age)
+  intersected <- arrange(intersect(df, df2), df$age)
   expect_true(inherits(unioned, "DataFrame"))
   expect_true(count(intersected) == 1)
   expect_true(first(intersected)$name == "Andy")
@@ -687,6 +699,18 @@ test_that("withColumn() and withColumnRenamed()", {
   expect_true(columns(newDF2)[1] == "newerAge")
 })

+test_that("mutate() and rename()", {
+  df <- jsonFile(sqlCtx, jsonPath)
+  newDF <- mutate(df, newAge = df$age + 2)
+  expect_true(length(columns(newDF)) == 3)
+  expect_true(columns(newDF)[3] == "newAge")
+  expect_true(first(filter(newDF, df$name != "Michael"))$newAge == 32)
+
+  newDF2 <- rename(df, newerAge = df$age)
+  expect_true(length(columns(newDF2)) == 2)
+  expect_true(columns(newDF2)[1] == "newerAge")
+})
+
 test_that("saveDF() on DataFrame and works with parquetFile", {
   df <- jsonFile(sqlCtx, jsonPath)
   saveDF(df, parquetPath, "parquet", mode="overwrite")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 1728b0b8c910e..fae4bd0fd2994 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -246,6 +246,22 @@ object functions {
    */
   def last(columnName: String): Column = last(Column(columnName))

+  /**
+   * Aggregate function: returns the average of the values in a group.
+   * Alias for avg.
+   *
+   * @group agg_funcs
+   */
+  def mean(e: Column): Column = avg(e)
+
+  /**
+   * Aggregate function: returns the average of the values in a group.
+   * Alias for avg.
+   *
+   * @group agg_funcs
+   */
+  def mean(columnName: String): Column = avg(columnName)
+
   /**
    * Aggregate function: returns the minimum value of the expression in a group.
    *
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index d2ca8dccae574..cf590cbd5219c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -308,6 +308,11 @@ class DataFrameSuite extends QueryTest {
       testData2.agg(avg('a)),
       Row(2.0))

+    // Also check mean
+    checkAnswer(
+      testData2.agg(mean('a)),
+      Row(2.0))
+
     checkAnswer(
       testData2.agg(avg('a), sumDistinct('a)),  // non-partial
       Row(2.0, 6.0) :: Nil)
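
Taken together, the new verbs are thin dplyr-style aliases over the existing SparkR API. The sketch below is not part of the patch; it strings the aliases together roughly the way the new test cases do, and the SparkR session setup and the "path/to/file.json" input are assumed placeholders borrowed from the roxygen \dontrun examples above.

# Usage sketch only: assumes a running SparkR session; the JSON path is a placeholder.
sc <- sparkR.init()
sqlCtx <- sparkRSQL.init(sc)
df <- jsonFile(sqlCtx, "path/to/file.json")

sampled <- sample_frac(df, FALSE, 0.1)             # alias for sampleDF(df, FALSE, 0.1)
gd      <- group_by(df, "name")                    # alias for groupBy(df, "name")
summed  <- summarize(gd, mean_age = mean(df$age))  # alias for agg() on GroupedData
sorted  <- arrange(df, asc(df$age))                # replaces the old sortDF()
newDF   <- mutate(df, newAge = df$age + 2)         # adds a column, like withColumn()
renamed <- rename(df, newerAge = df$age)           # alias for withColumnRenamed()
distinctNames <- n_distinct(df$name)               # Column expression wrapping countDistinct()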