Skip to content

Commit

Permalink
Define functions for schema and fields
Browse files Browse the repository at this point in the history
Instead of using a list[list[list[]]], use specific constructors for schema and field objects.
  • Loading branch information
cafreeman authored and Davies Liu committed Apr 14, 2015
1 parent 2fe0a1a commit 0e2a94f
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 31 deletions.
6 changes: 5 additions & 1 deletion R/pkg/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ export("cacheTable",
"jsonRDD",
"loadDF",
"parquetFile",
"buildSchema",
"field",
"sql",
"table",
"tableNames",
Expand All @@ -179,4 +181,6 @@ export("cacheTable",
"uncacheTable")

export("print.structType",
"print.structField")
"print.structField",
"print.struct",
"print.field")
74 changes: 53 additions & 21 deletions R/pkg/R/SQLContext.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ infer_type <- function(x) {
# StructType
types <- lapply(x, infer_type)
fields <- lapply(1:length(x), function(i) {
list(name = names[[i]], type = types[[i]], nullable = TRUE)
field(names[[i]], types[[i]], TRUE)
})
list(type = "struct", fields = fields)
do.call(buildSchema, fields)
}
} else if (length(x) > 1) {
list(type = "array", elementType = type, containsNull = TRUE)
Expand All @@ -67,19 +67,19 @@ infer_type <- function(x) {

#' Dump the schema into a JSON string
tojson <- function(x) {
if (is.list(x)) {
if (inherits(x, "struct")) {
# schema object
l <- paste(lapply(x, tojson), collapse = ", ")
paste('{\"type\":\"struct\", \"fields\":','[', l, ']}', sep = '')
} else if (inherits(x, "field")) {
# field object
names <- names(x)
if (!is.null(names)) {
items <- lapply(names, function(n) {
safe_n <- gsub('"', '\\"', n)
paste(tojson(safe_n), ':', tojson(x[[n]]), sep = '')
})
d <- paste(items, collapse = ', ')
paste('{', d, '}', sep = '')
} else {
l <- paste(lapply(x, tojson), collapse = ', ')
paste('[', l, ']', sep = '')
}
items <- lapply(names, function(n) {
safe_n <- gsub('"', '\\"', n)
paste(tojson(safe_n), ':', tojson(x[[n]]), sep = '')
})
d <- paste(items, collapse = ", ")
paste('{', d, '}', sep = '')
} else if (is.character(x)) {
paste('"', x, '"', sep = '')
} else if (is.logical(x)) {
Expand Down Expand Up @@ -134,7 +134,7 @@ createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) {
stop(paste("unexpected type:", class(data)))
}

if (is.null(schema) || is.null(names(schema))) {
if (is.null(schema) || (!inherits(schema, "struct") && is.null(names(schema)))) {
row <- first(rdd)
names <- if (is.null(schema)) {
names(row)
Expand All @@ -143,7 +143,7 @@ createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) {
}
if (is.null(names)) {
names <- lapply(1:length(row), function(x) {
paste("_", as.character(x), sep = "")
paste("_", as.character(x), sep = "")
})
}

Expand All @@ -159,14 +159,12 @@ createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) {

types <- lapply(row, infer_type)
fields <- lapply(1:length(row), function(i) {
list(name = names[[i]], type = types[[i]], nullable = TRUE)
field(names[[i]], types[[i]], TRUE)
})
schema <- list(type = "struct", fields = fields)
schema <- do.call(buildSchema, fields)
}

stopifnot(class(schema) == "list")
stopifnot(schema$type == "struct")
stopifnot(class(schema$fields) == "list")
stopifnot(class(schema) == "struct")
schemaString <- tojson(schema)

jrdd <- getJRDD(lapply(rdd, function(x) x), "row")
Expand Down Expand Up @@ -518,3 +516,37 @@ createExternalTable <- function(sqlCtx, tableName, path = NULL, source = NULL, .
sdf <- callJMethod(sqlCtx, "createExternalTable", tableName, source, options)
dataFrame(sdf)
}

#' Build a schema object from one or more field objects
#'
#' @param field An object of class "field".
#' @param ... Additional objects of class "field".
#' @return A list holding the supplied fields in call order, with class
#'   "struct".
buildSchema <- function(field, ...) {
  fields <- list(field, ...)
  # vapply() pins the result to a logical vector; sapply() would pick the
  # return type from whatever the inputs happen to produce.
  is_field <- vapply(fields, inherits, logical(1), what = "field")
  if (!all(is_field)) {
    stop("All arguments must be Field objects.")
  }

  structure(fields, class = "struct")
}

#' Print a schema ("struct") object, one line per field
#'
#' Each field is written as
#' `|-name = "<name>", type = "<type>", nullable = <nullable>` followed by a
#' newline.
#'
#' @param x An object of class "struct": a list whose elements expose
#'   `name`, `type` and `nullable`.
#' @param ... Unused; present for compatibility with the `print` generic.
#' @return `x`, invisibly, following the convention for print methods.
print.struct <- function(x, ...) {
  # unclass() so that iteration happens over the plain underlying list.
  lines <- vapply(unclass(x), function(f) {
    # collapse = "" keeps one string per field even if a component has
    # length > 1; the concatenated output is unchanged.
    paste0("|-", "name = \"", f$name,
           "\", type = \"", f$type,
           "\", nullable = ", f$nullable, "\n",
           collapse = "")
  }, character(1))
  cat(lines, sep = "")
  invisible(x)
}

#' Create a field object describing one column of a schema
#'
#' @param name A single string: the field name.
#' @param type A single string: the field's type name.
#' @param nullable `TRUE` or `FALSE`. Defaults to `TRUE`.
#' @return A list with elements `name`, `type` and `nullable`, with class
#'   "field".
field <- function(name, type, nullable = TRUE) {
  # is.character()/is.logical() are used instead of comparing class(x) with
  # `!=`: for an object with several classes that comparison yields a
  # condition of length > 1, which `if` rejects.
  if (!is.character(name) || length(name) != 1L) {
    stop("Field name must be a string.")
  }
  if (!is.character(type) || length(type) != 1L) {
    stop("Field type must be a string.")
  }
  # NA is logical but is neither TRUE nor FALSE, so it is rejected as well.
  if (!is.logical(nullable) || length(nullable) != 1L || is.na(nullable)) {
    stop("nullable must be either TRUE or FALSE")
  }
  structure(
    list(name = name, type = type, nullable = nullable),
    class = "field"
  )
}

#' Print a field object on a single line
#'
#' Writes `name = "<name>", type = "<type>", nullable = <nullable>`.
#' No trailing newline is written.
#'
#' @param x An object exposing `name`, `type` and `nullable` elements.
#' @param ... Unused; present for compatibility with the `print` generic.
#' @return `x`, invisibly, following the convention for print methods.
print.field <- function(x, ...) {
  cat(
    "name = \"", x$name,
    "\", type = \"", x$type,
    "\", nullable = ", x$nullable,
    sep = ""
  )
  invisible(x)
}
27 changes: 18 additions & 9 deletions R/pkg/inst/tests/test_sparkSQL.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,27 @@ test_that("infer types", {
expect_equal(infer_type(list(1L, 2L)),
list(type = 'array', elementType = "integer", containsNull = TRUE))
expect_equal(infer_type(list(a = 1L, b = "2")),
list(type = "struct",
fields = list(list(name = "a", type = "integer", nullable = TRUE),
list(name = "b", type = "string", nullable = TRUE))))
buildSchema(field(name = "a", type = "integer", nullable = TRUE),
field(name = "b", type = "string", nullable = TRUE)))
e <- new.env()
assign("a", 1L, envir = e)
expect_equal(infer_type(e),
list(type = "map", keyType = "string", valueType = "integer",
valueContainsNull = TRUE))
})

# Checks the field() and buildSchema() constructors: the class of the
# returned objects, the default for `nullable`, and element access on the
# resulting schema.
test_that("buildSchema and field", {
  testField <- field("a", "string")
  expect_true(inherits(testField, "field"))
  # expect_equal reports both values on failure, unlike expect_true(a == b).
  expect_equal(testField$name, "a")
  expect_true(testField$nullable)

  testSchema <- buildSchema(testField, field("b", "integer"))
  expect_true(inherits(testSchema, "struct"))
  expect_true(inherits(testSchema[[2]], "field"))
  expect_equal(testSchema[[1]]$type, "string")
})

test_that("create DataFrame from RDD", {
rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) })
df <- createDataFrame(sqlCtx, rdd, list("a", "b"))
Expand All @@ -66,9 +77,8 @@ test_that("create DataFrame from RDD", {
expect_true(inherits(df, "DataFrame"))
expect_equal(columns(df), c("_1", "_2"))

fields <- list(list(name = "a", type = "integer", nullable = TRUE),
list(name = "b", type = "string", nullable = TRUE))
schema <- list(type = "struct", fields = fields)
schema <- buildSchema(field(name = "a", type = "integer", nullable = TRUE),
field(name = "b", type = "string", nullable = TRUE))
df <- createDataFrame(sqlCtx, rdd, schema)
expect_true(inherits(df, "DataFrame"))
expect_equal(columns(df), c("a", "b"))
Expand All @@ -94,9 +104,8 @@ test_that("toDF", {
expect_true(inherits(df, "DataFrame"))
expect_equal(columns(df), c("_1", "_2"))

fields <- list(list(name = "a", type = "integer", nullable = TRUE),
list(name = "b", type = "string", nullable = TRUE))
schema <- list(type = "struct", fields = fields)
schema <- buildSchema(field(name = "a", type = "integer", nullable = TRUE),
field(name = "b", type = "string", nullable = TRUE))
df <- toDF(rdd, schema)
expect_true(inherits(df, "DataFrame"))
expect_equal(columns(df), c("a", "b"))
Expand Down

0 comments on commit 0e2a94f

Please sign in to comment.