diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index a354cdce74afa..450aacfb51718 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -171,6 +171,8 @@ export("cacheTable", "jsonRDD", "loadDF", "parquetFile", + "buildSchema", + "field", "sql", "table", "tableNames", @@ -179,4 +181,6 @@ export("cacheTable", "uncacheTable") export("print.structType", - "print.structField") + "print.structField", + "print.struct", + "print.field") diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 930ada22f4c38..65057cc45a2f6 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -54,9 +54,9 @@ infer_type <- function(x) { # StructType types <- lapply(x, infer_type) fields <- lapply(1:length(x), function(i) { - list(name = names[[i]], type = types[[i]], nullable = TRUE) + field(names[[i]], types[[i]], TRUE) }) - list(type = "struct", fields = fields) + do.call(buildSchema, fields) } } else if (length(x) > 1) { list(type = "array", elementType = type, containsNull = TRUE) @@ -67,19 +67,19 @@ infer_type <- function(x) { #' dump the schema into JSON string tojson <- function(x) { - if (is.list(x)) { + if (inherits(x, "struct")) { + # schema object + l <- paste(lapply(x, tojson), collapse = ", ") + paste('{\"type\":\"struct\", \"fields\":','[', l, ']}', sep = '') + } else if (inherits(x, "field")) { + # field object names <- names(x) - if (!is.null(names)) { - items <- lapply(names, function(n) { - safe_n <- gsub('"', '\\"', n) - paste(tojson(safe_n), ':', tojson(x[[n]]), sep = '') - }) - d <- paste(items, collapse = ', ') - paste('{', d, '}', sep = '') - } else { - l <- paste(lapply(x, tojson), collapse = ', ') - paste('[', l, ']', sep = '') - } + items <- lapply(names, function(n) { + safe_n <- gsub('"', '\\"', n) + paste(tojson(safe_n), ':', tojson(x[[n]]), sep = '') + }) + d <- paste(items, collapse = ", ") + paste('{', d, '}', sep = '') } else if (is.character(x)) { paste('"', x, '"', sep = '') } else if (is.logical(x)) { @@ -134,7 +134,7 @@ createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) { stop(paste("unexpected type:", class(data))) } - if (is.null(schema) || is.null(names(schema))) { + if (is.null(schema) || (!inherits(schema, "struct") && is.null(names(schema)))) { row <- first(rdd) names <- if (is.null(schema)) { names(row) @@ -143,7 +143,7 @@ createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) { } if (is.null(names)) { names <- lapply(1:length(row), function(x) { - paste("_", as.character(x), sep = "") + paste("_", as.character(x), sep = "") }) } @@ -159,14 +159,12 @@ createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) { types <- lapply(row, infer_type) fields <- lapply(1:length(row), function(i) { - list(name = names[[i]], type = types[[i]], nullable = TRUE) + field(names[[i]], types[[i]], TRUE) }) - schema <- list(type = "struct", fields = fields) + schema <- do.call(buildSchema, fields) } - stopifnot(class(schema) == "list") - stopifnot(schema$type == "struct") - stopifnot(class(schema$fields) == "list") + stopifnot(class(schema) == "struct") schemaString <- tojson(schema) jrdd <- getJRDD(lapply(rdd, function(x) x), "row") @@ -518,3 +516,37 @@ createExternalTable <- function(sqlCtx, tableName, path = NULL, source = NULL, . sdf <- callJMethod(sqlCtx, "createExternalTable", tableName, source, options) dataFrame(sdf) } + +buildSchema <- function(field, ...) { + fields <- list(field, ...) + if (!all(sapply(fields, inherits, "field"))) { + stop("All arguments must be Field objects.") + } + + structure(fields, class = "struct") +} + +print.struct <- function(x, ...) { + cat(sapply(x, function(field) { paste("|-", "name = \"", field$name, + "\", type = \"", field$type, + "\", nullable = ", field$nullable, "\n", + sep = "") }) + , sep = "") +} + +field <- function(name, type, nullable = TRUE) { + if (class(name) != "character") { + stop("Field name must be a string.") + } + if (class(type) != "character") { + stop("Field type must be a string.") + } + if (class(nullable) != "logical") { + stop("nullable must be either TRUE or FALSE") + } + structure(list("name" = name, "type" = type, "nullable" = nullable), class = "field") +} + +print.field <- function(x, ...) { + cat("name = \"", x$name, "\", type = \"", x$type, "\", nullable = ", x$nullable, sep = "") +} diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index cf5cf6d1692af..bf2101e276df6 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -44,9 +44,8 @@ test_that("infer types", { expect_equal(infer_type(list(1L, 2L)), list(type = 'array', elementType = "integer", containsNull = TRUE)) expect_equal(infer_type(list(a = 1L, b = "2")), - list(type = "struct", - fields = list(list(name = "a", type = "integer", nullable = TRUE), - list(name = "b", type = "string", nullable = TRUE)))) + buildSchema(field(name = "a", type = "integer", nullable = TRUE), + field(name = "b", type = "string", nullable = TRUE))) e <- new.env() assign("a", 1L, envir = e) expect_equal(infer_type(e), @@ -54,6 +53,18 @@ test_that("infer types", { valueContainsNull = TRUE)) }) +test_that("buildSchema and field", { + testField <- field("a", "string") + expect_true(inherits(testField, "field")) + expect_true(testField$name == "a") + expect_true(testField$nullable) + + testSchema <- buildSchema(testField, field("b", "integer")) + expect_true(inherits(testSchema, "struct")) + expect_true(inherits(testSchema[[2]], "field")) + expect_true(testSchema[[1]]$type == "string") +}) + test_that("create DataFrame from RDD", { rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- createDataFrame(sqlCtx, rdd, list("a", "b")) @@ -66,9 +77,8 @@ test_that("create DataFrame from RDD", { expect_true(inherits(df, "DataFrame")) expect_equal(columns(df), c("_1", "_2")) - fields <- list(list(name = "a", type = "integer", nullable = TRUE), - list(name = "b", type = "string", nullable = TRUE)) - schema <- list(type = "struct", fields = fields) + schema <- buildSchema(field(name = "a", type = "integer", nullable = TRUE), + field(name = "b", type = "string", nullable = TRUE)) df <- createDataFrame(sqlCtx, rdd, schema) expect_true(inherits(df, "DataFrame")) expect_equal(columns(df), c("a", "b")) @@ -94,9 +104,8 @@ test_that("toDF", { expect_true(inherits(df, "DataFrame")) expect_equal(columns(df), c("_1", "_2")) - fields <- list(list(name = "a", type = "integer", nullable = TRUE), - list(name = "b", type = "string", nullable = TRUE)) - schema <- list(type = "struct", fields = fields) + schema <- buildSchema(field(name = "a", type = "integer", nullable = TRUE), + field(name = "b", type = "string", nullable = TRUE)) df <- toDF(rdd, schema) expect_true(inherits(df, "DataFrame")) expect_equal(columns(df), c("a", "b"))