diff --git a/pkg/NAMESPACE b/pkg/NAMESPACE index b52d67acd62a6..2ab7b84b65c97 100644 --- a/pkg/NAMESPACE +++ b/pkg/NAMESPACE @@ -147,6 +147,7 @@ exportMethods("agg") export("cacheTable", "clearCache", + "createDataFrame", "createExternalTable", "dropTempTable", "jsonFile", @@ -157,6 +158,7 @@ export("cacheTable", "table", "tableNames", "tables", + "toDF", "uncacheTable") export("sparkRSQL.init", diff --git a/pkg/R/DataFrame.R b/pkg/R/DataFrame.R index 203110357ef25..f5a3c55bfde61 100644 --- a/pkg/R/DataFrame.R +++ b/pkg/R/DataFrame.R @@ -168,7 +168,7 @@ setGeneric("showDF", function(x,...) { standardGeneric("showDF") }) setMethod("showDF", signature(x = "DataFrame"), function(x, numRows = 20) { - cat(callJMethod(x@sdf, "showString", numToInt(numRows))) + cat(callJMethod(x@sdf, "showString", numToInt(numRows)), "\n") }) setMethod("show", "DataFrame", @@ -569,10 +569,8 @@ setMethod("collect", close(objRaw) col }) - colNames <- callJMethod(x@sdf, "columns") - names(cols) <- colNames - dfOut <- do.call(cbind.data.frame, list(cols, stringsAsFactors = stringsAsFactors)) - dfOut + names(cols) <- columns(x) + do.call(cbind.data.frame, list(cols, stringsAsFactors = stringsAsFactors)) }) #' Limit diff --git a/pkg/R/RDD.R b/pkg/R/RDD.R index b204d7e48d463..00f227c0df2d5 100644 --- a/pkg/R/RDD.R +++ b/pkg/R/RDD.R @@ -160,6 +160,7 @@ setMethod("getJRDD", signature(rdd = "PipelinedRDD"), callJMethod(prev_jrdd, "rdd"), serializedFuncArr, rdd@env$prev_serializedMode, + serializedMode, depsBin, packageNamesArr, as.character(.sparkREnv[["libname"]]), diff --git a/pkg/R/SQLContext.R b/pkg/R/SQLContext.R index 02a1a1ee5b22b..f5ca720b8d5c5 100644 --- a/pkg/R/SQLContext.R +++ b/pkg/R/SQLContext.R @@ -1,5 +1,184 @@ # SQLcontext.R: SQLContext-driven functions +#' infer the SQL type +infer_type <- function(x) { + if (is.null(x)) { + stop("can not infer type from NULL") + } + + # class of POSIXlt is c("POSIXlt" "POSIXt") + type <- switch(class(x)[[1]], + integer = "integer", + character = "string", + logical = "boolean", + double = "double", + numeric = "double", + raw = "binary", + list = "array", + environment = "map", + Date = "date", + POSIXlt = "timestamp", + POSIXct = "timestamp", + stop(paste("Unsupported type for DataFrame:", class(x)))) + + if (type == "map") { + stopifnot(length(x) > 0) + key <- ls(x)[[1]] + list(type = "map", + keyType = "string", + valueType = infer_type(get(key, x)), + valueContainsNull = TRUE) + } else if (type == "array") { + stopifnot(length(x) > 0) + names <- names(x) + if (is.null(names)) { + list(type = "array", elementType = infer_type(x[[1]]), containsNull = TRUE) + } else { + # StructType + types <- lapply(x, infer_type) + fields <- lapply(1:length(x), function(i) { + list(name = names[[i]], type = types[[i]], nullable = TRUE) + }) + list(type = "struct", fields = fields) + } + } else if (length(x) > 1) { + list(type = "array", elementType = type, containsNull = TRUE) + } else { + type + } +} + +#' dump the schema into JSON string +tojson <- function(x) { + if (is.list(x)) { + names <- names(x) + if (!is.null(names)) { + items <- lapply(names, function(n) { + safe_n <- gsub('"', '\\"', n) + paste(tojson(safe_n), ':', tojson(x[[n]]), sep = '') + }) + d <- paste(items, collapse = ', ') + paste('{', d, '}', sep = '') + } else { + l <- paste(lapply(x, tojson), collapse = ', ') + paste('[', l, ']', sep = '') + } + } else if (is.character(x)) { + paste('"', x, '"', sep = '') + } else if (is.logical(x)) { + if (x) "true" else "false" + } else { + stop(paste("unexpected type:", 
class(x)))
+  }
+}
+
+#' Create a DataFrame from an RDD
+#'
+#' Converts an RDD to a DataFrame by inferring the column types.
+#'
+#' @param sqlCtx A SQLContext
+#' @param data An RDD, a list, or a data.frame
+#' @param schema A list of column names or a named list (StructType); optional
+#' @return A DataFrame
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' rdd <- lapply(parallelize(sc, 1:10), function(x) list(a=x, b=as.character(x)))
+#' df <- createDataFrame(sqlCtx, rdd)
+#' }
+
+# TODO(davies): support sampling and infer type from NA
+createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) {
+  if (is.data.frame(data)) {
+    # get the names of columns, they will be put into RDD
+    schema <- names(data)
+    n <- nrow(data)
+    m <- ncol(data)
+    # get rid of factor type
+    dropFactor <- function(x) {
+      if (is.factor(x)) {
+        as.character(x)
+      } else {
+        x
+      }
+    }
+    data <- lapply(1:n, function(i) {
+      lapply(1:m, function(j) { dropFactor(data[i,j]) })
+    })
+  }
+  if (is.list(data)) {
+    sc <- callJStatic("edu.berkeley.cs.amplab.sparkr.SQLUtils", "getJavaSparkContext", sqlCtx)
+    rdd <- parallelize(sc, data)
+  } else if (inherits(data, "RDD")) {
+    rdd <- data
+  } else {
+    stop(paste("unexpected type:", class(data)))
+  }
+
+  if (is.null(schema) || is.null(names(schema))) {
+    row <- first(rdd)
+    names <- if (is.null(schema)) {
+      names(row)
+    } else {
+      as.list(schema)
+    }
+    if (is.null(names)) {
+      names <- lapply(1:length(row), function(x) {
+        paste("_", as.character(x), sep = "")
+      })
+    }
+
+    types <- lapply(row, infer_type)
+    fields <- lapply(1:length(row), function(i) {
+      list(name = names[[i]], type = types[[i]], nullable = TRUE)
+    })
+    schema <- list(type = "struct", fields = fields)
+  }
+
+  stopifnot(class(schema) == "list")
+  stopifnot(schema$type == "struct")
+  stopifnot(class(schema$fields) == "list")
+  schemaString <- tojson(schema)
+
+  jrdd <- getJRDD(lapply(rdd, function(x) x), "row")
+  srdd <- callJMethod(jrdd, "rdd")
+  sdf <- callJStatic("edu.berkeley.cs.amplab.sparkr.SQLUtils", "createDF",
+                     srdd, schemaString, sqlCtx)
+  dataFrame(sdf)
+}
+
+#' toDF()
+#'
+#' Converts an RDD to a DataFrame by inferring the column types.
+#'
+#' @param x An RDD
+#'
+#' @rdname DataFrame
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' rdd <- lapply(parallelize(sc, 1:10), function(x) list(a=x, b=as.character(x)))
+#' df <- toDF(rdd)
+#' }
+
+setGeneric("toDF", function(x, ...) { standardGeneric("toDF") })
+
+setMethod("toDF", signature(x = "RDD"),
+          function(x, ...) {
+            sqlCtx <- if (exists(".sparkRHivesc", envir = .sparkREnv)) {
+              get(".sparkRHivesc", envir = .sparkREnv)
+            } else if (exists(".sparkRSQLsc", envir = .sparkREnv)) {
+              get(".sparkRSQLsc", envir = .sparkREnv)
+            } else {
+              stop("no SQL context available")
+            }
+            createDataFrame(sqlCtx, x, ...)
+          })
+
 #' Create a DataFrame from a JSON file.
#' #' Loads a JSON file (one object per line), returning the result as a DataFrame diff --git a/pkg/R/deserialize.R b/pkg/R/deserialize.R index f8e4c4e630125..2500c072bdb50 100644 --- a/pkg/R/deserialize.R +++ b/pkg/R/deserialize.R @@ -9,6 +9,8 @@ # Double -> double # Long -> double # Array[Byte] -> raw +# Date -> Date +# Time -> POSIXct # # Array[T] -> list() # Object -> jobj @@ -26,10 +28,12 @@ readTypedObject <- function(con, type) { "b" = readBoolean(con), "d" = readDouble(con), "r" = readRaw(con), + "D" = readDate(con), + "t" = readTime(con), "l" = readList(con), "n" = NULL, "j" = getJobj(readString(con)), - stop("Unsupported type for deserialization")) + stop(paste("Unsupported type for deserialization", type))) } readString <- function(con) { @@ -54,6 +58,15 @@ readType <- function(con) { rawToChar(readBin(con, "raw", n = 1L)) } +readDate <- function(con) { + as.Date(readInt(con), origin = "1970-01-01") +} + +readTime <- function(con) { + t <- readDouble(con) + as.POSIXct(t, origin = "1970-01-01") +} + # We only support lists where all elements are of same type readList <- function(con) { type <- readType(con) @@ -107,11 +120,12 @@ readDeserializeRows <- function(inputCon) { # the number of rows varies, we put the readRow function in a while loop # that termintates when the next row is empty. data <- list() - numCols <- readInt(inputCon) - # We write a length for each row out - while(length(numCols) > 0 && numCols > 0) { - data[[length(data) + 1L]] <- readRow(inputCon, numCols) - numCols <- readInt(inputCon) + while(TRUE) { + row <- readRow(inputCon, numCols) + if (length(row) == 0) { + break + } + data[[length(data) + 1L]] <- row } data # this is a list of named lists now } @@ -122,28 +136,32 @@ readRowList <- function(obj) { # the numCols bytes inside the read function in order to correctly # deserialize the row. rawObj <- rawConnection(obj, "r+") - numCols <- SparkR:::readInt(rawObj) - rowOut <- SparkR:::readRow(rawObj, numCols) - close(rawObj) - rowOut + on.exit(close(rawObj)) + SparkR:::readRow(rawObj, numCols) } readRow <- function(inputCon, numCols) { - lapply(1:numCols, function(x) { - obj <- readObject(inputCon) - if (is.null(obj)) { - NA - } else { - obj - } - }) # each row is a list now + numCols <- readInt(inputCon) + if (length(numCols) > 0 && numCols > 0) { + lapply(1:numCols, function(x) { + obj <- readObject(inputCon) + if (is.null(obj)) { + NA + } else { + obj + } + }) # each row is a list now + } else { + list() + } } # Take a single column as Array[Byte] and deserialize it into an atomic vector readCol <- function(inputCon, numRows) { - sapply(1:numRows, function(x) { + # sapply can not work with POSIXlt + do.call(c, lapply(1:numRows, function(x) { value <- readObject(inputCon) # Replace NULL with NA so we can coerce to vectors if (is.null(value)) NA else value - }) # each column is an atomic vector now + })) } diff --git a/pkg/R/serialize.R b/pkg/R/serialize.R index 924d8e8e1e20b..22a462fac89b0 100644 --- a/pkg/R/serialize.R +++ b/pkg/R/serialize.R @@ -1,12 +1,15 @@ # Utility functions to serialize R objects so they can be read in Java. 
# Type mapping from R to Java -# +# +# NULL -> Void # integer -> Int # character -> String # logical -> Boolean # double, numeric -> Double # raw -> Array[Byte] +# Date -> Date +# POSIXct,POSIXlt -> Time # # list[T] -> Array[T], where T is one of above mentioned types # environment -> Map[String, T], where T is a native type @@ -16,10 +19,12 @@ writeObject <- function(con, object, writeType = TRUE) { # NOTE: In R vectors have same type as objects. So we don't support # passing in vectors as arrays and instead require arrays to be passed # as lists. + type <- class(object)[[1]] # class of POSIXlt is c("POSIXlt", "POSIXt") if (writeType) { - writeType(con, class(object)) + writeType(con, type) } - switch(class(object), + switch(type, + NULL = writeVoid(con), integer = writeInt(con, object), character = writeString(con, object), logical = writeBoolean(con, object), @@ -29,7 +34,14 @@ writeObject <- function(con, object, writeType = TRUE) { list = writeList(con, object), jobj = writeString(con, object$id), environment = writeEnv(con, object), - stop("Unsupported type for serialization")) + Date = writeDate(con, object), + POSIXlt = writeTime(con, object), + POSIXct = writeTime(con, object), + stop(paste("Unsupported type for serialization", type))) +} + +writeVoid <- function(con) { + # no value for NULL } writeString <- function(con, value) { @@ -55,6 +67,28 @@ writeRawSerialize <- function(outputCon, batch) { writeRaw(outputCon, outputSer) } +writeRowSerialize <- function(outputCon, rows) { + invisible(lapply(rows, function(r) { + bytes <- serializeRow(r) + writeRaw(outputCon, bytes) + })) +} + +serializeRow <- function(row) { + rawObj <- rawConnection(raw(0), "wb") + on.exit(close(rawObj)) + SparkR:::writeRow(rawObj, row) + rawConnectionValue(rawObj) +} + +writeRow <- function(con, row) { + numCols <- length(row) + writeInt(con, numCols) + for (i in 1:numCols) { + writeObject(con, row[[i]]) + } +} + writeRaw <- function(con, batch) { writeInt(con, length(batch)) writeBin(batch, con, endian = "big") @@ -62,6 +96,7 @@ writeRaw <- function(con, batch) { writeType <- function(con, class) { type <- switch(class, + NULL = "n", integer = "i", character = "c", logical = "b", @@ -71,7 +106,10 @@ writeType <- function(con, class) { list = "l", jobj = "j", environment = "e", - stop("Unsupported type for serialization")) + Date = "D", + POSIXlt = 't', + POSIXct = 't', + stop(paste("Unsupported type for serialization", class))) writeBin(charToRaw(type), con) } @@ -109,6 +147,14 @@ writeEnv <- function(con, env) { } } +writeDate <- function(con, date) { + writeInt(con, as.integer(date)) +} + +writeTime <- function(con, time) { + writeDouble(con, as.double(time)) +} + # Used to serialize in a list of objects where each # object can be of a different type. Serialization format is # for each object diff --git a/pkg/R/sparkR.R b/pkg/R/sparkR.R index 6135c4d9c20d2..cf80d1332c1bd 100644 --- a/pkg/R/sparkR.R +++ b/pkg/R/sparkR.R @@ -222,7 +222,6 @@ sparkR.init <- function( sparkRSQL.init <- function(jsc) { if (exists(".sparkRSQLsc", envir = .sparkREnv)) { - cat("Re-using existing SparkSQL Context. Please restart R to create a new SparkSQL Context\n") return(get(".sparkRSQLsc", envir = .sparkREnv)) } @@ -247,7 +246,6 @@ sparkRSQL.init <- function(jsc) { sparkRHive.init <- function(jsc) { if (exists(".sparkRHivesc", envir = .sparkREnv)) { - cat("Re-using existing HiveContext. 
Please restart R to create a new HiveContext\n") return(get(".sparkRHivesc", envir = .sparkREnv)) } diff --git a/pkg/inst/sparkR-submit b/pkg/inst/sparkR-submit index 9c451ab8e3712..68b7698d630d5 100755 --- a/pkg/inst/sparkR-submit +++ b/pkg/inst/sparkR-submit @@ -64,9 +64,11 @@ cat > /tmp/sparkR.profile << EOF .libPaths(c(paste(projecHome,"/..", sep=""), .libPaths())) require(SparkR) sc <- sparkR.init() + sqlCtx <- sparkRSQL.init(sc) assign("sc", sc, envir=.GlobalEnv) + assign("sqlCtx", sqlCtx, envir=.GlobalEnv) cat("\n Welcome to SparkR!") - cat("\n Spark context is available as sc\n") + cat("\n Spark context is available as sc, SQL Context is available as sqlCtx\n") } EOF R diff --git a/pkg/inst/tests/test_sparkSQL.R b/pkg/inst/tests/test_sparkSQL.R index c78baab4772ec..d8096ddcd9678 100644 --- a/pkg/inst/tests/test_sparkSQL.R +++ b/pkg/inst/tests/test_sparkSQL.R @@ -15,6 +15,127 @@ jsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") parquetPath <- tempfile(pattern="sparkr-test", fileext=".parquet") writeLines(mockLines, jsonPath) +test_that("infer types", { + expect_equal(infer_type(1L), "integer") + expect_equal(infer_type(1.0), "double") + expect_equal(infer_type("abc"), "string") + expect_equal(infer_type(TRUE), "boolean") + expect_equal(infer_type(as.Date("2015-03-11")), "date") + expect_equal(infer_type(as.POSIXlt("2015-03-11 12:13:04.043")), "timestamp") + expect_equal(infer_type(c(1L, 2L)), + list(type = 'array', elementType = "integer", containsNull = TRUE)) + expect_equal(infer_type(list(1L, 2L)), + list(type = 'array', elementType = "integer", containsNull = TRUE)) + expect_equal(infer_type(list(a = 1L, b = "2")), + list(type = "struct", + fields = list(list(name = "a", type = "integer", nullable = TRUE), + list(name = "b", type = "string", nullable = TRUE)))) + e <- new.env() + assign("a", 1L, envir = e) + expect_equal(infer_type(e), + list(type = "map", keyType = "string", valueType = "integer", + valueContainsNull = TRUE)) +}) + +test_that("create DataFrame from RDD", { + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) + df <- createDataFrame(sqlCtx, rdd, list("a", "b")) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 10) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + df <- createDataFrame(sqlCtx, rdd) + expect_true(inherits(df, "DataFrame")) + expect_equal(columns(df), c("_1", "_2")) + + fields <- list(list(name = "a", type = "integer", nullable = TRUE), + list(name = "b", type = "string", nullable = TRUE)) + schema <- list(type = "struct", fields = fields) + df <- createDataFrame(sqlCtx, rdd, schema) + expect_true(inherits(df, "DataFrame")) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) + df <- createDataFrame(sqlCtx, rdd) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 10) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) +}) + +test_that("toDF", { + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) + df <- toDF(rdd, list("a", "b")) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 10) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + df <- toDF(rdd) + expect_true(inherits(df, "DataFrame")) + expect_equal(columns(df), 
c("_1", "_2")) + + fields <- list(list(name = "a", type = "integer", nullable = TRUE), + list(name = "b", type = "string", nullable = TRUE)) + schema <- list(type = "struct", fields = fields) + df <- toDF(rdd, schema) + expect_true(inherits(df, "DataFrame")) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) + df <- toDF(rdd) + expect_true(inherits(df, "DataFrame")) + expect_true(count(df) == 10) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) +}) + +test_that("create DataFrame from list or data.frame", { + l <- list(list(1, 2), list(3, 4)) + df <- createDataFrame(sqlCtx, l, c("a", "b")) + expect_equal(columns(df), c("a", "b")) + + l <- list(list(a=1, b=2), list(a=3, b=4)) + df <- createDataFrame(sqlCtx, l) + expect_equal(columns(df), c("a", "b")) + + a <- 1:3 + b <- c("a", "b", "c") + ldf <- data.frame(a, b) + df <- createDataFrame(sqlCtx, ldf) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + expect_equal(count(df), 3) + ldf2 <- collect(df) + expect_equal(ldf$a, ldf2$a) +}) + +test_that("create DataFrame with different data types", { + l <- list(a = 1L, b = 2, c = TRUE, d = "ss", e = as.Date("2012-12-13"), + f = as.POSIXct("2015-03-15 12:13:14.056")) + df <- createDataFrame(sqlCtx, list(l)) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "double"), c("c", "boolean"), + c("d", "string"), c("e", "date"), c("f", "timestamp"))) + expect_equal(count(df), 1) + expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE)) +}) + +# TODO: enable this test after fix serialization for nested object +#test_that("create DataFrame with nested array and struct", { +# e <- new.env() +# assign("n", 3L, envir = e) +# l <- list(1:10, list("a", "b"), e, list(a="aa", b=3L)) +# df <- createDataFrame(sqlCtx, list(l), c("a", "b", "c", "d")) +# expect_equal(dtypes(df), list(c("a", "array"), c("b", "array"), +# c("c", "map"), c("d", "struct"))) +# expect_equal(count(df), 1) +# ldf <- collect(df) +# expect_equal(ldf[1,], l[[1]]) +#}) + test_that("jsonFile() on a local file returns a DataFrame", { df <- jsonFile(sqlCtx, jsonPath) expect_true(inherits(df, "DataFrame")) @@ -424,7 +545,7 @@ test_that("sortDF() and orderBy() on a DataFrame", { sorted3 <- orderBy(df, asc(df$age)) expect_true(is.na(first(sorted3)$age)) expect_true(collect(sorted3)[2, "age"] == 19) - + sorted4 <- orderBy(df, desc(df$name)) expect_true(first(sorted4)$name == "Michael") expect_true(collect(sorted4)[3,"name"] == "Andy") @@ -442,7 +563,7 @@ test_that("filter() on a DataFrame", { test_that("join() on a DataFrame", { df <- jsonFile(sqlCtx, jsonPath) - + mockLines2 <- c("{\"name\":\"Michael\", \"test\": \"yes\"}", "{\"name\":\"Andy\", \"test\": \"no\"}", "{\"name\":\"Justin\", \"test\": \"yes\"}", @@ -450,20 +571,20 @@ test_that("join() on a DataFrame", { jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(mockLines2, jsonPath2) df2 <- jsonFile(sqlCtx, jsonPath2) - + joined <- join(df, df2) expect_equal(names(joined), c("age", "name", "name", "test")) expect_true(count(joined) == 12) - + joined2 <- join(df, df2, df$name == df2$name) expect_equal(names(joined2), c("age", "name", "name", "test")) expect_true(count(joined2) == 3) - + joined3 <- join(df, df2, df$name == df2$name, "right_outer") expect_equal(names(joined3), c("age", "name", "name", "test")) 
expect_true(count(joined3) == 4) expect_true(is.na(collect(joined3)$age[4])) - + joined4 <- select(join(df, df2, df$name == df2$name, "outer"), alias(df$age + 5, "newAge"), df$name, df2$test) expect_equal(names(joined4), c("newAge", "name", "test")) @@ -491,24 +612,24 @@ test_that("isLocal()", { test_that("unionAll(), subtract(), and intersect() on a DataFrame", { df <- jsonFile(sqlCtx, jsonPath) - + lines <- c("{\"name\":\"Bob\", \"age\":24}", "{\"name\":\"Andy\", \"age\":30}", "{\"name\":\"James\", \"age\":35}") jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(lines, jsonPath2) df2 <- loadDF(sqlCtx, jsonPath2, "json") - + unioned <- sortDF(unionAll(df, df2), df$age) expect_true(inherits(unioned, "DataFrame")) expect_true(count(unioned) == 6) expect_true(first(unioned)$name == "Michael") - + subtracted <- sortDF(subtract(df, df2), desc(df$age)) expect_true(inherits(unioned, "DataFrame")) expect_true(count(subtracted) == 2) expect_true(first(subtracted)$name == "Justin") - + intersected <- sortDF(intersect(df, df2), df$age) expect_true(inherits(unioned, "DataFrame")) expect_true(count(intersected) == 1) @@ -521,7 +642,7 @@ test_that("withColumn() and withColumnRenamed()", { expect_true(length(columns(newDF)) == 3) expect_true(columns(newDF)[3] == "newAge") expect_true(first(filter(newDF, df$name != "Michael"))$newAge == 32) - + newDF2 <- withColumnRenamed(df, "age", "newerAge") expect_true(length(columns(newDF2)) == 2) expect_true(columns(newDF2)[1] == "newerAge") diff --git a/pkg/inst/worker/worker.R b/pkg/inst/worker/worker.R index aeb8ec8f37c68..2a61f2ec87478 100644 --- a/pkg/inst/worker/worker.R +++ b/pkg/inst/worker/worker.R @@ -25,11 +25,8 @@ splitIndex <- SparkR:::readInt(inputCon) execLen <- SparkR:::readInt(inputCon) execFunctionName <- unserialize(SparkR:::readRawLen(inputCon, execLen)) -# read the inputSerialization bit value -inputSerialization <- SparkR:::readString(inputCon) - -# read the isOutputSerialized bit flag -isOutputSerialized <- SparkR:::readInt(inputCon) +deserializer <- SparkR:::readString(inputCon) +serializer <- SparkR:::readString(inputCon) # Redirect stdout to stderr to prevent print statements from # interfering with outputStream @@ -68,27 +65,29 @@ isEmpty <- SparkR:::readInt(inputCon) if (isEmpty != 0) { if (numPartitions == -1) { - if (inputSerialization == "byte") { + if (deserializer == "byte") { # Now read as many characters as described in funcLen data <- SparkR:::readDeserialize(inputCon) - } else if (inputSerialization == "string") { + } else if (deserializer == "string") { data <- readLines(inputCon) - } else if (inputSerialization == "row") { + } else if (deserializer == "row") { data <- SparkR:::readDeserializeRows(inputCon) } output <- do.call(execFunctionName, list(splitIndex, data)) - if (isOutputSerialized) { + if (serializer == "byte") { SparkR:::writeRawSerialize(outputCon, output) + } else if (serializer == "row") { + SparkR:::writeRowSerialize(outputCon, output) } else { SparkR:::writeStrings(outputCon, output) } } else { - if (inputSerialization == "byte") { + if (deserializer == "byte") { # Now read as many characters as described in funcLen data <- SparkR:::readDeserialize(inputCon) - } else if (inputSerialization == "string") { + } else if (deserializer == "string") { data <- readLines(inputCon) - } else if (inputSerialization == "row") { + } else if (deserializer == "row") { data <- SparkR:::readDeserializeRows(inputCon) } @@ -121,7 +120,7 @@ if (isEmpty != 0) { } # End of output -if (isOutputSerialized) { +if 
(serializer %in% c("byte", "row")) { SparkR:::writeInt(outputCon, 0L) } diff --git a/pkg/src/src/main/scala/edu/berkeley/cs/amplab/sparkr/RRDD.scala b/pkg/src/src/main/scala/edu/berkeley/cs/amplab/sparkr/RRDD.scala index 8e2e3e40ab64f..832a81b9113a1 100644 --- a/pkg/src/src/main/scala/edu/berkeley/cs/amplab/sparkr/RRDD.scala +++ b/pkg/src/src/main/scala/edu/berkeley/cs/amplab/sparkr/RRDD.scala @@ -16,8 +16,8 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( parent: RDD[T], numPartitions: Int, func: Array[Byte], - parentSerializedMode: String, - dataSerialized: Boolean, + deserializer: String, + serializer: String, functionDependencies: Array[Byte], packageNames: Array[Byte], rLibDir: String, @@ -127,18 +127,16 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( dataOut.writeInt(splitIndex) dataOut.writeInt(func.length) - dataOut.write(func, 0, func.length) + dataOut.write(func) - // R worker process input serialization flag - SerDe.writeString(dataOut, parentSerializedMode) - // R worker process output serialization flag - dataOut.writeInt(if (dataSerialized) 1 else 0) + SerDe.writeString(dataOut, deserializer) + SerDe.writeString(dataOut, serializer) dataOut.writeInt(packageNames.length) - dataOut.write(packageNames, 0, packageNames.length) + dataOut.write(packageNames) dataOut.writeInt(functionDependencies.length) - dataOut.write(functionDependencies, 0, functionDependencies.length) + dataOut.write(functionDependencies) dataOut.writeInt(broadcastVars.length) broadcastVars.foreach { broadcast => @@ -159,14 +157,14 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( } for (elem <- iter) { - if (parentSerializedMode == SerializationFormats.BYTE) { + if (deserializer == SerializationFormats.BYTE) { val elemArr = elem.asInstanceOf[Array[Byte]] dataOut.writeInt(elemArr.length) dataOut.write(elemArr, 0, elemArr.length) - } else if (parentSerializedMode == SerializationFormats.ROW) { + } else if (deserializer == SerializationFormats.ROW) { val rowArr = elem.asInstanceOf[Array[Byte]] dataOut.write(rowArr, 0, rowArr.length) - } else if (parentSerializedMode == SerializationFormats.STRING) { + } else if (deserializer == SerializationFormats.STRING) { printOut.println(elem) } } @@ -210,13 +208,14 @@ private class PairwiseRRDD[T: ClassTag]( parent: RDD[T], numPartitions: Int, hashFunc: Array[Byte], - parentSerializedMode: String, + deserializer: String, functionDependencies: Array[Byte], packageNames: Array[Byte], rLibDir: String, broadcastVars: Array[Object]) - extends BaseRRDD[T, (Int, Array[Byte])](parent, numPartitions, hashFunc, parentSerializedMode, - true, functionDependencies, packageNames, rLibDir, + extends BaseRRDD[T, (Int, Array[Byte])](parent, numPartitions, hashFunc, deserializer, + SerializationFormats.BYTE, functionDependencies, + packageNames, rLibDir, broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { private var dataStream: DataInputStream = _ @@ -255,14 +254,15 @@ private class PairwiseRRDD[T: ClassTag]( private class RRDD[T: ClassTag]( parent: RDD[T], func: Array[Byte], - parentSerializedMode: String, + deserializer: String, + serializer: String, functionDependencies: Array[Byte], packageNames: Array[Byte], rLibDir: String, broadcastVars: Array[Object]) - extends BaseRRDD[T, Array[Byte]](parent, -1, func, parentSerializedMode, - true, functionDependencies, packageNames, rLibDir, - broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { + extends BaseRRDD[T, Array[Byte]](parent, -1, func, deserializer, + serializer, 
functionDependencies, packageNames, rLibDir, + broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { private var dataStream: DataInputStream = _ @@ -298,14 +298,14 @@ private class RRDD[T: ClassTag]( private class StringRRDD[T: ClassTag]( parent: RDD[T], func: Array[Byte], - parentSerializedMode: String, + deserializer: String, functionDependencies: Array[Byte], packageNames: Array[Byte], rLibDir: String, broadcastVars: Array[Object]) - extends BaseRRDD[T, String](parent, -1, func, parentSerializedMode, - false, functionDependencies, packageNames, rLibDir, - broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { + extends BaseRRDD[T, String](parent, -1, func, deserializer, SerializationFormats.STRING, + functionDependencies, packageNames, rLibDir, + broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { private var dataStream: BufferedReader = _ diff --git a/pkg/src/src/main/scala/edu/berkeley/cs/amplab/sparkr/SQLUtils.scala b/pkg/src/src/main/scala/edu/berkeley/cs/amplab/sparkr/SQLUtils.scala index 2491027d6d7b7..8df90c160f675 100644 --- a/pkg/src/src/main/scala/edu/berkeley/cs/amplab/sparkr/SQLUtils.scala +++ b/pkg/src/src/main/scala/edu/berkeley/cs/amplab/sparkr/SQLUtils.scala @@ -1,20 +1,35 @@ package edu.berkeley.cs.amplab.sparkr -import java.io.{ByteArrayOutputStream, DataOutputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression} +import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.{Column, DataFrame, GroupedData, Row, SQLContext, SaveMode} +import edu.berkeley.cs.amplab.sparkr.SerDe._ + object SQLUtils { def createSQLContext(jsc: JavaSparkContext): SQLContext = { - new SQLContext(jsc.sc) + new SQLContext(jsc) + } + + def getJavaSparkContext(sqlCtx: SQLContext): JavaSparkContext = { + new JavaSparkContext(sqlCtx.sparkContext) } def toSeq[T](arr: Array[T]): Seq[T] = { arr.toSeq } + def createDF(rdd: RDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { + val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] + val num = schema.fields.size + val rowRDD = rdd.map(bytesToRow) + sqlContext.createDataFrame(rowRDD, schema) + } + // A helper to include grouping columns in Agg() // TODO(davies): use internal API after merged into Spark def aggWithGrouping(gd: GroupedData, exprs: Column*): DataFrame = { @@ -36,6 +51,15 @@ object SQLUtils { df.map(r => rowToRBytes(r)) } + private[this] def bytesToRow(bytes: Array[Byte]): Row = { + val bis = new ByteArrayInputStream(bytes) + val dis = new DataInputStream(bis) + val num = readInt(dis) + Row.fromSeq((0 until num).map { i => + readObject(dis) + }.toSeq) + } + private[this] def rowToRBytes(row: Row): Array[Byte] = { val bos = new ByteArrayOutputStream() val dos = new DataOutputStream(bos) diff --git a/pkg/src/src/main/scala/edu/berkeley/cs/amplab/sparkr/SerDe.scala b/pkg/src/src/main/scala/edu/berkeley/cs/amplab/sparkr/SerDe.scala index 5e6cc0bf5e931..ea0115dd3e844 100644 --- a/pkg/src/src/main/scala/edu/berkeley/cs/amplab/sparkr/SerDe.scala +++ b/pkg/src/src/main/scala/edu/berkeley/cs/amplab/sparkr/SerDe.scala @@ -1,9 +1,9 @@ package edu.berkeley.cs.amplab.sparkr -import scala.collection.JavaConversions._ +import java.io.{DataInputStream, DataOutputStream} +import java.sql.{Date, Time} -import java.io.DataInputStream -import 
java.io.DataOutputStream +import scala.collection.JavaConversions._ /** * Utility functions to serialize, deserialize objects to / from R @@ -12,11 +12,14 @@ object SerDe { // Type mapping from R to Java // + // NULL -> void // integer -> Int // character -> String // logical -> Boolean // double, numeric -> Double // raw -> Array[Byte] + // Date -> Date + // POSIXlt/POSIXct -> Time // // list[T] -> Array[T], where T is one of above mentioned types // environment -> Map[String, T], where T is a native type @@ -35,6 +38,7 @@ object SerDe { dis: DataInputStream, dataType: Char): Object = { dataType match { + case 'n' => null case 'i' => new java.lang.Integer(readInt(dis)) case 'd' => new java.lang.Double(readDouble(dis)) case 'b' => new java.lang.Boolean(readBoolean(dis)) @@ -42,6 +46,8 @@ object SerDe { case 'e' => readMap(dis) case 'r' => readBytes(dis) case 'l' => readList(dis) + case 'D' => readDate(dis) + case 't' => readTime(dis) case 'j' => JVMObjectTracker.getObject(readString(dis)) case _ => throw new IllegalArgumentException(s"Invalid type $dataType") } @@ -77,6 +83,16 @@ object SerDe { if (intVal == 0) false else true } + def readDate(in: DataInputStream) = { + val d = in.readInt() + new Date(d.toLong * 24 * 3600 * 1000) + } + + def readTime(in: DataInputStream) = { + val t = in.readDouble() + new Time((t * 1000L).toLong) + } + def readBytesArr(in: DataInputStream) = { val len = readInt(in) (0 until len).map(_ => readBytes(in)).toArray @@ -142,6 +158,8 @@ object SerDe { // Double -> double // Long -> double // Array[Byte] -> raw + // Date -> Date + // Time -> POSIXct // // Array[T] -> list() // Object -> jobj @@ -153,6 +171,8 @@ object SerDe { case "double" => dos.writeByte('d') case "integer" => dos.writeByte('i') case "logical" => dos.writeByte('b') + case "date" => dos.writeByte('D') + case "time" => dos.writeByte('t') case "raw" => dos.writeByte('r') case "list" => dos.writeByte('l') case "jobj" => dos.writeByte('j') @@ -180,6 +200,12 @@ object SerDe { case "boolean" | "java.lang.Boolean" => writeType(dos, "logical") writeBoolean(dos, value.asInstanceOf[Boolean]) + case "java.sql.Date" => + writeType(dos, "date") + writeDate(dos, value.asInstanceOf[Date]) + case "java.sql.Time" => + writeType(dos, "time") + writeTime(dos, value.asInstanceOf[Time]) case "[B" => writeType(dos, "raw") writeBytes(dos, value.asInstanceOf[Array[Byte]]) @@ -234,6 +260,15 @@ object SerDe { out.writeInt(intValue) } + def writeDate(out: DataOutputStream, value: Date) { + out.writeInt((value.getTime / 1000 / 3600 / 24).toInt) + } + + def writeTime(out: DataOutputStream, value: Time) { + out.writeDouble(value.getTime.toDouble / 1000.0) + } + + // NOTE: Only works for ASCII right now def writeString(out: DataOutputStream, value: String) { val len = value.length diff --git a/sparkR b/sparkR index 2ea730fff4a59..ee25123041046 100755 --- a/sparkR +++ b/sparkR @@ -31,8 +31,10 @@ cat > /tmp/sparkR.profile << EOF require(SparkR) sc <- sparkR.init(Sys.getenv("MASTER", unset = "")) assign("sc", sc, envir=.GlobalEnv) + sqlCtx <- sparkRSQL.init(sc) + assign("sqlCtx", sqlCtx, envir=.GlobalEnv) cat("\n Welcome to SparkR!") - cat("\n Spark context is available as sc\n") + cat("\n Spark context is available as sc, SQL context is available as sqlCtx\n") } EOF
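
For reference, a usage sketch of the new createDataFrame()/toDF() entry points, assembled from the roxygen examples and test_sparkSQL.R above. It assumes a running SparkR session; the column names and values are illustrative.

library(SparkR)
sc <- sparkR.init()
sqlCtx <- sparkRSQL.init(sc)

# From a local data.frame (factor columns are converted to character internally)
ldf <- data.frame(a = 1:3, b = c("a", "b", "c"))
df <- createDataFrame(sqlCtx, ldf)
dtypes(df)        # list(c("a", "int"), c("b", "string"))

# From an RDD of named lists; names and types are inferred from the first element
rdd <- lapply(parallelize(sc, 1:10), function(x) list(a = x, b = as.character(x)))
df2 <- createDataFrame(sqlCtx, rdd)
columns(df2)      # c("a", "b")

# toDF() picks up the active HiveContext or SQLContext; an optional schema
# (a list of column names, or a StructType-style named list) is passed through
rdd2 <- lapply(parallelize(sc, 1:10), function(x) list(x, as.character(x)))
df3 <- toDF(rdd2, list("a", "b"))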
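
The schema handed to SQLUtils.createDF is the JSON form of Spark SQL's StructType, produced by the internal infer_type()/tojson() helpers added in SQLContext.R and parsed on the JVM side by DataType.fromJson. A small sketch of what they produce (output shown approximately):

row <- list(a = 1L, b = "2")
schema <- SparkR:::infer_type(row)
# list(type = "struct",
#      fields = list(list(name = "a", type = "integer", nullable = TRUE),
#                    list(name = "b", type = "string",  nullable = TRUE)))
cat(SparkR:::tojson(schema))
# {"type":"struct", "fields":[{"name":"a", "type":"integer", "nullable":true},
#                             {"name":"b", "type":"string", "nullable":true}]}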
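
The new "row" serialization format is length-prefixed: writeRow() emits the column count as an int, then one type-tagged value per column via writeObject(); readRow() reads the same layout back. A sketch using the internal, unexported helpers; it assumes the R-side readers mirror the R-side writers, which holds for the primitive types used here.

row <- list(1L, 2.5, TRUE, as.Date("2015-03-11"))
bytes <- SparkR:::serializeRow(row)   # int32 count, then 'i', 'd', 'b', 'D' tagged values

con <- rawConnection(bytes, "r+")
SparkR:::readRow(con, 0L)   # the numCols argument is unused; the count is re-read from the stream
close(con)
# list(1L, 2.5, TRUE, as.Date("2015-03-11"))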
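
Dates and timestamps travel over the wire as plain numbers: writeDate() sends as.integer(date), i.e. days since 1970-01-01 (matching SerDe.readDate on the JVM side), and writeTime() sends as.double(time), i.e. seconds since the epoch (matching SerDe.readTime). An illustration of the two encodings:

d <- as.Date("2015-03-11")
as.integer(d)                              # 16505 -- the value writeDate() emits
as.Date(16505L, origin = "1970-01-01")     # "2015-03-11" -- what readDate() rebuilds

t <- as.POSIXct("2015-03-11 12:13:04", tz = "UTC")
as.double(t)                               # seconds since the epoch -- writeTime()
as.POSIXct(as.double(t), origin = "1970-01-01", tz = "UTC")   # readTime()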
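
On why readCol() now builds columns with do.call(c, lapply(...)) instead of sapply(): sapply() simplifies its results to a bare atomic vector and strips the POSIXct/POSIXlt class from timestamp values, while c() preserves it. A quick illustration (values arbitrary):

ts <- as.POSIXct("2015-03-11 12:13:04", tz = "UTC")
sapply(1:2, function(i) ts)               # plain numeric; the POSIXct class is lost
do.call(c, lapply(1:2, function(i) ts))   # a POSIXct vector of length 2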