Commit

refactor schema functions
Refactored `structType` and `structField` so that they can be used to create schemas from R for use with `createDataFrame`.

Moved everything to `schema.R`

Added new methods to `SQLUtils.scala` for handling `StructType` and `StructField` on the JVM side
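
As a brief illustration, the new workflow looks like this from R (a sketch mirroring the roxygen examples in `schema.R`; the SparkR session and sample RDD are placeholders):

    sc <- sparkR.init()
    sqlCtx <- sparkRSQL.init(sc)
    # Each list element becomes one row of the resulting DataFrame
    rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) })
    # The schema is built entirely in R; no JSON round-trip is needed
    schema <- structType(structField("a", "integer"), structField("b", "string"))
    df <- createDataFrame(sqlCtx, rdd, schema)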
cafreeman authored and Davies Liu committed Apr 14, 2015
1 parent 40338a4 commit be5d5c1
Showing 3 changed files with 197 additions and 100 deletions.
97 changes: 0 additions & 97 deletions R/pkg/R/SQLContext.R
@@ -65,30 +65,6 @@ infer_type <- function(x) {
}
}

#' dump the schema into JSON string
tojson <- function(x) {
  if (inherits(x, "struct")) {
    # schema object
    l <- paste(lapply(x, tojson), collapse = ", ")
    paste('{\"type\":\"struct\", \"fields\":','[', l, ']}', sep = '')
  } else if (inherits(x, "field")) {
    # field object
    names <- names(x)
    items <- lapply(names, function(n) {
      safe_n <- gsub('"', '\\"', n)
      paste(tojson(safe_n), ':', tojson(x[[n]]), sep = '')
    })
    d <- paste(items, collapse = ", ")
    paste('{', d, '}', sep = '')
  } else if (is.character(x)) {
    paste('"', x, '"', sep = '')
  } else if (is.logical(x)) {
    if (x) "true" else "false"
  } else {
    stop(paste("unexpected type:", class(x)))
  }
}

#' Create a DataFrame from an RDD
#'
#' Converts an RDD to a DataFrame by inferring the types.
@@ -516,76 +492,3 @@ createExternalTable <- function(sqlCtx, tableName, path = NULL, source = NULL, .
  sdf <- callJMethod(sqlCtx, "createExternalTable", tableName, source, options)
  dataFrame(sdf)
}

#' Create a Schema object
#'
#' Create an object of type "struct" that contains the metadata for a DataFrame. Intended for
#' use with createDataFrame and toDF.
#'
#' @param field a Field object (created with the field() function)
#' @param ... additional Field objects
#' @return a Schema object
#' @export
#' @examples
#'\dontrun{
#' sc <- sparkR.init()
#' sqlCtx <- sparkRSQL.init(sc)
#' rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) })
#' schema <- buildSchema(field("a", "integer"), field("b", "string"))
#' df <- createDataFrame(sqlCtx, rdd, schema)
#' }
buildSchema <- function(field, ...) {
  fields <- list(field, ...)
  if (!all(sapply(fields, inherits, "field"))) {
    stop("All arguments must be Field objects.")
  }

  structure(fields, class = "struct")
}

# print method for "struct" object
print.struct <- function(x, ...) {
  cat(sapply(x, function(field) {
        paste("|-", "name = \"", field$name,
              "\", type = \"", field$type,
              "\", nullable = ", field$nullable, "\n",
              sep = "")
      }),
      sep = "")
}

#' Create a Field object
#'
#' Create a Field object that contains the metadata for a single field in a schema.
#'
#' @param name The name of the field
#' @param type The data type of the field
#' @param nullable A logical value indicating whether or not the field is nullable
#' @return a Field object
#' @export
#' @examples
#'\dontrun{
#' sc <- sparkR.init()
#' sqlCtx <- sparkRSQL.init(sc)
#' rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) })
#' field1 <- field("a", "integer", TRUE)
#' field2 <- field("b", "string", TRUE)
#' schema <- buildSchema(field1, field2)
#' df <- createDataFrame(sqlCtx, rdd, schema)
#' }
field <- function(name, type, nullable = TRUE) {
  if (class(name) != "character") {
    stop("Field name must be a string.")
  }
  if (class(type) != "character") {
    stop("Field type must be a string.")
  }
  if (class(nullable) != "logical") {
    stop("nullable must be either TRUE or FALSE")
  }
  structure(list("name" = name, "type" = type, "nullable" = nullable), class = "field")
}

# print method for Field objects
print.field <- function(x, ...) {
  cat("name = \"", x$name, "\", type = \"", x$type, "\", nullable = ", x$nullable, sep = "")
}

169 changes: 169 additions & 0 deletions R/pkg/R/schema.R
@@ -0,0 +1,169 @@
#' structType
#'
#' Create a structType object that contains the metadata for a DataFrame. Intended for
#' use with createDataFrame and toDF.
#'
#' @param x a structField object (created with the structField() function)
#' @param ... additional structField objects
#' @return a structType object
#' @export
#' @examples
#'\dontrun{
#' sc <- sparkR.init()
#' sqlCtx <- sparkRSQL.init(sc)
#' rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) })
#' schema <- structType(structField("a", "integer"), structField("b", "string"))
#' df <- createDataFrame(sqlCtx, rdd, schema)
#' }
structType <- function(x, ...) {
  UseMethod("structType", x)
}

structType.jobj <- function(x) {
  obj <- structure(list(), class = "structType")
  obj$jobj <- x
  obj$fields <- function() { lapply(callJMethod(x, "fields"), structField) }
  obj
}

structType.structField <- function(x, ...) {
  fields <- list(x, ...)
  if (!all(sapply(fields, inherits, "structField"))) {
    stop("All arguments must be structField objects.")
  }
  sfObjList <- lapply(fields, function(field) {
    field$jobj
  })
  stObj <- callJStatic("edu.berkeley.cs.amplab.sparkr.SQLUtils",
                       "createStructType",
                       listToSeq(sfObjList))
  structType(stObj)
}

#' Print a Spark StructType.
#'
#' This function prints the contents of a StructType returned from the
#' SparkR JVM backend.
#'
#' @param x A StructType object
#' @param ... further arguments passed to or from other methods
print.structType <- function(x, ...) {
  cat("StructType\n",
      sapply(x$fields(), function(field) {
        paste("|-", "name = \"", field$name(),
              "\", type = \"", field$dataType.toString(),
              "\", nullable = ", field$nullable(), "\n",
              sep = "")
      }),
      sep = "")
}
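
For reference, printing a structType built from the example above would produce output along these lines (sketched from the cat() calls, assuming the JVM reports IntegerType and StringType for the two fields):

    # > schema <- structType(structField("a", "integer"), structField("b", "string"))
    # > print(schema)
    # StructType
    # |-name = "a", type = "IntegerType", nullable = TRUE
    # |-name = "b", type = "StringType", nullable = TRUE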

#' structField
#'
#' Create a structField object that contains the metadata for a single field in a schema.
#'
#' @param x The name of the field
#' @param type The data type of the field
#' @param nullable A logical value indicating whether or not the field is nullable
#' @return a structField object
#' @export
#' @examples
#'\dontrun{
#' sc <- sparkR.init()
#' sqlCtx <- sparkRSQL.init(sc)
#' rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) })
#' field1 <- structField("a", "integer", TRUE)
#' field2 <- structField("b", "string", TRUE)
#' schema <- structType(field1, field2)
#' df <- createDataFrame(sqlCtx, rdd, schema)
#' }
structField <- function(x, ...) {
  UseMethod("structField", x)
}

structField.jobj <- function(x) {
  obj <- structure(list(), class = "structField")
  obj$jobj <- x
  obj$name <- function() { callJMethod(x, "name") }
  obj$dataType <- function() { callJMethod(x, "dataType") }
  obj$dataType.toString <- function() { callJMethod(obj$dataType(), "toString") }
  obj$dataType.simpleString <- function() { callJMethod(obj$dataType(), "simpleString") }
  obj$nullable <- function() { callJMethod(x, "nullable") }
  obj
}

structField.character <- function(x, type, nullable = TRUE) {
  if (class(x) != "character") {
    stop("Field name must be a string.")
  }
  if (class(type) != "character") {
    stop("Field type must be a string.")
  }
  if (class(nullable) != "logical") {
    stop("nullable must be either TRUE or FALSE")
  }
  options <- c("byte",
               "integer",
               "double",
               "numeric",
               "character",
               "string",
               "binary",
               "raw",
               "logical",
               "boolean",
               "timestamp",
               "date")
  dataType <- if (type %in% options) {
    type
  } else {
    stop(paste("Unsupported type for DataFrame:", type))
  }
  sfObj <- callJStatic("edu.berkeley.cs.amplab.sparkr.SQLUtils",
                       "createStructField",
                       x,
                       dataType,
                       nullable)
  structField(sfObj)
}

#' Print a Spark StructField.
#'
#' This function prints the contents of a StructField returned from the
#' SparkR JVM backend.
#'
#' @param x A StructField object
#' @param ... further arguments passed to or from other methods
print.structField <- function(x, ...) {
  cat("StructField(name = \"", x$name(),
      "\", type = \"", x$dataType.toString(),
      "\", nullable = ", x$nullable(),
      ")",
      sep = "")
}
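
A single structField prints on one line; for example (output sketched from the cat() call above, assuming an integer field):

    # > print(structField("a", "integer", TRUE))
    # StructField(name = "a", type = "IntegerType", nullable = TRUE)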

# cfreeman: Don't think we need this function since we can create
# structType in R and pass to createDataFrame
#
# #' dump the schema into JSON string
# tojson <- function(x) {
# if (inherits(x, "struct")) {
# # schema object
# l <- paste(lapply(x, tojson), collapse = ", ")
# paste('{\"type\":\"struct\", \"fields\":','[', l, ']}', sep = '')
# } else if (inherits(x, "field")) {
# # field object
# names <- names(x)
# items <- lapply(names, function(n) {
# safe_n <- gsub('"', '\\"', n)
# paste(tojson(safe_n), ':', tojson(x[[n]]), sep = '')
# })
# d <- paste(items, collapse = ", ")
# paste('{', d, '}', sep = '')
# } else if (is.character(x)) {
# paste('"', x, '"', sep = '')
# } else if (is.logical(x)) {
# if (x) "true" else "false"
# } else {
# stop(paste("unexpected type:", class(x)))
# }
# }
31 changes: 28 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -23,7 +23,7 @@ import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.api.r.SerDe
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression}
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, DataFrame, GroupedData, Row, SQLContext, SaveMode}

private[r] object SQLUtils {
@@ -39,8 +39,33 @@ private[r] object SQLUtils {
    arr.toSeq
  }

  def createDF(rdd: RDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = {
    val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
  def createStructType(fields: Seq[StructField]): StructType = {
    StructType(fields)
  }

  def DataTypeObject(dataType: String): DataType = {
    dataType match {
      case "byte" => org.apache.spark.sql.types.ByteType
      case "integer" => org.apache.spark.sql.types.IntegerType
      case "double" => org.apache.spark.sql.types.DoubleType
      case "numeric" => org.apache.spark.sql.types.DoubleType
      case "character" => org.apache.spark.sql.types.StringType
      case "string" => org.apache.spark.sql.types.StringType
      case "binary" => org.apache.spark.sql.types.BinaryType
      case "raw" => org.apache.spark.sql.types.BinaryType
      case "logical" => org.apache.spark.sql.types.BooleanType
      case "boolean" => org.apache.spark.sql.types.BooleanType
      case "timestamp" => org.apache.spark.sql.types.TimestampType
      case "date" => org.apache.spark.sql.types.DateType
      case _ => throw new IllegalArgumentException(s"Invalid type $dataType")
    }
  }

  def createStructField(name: String, dataType: String, nullable: Boolean): StructField = {
    val dtObj = DataTypeObject(dataType)
    StructField(name, dtObj, nullable)
  }

  def createDF(rdd: RDD[Array[Byte]], schema: StructType, sqlContext: SQLContext): DataFrame = {
    val num = schema.fields.size
    val rowRDD = rdd.map(bytesToRow)
    sqlContext.createDataFrame(rowRDD, schema)
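
To illustrate the mapping, the R type names accepted by structField() are resolved to Spark SQL types by DataTypeObject; a few sketched examples (JVM-side results noted in comments, assuming a running SparkR session):

    structField("x", "numeric")   # backed by DoubleType on the JVM
    structField("s", "character") # backed by StringType
    structField("ok", "logical")  # backed by BooleanType
    structField("t", "float")     # error: Unsupported type for DataFrame: float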
