# Patch summary (unified git diff; the original line breaks appear to have been lost in extraction).
#   * pkg/NAMESPACE: adds "selectExpr" to the exportMethods(...) list, between "select" and "toRDD".
#   * pkg/R/DataFrame.R: adds the S4 generic selectExpr(x, expr, ...) and a method for
#     signature(x = "DataFrame", expr = "character"), inserted before the SortDF documentation.
#     The method gathers `expr` and `...` into one list, passes listToSeq(<that list>) to
#     callJMethod(x@sdf, "selectExpr", ...), and wraps the returned object with dataFrame().
# NOTE(review): listToSeq(), callJMethod() and dataFrame() are defined outside this patch; only
#   `expr` is dispatched on as "character" -- values passed through `...` are not type-checked here.
diff --git a/pkg/NAMESPACE b/pkg/NAMESPACE index 18eaa752f3103..f72a2c103ce32 100644 --- a/pkg/NAMESPACE +++ b/pkg/NAMESPACE @@ -100,6 +100,7 @@ exportMethods("columns", "schema", "sortDF", "select", + "selectExpr", "toRDD", "where") diff --git a/pkg/R/DataFrame.R b/pkg/R/DataFrame.R index 8da89e54f00bf..704470cd5210c 100644 --- a/pkg/R/DataFrame.R +++ b/pkg/R/DataFrame.R @@ -607,6 +607,36 @@ setMethod("select", signature(x = "DataFrame", col = "Column"), dataFrame(sdf) }) +#' SelectExpr +#' +#' Select from a DataFrame using a set of SQL expressions. +#' +#' @param x A DataFrame to be selected from. +#' @param expr A string containing a SQL expression +#' @param ... Additional expressions +#' @return A DataFrame +#' @rdname selectExpr +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' selectExpr(df, "col1", "(col2 * 5) as newCol") +#' } +setGeneric("selectExpr", function(x, expr, ...) { standardGeneric("selectExpr") }) + +#' @rdname selectExpr +#' @export +setMethod("selectExpr", + signature(x = "DataFrame", expr = "character"), + function(x, expr, ...) { + exprList <- list(expr, ...) + sdf <- callJMethod(x@sdf, "selectExpr", listToSeq(exprList)) + dataFrame(sdf) + }) + #' SortDF #' #' Sort a DataFrame by the specified column(s). 
# Patch summary: adds the test "selectExpr() on a DataFrame" to pkg/inst/tests/test_sparkSQL.R,
# placed between the "select with column" and "column calculation" tests.
#   * selectExpr(df, "age * 2") is expected to produce a column named "(age * 2)" whose collected
#     contents equal collect(select(df, df$age * 2L)).
#   * selectExpr(df, "name as newName", "abs(age) as age") is expected to produce columns
#     c("newName", "age") and a row count of 3.
# NOTE(review): the expected count of 3 presumably reflects the fixture at jsonPath, which is
#   defined outside this hunk -- confirm against the test file's setup.
diff --git a/pkg/inst/tests/test_sparkSQL.R b/pkg/inst/tests/test_sparkSQL.R index 368660da0d75e..35dc94509d7f7 100644 --- a/pkg/inst/tests/test_sparkSQL.R +++ b/pkg/inst/tests/test_sparkSQL.R @@ -260,6 +260,17 @@ test_that("select with column", { expect_true(count(df2) == 3) }) +test_that("selectExpr() on a DataFrame", { + df <- jsonFile(sqlCtx, jsonPath) + selected <- selectExpr(df, "age * 2") + expect_true(names(selected) == "(age * 2)") + expect_equal(collect(selected), collect(select(df, df$age * 2L))) + + selected2 <- selectExpr(df, "name as newName", "abs(age) as age") + expect_equal(names(selected2), c("newName", "age")) + expect_true(count(selected2) == 3) +}) + test_that("column calculation", { df <- jsonFile(sqlCtx, jsonPath) d <- collect(select(df, alias(df$age + 1, "age2")))